Commit 9ebcf9d5 authored by Daniel Povey

Fix more bugs..

parent 85c97136
@@ -192,11 +192,11 @@ std::vector<torch::Tensor> mutual_information_backward_cpu(
     for (int s = s_end; s > s_begin; --s) {
       for (int t = t_end; t > t_begin; --t) {
-        // The s,t indexes correspond to
+        // The statement we are backpropagating here is:
         // p_a[b][s][t] = LogAdd(p_a[b][s - 1][t] + px_a[b][s - 1][t],
         //                       p_a[b][s][t - 1] + py_a[b][s][t - 1]);
         // .. which obtains p_a[b][s][t - 1] from a register.
         scalar_t term1 = p_a[b][s - 1][t] + px_a[b][s - 1][t],
             // term2 = p_a[b][s][t - 1] + py_a[b][s][t - 1],  <-- not
             // actually needed..
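For orientation, the recursion being backpropagated here computes p[b][s][t] = LogAdd(p[b][s-1][t] + px[b][s-1][t], p[b][s][t-1] + py[b][s][t-1]) with p[b][s_begin][t_begin] = 0. A minimal Python reference, assuming the default boundary (s_begin = t_begin = 0); the function name is mine, not the package's API:

import torch

def mutual_information_ref(px: torch.Tensor, py: torch.Tensor) -> torch.Tensor:
    """Naive O(S*T) reference for the forward recursion.
    px: (B, S, T+1), py: (B, S+1, T); returns p[:, S, T]."""
    B, S, T = px.shape[0], px.shape[1], py.shape[2]
    # Build p[s][t] functionally (no in-place writes) so autograd works.
    p = [[None] * (T + 1) for _ in range(S + 1)]
    p[0][0] = torch.zeros(B, dtype=px.dtype, device=px.device)
    for s in range(1, S + 1):           # t_begin column: only px terms
        p[s][0] = p[s - 1][0] + px[:, s - 1, 0]
    for t in range(1, T + 1):           # s_begin row: only py terms
        p[0][t] = p[0][t - 1] + py[:, 0, t - 1]
    for s in range(1, S + 1):
        for t in range(1, T + 1):
            p[s][t] = torch.logaddexp(p[s - 1][t] + px[:, s - 1, t],
                                      p[s][t - 1] + py[:, s, t - 1])
    return p[S][T]

Calling .backward() on this reference's sum lets autograd produce px/py gradients to cross-check the hand-written backward kernels fixed in this commit.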
@@ -212,19 +212,19 @@ std::vector<torch::Tensor> mutual_information_backward_cpu(
           p_grad_a[b][s][t - 1] += term2_grad;
         }
       }
-      for (int t = t_end; t >= t_begin; --t) {
+      for (int t = t_end; t > t_begin; --t) {
         // Backprop for:
         // p_a[b][s_begin][t] = p_a[b][s_begin][t - 1] + py_a[b][s_begin][t - 1];
         scalar_t this_p_grad = p_grad_a[b][s_begin][t];
         p_grad_a[b][s_begin][t - 1] += this_p_grad;
-        py_grad_a[b][s_begin][t - 1] += this_p_grad;
+        py_grad_a[b][s_begin][t - 1] = this_p_grad;
       }
-      for (int s = s_end; s >= s_begin; --s) {
+      for (int s = s_end; s > s_begin; --s) {
         // Backprop for:
         // p_a[b][s][t_begin] = p_a[b][s - 1][t_begin] + px_a[b][s - 1][t_begin];
-        scalar_t this_p_grad = p_grad_a[b][s][s_begin];
-        p_a[b][s - 1][t_begin] += this_p_grad;
-        px_a[b][s - 1][t_begin] += this_p_grad;
+        scalar_t this_p_grad = p_grad_a[b][s][t_begin];
+        p_grad_a[b][s - 1][t_begin] += this_p_grad;
+        px_grad_a[b][s - 1][t_begin] = this_p_grad;
       }
       // There is no backprop for:
       // p_a[b][s_begin][t_begin] = 0.0;
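The interior-cell backprop follows from d/dx LogAdd(x, y) = exp(x - p), where p = LogAdd(x, y): each incoming grad is split between the two terms with softmax weights. A scalar sketch of that rule (the function name is illustrative, not from the repo):

import math

def logadd_backward(term1: float, term2: float, p_grad: float):
    """Split p_grad between term1 and term2 for p = log(exp(term1) + exp(term2)).
    exp(term1 - p) and exp(term2 - p) are softmax weights summing to 1."""
    p = max(term1, term2) + math.log1p(math.exp(-abs(term1 - term2)))
    return math.exp(term1 - p) * p_grad, math.exp(term2 - p) * p_grad

On the s_begin row and t_begin column, each py_grad/px_grad entry is written exactly once, and those buffers may come from torch::empty(); hence the fix assigns with = rather than accumulating with += into uninitialized memory.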
@@ -232,7 +232,7 @@ std::vector<torch::Tensor> mutual_information_backward_cpu(
       // of the sequence is equal to the grad at the end of the sequence.
       if (ans_grad_a[b] != 0.0) {
         float grad_ratio = p_a[b][s_begin][t_begin] / ans_grad_a[b];
-        if (grad_ratio - 1.0 > 0.01) {
+        if (fabs(grad_ratio - 1.0) > 0.01) {
          printf("Warning: mutual_information backprop: expected these numbers to be the same: %f vs. %f\n",
                 (float)p_a[b][s_begin][t_begin], (float)ans_grad_a[b]);
        }
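The fabs makes the consistency check two-sided: the original test only fired when grad_ratio exceeded 1.01 and silently passed when the backpropagated grad undershot. The same predicate in Python, with the 0.01 threshold from the code above:

def grads_consistent(p_start: float, ans_grad: float, rtol: float = 0.01) -> bool:
    # Mirrors the fixed check: skip when ans_grad == 0, else |ratio - 1| <= rtol.
    return ans_grad == 0.0 or abs(p_start / ans_grad - 1.0) <= rtol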
@@ -750,6 +750,8 @@ void mutual_information_backward_kernel(
       }
     }
+    __syncthreads();
     // Write out p_grad, px_grad and py_grad.
     for (int i = threadIdx.x; i < BLOCK_SIZE * BLOCK_SIZE; i += blockDim.x) {
       int s_in_block = i / BLOCK_SIZE,
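The __syncthreads() added in this hunk is a block-wide barrier: it ensures every thread has finished the updates above before any thread starts the write-out loop, which reads tile entries other threads wrote. The loop strides a flattened BLOCK_SIZE x BLOCK_SIZE tile; a Python sketch of the same index arithmetic (BLOCK_SIZE and the thread parameters stand in for the kernel's compile-time/launch values):

BLOCK_SIZE = 32  # placeholder; a template parameter in the kernel

def tile_coords(thread_idx: int, block_dim: int):
    """(s_in_block, t_in_block) pairs one thread visits, matching
    for (int i = threadIdx.x; i < BLOCK_SIZE*BLOCK_SIZE; i += blockDim.x)."""
    for i in range(thread_idx, BLOCK_SIZE * BLOCK_SIZE, block_dim):
        yield i // BLOCK_SIZE, i % BLOCK_SIZE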
@@ -881,7 +883,7 @@ mutual_information_backward_cuda(torch::Tensor px,
   torch::Tensor p_grad = torch::empty({B, S + 1, T + 1}, opts),
       px_grad = (has_boundary ? torch::zeros({B, S, T + 1}, opts) :
                  torch::empty({B, S, T + 1}, opts)),
-      py_grad = (has_boundary ? torch::zeros({B, S, T + 1}, opts) :
+      py_grad = (has_boundary ? torch::zeros({B, S + 1, T}, opts) :
                  torch::empty({B, S + 1, T}, opts));
   // num_threads and num_blocks and BLOCK_SIZE can be tuned.
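The fix gives py_grad's zero-initialized branch the same (B, S+1, T) shape as py itself and as the torch::empty branch on the next line; before, the boundary path allocated px's shape. The shape invariant in torch terms (sizes taken from the test below, for illustration):

import torch

B, S, T = 2, 14, 14
px = torch.zeros(B, S, T + 1)    # S transitions in s, T+1 positions in t
py = torch.zeros(B, S + 1, T)    # S+1 positions in s, T transitions in t
px_grad = torch.zeros_like(px)   # a grad buffer must share its tensor's shape
py_grad = torch.zeros_like(py)   # pre-fix code allocated (B, S, T+1) here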
@@ -9,20 +9,35 @@ from torch_mutual_information import mutual_information_recursion
 def test_mutual_information_basic():
     print("Running test_mutual_information_basic()")
     for dtype in [torch.float32, torch.float64]:
+        px_grads = []
+        py_grads = []
         for device in [ torch.device('cpu'), torch.device('cuda:0') ]:
             print("dtype = ", dtype, ", device = ", device)
             B = 2
-            S = 33
-            T = 33
+            S = 14
+            T = 14
             boundary = torch.tensor([ 0, 0, S, T ], dtype=torch.int64).unsqueeze(0).expand(B, 4).to(device)
             px = torch.zeros(B, S, T + 1, dtype=dtype).to(device)  # log of an odds ratio
             py = torch.zeros(B, S + 1, T, dtype=dtype).to(device)  # log of an odds ratio
             px.requires_grad = True
             py.requires_grad = True
-            m = mutual_information_recursion(px, py, None)
-            #m = mutual_information_recursion(px, py, boundary)
+            #m = mutual_information_recursion(px, py, None)
+            m = mutual_information_recursion(px, py, boundary)
             print("m = ", m, ", size = ", m.shape)
             print("exp(m) = ", m.exp())
             (m.sum() * 3).backward()
             print("px_grad = ", px.grad)
             print("py_grad = ", py.grad)
+            px_grads.append(px.grad.to('cpu'))
+            py_grads.append(py.grad.to('cpu'))
+        if not torch.allclose(px_grads[0], px_grads[1]):
+            print(f"px_grads differed CPU vs CUDA: {px_grads[0]} vs. {px_grads[1]}")
+            assert 0
+        if not torch.allclose(py_grads[0], py_grads[1]):
+            print(f"py_grads differed CPU vs CUDA: {py_grads[0]} vs. {py_grads[1]}")
+            assert 0
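The new cross-device check compares gradients with torch.allclose at its default tolerances, which are tight for float32 kernels that reorder floating-point sums. A slightly more forgiving variant, with thresholds that are my choice rather than the repo's:

import torch

def grads_match(cpu_grad: torch.Tensor, cuda_grad: torch.Tensor) -> bool:
    # Compare CPU and CUDA gradients with a dtype-aware tolerance;
    # the two backward kernels may legitimately differ in summation order.
    tol = 1e-5 if cpu_grad.dtype == torch.float32 else 1e-10
    return torch.allclose(cpu_grad, cuda_grad.to('cpu'), rtol=tol, atol=tol)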