Commit 10381dab authored by Daniel Povey's avatar Daniel Povey
Browse files

More tests...

parent 5a8c1e3a
...@@ -52,7 +52,7 @@ torch::Tensor integrated_conv_cpu(torch::Tensor input, ...@@ -52,7 +52,7 @@ torch::Tensor integrated_conv_cpu(torch::Tensor input,
for (int kh = 0; kh < kH; kh++) { for (int kh = 0; kh < kH; kh++) {
int src_h = h + kh - kH / 2; int src_h = h + kh - kH / 2;
for (int kw = 0; kw < kW; kw++) { for (int kw = 0; kw < kW; kw++) {
int src_w = h + kh - kH / 2; int src_w = w + kw - kW / 2;
scalar_t src = 0.0; scalar_t src = 0.0;
if (static_cast<unsigned int>(src_h) < static_cast<unsigned int>(H) && if (static_cast<unsigned int>(src_h) < static_cast<unsigned int>(H) &&
static_cast<unsigned int>(src_w) < static_cast<unsigned int>(W)) static_cast<unsigned int>(src_w) < static_cast<unsigned int>(W))
......
...@@ -294,8 +294,9 @@ torch::Tensor integrated_conv_cuda(torch::Tensor input, ...@@ -294,8 +294,9 @@ torch::Tensor integrated_conv_cuda(torch::Tensor input,
<< "; patchH,patchW=" << patchH << "," << "; patchH,patchW=" << patchH << ","
<< patchW << ", num_blocks_patch=" << patchW << ", num_blocks_patch="
<< num_blocks_patch << ", num_blocks_batch=" << num_blocks_patch << ", num_blocks_batch="
<< num_blocks_batch << std::endl; << num_blocks_batch << std::endl
<< ", threads_per_opixel=" << threads_per_opixel
<< ", threads_per_block=" << threads_per_block;
dim3 gridDim(C, num_blocks_patch, num_blocks_batch); dim3 gridDim(C, num_blocks_patch, num_blocks_batch);
// blockDim is scalar, just threads_per_block. // blockDim is scalar, just threads_per_block.
......
import random
import torch import torch
from torch_integrated_conv import integrated_conv from torch_integrated_conv import integrated_conv
...@@ -8,6 +9,9 @@ def test_integrated_conv_zeros(): ...@@ -8,6 +9,9 @@ def test_integrated_conv_zeros():
H = 3 H = 3
W = 4 W = 4
for device in [ torch.device('cpu'), torch.device('cuda:0') ]: for device in [ torch.device('cpu'), torch.device('cuda:0') ]:
if device == torch.device('cuda:0') and not torch.cuda.is_available():
print("Warning: torch not available, not testing this part.")
continue
for dtype in [torch.float32, torch.float64]: for dtype in [torch.float32, torch.float64]:
print("device=", device, ", dtype=", dtype) print("device=", device, ", dtype=", dtype)
input = torch.zeros(N, 2 * C, H, W, device=device, dtype=dtype) input = torch.zeros(N, 2 * C, H, W, device=device, dtype=dtype)
...@@ -19,3 +23,76 @@ def test_integrated_conv_zeros(): ...@@ -19,3 +23,76 @@ def test_integrated_conv_zeros():
output_ref = torch.zeros(N, C, H, W, device=device, dtype=dtype) output_ref = torch.zeros(N, C, H, W, device=device, dtype=dtype)
output = integrated_conv(input, pos_add, pos_mul) output = integrated_conv(input, pos_add, pos_mul)
assert torch.allclose(output, output_ref) assert torch.allclose(output, output_ref)
def test_integrated_conv_compare():
    """Check that the CPU and CUDA implementations of integrated_conv
    agree (via torch.allclose) on a small all-ones input, for both
    float32 and float64."""
    N = 1
    C = 2
    H = 3
    W = 4
    if not torch.cuda.is_available():
        # Message fixed: torch itself is importable here; it is CUDA that
        # is unavailable, so say so.
        print("Warning: CUDA not available, not testing this part.")
        return
    device = torch.device('cuda:0')  # hoisted: loop-invariant
    for dtype in [torch.float32, torch.float64]:
        print("dtype=", dtype)
        # Input has 2*C channels; presumably C "value" channels plus C
        # channels consumed by the positional terms -- TODO confirm.
        input = torch.ones(N, 2 * C, H, W, dtype=dtype)
        input_cuda = input.to(device)
        kH = 5
        kW = 5
        pos_add = torch.ones(C, kH, kW, dtype=dtype)
        pos_mul = torch.ones(C, kH, kW, dtype=dtype)
        pos_add_cuda = pos_add.to(device)
        pos_mul_cuda = pos_mul.to(device)
        output = integrated_conv(input, pos_add, pos_mul)
        output_cuda = integrated_conv(input_cuda, pos_add_cuda, pos_mul_cuda)
        print("output = ", output)
        print("output_cuda = ", output_cuda)
        assert torch.allclose(output, output_cuda.to(torch.device('cpu')))
def test_integrated_conv_rand_compare():
    """Compare CPU vs CUDA integrated_conv on 30 randomly sized inputs,
    with random odd kernel sizes, for both float32 and float64."""
    if not torch.cuda.is_available():
        # Guard hoisted out of the loop: availability cannot change between
        # iterations, and the original returned on the first pass anyway.
        # Message fixed: torch itself is importable; CUDA is what's missing.
        print("Warning: CUDA not available, not testing this part.")
        return
    device = torch.device('cuda:0')  # hoisted: loop-invariant
    for _ in range(30):
        N = random.randint(1, 256)
        C = random.randint(1, 64)
        H = random.randint(1, 128)
        W = random.randint(1, 128)
        # Shrink the largest dimension until the total problem size fits
        # under 65535 -- presumably a CUDA grid-dimension limit in the
        # kernel launch; TODO confirm against the .cu source.
        while N * C * H * W > 65535:
            if N >= C and N >= H and N >= W:
                N = N // 2
            elif C >= H and C >= W:
                C = C // 2
            elif H >= W:
                H = H // 2
            else:
                W = W // 2
        for dtype in [torch.float32, torch.float64]:
            print("dtype=", dtype)
            input = torch.ones(N, 2 * C, H, W, dtype=dtype)
            input_cuda = input.to(device)
            # Kernel dims must be odd: the C++ code centers the window
            # with kH/2, kW/2, so bump even draws up by one (range 1..11).
            kH = random.randint(1, 10)
            kW = random.randint(1, 10)
            if kH % 2 == 0:
                kH += 1
            if kW % 2 == 0:
                kW += 1
            pos_add = torch.ones(C, kH, kW, dtype=dtype)
            pos_mul = torch.ones(C, kH, kW, dtype=dtype)
            pos_add_cuda = pos_add.to(device)
            pos_mul_cuda = pos_mul.to(device)
            output = integrated_conv(input, pos_add, pos_mul)
            output_cuda = integrated_conv(input_cuda, pos_add_cuda, pos_mul_cuda)
            print("output = ", output)
            print("output_cuda = ", output_cuda)
            assert torch.allclose(output, output_cuda.to(torch.device('cpu')))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment