"vscode:/vscode.git/clone" did not exist on "a16957159e24e53df0a40bfadbcd0a61b4b3ea8c"
Commit 99baecdf authored by rusty1s's avatar rusty1s
Browse files

added gpu tests

parent 72b1ce14
...@@ -29,4 +29,20 @@ def test_grid_cluster_cpu(tensor): ...@@ -29,4 +29,20 @@ def test_grid_cluster_cpu(tensor):
@pytest.mark.skipif(not torch.cuda.is_available(), reason='no CUDA') @pytest.mark.skipif(not torch.cuda.is_available(), reason='no CUDA')
@pytest.mark.parametrize('tensor', tensors) @pytest.mark.parametrize('tensor', tensors)
def test_grid_cluster_gpu(tensor): # pragma: no cover def test_grid_cluster_gpu(tensor): # pragma: no cover
pass position = Tensor(tensor, [[0, 0], [9, 9], [2, 8], [2, 2], [8, 3]]).cuda()
size = torch.cuda.LongTensor([5, 5])
expected = torch.LongTensor([0, 3, 1, 0, 2])
output = grid_cluster(position, size)
assert output.cpu().tolist() == expected.tolist()
output = grid_cluster(position.expand(2, 5, 2), size)
assert output.cpu().tolist() == expected.expand(2, 5).tolist()
expected = torch.LongTensor([0, 1, 3, 2, 4])
batch = torch.cuda.LongTensor([0, 0, 1, 1, 1])
output = grid_cluster(position, size, batch)
assert output.cpu().tolist() == expected.tolist()
output = grid_cluster(position.expand(2, 5, 2), size, batch.expand(2, 5))
assert output.cpu().tolist() == expected.expand(2, 5).tolist()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment