"tests/nn/git@developer.sourcefind.cn:OpenDAS/fairscale.git" did not exist on "4b5b4d3d373f9af9fd97dc9cf1b74ab3b4826d90"
Commit 5a8c1e3a authored by Daniel Povey's avatar Daniel Povey
Browse files

test with zeros working.

parent 12d3b03d
...@@ -26,7 +26,7 @@ except ImportError: ...@@ -26,7 +26,7 @@ except ImportError:
try: try:
import torch_integrated_conv_cuda import torch_integrated_conv_cuda
except ImportError: except ImportError:
if VERBOSE: if VERBOSE:
print('Falling back to JIT compiling torch_integrated_conv_cuda') print('Falling back to JIT compiling torch_integrated_conv_cuda')
...@@ -43,7 +43,7 @@ except ImportError: ...@@ -43,7 +43,7 @@ except ImportError:
def _integrated_conv_forward_dispather(input: torch.Tensor, def _integrated_conv_forward_dispatcher(input: torch.Tensor,
pos_add: torch.Tensor, pos_add: torch.Tensor,
pos_mul: torch.Tensor) -> torch.Tensor: pos_mul: torch.Tensor) -> torch.Tensor:
if input.is_cuda: if input.is_cuda:
......
...@@ -16,5 +16,6 @@ def test_integrated_conv_zeros(): ...@@ -16,5 +16,6 @@ def test_integrated_conv_zeros():
pos_add = torch.zeros(C, kH, kW, device=device, dtype=dtype) pos_add = torch.zeros(C, kH, kW, device=device, dtype=dtype)
pos_mul = torch.zeros(C, kH, kW, device=device, dtype=dtype) pos_mul = torch.zeros(C, kH, kW, device=device, dtype=dtype)
output_ref = torch.zeros(N, C, H, W, device=device, dtype=dtype)
output = integrated_conv(input, pos_add, pos_mul) output = integrated_conv(input, pos_add, pos_mul)
assert torch.allclose(input, output) assert torch.allclose(output, output_ref)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment