Commit 85c97136 authored by Daniel Povey's avatar Daniel Povey
Browse files

Fix various bugs...

parent 02a36166
......@@ -174,11 +174,13 @@ void mutual_information_kernel(
t_block_begin = (iter - block) * BLOCK_SIZE;
bool is_origin_block = (s_block_begin * t_block_begin == 0);
assert(b < B && b >= 0 && t_block_begin >= 0); // TODO: remove.
__syncthreads();
if (boundary.size(0) != 0 && threadIdx.x < 4)
boundary_buf[threadIdx.x] = boundary[b][threadIdx.x];
__syncthreads();
int s_begin = boundary_buf[0],
t_begin = boundary_buf[1],
s_end = boundary_buf[2],
......@@ -225,34 +227,36 @@ void mutual_information_kernel(
// needed). This is the context from previously computed blocks of the
// image. Remember: p_buf[s][t] will correspond to exp(p[s + s_block_begin -
// 1][t + t_block_begin - 1] - normalizer.
if (threadIdx.x < 64) { // 64 == warp size. First half of threads...
if (threadIdx.x <= BLOCK_SIZE) {
// s_in_p_buf are simply the indexes into p_buf
int s_in_p_buf = threadIdx.x,
t_in_p_buf = 0,
s = s_in_p_buf + s_block_begin - 1,
t = t_in_p_buf + t_block_begin - 1;
scalar_t this_p = -INFINITY;
if (static_cast<unsigned int>(s) <= static_cast<unsigned int>(s_end) &&
static_cast<unsigned int>(t) <= static_cast<unsigned int>(t_end))
this_p = p[b][s][t];
p_buf[s_in_p_buf][t_in_p_buf] = this_p;
if (threadIdx.x <= BLOCK_SIZE) {
// s_in_p_buf are simply the indexes into p_buf
int s_in_p_buf = threadIdx.x,
t_in_p_buf = 0,
s = s_in_p_buf + s_block_begin - 1,
t = t_in_p_buf + t_block_begin - 1;
scalar_t this_p = -INFINITY;
if (static_cast<unsigned int>(s) <= static_cast<unsigned int>(s_end) &&
static_cast<unsigned int>(t) <= static_cast<unsigned int>(t_end)) {
this_p = p[b][s][t];
/*printf("p[%d][%d][%d] = %f, threadIdx.x = %d, px = %f, py = %f\n", b, s, t, (float)this_p, (int)threadIdx.x,
(float)px_buf[s_in_p_buf][t_in_p_buf], (float)py_buf[s_in_p_buf][t_in_p_buf]); */
}
p_buf[s_in_p_buf][t_in_p_buf] = this_p;
} else {
// Another warp handles the other leg. Checking as unsigned
// tests that threadIdx.x - 64 is both >= 0 and <= BLOCK_SIZE
if (static_cast<unsigned int>(int(threadIdx.x) - 64) <=
static_cast<unsigned int>(BLOCK_SIZE)) {
int s_in_p_buf = 0,
t_in_p_buf = threadIdx.x - 64,
t_in_p_buf = (int)threadIdx.x - 64,
s = s_in_p_buf + s_block_begin - 1,
t = t_in_p_buf + t_block_begin - 1;
// The if-statement below just guards against out-of-range memory
// accesses, it does not guarantee that we really need these values.
scalar_t this_p = -INFINITY;
if (static_cast<unsigned int>(s) <= static_cast<unsigned int>(s_end) &&
static_cast<unsigned int>(t) <= static_cast<unsigned int>(t_end))
static_cast<unsigned int>(t) <= static_cast<unsigned int>(t_end)) {
this_p = p[b][s][t];
/*printf("p[%d][%d][%d] = %f, threadIdx.x = %d, px = %f, py = %f\n", b, s, t, (float)this_p, (int)threadIdx.x,
(float)px_buf[s_in_p_buf][t_in_p_buf], (float)py_buf[s_in_p_buf][t_in_p_buf]);*/
}
p_buf[s_in_p_buf][t_in_p_buf] = this_p;
}
}
......@@ -264,7 +268,9 @@ void mutual_information_kernel(
// zero, and then exponentiate. We'll do everything in non-log space, for
// speed, and later take a log before we write out the data.
scalar_t normalizer = (is_origin_block ? 0.0 :
max(px_buf[0][1], px_buf[1][0]));
max(p_buf[0][1], p_buf[1][0]));
__syncthreads();
// Normalize and exponentiate the edge elements of p_buf, i.e. the elements
// where at one index is 0. The [0][0] element is special; we write 0.0,
......@@ -286,8 +292,12 @@ void mutual_information_kernel(
p_buf[0][t] = exp(p_buf[0][t] - normalizer);
}
__syncthreads();
// from here to the next __syncthreads(), only the 1st warp should be active
// so we shouldn't need to synchronize. (implicit within-warp
// synchronization).
if (threadIdx.x == 0 && is_origin_block) {
if (threadIdx.x == 0) {
// This if-statement is an optimization and modification of the loop below
// for the value i == 0, i.e. inner-iteration == 0. The modification is
// to set p_buf to 1.0 = exp(0.0) if this is the "origin block",
......@@ -309,55 +319,58 @@ void mutual_information_kernel(
}
int s = threadIdx.x;
if (s < block_S) {
for (int i = 1; i < block_S + block_T - 1; ++i) {
// i is the inner iteration, which corresponds to the (s + t) indexes of
// the elements within the block that we write. So i == 0 writes
// positions (s, t) == (0, 0) (but we treated i == 0 as a special case
// above); i == 1 writes (0, 1) and (1, 0); i == 2 writes (0, 2), (1, 1)
// and (2, 1); and so on. Note: not many threads participate in this
// part, only up to BLOCK_SIZE at most. Unfortunately we couldn't figure
// out a very meaningful way for more threads to do work, that looked like
// it would really spead things up.
// So this kernel does (2 * BLOCK_SIZE) iterations, which may seem a lot,
// but we do at least do the I/O in an efficient way and keep the
// inner loop simple and fast (e.g. no exp() or log()).
int s = threadIdx.x,
t = i - s;
if (static_cast<unsigned int>(t) < static_cast<unsigned int>(block_T)) {
// p_buf is indexed by s + 1 and t + 1 because it has an extra initial
// row and column for context from previous blocks. Taking into account
// the way these buffers relate to the tensors p, px and py, and
// ignoring `normalizer`, code below can be interpreted as follows,
// writing sbb for s_block_begin and tbb for t_block_begin:
//
// p[b][s+sbb][t+tbb] = LogAdd(p[b][s+sbb-1][t+tbb] + px[s+sbb-1][t+tbb],
// p[b][s+sbb][t+tbb-1] + py[s+sbb][t+tbb-1]
//
// where you can see that apart from the offsets of tbb and sbb, this is
// the same as the recursion defined for p in
// mutual_information.py:mutual_information_recursion(); and (eq. 0) above.
for (int i = 1; i < block_S + block_T - 1; ++i) {
__syncwarp();
// i is the inner iteration, which corresponds to the (s + t) indexes of
// the elements within the block that we write. So i == 0 writes
// positions (s, t) == (0, 0) (but we treated i == 0 as a special case
// above); i == 1 writes (0, 1) and (1, 0); i == 2 writes (0, 2), (1, 1)
// and (2, 1); and so on. Note: not many threads participate in this
// part, only up to BLOCK_SIZE at most. Unfortunately we couldn't figure
// out a very meaningful way for more threads to do work, that looked like
// it would really spead things up.
// So this kernel does (2 * BLOCK_SIZE) iterations, which may seem a lot,
// but we do at least do the I/O in an efficient way and keep the
// inner loop simple and fast (e.g. no exp() or log()).
int t = i - s;
if (s < block_S &&
static_cast<unsigned int>(t) < static_cast<unsigned int>(block_T)) {
// p_buf is indexed by s + 1 and t + 1 because it has an extra initial
// row and column for context from previous blocks. Taking into account
// the way these buffers relate to the tensors p, px and py, and
// ignoring `normalizer`, code below can be interpreted as follows,
// writing sbb for s_block_begin and tbb for t_block_begin:
//
// p[b][s+sbb][t+tbb] = LogAdd(p[b][s+sbb-1][t+tbb] + px[s+sbb-1][t+tbb],
// p[b][s+sbb][t+tbb-1] + py[s+sbb][t+tbb-1]
//
// where you can see that apart from the offsets of tbb and sbb, this is
// the same as the recursion defined for p in
// mutual_information.py:mutual_information_recursion(); and (eq. 0) above.
#if 1
p_buf[s + 1][t + 1] = p_buf[s][t + 1] * px_buf[s][t] + p_buf[s + 1][t] * py_buf[s][t];
p_buf[s + 1][t + 1] = p_buf[s][t + 1] * px_buf[s][t] + p_buf[s + 1][t] * py_buf[s][t];
/*printf("threadIdx.x = %d, i = %d, s = %d, t = %d, p_buf[s+1][t+1] = %f, p_buf[s][t+1] = %f, "
"px_buf[s][t] = %f, p_buf[s + 1][t] = %f, py_buf[s][t] = %f\n",
(int)threadIdx.x, i, s, t, (float)p_buf[s+1][t+1], (float)p_buf[s][t+1],
(float)px_buf[s][t], (float)p_buf[s+1][t], (float)py_buf[s][t]);*/
#else
// This is an optimization of the statement above (the other half of
// this #if/#else) where we keep p_buf[s + 1][t] in a register to avoid
// the need for a load from shared memory.
p_buf_s1_t = p_buf[s][t + 1] * px_buf[s][t] + p_buf_s1_t * py_buf[s][t];
// The next time this thread reads p_buf_s1_t, t will be one greater,
// so p_buf_s1_t will contain p_buf[s + 1][t]. The first time this
// thread uses p_buf_s1_t is when t == 0, except for thread 0 where
// the 1st item accessed is for s == 0, t == 1.
p_buf[s + 1][t + 1] = p_buf_s1_t;
// This is an optimization of the statement above (the other half of
// this #if/#else) where we keep p_buf[s + 1][t] in a register to avoid
// the need for a load from shared memory.
p_buf_s1_t = p_buf[s][t + 1] * px_buf[s][t] + p_buf_s1_t * py_buf[s][t];
// The next time this thread reads p_buf_s1_t, t will be one greater,
// so p_buf_s1_t will contain p_buf[s + 1][t]. The first time this
// thread uses p_buf_s1_t is when t == 0, except for thread 0 where
// the 1st item accessed is for s == 0, t == 1.
p_buf[s + 1][t + 1] = p_buf_s1_t;
#endif
// We don't need to do __syncthreads() in this loop because all the
// threads that are active are in the same warp. (However, in future,
// if NVidia changes some things, we might need to sync here).
}
// We don't need to do __syncthreads() in this loop because all the
// threads that are active are in the same warp. (However, in future,
// if NVidia changes some things, we might need to sync here).
}
__syncthreads();
}
__syncthreads();
// Write out the data to p; check that nothing has gone out of numerical
// range, and write 'panic' flag if it has.
......@@ -369,9 +382,11 @@ void mutual_information_kernel(
if (s_in_block < block_S && t_in_block < block_T) {
float this_p = p_buf[s_in_block + 1][t_in_block + 1];
p[b][s][t] = normalizer + log(this_p);
// If this_p is infinity, NaN or zero...
if (this_p - this_p != 0 || this_p == 0)
// If this_p is infinity or NaN..
if (this_p - this_p != 0) {
printf("[panic] threadIdx.x = %d, this_p = %f\n", (int)threadIdx.x, (float)this_p);
p_buf[0][0] = 1.0; // This is a "panic" flag.
}
}
}
......@@ -391,6 +406,8 @@ void mutual_information_kernel(
}
if (p_buf[0][0] != 0.0) {
if (threadIdx.x == 0)
printf("Panic flag set, value = %f\n", (float)p_buf[0][0]); // TEMP
// The "panic" flag is set. We need to re-do the computation using log-add.
// This time we won't use the buffers, we'll just load and save from main
// memory. This code should very rarely be reached; and anyway, caching
......@@ -398,9 +415,10 @@ void mutual_information_kernel(
int s_in_block = threadIdx.x;
if (s_in_block < block_S) {
for (int i = 0; i < block_S + block_T - 1; ++i) {
__syncwarp();
int t_in_block = i - s_in_block;
if (s_in_block < block_S &&
static_cast<unsigned int>(t_in_block) < static_cast<unsigned int>(block_T)) {
if (static_cast<unsigned int>(t_in_block) <
static_cast<unsigned int>(block_T)) {
int s = s_in_block + s_block_begin,
t = t_in_block + t_block_begin;
float p_s1 = (s == 0 ? -INFINITY : p[b][s - 1][t]),
......@@ -717,18 +735,17 @@ void mutual_information_backward_kernel(
{
int s = threadIdx.x;
if (s < block_S) {
for (int i = first_iter; i >= 0; --i) {
int t = i - s;
if (t >= 0) {
// The following statement is really operating on the gradients;
// it corresponds, with offsets of s_block_begin and t_block_begin
// on the indexes, to (eq. 6) defined above, i.e.:
// p_grad[b][s][t] = p_grad[b][s + 1][t] * xderiv[b][s][t] +
// p_grad[b][s][t + 1] * yderiv[b][s][t]
p_buf[s][t] = (p_buf[s + 1][t] * px_buf[s][t] +
p_buf[s][t + 1] * py_buf[s][t]);
}
for (int i = first_iter; i >= 0; --i) {
__syncwarp();
int t = i - s;
if (t >= 0 && s < block_S) {
// The following statement is really operating on the gradients;
// it corresponds, with offsets of s_block_begin and t_block_begin
// on the indexes, to (eq. 6) defined above, i.e.:
// p_grad[b][s][t] = p_grad[b][s + 1][t] * xderiv[b][s][t] +
// p_grad[b][s][t + 1] * yderiv[b][s][t]
p_buf[s][t] = (p_buf[s + 1][t] * px_buf[s][t] +
p_buf[s][t + 1] * py_buf[s][t]);
}
}
}
......
......@@ -12,16 +12,16 @@ def test_mutual_information_basic():
for device in [ torch.device('cpu'), torch.device('cuda:0') ]:
print("dtype = ", dtype, ", device = ", device)
B = 2
S = 17
T = 17
S = 33
T = 33
boundary = torch.tensor([ 0, 0, S, T ], dtype=torch.int64).unsqueeze(0).expand(B, 4).to(device)
px = torch.zeros(B, S, T + 1).to(device) # log of an odds ratio
py = torch.zeros(B, S + 1, T).to(device) # log of an odds ratio
px = torch.zeros(B, S, T + 1, dtype=dtype).to(device) # log of an odds ratio
py = torch.zeros(B, S + 1, T, dtype=dtype).to(device) # log of an odds ratio
m = mutual_information_recursion(px, py, None)
#m = mutual_information_recursion(px, py, boundary)
print("m = ", m)
print("m = ", m, ", size = ", m.shape)
print("exp(m) = ", m.exp())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment