Commit 02a36166 authored by Daniel Povey's avatar Daniel Povey
Browse files

Fix more bugs, add some debug statements.

parent 9ac065f0
......@@ -168,11 +168,14 @@ void mutual_information_kernel(
b = batch_block_iter % B; // b is the index into the batch
// Note: `block` can be no greater than `iter` because num_blocks_this_iter
// <= iter + 1, so iter - block >= 0.
// <= iter + 1, i.e. iter >= num_blocks_this_iter - 1; and
// block < num_blocks_this_iter, so iter - block >= 0.
int s_block_begin = block * BLOCK_SIZE,
t_block_begin = (iter - block) * BLOCK_SIZE;
bool is_origin_block = (s_block_begin * t_block_begin == 0);
assert(b < B && b >= 0 && t_block_begin >= 0); // TODO: remove.
if (boundary.size(0) != 0 && threadIdx.x < 4)
boundary_buf[threadIdx.x] = boundary[b][threadIdx.x];
__syncthreads();
......@@ -233,10 +236,13 @@ void mutual_information_kernel(
if (static_cast<unsigned int>(s) <= static_cast<unsigned int>(s_end) &&
static_cast<unsigned int>(t) <= static_cast<unsigned int>(t_end))
this_p = p[b][s][t];
p_buf[threadIdx.x][0] = this_p;
p_buf[s_in_p_buf][t_in_p_buf] = this_p;
}
} else { // Another warp handles the other leg
if (int(threadIdx.x) - 64 <= BLOCK_SIZE) {
} else {
// Another warp handles the other leg. Checking as unsigned
// tests that threadIdx.x - 64 is both >= 0 and <= BLOCK_SIZE
if (static_cast<unsigned int>(int(threadIdx.x) - 64) <=
static_cast<unsigned int>(BLOCK_SIZE)) {
int s_in_p_buf = 0,
t_in_p_buf = threadIdx.x - 64,
s = s_in_p_buf + s_block_begin - 1,
......@@ -247,7 +253,7 @@ void mutual_information_kernel(
if (static_cast<unsigned int>(s) <= static_cast<unsigned int>(s_end) &&
static_cast<unsigned int>(t) <= static_cast<unsigned int>(t_end))
this_p = p[b][s][t];
p_buf[threadIdx.x][0] = this_p;
p_buf[s_in_p_buf][t_in_p_buf] = this_p;
}
}
......@@ -269,11 +275,15 @@ void mutual_information_kernel(
// p_buf[0][0] = 0.0; <-- for search purposes.
// We'll later write an infinity there if something goes wrong, as a
// 'panic' indicator.
p_buf[threadIdx.x][0] = (threadIdx.x == 0 ? 0.0 :
exp(p_buf[threadIdx.x][0] - normalizer));
} else if (int(threadIdx.x) - 64 < BLOCK_SIZE) {
int s = threadIdx.x;
p_buf[s][0] = (s == 0 ? 0.0 :
exp(p_buf[s][0] - normalizer));
} else if (static_cast<unsigned int>(int(threadIdx.x) - 64) <
static_cast<unsigned int>(BLOCK_SIZE)) {
// i.e. if (threadIdx.x - 64) >= 0 && (threadIdx.x - 64) < BLOCK_SIZE.
int t = (int)threadIdx.x - 64 + 1; // 0 < t <= BLOCK_SIZE
// this happens in a different warp so can be in parallel to the code above.
p_buf[0][threadIdx.x + 1] = exp(p_buf[0][threadIdx.x + 1] - normalizer);
p_buf[0][t] = exp(p_buf[0][t] - normalizer);
}
......@@ -298,6 +308,8 @@ void mutual_information_kernel(
p_buf_s1_t = p_buf[s + 1][threadIdx.x == 0 ? 1 : 0];
}
int s = threadIdx.x;
if (s < block_S) {
for (int i = 1; i < block_S + block_T - 1; ++i) {
// i is the inner iteration, which corresponds to the (s + t) indexes of
// the elements within the block that we write. So i == 0 writes
......@@ -343,6 +355,7 @@ void mutual_information_kernel(
// threads that are active are in the same warp. (However, in future,
// if NVidia changes some things, we might need to sync here).
}
}
__syncthreads();
}
......@@ -382,9 +395,10 @@ void mutual_information_kernel(
// This time we won't use the buffers, we'll just load and save from main
// memory. This code should very rarely be reached; and anyway, caching
// should help us quite a bit.
int s_in_block = threadIdx.x;
if (s_in_block < block_S) {
for (int i = 0; i < block_S + block_T - 1; ++i) {
int s_in_block = threadIdx.x,
t_in_block = i - s_in_block;
int t_in_block = i - s_in_block;
if (s_in_block < block_S &&
static_cast<unsigned int>(t_in_block) < static_cast<unsigned int>(block_T)) {
int s = s_in_block + s_block_begin,
......@@ -400,6 +414,7 @@ void mutual_information_kernel(
p[b][s][t] = this_p;
}
}
}
if (threadIdx.x == 0) {
// Write `ans`, if this is the final (top-right) block in its sequence.
// This is only reached in the 'panic situation' where we had overflow.
......@@ -700,9 +715,11 @@ void mutual_information_backward_kernel(
--first_iter;
}
{
int s = threadIdx.x;
if (s < block_S) {
for (int i = first_iter; i >= 0; --i) {
int s = i,
t = i - threadIdx.x;
int t = i - s;
if (t >= 0) {
// The following statement is really operating on the gradients;
// it corresponds, with offsets of s_block_begin and t_block_begin
......@@ -713,6 +730,8 @@ void mutual_information_backward_kernel(
p_buf[s][t + 1] * py_buf[s][t]);
}
}
}
}
// Write out p_grad, px_grad and py_grad.
for (int i = threadIdx.x; i < BLOCK_SIZE * BLOCK_SIZE; i += blockDim.x) {
......
......@@ -12,13 +12,14 @@ def test_mutual_information_basic():
for device in [ torch.device('cpu'), torch.device('cuda:0') ]:
print("dtype = ", dtype, ", device = ", device)
B = 2
S = 4
T = 5
S = 17
T = 17
boundary = torch.tensor([ 0, 0, S, T ], dtype=torch.int64).unsqueeze(0).expand(B, 4).to(device)
px = torch.zeros(B, S, T + 1).to(device) # log of an odds ratio
py = torch.zeros(B, S + 1, T).to(device) # log of an odds ratio
m = mutual_information_recursion(px, py, boundary)
m = mutual_information_recursion(px, py, None)
#m = mutual_information_recursion(px, py, boundary)
print("m = ", m)
print("exp(m) = ", m.exp())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment