Commit 85c97136 authored by Daniel Povey

Fix various bugs...

parent 02a36166
@@ -174,11 +174,13 @@ void mutual_information_kernel(
       t_block_begin = (iter - block) * BLOCK_SIZE;
   bool is_origin_block = (s_block_begin * t_block_begin == 0);
-  assert(b < B && b >= 0 && t_block_begin >= 0);  // TODO: remove.
+  __syncthreads();
   if (boundary.size(0) != 0 && threadIdx.x < 4)
     boundary_buf[threadIdx.x] = boundary[b][threadIdx.x];
   __syncthreads();
   int s_begin = boundary_buf[0],
       t_begin = boundary_buf[1],
       s_end = boundary_buf[2],
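Note on the change above: the added `__syncthreads()` guards the refill of `boundary_buf` at the top of each outer iteration, so no thread can still be reading the previous iteration's boundary values while threads 0..3 overwrite them. Below is a minimal toy kernel (hypothetical, not from this commit) showing the same two-barrier pattern around a shared buffer that is refilled inside a loop:

```cuda
// Illustration only: a shared buffer refilled once per loop iteration needs a
// barrier *before* the refill (write-after-read hazard) in addition to the
// usual barrier *after* it (read-after-write hazard).
__global__ void refill_loop(const int *src, int *dst, int num_iters) {
  __shared__ int buf[4];
  for (int iter = 0; iter < num_iters; ++iter) {
    __syncthreads();  // all reads of the previous contents must finish first
    if (threadIdx.x < 4)
      buf[threadIdx.x] = src[iter * 4 + threadIdx.x];
    __syncthreads();  // make the new contents visible to every thread
    dst[iter * blockDim.x + threadIdx.x] = buf[threadIdx.x % 4];
  }
}
```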
@@ -225,34 +227,36 @@ void mutual_information_kernel(
   // needed).  This is the context from previously computed blocks of the
   // image.  Remember: p_buf[s][t] will correspond to exp(p[s + s_block_begin -
   // 1][t + t_block_begin - 1] - normalizer).
-  if (threadIdx.x < 64) {  // 64 == warp size.  First half of threads...
-    if (threadIdx.x <= BLOCK_SIZE) {
-      // s_in_p_buf are simply the indexes into p_buf
-      int s_in_p_buf = threadIdx.x,
-          t_in_p_buf = 0,
-          s = s_in_p_buf + s_block_begin - 1,
-          t = t_in_p_buf + t_block_begin - 1;
-      scalar_t this_p = -INFINITY;
-      if (static_cast<unsigned int>(s) <= static_cast<unsigned int>(s_end) &&
-          static_cast<unsigned int>(t) <= static_cast<unsigned int>(t_end))
-        this_p = p[b][s][t];
-      p_buf[s_in_p_buf][t_in_p_buf] = this_p;
-    }
+  if (threadIdx.x <= BLOCK_SIZE) {
+    // s_in_p_buf are simply the indexes into p_buf
+    int s_in_p_buf = threadIdx.x,
+        t_in_p_buf = 0,
+        s = s_in_p_buf + s_block_begin - 1,
+        t = t_in_p_buf + t_block_begin - 1;
+    scalar_t this_p = -INFINITY;
+    if (static_cast<unsigned int>(s) <= static_cast<unsigned int>(s_end) &&
+        static_cast<unsigned int>(t) <= static_cast<unsigned int>(t_end)) {
+      this_p = p[b][s][t];
+      /*printf("p[%d][%d][%d] = %f, threadIdx.x = %d, px = %f, py = %f\n", b, s, t, (float)this_p, (int)threadIdx.x,
+        (float)px_buf[s_in_p_buf][t_in_p_buf], (float)py_buf[s_in_p_buf][t_in_p_buf]); */
+    }
+    p_buf[s_in_p_buf][t_in_p_buf] = this_p;
   } else {
     // Another warp handles the other leg.  Checking as unsigned
     // tests that threadIdx.x - 64 is both >= 0 and <= BLOCK_SIZE
     if (static_cast<unsigned int>(int(threadIdx.x) - 64) <=
         static_cast<unsigned int>(BLOCK_SIZE)) {
       int s_in_p_buf = 0,
-          t_in_p_buf = threadIdx.x - 64,
+          t_in_p_buf = (int)threadIdx.x - 64,
           s = s_in_p_buf + s_block_begin - 1,
           t = t_in_p_buf + t_block_begin - 1;
+      // The if-statement below just guards against out-of-range memory
+      // accesses, it does not guarantee that we really need these values.
       scalar_t this_p = -INFINITY;
       if (static_cast<unsigned int>(s) <= static_cast<unsigned int>(s_end) &&
-          static_cast<unsigned int>(t) <= static_cast<unsigned int>(t_end))
-        this_p = p[b][s][t];
+          static_cast<unsigned int>(t) <= static_cast<unsigned int>(t_end)) {
+        this_p = p[b][s][t];
+        /*printf("p[%d][%d][%d] = %f, threadIdx.x = %d, px = %f, py = %f\n", b, s, t, (float)this_p, (int)threadIdx.x,
+          (float)px_buf[s_in_p_buf][t_in_p_buf], (float)py_buf[s_in_p_buf][t_in_p_buf]);*/
+      }
       p_buf[s_in_p_buf][t_in_p_buf] = this_p;
     }
   }
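The `static_cast<unsigned int>` comparisons in this hunk are a standard two-sided range check collapsed into one comparison: converting a possibly negative `int` to `unsigned int` wraps negatives to very large values, so a single unsigned `<=` rejects both `s < 0` and `s > s_end`. A small host-side sketch (hypothetical helper name):

```cuda
#include <cassert>

static inline bool in_range(int i, int end) {  // tests 0 <= i && i <= end
  return static_cast<unsigned int>(i) <= static_cast<unsigned int>(end);
}

int main() {
  assert(in_range(0, 10) && in_range(10, 10));
  assert(!in_range(-1, 10));  // -1 wraps to 0xffffffff, which is > 10
  assert(!in_range(11, 10));
  return 0;
}
```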
@@ -264,7 +268,9 @@ void mutual_information_kernel(
   // zero, and then exponentiate.  We'll do everything in non-log space, for
   // speed, and later take a log before we write out the data.
   scalar_t normalizer = (is_origin_block ? 0.0 :
-                         max(px_buf[0][1], px_buf[1][0]));
+                         max(p_buf[0][1], p_buf[1][0]));
+  __syncthreads();
   // Normalize and exponentiate the edge elements of p_buf, i.e. the elements
   // where at least one index is 0.  The [0][0] element is special; we write 0.0,
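For context, the normalizer trick described in the comments above looks like this in isolation (host-side sketch with made-up values): subtract the largest log-value before exponentiating so that exp() stays in a safe range, then add the normalizer back after the final log().

```cuda
#include <cmath>
#include <cstdio>

int main() {
  double log_a = -1000.0, log_b = -1001.0;    // exp() of these underflows to 0
  double normalizer = std::fmax(log_a, log_b);
  double a = std::exp(log_a - normalizer),    // 1.0
         b = std::exp(log_b - normalizer);    // exp(-1), about 0.368
  double log_sum = normalizer + std::log(a + b);  // LogAdd(log_a, log_b)
  std::printf("LogAdd = %f\n", log_sum);      // about -999.687
  return 0;
}
```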
@@ -286,8 +292,12 @@ void mutual_information_kernel(
     p_buf[0][t] = exp(p_buf[0][t] - normalizer);
   }
-  if (threadIdx.x == 0 && is_origin_block) {
+  __syncthreads();
+  // From here to the next __syncthreads(), only the 1st warp should be active,
+  // so we shouldn't need to synchronize (implicit within-warp
+  // synchronization).
+  if (threadIdx.x == 0) {
     // This if-statement is an optimization and modification of the loop below
     // for the value i == 0, i.e. inner-iteration == 0.  The modification is
     // to set p_buf to 1.0 = exp(0.0) if this is the "origin block",
@@ -309,55 +319,58 @@ void mutual_information_kernel(
   }
   int s = threadIdx.x;
-  if (s < block_S) {
-    for (int i = 1; i < block_S + block_T - 1; ++i) {
+  for (int i = 1; i < block_S + block_T - 1; ++i) {
+    __syncwarp();
     // i is the inner iteration, which corresponds to the (s + t) indexes of
     // the elements within the block that we write.  So i == 0 writes
     // positions (s, t) == (0, 0) (but we treated i == 0 as a special case
     // above); i == 1 writes (0, 1) and (1, 0); i == 2 writes (0, 2), (1, 1)
     // and (2, 0); and so on.  Note: not many threads participate in this
     // part, only up to BLOCK_SIZE at most.  Unfortunately we couldn't figure
     // out a very meaningful way for more threads to do work, that looked like
     // it would really speed things up.
     // So this kernel does (2 * BLOCK_SIZE) iterations, which may seem a lot,
     // but we do at least do the I/O in an efficient way and keep the
     // inner loop simple and fast (e.g. no exp() or log()).
-    int s = threadIdx.x,
-        t = i - s;
-    if (static_cast<unsigned int>(t) < static_cast<unsigned int>(block_T)) {
+    int t = i - s;
+    if (s < block_S &&
+        static_cast<unsigned int>(t) < static_cast<unsigned int>(block_T)) {
       // p_buf is indexed by s + 1 and t + 1 because it has an extra initial
       // row and column for context from previous blocks.  Taking into account
       // the way these buffers relate to the tensors p, px and py, and
       // ignoring `normalizer`, code below can be interpreted as follows,
       // writing sbb for s_block_begin and tbb for t_block_begin:
       //
       //   p[b][s+sbb][t+tbb] = LogAdd(p[b][s+sbb-1][t+tbb] + px[s+sbb-1][t+tbb],
       //                               p[b][s+sbb][t+tbb-1] + py[s+sbb][t+tbb-1])
       //
       // where you can see that apart from the offsets of tbb and sbb, this is
       // the same as the recursion defined for p in
       // mutual_information.py:mutual_information_recursion(), i.e. (eq. 0) above.
 #if 1
       p_buf[s + 1][t + 1] = p_buf[s][t + 1] * px_buf[s][t] + p_buf[s + 1][t] * py_buf[s][t];
+      /*printf("threadIdx.x = %d, i = %d, s = %d, t = %d, p_buf[s+1][t+1] = %f, p_buf[s][t+1] = %f, "
+               "px_buf[s][t] = %f, p_buf[s + 1][t] = %f, py_buf[s][t] = %f\n",
+               (int)threadIdx.x, i, s, t, (float)p_buf[s+1][t+1], (float)p_buf[s][t+1],
+               (float)px_buf[s][t], (float)p_buf[s+1][t], (float)py_buf[s][t]);*/
 #else
       // This is an optimization of the statement above (the other half of
       // this #if/#else) where we keep p_buf[s + 1][t] in a register to avoid
       // the need for a load from shared memory.
       p_buf_s1_t = p_buf[s][t + 1] * px_buf[s][t] + p_buf_s1_t * py_buf[s][t];
       // The next time this thread reads p_buf_s1_t, t will be one greater,
       // so p_buf_s1_t will contain p_buf[s + 1][t].  The first time this
       // thread uses p_buf_s1_t is when t == 0, except for thread 0 where
       // the 1st item accessed is for s == 0, t == 1.
       p_buf[s + 1][t + 1] = p_buf_s1_t;
 #endif
       // We don't need to do __syncthreads() in this loop because all the
       // threads that are active are in the same warp.  (However, in future,
       // if NVidia changes some things, we might need to sync here).
-      }
     }
+    __syncthreads();
   }
+  __syncthreads();
   // Write out the data to p; check that nothing has gone out of numerical
   // range, and write 'panic' flag if it has.
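As a cross-check of the comments in this hunk, here is a host-side reference (hypothetical helper names, double precision, plain nested loops rather than a wavefront) of the recursion the kernel implements, directly in log space; px has shape [S][T + 1] and py has shape [S + 1][T], matching the tensors in the test at the bottom of this commit:

```cuda
// p[s][t] = LogAdd(p[s-1][t] + px[s-1][t], p[s][t-1] + py[s][t-1]),
// with p[0][0] = 0; p[S][T] is the final value.
#include <cmath>
#include <utility>
#include <vector>

static double log_add(double a, double b) {
  if (a < b) std::swap(a, b);              // now a >= b
  return std::isinf(b) ? a : a + std::log1p(std::exp(b - a));
}

std::vector<std::vector<double>> forward_reference(
    const std::vector<std::vector<double>> &px,
    const std::vector<std::vector<double>> &py) {
  int S = px.size(), T = py.empty() ? 0 : py[0].size();
  std::vector<std::vector<double>> p(S + 1, std::vector<double>(T + 1, -INFINITY));
  p[0][0] = 0.0;
  for (int s = 0; s <= S; ++s)
    for (int t = 0; t <= T; ++t) {
      if (s == 0 && t == 0) continue;
      double from_s = (s > 0 ? p[s - 1][t] + px[s - 1][t] : -INFINITY),
             from_t = (t > 0 ? p[s][t - 1] + py[s][t - 1] : -INFINITY);
      p[s][t] = log_add(from_s, from_t);
    }
  return p;
}
```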
@@ -369,9 +382,11 @@ void mutual_information_kernel(
   if (s_in_block < block_S && t_in_block < block_T) {
     float this_p = p_buf[s_in_block + 1][t_in_block + 1];
     p[b][s][t] = normalizer + log(this_p);
-    // If this_p is infinity, NaN or zero...
-    if (this_p - this_p != 0 || this_p == 0)
+    // If this_p is infinity or NaN...
+    if (this_p - this_p != 0) {
+      printf("[panic] threadIdx.x = %d, this_p = %f\n", (int)threadIdx.x, (float)this_p);
       p_buf[0][0] = 1.0;  // This is a "panic" flag.
+    }
   }
 }
@@ -391,6 +406,8 @@ void mutual_information_kernel(
   }
   if (p_buf[0][0] != 0.0) {
+    if (threadIdx.x == 0)
+      printf("Panic flag set, value = %f\n", (float)p_buf[0][0]);  // TEMP
     // The "panic" flag is set.  We need to re-do the computation using log-add.
     // This time we won't use the buffers, we'll just load and save from main
     // memory.  This code should very rarely be reached; and anyway, caching
@@ -398,9 +415,10 @@ void mutual_information_kernel(
     int s_in_block = threadIdx.x;
     if (s_in_block < block_S) {
       for (int i = 0; i < block_S + block_T - 1; ++i) {
+        __syncwarp();
         int t_in_block = i - s_in_block;
-        if (s_in_block < block_S &&
-            static_cast<unsigned int>(t_in_block) < static_cast<unsigned int>(block_T)) {
+        if (static_cast<unsigned int>(t_in_block) <
+            static_cast<unsigned int>(block_T)) {
           int s = s_in_block + s_block_begin,
               t = t_in_block + t_block_begin;
           float p_s1 = (s == 0 ? -INFINITY : p[b][s - 1][t]),
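Both this fallback loop and the main loop walk the block along anti-diagonals: on iteration `i`, every cell with `s + t == i` is computed, and its two inputs lie on diagonal `i - 1`, so all cells of one diagonal are independent and one thread per `s` value can run in parallel (hence one `__syncwarp()` per `i`). A tiny host-side demo of the visiting order, with illustrative sizes:

```cuda
#include <cstdio>

int main() {
  const int S = 3, T = 4;
  for (int i = 0; i < S + T - 1; ++i) {   // one wavefront per iteration
    for (int s = 0; s < S; ++s) {         // in the kernel: s = threadIdx.x
      int t = i - s;
      if (t >= 0 && t < T)
        std::printf("iter %d computes (s=%d, t=%d)\n", i, s, t);
    }
  }
  return 0;
}
```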
@@ -717,18 +735,17 @@ void mutual_information_backward_kernel(
 {
   int s = threadIdx.x;
-  if (s < block_S) {
-    for (int i = first_iter; i >= 0; --i) {
+  for (int i = first_iter; i >= 0; --i) {
+    __syncwarp();
     int t = i - s;
-    if (t >= 0) {
+    if (t >= 0 && s < block_S) {
       // The following statement is really operating on the gradients;
       // it corresponds, with offsets of s_block_begin and t_block_begin
       // on the indexes, to (eq. 6) defined above, i.e.:
       // p_grad[b][s][t] = p_grad[b][s + 1][t] * xderiv[b][s][t] +
       //                   p_grad[b][s][t + 1] * yderiv[b][s][t]
       p_buf[s][t] = (p_buf[s + 1][t] * px_buf[s][t] +
                      p_buf[s][t + 1] * py_buf[s][t]);
-      }
     }
   }
 }
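The recursion in this backward kernel, (eq. 6) in the comments, mirrors the forward pass but runs over the anti-diagonals in decreasing order and propagates gradients rather than probabilities. A host-side reference sketch (names and shapes are illustrative, not the kernel's API): xderiv is [S][T + 1], yderiv is [S + 1][T], and p_grad is [S + 1][T + 1], seeded with 1.0 at (S, T).

```cuda
#include <algorithm>
#include <vector>

std::vector<std::vector<double>> backward_reference(
    const std::vector<std::vector<double>> &xderiv,
    const std::vector<std::vector<double>> &yderiv, int S, int T) {
  std::vector<std::vector<double>> p_grad(S + 1, std::vector<double>(T + 1, 0.0));
  p_grad[S][T] = 1.0;                    // d(output) / d(output)
  for (int i = S + T - 1; i >= 0; --i)   // anti-diagonal s + t == i, reversed
    for (int s = std::min(S, i); s >= 0; --s) {
      int t = i - s;
      if (t < 0 || t > T) continue;
      double g = 0.0;
      if (s < S) g += p_grad[s + 1][t] * xderiv[s][t];
      if (t < T) g += p_grad[s][t + 1] * yderiv[s][t];
      p_grad[s][t] = g;
    }
  return p_grad;
}
```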
@@ -12,16 +12,16 @@ def test_mutual_information_basic():
     for device in [ torch.device('cpu'), torch.device('cuda:0') ]:
       print("dtype = ", dtype, ", device = ", device)
       B = 2
-      S = 17
-      T = 17
+      S = 33
+      T = 33
       boundary = torch.tensor([ 0, 0, S, T ], dtype=torch.int64).unsqueeze(0).expand(B, 4).to(device)
-      px = torch.zeros(B, S, T + 1).to(device)  # log of an odds ratio
-      py = torch.zeros(B, S + 1, T).to(device)  # log of an odds ratio
+      px = torch.zeros(B, S, T + 1, dtype=dtype).to(device)  # log of an odds ratio
+      py = torch.zeros(B, S + 1, T, dtype=dtype).to(device)  # log of an odds ratio
       m = mutual_information_recursion(px, py, None)
       #m = mutual_information_recursion(px, py, boundary)
-      print("m = ", m)
+      print("m = ", m, ", size = ", m.shape)
       print("exp(m) = ", m.exp())
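One way to sanity-check the new S = T = 33 test (which now exceeds BLOCK_SIZE, so it exercises the multi-block path): with px and py identically zero, the recursion simply counts monotone lattice paths, so, assuming the returned m equals p[s_end][t_end] as in the recursion above, exp(m) should be the binomial coefficient C(S + T, S). A host-side check of the expected value:

```cuda
#include <cmath>
#include <cstdio>

int main() {
  int S = 33, T = 33;
  // log C(S + T, S) via lgamma, to compare against the test's print of m
  double expected_m = std::lgamma(S + T + 1.0) - std::lgamma(S + 1.0) -
                      std::lgamma(T + 1.0);
  std::printf("expected m = %.6f\n", expected_m);
  return 0;
}
```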