Commit e95d789b authored by rusty1s's avatar rusty1s
Browse files

test one-dimensional positions

parent 511f78e4
......@@ -7,6 +7,13 @@ from .utils import tensors, Tensor
@pytest.mark.parametrize('tensor', tensors)
def test_grid_cluster_cpu(tensor):
position = Tensor(tensor, [0, 9, 2, 8, 3])
size = torch.LongTensor([5])
expected = torch.LongTensor([0, 1, 0, 1, 0])
output = grid_cluster(position, size)
assert output.tolist() == expected.tolist()
position = Tensor(tensor, [[0, 0], [9, 9], [2, 8], [2, 2], [8, 3]])
size = torch.LongTensor([5, 5])
expected = torch.LongTensor([0, 3, 1, 0, 2])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment