Commit 86e3a617 authored by Daniel Povey

Fix some bugs..

parent 92814db2
...
@@ -62,11 +62,13 @@ def _integrated_conv_backward_dispatcher(input: torch.Tensor,
     if input.is_cuda:
         if torch_integrated_conv_cuda is None:
             raise EnvironmentError(f'Failed to load native CUDA module')
+        # Actually it's not a hard requirement that these things be contiguous.
         return tuple(torch_integrated_conv_cuda.integrated_conv_backward_cuda(
-            input.contiguous(), pos_add.contiguous(), pos_mul.contiguous()))
+            input.contiguous(), pos_add.contiguous(), pos_mul.contiguous(),
+            grad_output))
     else:
         return tuple(torch_integrated_conv_cpu.integrated_conv_backward_cpu(
-            input, pos_add, pos_mul))
+            input, pos_add, pos_mul, grad_output))
...
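The point of this hunk is that `grad_output` was previously dropped before it reached the native backward kernels. Presumably the dispatcher is invoked from a `torch.autograd.Function`; the sketch below shows the usual wiring, where every name except `_integrated_conv_backward_dispatcher` is an illustrative guess, not this repository's actual code.

```python
import torch

class _IntegratedConvFunction(torch.autograd.Function):
    # Hypothetical wiring, for illustration only; this class and the forward
    # dispatcher name are assumptions, not taken from this commit.
    @staticmethod
    def forward(ctx, input, pos_add, pos_mul):
        ctx.save_for_backward(input, pos_add, pos_mul)
        return _integrated_conv_dispatcher(input, pos_add, pos_mul)

    @staticmethod
    def backward(ctx, grad_output):
        input, pos_add, pos_mul = ctx.saved_tensors
        # The fix above matters here: grad_output must be forwarded so the
        # native code can compute all three gradients (the argument order is
        # a guess based on the visible call sites).
        return _integrated_conv_backward_dispatcher(input, pos_add, pos_mul,
                                                    grad_output)
```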
...
@@ -212,6 +212,8 @@ void integrated_conv_kernel(
       if (relu > 0.0)
         sum += relu * pos_mul_buf[pos_in_kernel];
     }
+    // Sync threads because src_img_buf is also used above.
+    __syncthreads();
     // Aggregate `sum` over threads
     sum = tiled_warp_reduce_sum(threads_per_opixel, src_img_buf, sum);
     if (threadIdx.x % threads_per_opixel == 0 && h < H && w < W) {
...
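For readers without the full file: the two visible kernel lines suggest each output pixel accumulates `relu(src + pos_add) * pos_mul` over kernel positions, and that `src_img_buf` doubles as scratch space for the cross-thread reduction that follows, hence the new `__syncthreads()` barrier (without it, a fast thread could start overwriting the buffer while a slower one is still reading image data from it). Below is a rough pure-PyTorch reading of the forward computation, with the padding, shapes, and channel layout guessed rather than taken from the repository:

```python
import torch

def integrated_conv_ref(input, pos_add, pos_mul):
    """Guessed reference semantics: out[n,c,h,w] is the sum over kernel
    positions of relu(input at the shifted position + pos_add) * pos_mul.
    Zero padding and the (N, C, H, W) layout are assumptions, not confirmed
    by the diff (but they are consistent with the zeros test below)."""
    N, C, H, W = input.shape
    _, kH, kW = pos_add.shape
    padded = torch.nn.functional.pad(input, (kW // 2, kW // 2, kH // 2, kH // 2))
    out = torch.zeros_like(input)
    for kh in range(kH):
        for kw in range(kW):
            src = padded[:, :, kh:kh + H, kw:kw + W]
            relu = (src + pos_add[:, kh, kw].view(1, C, 1, 1)).clamp(min=0.0)
            out = out + relu * pos_mul[:, kh, kw].view(1, C, 1, 1)
    return out
```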
...
@@ -19,11 +19,19 @@ def test_integrated_conv_zeros():
     kW = 5
     pos_add = torch.zeros(C, kH, kW, device=device, dtype=dtype)
     pos_mul = torch.zeros(C, kH, kW, device=device, dtype=dtype)
+    input.requires_grad = True
+    pos_add.requires_grad = True
+    pos_mul.requires_grad = True
     output_ref = torch.zeros(N, C, H, W, device=device, dtype=dtype)
     output = integrated_conv(input, pos_add, pos_mul)
     assert torch.allclose(output, output_ref)
+    output.sum().backward()
+    print("input_grad=", input.grad)
+    print("pos_add_grad=", pos_add.grad)
+    print("pos_mul_grad=", pos_mul.grad)


 def test_integrated_conv_compare():
     N = 1
...
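The new test code only prints the gradients for eyeballing. A natural follow-up (a sketch, not part of this commit; the import path and float64 support are assumptions) would be to check the custom backward against finite differences with `torch.autograd.gradcheck`:

```python
import torch
from torch_integrated_conv import integrated_conv  # assumed import path

def test_integrated_conv_gradcheck():
    # gradcheck wants float64; if the native kernels are float32-only, this
    # would need looser tolerances or a float64 CPU reference path instead.
    N, C, H, W, kH, kW = 1, 2, 4, 4, 3, 3
    input = torch.randn(N, C, H, W, dtype=torch.float64, requires_grad=True)
    pos_add = torch.randn(C, kH, kW, dtype=torch.float64, requires_grad=True)
    pos_mul = torch.randn(C, kH, kW, dtype=torch.float64, requires_grad=True)
    # Random doubles essentially never land exactly on the relu kink, so the
    # finite-difference comparison is well-defined here.
    assert torch.autograd.gradcheck(integrated_conv,
                                    (input, pos_add, pos_mul))
```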