Commit 02a36166 authored by Daniel Povey
Browse files

Fix more bugs, add some debug statements.

parent 9ac065f0
...@@ -168,11 +168,14 @@ void mutual_information_kernel( ...@@ -168,11 +168,14 @@ void mutual_information_kernel(
b = batch_block_iter % B; // b is the index into the batch b = batch_block_iter % B; // b is the index into the batch
// Note: `block` can be no greater than `iter` because num_blocks_this_iter // Note: `block` can be no greater than `iter` because num_blocks_this_iter
// <= iter + 1, so iter - block >= 0. // <= iter + 1, i.e. iter >= num_blocks_this_iter - 1; and
// block < num_blocks_this_iter, so iter - block >= 0.
int s_block_begin = block * BLOCK_SIZE, int s_block_begin = block * BLOCK_SIZE,
t_block_begin = (iter - block) * BLOCK_SIZE; t_block_begin = (iter - block) * BLOCK_SIZE;
bool is_origin_block = (s_block_begin * t_block_begin == 0); bool is_origin_block = (s_block_begin * t_block_begin == 0);
assert(b < B && b >= 0 && t_block_begin >= 0); // TODO: remove.
if (boundary.size(0) != 0 && threadIdx.x < 4) if (boundary.size(0) != 0 && threadIdx.x < 4)
boundary_buf[threadIdx.x] = boundary[b][threadIdx.x]; boundary_buf[threadIdx.x] = boundary[b][threadIdx.x];
__syncthreads(); __syncthreads();
...@@ -233,10 +236,13 @@ void mutual_information_kernel( ...@@ -233,10 +236,13 @@ void mutual_information_kernel(
if (static_cast<unsigned int>(s) <= static_cast<unsigned int>(s_end) && if (static_cast<unsigned int>(s) <= static_cast<unsigned int>(s_end) &&
static_cast<unsigned int>(t) <= static_cast<unsigned int>(t_end)) static_cast<unsigned int>(t) <= static_cast<unsigned int>(t_end))
this_p = p[b][s][t]; this_p = p[b][s][t];
p_buf[threadIdx.x][0] = this_p; p_buf[s_in_p_buf][t_in_p_buf] = this_p;
} }
} else { // Another warp handles the other leg } else {
if (int(threadIdx.x) - 64 <= BLOCK_SIZE) { // Another warp handles the other leg. Checking as unsigned
// tests that threadIdx.x - 64 is both >= 0 and <= BLOCK_SIZE
if (static_cast<unsigned int>(int(threadIdx.x) - 64) <=
static_cast<unsigned int>(BLOCK_SIZE)) {
int s_in_p_buf = 0, int s_in_p_buf = 0,
t_in_p_buf = threadIdx.x - 64, t_in_p_buf = threadIdx.x - 64,
s = s_in_p_buf + s_block_begin - 1, s = s_in_p_buf + s_block_begin - 1,
...@@ -247,7 +253,7 @@ void mutual_information_kernel( ...@@ -247,7 +253,7 @@ void mutual_information_kernel(
if (static_cast<unsigned int>(s) <= static_cast<unsigned int>(s_end) && if (static_cast<unsigned int>(s) <= static_cast<unsigned int>(s_end) &&
static_cast<unsigned int>(t) <= static_cast<unsigned int>(t_end)) static_cast<unsigned int>(t) <= static_cast<unsigned int>(t_end))
this_p = p[b][s][t]; this_p = p[b][s][t];
p_buf[threadIdx.x][0] = this_p; p_buf[s_in_p_buf][t_in_p_buf] = this_p;
} }
} }
...@@ -269,11 +275,15 @@ void mutual_information_kernel( ...@@ -269,11 +275,15 @@ void mutual_information_kernel(
// p_buf[0][0] = 0.0; <-- for search purposes. // p_buf[0][0] = 0.0; <-- for search purposes.
// We'll later write an infinity there if something goes wrong, as a // We'll later write an infinity there if something goes wrong, as a
// 'panic' indicator. // 'panic' indicator.
p_buf[threadIdx.x][0] = (threadIdx.x == 0 ? 0.0 : int s = threadIdx.x;
exp(p_buf[threadIdx.x][0] - normalizer)); p_buf[s][0] = (s == 0 ? 0.0 :
} else if (int(threadIdx.x) - 64 < BLOCK_SIZE) { exp(p_buf[s][0] - normalizer));
} else if (static_cast<unsigned int>(int(threadIdx.x) - 64) <
static_cast<unsigned int>(BLOCK_SIZE)) {
// i.e. if (threadIdx.x - 64) >= 0 && (threadIdx.x - 64) < BLOCK_SIZE.
int t = (int)threadIdx.x - 64 + 1; // 0 < t <= BLOCK_SIZE
// this happens in a different warp so can be in parallel to the code above. // this happens in a different warp so can be in parallel to the code above.
p_buf[0][threadIdx.x + 1] = exp(p_buf[0][threadIdx.x + 1] - normalizer); p_buf[0][t] = exp(p_buf[0][t] - normalizer);
} }
...@@ -298,6 +308,8 @@ void mutual_information_kernel( ...@@ -298,6 +308,8 @@ void mutual_information_kernel(
p_buf_s1_t = p_buf[s + 1][threadIdx.x == 0 ? 1 : 0]; p_buf_s1_t = p_buf[s + 1][threadIdx.x == 0 ? 1 : 0];
} }
int s = threadIdx.x;
if (s < block_S) {
for (int i = 1; i < block_S + block_T - 1; ++i) { for (int i = 1; i < block_S + block_T - 1; ++i) {
// i is the inner iteration, which corresponds to the (s + t) indexes of // i is the inner iteration, which corresponds to the (s + t) indexes of
// the elements within the block that we write. So i == 0 writes // the elements within the block that we write. So i == 0 writes
...@@ -343,6 +355,7 @@ void mutual_information_kernel( ...@@ -343,6 +355,7 @@ void mutual_information_kernel(
// threads that are active are in the same warp. (However, in future, // threads that are active are in the same warp. (However, in future,
// if NVidia changes some things, we might need to sync here). // if NVidia changes some things, we might need to sync here).
} }
}
__syncthreads(); __syncthreads();
} }
...@@ -382,9 +395,10 @@ void mutual_information_kernel( ...@@ -382,9 +395,10 @@ void mutual_information_kernel(
// This time we won't use the buffers, we'll just load and save from main // This time we won't use the buffers, we'll just load and save from main
// memory. This code should very rarely be reached; and anyway, caching // memory. This code should very rarely be reached; and anyway, caching
// should help us quite a bit. // should help us quite a bit.
int s_in_block = threadIdx.x;
if (s_in_block < block_S) {
for (int i = 0; i < block_S + block_T - 1; ++i) { for (int i = 0; i < block_S + block_T - 1; ++i) {
int s_in_block = threadIdx.x, int t_in_block = i - s_in_block;
t_in_block = i - s_in_block;
if (s_in_block < block_S && if (s_in_block < block_S &&
static_cast<unsigned int>(t_in_block) < static_cast<unsigned int>(block_T)) { static_cast<unsigned int>(t_in_block) < static_cast<unsigned int>(block_T)) {
int s = s_in_block + s_block_begin, int s = s_in_block + s_block_begin,
...@@ -400,6 +414,7 @@ void mutual_information_kernel( ...@@ -400,6 +414,7 @@ void mutual_information_kernel(
p[b][s][t] = this_p; p[b][s][t] = this_p;
} }
} }
}
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
// Write `ans`, if this is the final (top-right) block in its sequence. // Write `ans`, if this is the final (top-right) block in its sequence.
// This is only reached in the 'panic situation' where we had overflow. // This is only reached in the 'panic situation' where we had overflow.
...@@ -700,9 +715,11 @@ void mutual_information_backward_kernel( ...@@ -700,9 +715,11 @@ void mutual_information_backward_kernel(
--first_iter; --first_iter;
} }
{
int s = threadIdx.x;
if (s < block_S) {
for (int i = first_iter; i >= 0; --i) { for (int i = first_iter; i >= 0; --i) {
int s = i, int t = i - s;
t = i - threadIdx.x;
if (t >= 0) { if (t >= 0) {
// The following statement is really operating on the gradients; // The following statement is really operating on the gradients;
// it corresponds, with offsets of s_block_begin and t_block_begin // it corresponds, with offsets of s_block_begin and t_block_begin
...@@ -713,6 +730,8 @@ void mutual_information_backward_kernel( ...@@ -713,6 +730,8 @@ void mutual_information_backward_kernel(
p_buf[s][t + 1] * py_buf[s][t]); p_buf[s][t + 1] * py_buf[s][t]);
} }
} }
}
}
// Write out p_grad, px_grad and py_grad. // Write out p_grad, px_grad and py_grad.
for (int i = threadIdx.x; i < BLOCK_SIZE * BLOCK_SIZE; i += blockDim.x) { for (int i = threadIdx.x; i < BLOCK_SIZE * BLOCK_SIZE; i += blockDim.x) {
......
...@@ -12,13 +12,14 @@ def test_mutual_information_basic(): ...@@ -12,13 +12,14 @@ def test_mutual_information_basic():
for device in [ torch.device('cpu'), torch.device('cuda:0') ]: for device in [ torch.device('cpu'), torch.device('cuda:0') ]:
print("dtype = ", dtype, ", device = ", device) print("dtype = ", dtype, ", device = ", device)
B = 2 B = 2
S = 4 S = 17
T = 5 T = 17
boundary = torch.tensor([ 0, 0, S, T ], dtype=torch.int64).unsqueeze(0).expand(B, 4).to(device) boundary = torch.tensor([ 0, 0, S, T ], dtype=torch.int64).unsqueeze(0).expand(B, 4).to(device)
px = torch.zeros(B, S, T + 1).to(device) # log of an odds ratio px = torch.zeros(B, S, T + 1).to(device) # log of an odds ratio
py = torch.zeros(B, S + 1, T).to(device) # log of an odds ratio py = torch.zeros(B, S + 1, T).to(device) # log of an odds ratio
m = mutual_information_recursion(px, py, boundary) m = mutual_information_recursion(px, py, None)
#m = mutual_information_recursion(px, py, boundary)
print("m = ", m) print("m = ", m)
print("exp(m) = ", m.exp()) print("exp(m) = ", m.exp())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment