Commit 5fc62fa6 authored by Daniel Povey
Browse files

Get forward tests to work

parent 10381dab
...@@ -37,7 +37,7 @@ torch::Tensor integrated_conv_cpu(torch::Tensor input, ...@@ -37,7 +37,7 @@ torch::Tensor integrated_conv_cpu(torch::Tensor input,
auto input_a = input.accessor<scalar_t, 4>(), auto input_a = input.accessor<scalar_t, 4>(),
output_a = output.accessor<scalar_t, 4>(); output_a = output.accessor<scalar_t, 4>();
auto pos_add_a = pos_add.accessor<scalar_t, 3>(), auto pos_add_a = pos_add.accessor<scalar_t, 3>(),
pos_mul_a = pos_add.accessor<scalar_t, 3>(); pos_mul_a = pos_mul.accessor<scalar_t, 3>();
for (int n = 0; n < N; n++) { for (int n = 0; n < N; n++) {
for (int c = 0; c < C; c++) { for (int c = 0; c < C; c++) {
......
...@@ -294,9 +294,10 @@ torch::Tensor integrated_conv_cuda(torch::Tensor input, ...@@ -294,9 +294,10 @@ torch::Tensor integrated_conv_cuda(torch::Tensor input,
<< "; patchH,patchW=" << patchH << "," << "; patchH,patchW=" << patchH << ","
<< patchW << ", num_blocks_patch=" << patchW << ", num_blocks_patch="
<< num_blocks_patch << ", num_blocks_batch=" << num_blocks_patch << ", num_blocks_batch="
<< num_blocks_batch << std::endl << num_blocks_batch
<< ", threads_per_opixel=" << threads_per_opixel << ", threads_per_opixel=" << threads_per_opixel
<< ", threads_per_block=" << threads_per_block; << ", threads_per_block=" << threads_per_block
<< std::endl;
dim3 gridDim(C, num_blocks_patch, num_blocks_batch); dim3 gridDim(C, num_blocks_patch, num_blocks_batch);
// blockDim is scalar, just threads_per_block. // blockDim is scalar, just threads_per_block.
......
...@@ -35,14 +35,15 @@ def test_integrated_conv_compare(): ...@@ -35,14 +35,15 @@ def test_integrated_conv_compare():
return return
for dtype in [torch.float32, torch.float64]: for dtype in [torch.float32, torch.float64]:
print("dtype=", dtype) print("dtype=", dtype)
input = torch.ones(N, 2 * C, H, W, dtype=dtype) input = torch.randn(N, 2 * C, H, W, dtype=dtype)
device = torch.device('cuda:0') device = torch.device('cuda:0')
input_cuda = input.to(device) input_cuda = input.to(device)
kH = 5 kH = 5
kW = 5 kW = 5
pos_add = torch.ones(C, kH, kW, dtype=dtype) pos_add = torch.randn(C, kH, kW, dtype=dtype)
pos_mul = torch.ones(C, kH, kW, dtype=dtype) pos_mul = torch.randn(C, kH, kW, dtype=dtype)
pos_add_cuda = pos_add.to(device) pos_add_cuda = pos_add.to(device)
pos_mul_cuda = pos_mul.to(device) pos_mul_cuda = pos_mul.to(device)
...@@ -50,7 +51,11 @@ def test_integrated_conv_compare(): ...@@ -50,7 +51,11 @@ def test_integrated_conv_compare():
output_cuda = integrated_conv(input_cuda, pos_add_cuda, pos_mul_cuda) output_cuda = integrated_conv(input_cuda, pos_add_cuda, pos_mul_cuda)
print("output = ", output) print("output = ", output)
print("output_cuda = ", output_cuda) print("output_cuda = ", output_cuda)
assert torch.allclose(output, output_cuda.to(torch.device('cpu'))) diff = (output - output_cuda.to(torch.device('cpu'))).abs().sum()
abs = output.abs().sum()
print("Diff = ", diff, ", abs = ", abs)
assert torch.allclose(output, output_cuda.to(torch.device('cpu')),
atol=1.0e-05)
def test_integrated_conv_rand_compare(): def test_integrated_conv_rand_compare():
...@@ -76,7 +81,7 @@ def test_integrated_conv_rand_compare(): ...@@ -76,7 +81,7 @@ def test_integrated_conv_rand_compare():
return return
for dtype in [torch.float32, torch.float64]: for dtype in [torch.float32, torch.float64]:
print("dtype=", dtype) print("dtype=", dtype)
input = torch.ones(N, 2 * C, H, W, dtype=dtype) input = torch.randn(N, 2 * C, H, W, dtype=dtype)
device = torch.device('cuda:0') device = torch.device('cuda:0')
input_cuda = input.to(device) input_cuda = input.to(device)
...@@ -86,13 +91,20 @@ def test_integrated_conv_rand_compare(): ...@@ -86,13 +91,20 @@ def test_integrated_conv_rand_compare():
kH += 1 kH += 1
if kW % 2 == 0: if kW % 2 == 0:
kW += 1 kW += 1
pos_add = torch.ones(C, kH, kW, dtype=dtype) pos_add = torch.randn(C, kH, kW, dtype=dtype)
pos_mul = torch.ones(C, kH, kW, dtype=dtype) pos_mul = torch.randn(C, kH, kW, dtype=dtype)
pos_add_cuda = pos_add.to(device) pos_add_cuda = pos_add.to(device)
pos_mul_cuda = pos_mul.to(device) pos_mul_cuda = pos_mul.to(device)
output = integrated_conv(input, pos_add, pos_mul) output = integrated_conv(input, pos_add, pos_mul)
output_cuda = integrated_conv(input_cuda, pos_add_cuda, pos_mul_cuda) output_cuda = integrated_conv(input_cuda, pos_add_cuda, pos_mul_cuda)
print("output = ", output)
print("output_cuda = ", output_cuda) diff = (output - output_cuda.to(torch.device('cpu'))).abs().sum()
assert torch.allclose(output, output_cuda.to(torch.device('cpu'))) abs = output.abs().sum()
print("Diff = ", diff, ", abs = ", abs)
if not torch.allclose(output, output_cuda.to(torch.device('cpu')),
atol=1.0e-05):
print("output = ", output)
print("output_cuda = ", output_cuda)
assert 0, "outputs differ"
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment