Commit 5a8c1e3a authored by Daniel Povey's avatar Daniel Povey
Browse files

test with zeros working.

parent 12d3b03d
......@@ -26,7 +26,7 @@ except ImportError:
try:
import torch_integrated_conv_cuda
import torch_integrated_conv_cuda
except ImportError:
if VERBOSE:
print('Falling back to JIT compiling torch_integrated_conv_cuda')
......@@ -43,7 +43,7 @@ except ImportError:
def _integrated_conv_forward_dispather(input: torch.Tensor,
def _integrated_conv_forward_dispatcher(input: torch.Tensor,
pos_add: torch.Tensor,
pos_mul: torch.Tensor) -> torch.Tensor:
if input.is_cuda:
......
......@@ -16,5 +16,6 @@ def test_integrated_conv_zeros():
pos_add = torch.zeros(C, kH, kW, device=device, dtype=dtype)
pos_mul = torch.zeros(C, kH, kW, device=device, dtype=dtype)
output_ref = torch.zeros(N, C, H, W, device=device, dtype=dtype)
output = integrated_conv(input, pos_add, pos_mul)
assert torch.allclose(input, output)
assert torch.allclose(output, output_ref)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment