Commit 53b31903 authored by Daniel Povey's avatar Daniel Povey
Browse files

Fix some bugs..

parent 17b18990
......@@ -68,7 +68,7 @@ def _mutual_information_backward_dispatcher(px: torch.Tensor, py: torch.Tensor,
px, py, boundary, p, ans_grad_copy, overwrite_ans_grad))
if overwrite_ans_grad:
if not torch.allclose(ans_grad, ans_grad_copy, rtol=1.0e-02):
print(f"Warning: possible excsssive roundoff in mutual information backward "
print(f"Warning: possible excesssive roundoff in mutual information backward "
f"recursion: {ans_grad} vs. {ans_grad_copy}");
return ans
else:
......@@ -106,7 +106,7 @@ class MutualInformationRecursionFunction(torch.autograd.Function):
ans = _mutual_information_forward_dispatcher(px, py, boundary, p)
#print(f"p = {p}, boundary = {boundary}")
print(f"p = {p}, boundary = {boundary}")
if px.requires_grad or py.requires_grad:
ctx.save_for_backward(px, py, boundary, p)
......
......@@ -97,8 +97,8 @@ torch::Tensor mutual_information_cpu(torch::Tensor px,
t_end = boundary_a[b][3];
} else {
s_begin = 0;
s_end = S;
t_begin = 0;
s_end = S;
t_end = T;
}
p_a[b][s_begin][t_begin] = 0.0;
......
......@@ -208,16 +208,16 @@ void mutual_information_kernel(
t_in_block = i % BLOCK_SIZE,
s = s_in_block + s_block_begin,
t = t_in_block + t_block_begin;
// comparing as unsigned int makes sure the index is nonnegative.
// Caution: if s_begin > 0 or t_begin > 0 we may end up loading some px and
// py values that are outside the proper boundaries that we need, but
// the corresponding p_buf values will end up being 0 so this won't matter.
scalar_t this_px = 0.0;
if (static_cast<unsigned int>(s - 1) < static_cast<unsigned int>(s_end) &&
t <= t_end)
if (s > s_begin && s <= s_end && t <= t_end)
this_px = exp(px[b][s - 1][t]);
px_buf[s_in_block][t_in_block] = this_px;
scalar_t this_py = 0.0;
if (static_cast<unsigned int>(t - 1) < static_cast<unsigned int>(t_end) &&
s <= s_end)
if (t > t_begin && t <= t_end && s <= s_end)
this_py = exp(py[b][s][t - 1]);
py_buf[s_in_block][t_in_block] = this_py;
}
......@@ -234,32 +234,28 @@ void mutual_information_kernel(
s = s_in_p_buf + s_block_begin - 1,
t = t_in_p_buf + t_block_begin - 1;
scalar_t this_p = -INFINITY;
if (static_cast<unsigned int>(s) <= static_cast<unsigned int>(s_end) &&
static_cast<unsigned int>(t) <= static_cast<unsigned int>(t_end)) {
if (s >= s_begin && s <= s_end &&
t >= t_begin && t <= t_end)
this_p = p[b][s][t];
/*printf("p[%d][%d][%d] = %f, threadIdx.x = %d, px = %f, py = %f\n", b, s, t, (float)this_p, (int)threadIdx.x,
(float)px_buf[s_in_p_buf][t_in_p_buf], (float)py_buf[s_in_p_buf][t_in_p_buf]); */
}
p_buf[s_in_p_buf][t_in_p_buf] = this_p;
} else {
} else if (static_cast<unsigned int>(int(threadIdx.x) - 64) <=
static_cast<unsigned int>(BLOCK_SIZE)) {
// Another warp handles the other leg. Checking as unsigned
// tests that threadIdx.x - 64 is both >= 0 and <= BLOCK_SIZE
if (static_cast<unsigned int>(int(threadIdx.x) - 64) <=
static_cast<unsigned int>(BLOCK_SIZE)) {
int s_in_p_buf = 0,
t_in_p_buf = (int)threadIdx.x - 64,
s = s_in_p_buf + s_block_begin - 1,
t = t_in_p_buf + t_block_begin - 1;
scalar_t this_p = -INFINITY;
if (static_cast<unsigned int>(s) <= static_cast<unsigned int>(s_end) &&
static_cast<unsigned int>(t) <= static_cast<unsigned int>(t_end)) {
if (s >= s_begin && s <= s_end &&
t >= t_begin && t <= t_end)
this_p = p[b][s][t];
/*printf("p[%d][%d][%d] = %f, threadIdx.x = %d, px = %f, py = %f\n", b, s, t, (float)this_p, (int)threadIdx.x,
(float)px_buf[s_in_p_buf][t_in_p_buf], (float)py_buf[s_in_p_buf][t_in_p_buf]);*/
}
p_buf[s_in_p_buf][t_in_p_buf] = this_p;
}
}
__syncthreads();
......@@ -421,10 +417,10 @@ void mutual_information_kernel(
static_cast<unsigned int>(block_T)) {
int s = s_in_block + s_block_begin,
t = t_in_block + t_block_begin;
float p_s1 = (s == 0 ? -INFINITY : p[b][s - 1][t]),
this_px = (s == 0 ? -INFINITY : px[b][s - 1][t]),
p_t1 = (t == 0 ? -INFINITY : p[b][s][t - 1]),
this_py = (t == 0 ? -INFINITY : py[b][s][t - 1]);
float p_s1 = (s == s_begin ? -INFINITY : p[b][s - 1][t]),
this_px = (s == s_begin ? -INFINITY : px[b][s - 1][t]),
p_t1 = (t == t_begin ? -INFINITY : p[b][s][t - 1]),
this_py = (t == t_begin ? -INFINITY : py[b][s][t - 1]);
float this_p = LogAdd(p_s1 + this_px,
p_t1 + this_py);
if (i == 0 && is_origin_block)
......@@ -433,6 +429,7 @@ void mutual_information_kernel(
}
}
}
__syncwarp();
if (threadIdx.x == 0) {
// Write `ans`, if this is the final (top-right) block in its sequence.
// This is only reached in the 'panic situation' where we had overflow.
......@@ -650,6 +647,7 @@ void mutual_information_backward_kernel(
this_py = py[b][s][t];
py_buf[s_in_block][t_in_block] = this_py;
}
__syncthreads();
// load p. We could use BLOCK_SIZE + 1 here, but we use + 8 to hopefully keep
// reads more aligned.
......@@ -669,6 +667,8 @@ void mutual_information_backward_kernel(
p_buf[s_in_block][t_in_block] = this_p;
}
__syncthreads();
// Set xderiv and yderiv; see (eq. 4) and (eq. 5).
for (int i = threadIdx.x; i < BLOCK_SIZE * BLOCK_SIZE; i += blockDim.x) {
// We can apply this formula to the entire block even if we are processing
......@@ -687,6 +687,8 @@ void mutual_information_backward_kernel(
py_buf[s][t] = exp(p_buf[s][t] + py_buf[s][t] - p_buf[s][t + 1]);
}
__syncthreads();
// Load p_grad for the top and right elements in p_buf: i.e. for elements
// p_buf[s][t] where s == block_S (exclusive-or) t == block_T. We don't
// need to load the top-right corner [block_S][block_T]; that location will
......@@ -714,6 +716,8 @@ void mutual_information_backward_kernel(
s <= s_end && t <= t_end ? p_grad[b][s][t] : 0.0);
}
__syncthreads();
// The highest-numbered value in p_buf that we need (corresponding,
// of course, to p_grad), is:
// p_buf[block_S - 1][block_T - 1],
......
......@@ -35,8 +35,11 @@ def test_mutual_information_basic():
s_end = random.randint(s_begin + 1, S)
t_end = random.randint(t_begin + 1, T)
return [s_begin, t_begin, s_end, t_end]
if device == torch.device('cpu'):
boundary = torch.tensor([ get_boundary_row() for _ in range(B) ],
dtype=torch.int64, device=device)
else:
boundary = boundary.to(device)
else:
# Use default boundary, but either specified directly or not.
if random.random() < 0.5:
......@@ -84,78 +87,7 @@ def test_mutual_information_basic():
if not torch.allclose(px_grads[0], px_grads[1], atol=1.0e-05, rtol=1.0e-04):
print(f"px_grads differed CPU vs CUDA: {px_grads[0]} vs. {px_grads[1]}")
assert 0
if not torch.allclose(py_grads[0], py_grads[1], atol=1.0e-05, rtol=1.0e-04):
print(f"py_grads differed CPU vs CUDA: {py_grads[0]} vs. {py_grads[1]}")
assert 0
def test_mutual_information_deriv():
print("Running test_mutual_information_basic()")
for _iter in range(100):
(B, S, T) = (random.randint(1, 10),
random.randint(1, 200),
random.randint(1, 200))
random_px = (random.random() < 0.1)
random_py = (random.random() < 0.1)
big_px = (random.random() < 0.1)
big_py = (random.random() < 0.1)
print(f"B, S, T = {B}, {S}, {T}, random_px={random_px}, random_py={random_py}, big_px={big_px}, big_py={big_py}")
for dtype in [torch.float32, torch.float64]:
px_grads = []
py_grads = []
m_vals = []
for device in [ torch.device('cpu'), torch.device('cuda:0') ]:
print("dtype = ", dtype, ", device = ", device)
B = 2
S = 14
T = 14
boundary = torch.tensor([ 0, 0, S, T ], dtype=torch.int64).unsqueeze(0).expand(B, 4).to(device)
if device == torch.device('cpu'):
if random_px:
px = torch.randn(B, S, T + 1, dtype=dtype).to(device) # log of an odds ratio
else:
px = torch.zeros(B, S, T + 1, dtype=dtype).to(device) # log of an odds ratio
# px and py get exponentiated, and then multiplied together up to
# 32 times (BLOCK_SIZE in the CUDA code), so 15 is actually a big number that
# could lead to overflow.
if big_px:
px += 15.0
if random_py:
py = torch.randn(B, S + 1, T, dtype=dtype).to(device) # log of an odds ratio
else:
py = torch.zeros(B, S + 1, T, dtype=dtype).to(device) # log of an odds ratio
if big_py:
py += 15.0
else:
px = px.to(device).detach()
py = py.to(device).detach()
px.requires_grad = True
py.requires_grad = True
#m = mutual_information_recursion(px, py, None)
m = mutual_information_recursion(px, py, boundary)
#print("m = ", m, ", size = ", m.shape)
#print("exp(m) = ", m.exp())
(m.sum() * 3).backward()
#print("px_grad = ", px.grad)
#print("py_grad = ", py.grad)
px_grads.append(px.grad.to('cpu'))
py_grads.append(py.grad.to('cpu'))
m_vals.append(m.to('cpu'))
if not torch.allclose(m_vals[0], m_vals[1], atol=1.0e-05, rtol=1.0e-04):
print(f"m_vals differed CPU vs CUDA: {m_vals[0]} vs. {m_vals[1]}")
assert 0
if not torch.allclose(px_grads[0], px_grads[1], atol=1.0e-05, rtol=1.0e-04):
print(f"px_grads differed CPU vs CUDA: {px_grads[0]} vs. {px_grads[1]}")
assert 0
if not torch.allclose(py_grads[0], py_grads[1], atol=1.0e-05, rtol=1.0e-04):
if not torch.allclose(py_grads[0], py_grads[1], atol=1.0e-05, rtol=1.0e-03):
print(f"py_grads differed CPU vs CUDA: {py_grads[0]} vs. {py_grads[1]}")
assert 0
......@@ -164,5 +96,6 @@ def test_mutual_information_deriv():
if __name__ == "__main__":
#torch.set_printoptions(edgeitems=30)
test_mutual_information_basic()
#test_mutual_information_deriv()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment