Commit 86e3a617 authored by Daniel Povey

Fix some bugs..

parent 92814db2
@@ -62,11 +62,13 @@ def _integrated_conv_backward_dispatcher(input: torch.Tensor,
     if input.is_cuda:
         if torch_integrated_conv_cuda is None:
             raise EnvironmentError(f'Failed to load native CUDA module')
+        # Actually it's not a hard requirement that these things be contiguous.
         return tuple(torch_integrated_conv_cuda.integrated_conv_backward_cuda(
-            input.contiguous(), pos_add.contiguous(), pos_mul.contiguous()))
+            input.contiguous(), pos_add.contiguous(), pos_mul.contiguous(),
+            grad_output))
     else:
         return tuple(torch_integrated_conv_cpu.integrated_conv_backward_cpu(
-            input, pos_add, pos_mul))
+            input, pos_add, pos_mul, grad_output))
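For context, this dispatcher is presumably called from the backward() of a torch.autograd.Function wrapper, which is where grad_output comes from; the fix threads it through to the native CUDA/CPU backward functions. A minimal sketch of that wiring (the class name IntegratedConvFunction and the forward dispatcher name are assumptions, not taken from this commit):

import torch

class IntegratedConvFunction(torch.autograd.Function):
    """Hypothetical autograd wrapper around the native integrated_conv kernels."""

    @staticmethod
    def forward(ctx, input, pos_add, pos_mul):
        ctx.save_for_backward(input, pos_add, pos_mul)
        # Assumed name for the forward-path dispatcher.
        return _integrated_conv_forward_dispatcher(input, pos_add, pos_mul)

    @staticmethod
    def backward(ctx, grad_output):
        input, pos_add, pos_mul = ctx.saved_tensors
        # The bug fixed above: grad_output must reach the native backward,
        # otherwise there is nothing to propagate through the kernel.
        grad_input, grad_pos_add, grad_pos_mul = _integrated_conv_backward_dispatcher(
            input, pos_add, pos_mul, grad_output)
        return grad_input, grad_pos_add, grad_pos_mul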
@@ -212,6 +212,8 @@ void integrated_conv_kernel(
         if (relu > 0.0)
           sum += relu * pos_mul_buf[pos_in_kernel];
       }
+      // Sync threads because src_img_buf is also used above.
+      __syncthreads();
       // Aggregate `sum` over threads
       sum = tiled_warp_reduce_sum(threads_per_opixel, src_img_buf, sum);
       if (threadIdx.x % threads_per_opixel == 0 && h < H && w < W) {
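The added __syncthreads() protects a shared-memory reuse: src_img_buf is still being read while each thread accumulates sum, and the same buffer is then recycled as scratch space inside tiled_warp_reduce_sum. Without a barrier in between, fast threads could start overwriting the buffer while slower threads are still reading it. A standalone sketch of the pattern (buffer size and names are invented for illustration, not taken from the kernel):

__global__ void shared_buffer_reuse_sketch(const float *in, float *out, int n) {
  // One shared buffer serves two phases in sequence, as src_img_buf does above.
  // Assumes blockDim.x <= 256 and blockDim.x is a power of two.
  __shared__ float buf[256];

  // Phase 1: fill the buffer, then every thread reads all of it.
  buf[threadIdx.x] = (threadIdx.x < n) ? in[threadIdx.x] : 0.0f;
  __syncthreads();
  float sum = 0.0f;
  for (int i = 0; i < blockDim.x; i++)
    sum += buf[i];

  // The analogue of the barrier this commit adds: without it, threads that
  // finish the read loop early would overwrite buf while others still read it.
  __syncthreads();

  // Phase 2: the same buffer is reused as scratch for a block-wide reduction.
  buf[threadIdx.x] = sum;
  __syncthreads();
  for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
    if (threadIdx.x < stride)
      buf[threadIdx.x] += buf[threadIdx.x + stride];
    __syncthreads();
  }
  if (threadIdx.x == 0)
    out[blockIdx.x] = buf[0];
}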
@@ -19,11 +19,19 @@ def test_integrated_conv_zeros():
     kW = 5
     pos_add = torch.zeros(C, kH, kW, device=device, dtype=dtype)
     pos_mul = torch.zeros(C, kH, kW, device=device, dtype=dtype)
+    input.requires_grad = True
+    pos_add.requires_grad = True
+    pos_mul.requires_grad = True
     output_ref = torch.zeros(N, C, H, W, device=device, dtype=dtype)
     output = integrated_conv(input, pos_add, pos_mul)
     assert torch.allclose(output, output_ref)
+    output.sum().backward()
+    print("input_grad=", input.grad)
+    print("pos_add_grad=", pos_add.grad)
+    print("pos_mul_grad=", pos_mul.grad)


 def test_integrated_conv_compare():
     N = 1
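The new backward-pass exercise only prints the gradients rather than asserting on them. A standard way to verify a custom backward like this one (not part of this commit) is torch.autograd.gradcheck, which compares the analytic gradients against finite differences in double precision. A sketch, assuming integrated_conv is importable here and accepts double-precision CPU tensors:

import torch

def check_integrated_conv_grads():
    # Hypothetical follow-up check, not part of this commit; shapes are arbitrary.
    N, C, H, W, kH, kW = 1, 2, 4, 4, 3, 3
    input = torch.randn(N, C, H, W, dtype=torch.double, requires_grad=True)
    pos_add = torch.randn(C, kH, kW, dtype=torch.double, requires_grad=True)
    pos_mul = torch.randn(C, kH, kW, dtype=torch.double, requires_grad=True)
    # gradcheck compares the analytic backward (the code this commit fixes)
    # against numerical finite differences.
    assert torch.autograd.gradcheck(integrated_conv, (input, pos_add, pos_mul))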