Commit 02a36166 authored by Daniel Povey's avatar Daniel Povey
Browse files

Fix more bugs, add some debug statement.

parent 9ac065f0
......@@ -168,11 +168,14 @@ void mutual_information_kernel(
b = batch_block_iter % B; // b is the index into the batch
// Note: `block` can be no greater than `iter` because num_blocks_this_iter
// <= iter + 1, so iter - block >= 0.
// <= iter + 1, i.e. iter >= num_blocks_this_iter - 1; and
// block < num_blocks_this_iter, so iter - block >= 0.
int s_block_begin = block * BLOCK_SIZE,
t_block_begin = (iter - block) * BLOCK_SIZE;
bool is_origin_block = (s_block_begin * t_block_begin == 0);
assert(b < B && b >= 0 && t_block_begin >= 0); // TODO: remove.
if (boundary.size(0) != 0 && threadIdx.x < 4)
boundary_buf[threadIdx.x] = boundary[b][threadIdx.x];
__syncthreads();
......@@ -233,10 +236,13 @@ void mutual_information_kernel(
if (static_cast<unsigned int>(s) <= static_cast<unsigned int>(s_end) &&
static_cast<unsigned int>(t) <= static_cast<unsigned int>(t_end))
this_p = p[b][s][t];
p_buf[threadIdx.x][0] = this_p;
p_buf[s_in_p_buf][t_in_p_buf] = this_p;
}
} else { // Another warp handles the other leg
if (int(threadIdx.x) - 64 <= BLOCK_SIZE) {
} else {
// Another warp handles the other leg. Checking as unsigned
// tests that threadIdx.x - 64 is both >= 0 and <= BLOCK_SIZE
if (static_cast<unsigned int>(int(threadIdx.x) - 64) <=
static_cast<unsigned int>(BLOCK_SIZE)) {
int s_in_p_buf = 0,
t_in_p_buf = threadIdx.x - 64,
s = s_in_p_buf + s_block_begin - 1,
......@@ -247,7 +253,7 @@ void mutual_information_kernel(
if (static_cast<unsigned int>(s) <= static_cast<unsigned int>(s_end) &&
static_cast<unsigned int>(t) <= static_cast<unsigned int>(t_end))
this_p = p[b][s][t];
p_buf[threadIdx.x][0] = this_p;
p_buf[s_in_p_buf][t_in_p_buf] = this_p;
}
}
......@@ -269,11 +275,15 @@ void mutual_information_kernel(
// p_buf[0][0] = 0.0; <-- for search purposes.
// We'll later write an infinity there if something goes wrong, as a
// 'panic' indicator.
p_buf[threadIdx.x][0] = (threadIdx.x == 0 ? 0.0 :
exp(p_buf[threadIdx.x][0] - normalizer));
} else if (int(threadIdx.x) - 64 < BLOCK_SIZE) {
int s = threadIdx.x;
p_buf[s][0] = (s == 0 ? 0.0 :
exp(p_buf[s][0] - normalizer));
} else if (static_cast<unsigned int>(int(threadIdx.x) - 64) <
static_cast<unsigned int>(BLOCK_SIZE)) {
// if (threadidx.x - 64) >= 0 && (threadIdx.x - 64) < BLOCK_SIZE..
int t = (int)threadIdx.x - 64 + 1; // 0 < t <= BLOCK_SIZE
// this happens in a different warp so can be in parallel to the code above.
p_buf[0][threadIdx.x + 1] = exp(p_buf[0][threadIdx.x + 1] - normalizer);
p_buf[0][t] = exp(p_buf[0][t] - normalizer);
}
......@@ -298,50 +308,53 @@ void mutual_information_kernel(
p_buf_s1_t = p_buf[s + 1][threadIdx.x == 0 ? 1 : 0];
}
for (int i = 1; i < block_S + block_T - 1; ++i) {
// i is the inner iteration, which corresponds to the (s + t) indexes of
// the elements within the block that we write. So i == 0 writes
// positions (s, t) == (0, 0) (but we treated i == 0 as a special case
// above); i == 1 writes (0, 1) and (1, 0); i == 2 writes (0, 2), (1, 1)
// and (2, 1); and so on. Note: not many threads participate in this
// part, only up to BLOCK_SIZE at most. Unfortunately we couldn't figure
// out a very meaningful way for more threads to do work, that looked like
// it would really spead things up.
// So this kernel does (2 * BLOCK_SIZE) iterations, which may seem a lot,
// but we do at least do the I/O in an efficient way and keep the
// inner loop simple and fast (e.g. no exp() or log()).
int s = threadIdx.x,
t = i - s;
if (static_cast<unsigned int>(t) < static_cast<unsigned int>(block_T)) {
// p_buf is indexed by s + 1 and t + 1 because it has an extra initial
// row and column for context from previous blocks. Taking into account
// the way these buffers relate to the tensors p, px and py, and
// ignoring `normalizer`, code below can be interpreted as follows,
// writing sbb for s_block_begin and tbb for t_block_begin:
//
// p[b][s+sbb][t+tbb] = LogAdd(p[b][s+sbb-1][t+tbb] + px[s+sbb-1][t+tbb],
// p[b][s+sbb][t+tbb-1] + py[s+sbb][t+tbb-1]
//
// where you can see that apart from the offsets of tbb and sbb, this is
// the same as the recursion defined for p in
// mutual_information.py:mutual_information_recursion(); and (eq. 0) above.
int s = threadIdx.x;
if (s < block_S) {
for (int i = 1; i < block_S + block_T - 1; ++i) {
// i is the inner iteration, which corresponds to the (s + t) indexes of
// the elements within the block that we write. So i == 0 writes
// positions (s, t) == (0, 0) (but we treated i == 0 as a special case
// above); i == 1 writes (0, 1) and (1, 0); i == 2 writes (0, 2), (1, 1)
// and (2, 1); and so on. Note: not many threads participate in this
// part, only up to BLOCK_SIZE at most. Unfortunately we couldn't figure
// out a very meaningful way for more threads to do work, that looked like
// it would really spead things up.
// So this kernel does (2 * BLOCK_SIZE) iterations, which may seem a lot,
// but we do at least do the I/O in an efficient way and keep the
// inner loop simple and fast (e.g. no exp() or log()).
int s = threadIdx.x,
t = i - s;
if (static_cast<unsigned int>(t) < static_cast<unsigned int>(block_T)) {
// p_buf is indexed by s + 1 and t + 1 because it has an extra initial
// row and column for context from previous blocks. Taking into account
// the way these buffers relate to the tensors p, px and py, and
// ignoring `normalizer`, code below can be interpreted as follows,
// writing sbb for s_block_begin and tbb for t_block_begin:
//
// p[b][s+sbb][t+tbb] = LogAdd(p[b][s+sbb-1][t+tbb] + px[s+sbb-1][t+tbb],
// p[b][s+sbb][t+tbb-1] + py[s+sbb][t+tbb-1]
//
// where you can see that apart from the offsets of tbb and sbb, this is
// the same as the recursion defined for p in
// mutual_information.py:mutual_information_recursion(); and (eq. 0) above.
#if 1
p_buf[s + 1][t + 1] = p_buf[s][t + 1] * px_buf[s][t] + p_buf[s + 1][t] * py_buf[s][t];
p_buf[s + 1][t + 1] = p_buf[s][t + 1] * px_buf[s][t] + p_buf[s + 1][t] * py_buf[s][t];
#else
// This is an optimization of the statement above (the other half of
// this #if/#else) where we keep p_buf[s + 1][t] in a register to avoid
// the need for a load from shared memory.
p_buf_s1_t = p_buf[s][t + 1] * px_buf[s][t] + p_buf_s1_t * py_buf[s][t];
// The next time this thread reads p_buf_s1_t, t will be one greater,
// so p_buf_s1_t will contain p_buf[s + 1][t]. The first time this
// thread uses p_buf_s1_t is when t == 0, except for thread 0 where
// the 1st item accessed is for s == 0, t == 1.
p_buf[s + 1][t + 1] = p_buf_s1_t;
// This is an optimization of the statement above (the other half of
// this #if/#else) where we keep p_buf[s + 1][t] in a register to avoid
// the need for a load from shared memory.
p_buf_s1_t = p_buf[s][t + 1] * px_buf[s][t] + p_buf_s1_t * py_buf[s][t];
// The next time this thread reads p_buf_s1_t, t will be one greater,
// so p_buf_s1_t will contain p_buf[s + 1][t]. The first time this
// thread uses p_buf_s1_t is when t == 0, except for thread 0 where
// the 1st item accessed is for s == 0, t == 1.
p_buf[s + 1][t + 1] = p_buf_s1_t;
#endif
// We don't need to do __syncthreads() in this loop because all the
// threads that are active are in the same warp. (However, in future,
// if NVidia changes some things, we might need to sync here).
// We don't need to do __syncthreads() in this loop because all the
// threads that are active are in the same warp. (However, in future,
// if NVidia changes some things, we might need to sync here).
}
}
__syncthreads();
}
......@@ -382,22 +395,24 @@ void mutual_information_kernel(
// This time we won't use the buffers, we'll just load and save from main
// memory. This code should very rarely be reached; and anyway, caching
// should help us quite a bit.
for (int i = 0; i < block_S + block_T - 1; ++i) {
int s_in_block = threadIdx.x,
t_in_block = i - s_in_block;
if (s_in_block < block_S &&
static_cast<unsigned int>(t_in_block) < static_cast<unsigned int>(block_T)) {
int s = s_in_block + s_block_begin,
t = t_in_block + t_block_begin;
float p_s1 = (s == 0 ? -INFINITY : p[b][s - 1][t]),
this_px = (s == 0 ? -INFINITY : px[b][s - 1][t]),
p_t1 = (t == 0 ? -INFINITY : p[b][s][t - 1]),
this_py = (t == 0 ? -INFINITY : py[b][s][t - 1]);
float this_p = LogAdd(p_s1 + this_px,
p_t1 + this_py);
if (i == 0 && is_origin_block)
this_p = 0.0;
p[b][s][t] = this_p;
int s_in_block = threadIdx.x;
if (s_in_block < block_S) {
for (int i = 0; i < block_S + block_T - 1; ++i) {
int t_in_block = i - s_in_block;
if (s_in_block < block_S &&
static_cast<unsigned int>(t_in_block) < static_cast<unsigned int>(block_T)) {
int s = s_in_block + s_block_begin,
t = t_in_block + t_block_begin;
float p_s1 = (s == 0 ? -INFINITY : p[b][s - 1][t]),
this_px = (s == 0 ? -INFINITY : px[b][s - 1][t]),
p_t1 = (t == 0 ? -INFINITY : p[b][s][t - 1]),
this_py = (t == 0 ? -INFINITY : py[b][s][t - 1]);
float this_p = LogAdd(p_s1 + this_px,
p_t1 + this_py);
if (i == 0 && is_origin_block)
this_p = 0.0;
p[b][s][t] = this_p;
}
}
}
if (threadIdx.x == 0) {
......@@ -700,17 +715,21 @@ void mutual_information_backward_kernel(
--first_iter;
}
for (int i = first_iter; i >= 0; --i) {
int s = i,
t = i - threadIdx.x;
if (t >= 0) {
// The following statement is really operating on the gradients;
// it corresponds, with offsets of s_block_begin and t_block_begin
// on the indexes, to (eq. 6) defined above, i.e.:
// p_grad[b][s][t] = p_grad[b][s + 1][t] * xderiv[b][s][t] +
// p_grad[b][s][t + 1] * yderiv[b][s][t]
p_buf[s][t] = (p_buf[s + 1][t] * px_buf[s][t] +
p_buf[s][t + 1] * py_buf[s][t]);
{
int s = threadIdx.x;
if (s < block_S) {
for (int i = first_iter; i >= 0; --i) {
int t = i - s;
if (t >= 0) {
// The following statement is really operating on the gradients;
// it corresponds, with offsets of s_block_begin and t_block_begin
// on the indexes, to (eq. 6) defined above, i.e.:
// p_grad[b][s][t] = p_grad[b][s + 1][t] * xderiv[b][s][t] +
// p_grad[b][s][t + 1] * yderiv[b][s][t]
p_buf[s][t] = (p_buf[s + 1][t] * px_buf[s][t] +
p_buf[s][t + 1] * py_buf[s][t]);
}
}
}
}
......
......@@ -12,13 +12,14 @@ def test_mutual_information_basic():
for device in [ torch.device('cpu'), torch.device('cuda:0') ]:
print("dtype = ", dtype, ", device = ", device)
B = 2
S = 4
T = 5
S = 17
T = 17
boundary = torch.tensor([ 0, 0, S, T ], dtype=torch.int64).unsqueeze(0).expand(B, 4).to(device)
px = torch.zeros(B, S, T + 1).to(device) # log of an odds ratio
py = torch.zeros(B, S + 1, T).to(device) # log of an odds ratio
m = mutual_information_recursion(px, py, boundary)
m = mutual_information_recursion(px, py, None)
#m = mutual_information_recursion(px, py, boundary)
print("m = ", m)
print("exp(m) = ", m.exp())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment