Commit 37778e99 authored by rusty1s's avatar rusty1s
Browse files

added new batch calculation

parent f5812714
......@@ -26,13 +26,13 @@ def test_grid_cluster_cpu(tensor):
output = grid_cluster(position.expand(2, 5, 2), size)
assert output.tolist() == expected.expand(2, 5).tolist()
expected = torch.LongTensor([0, 1, 3, 2, 4])
batch = torch.LongTensor([0, 0, 1, 1, 1])
output = grid_cluster(position, size, batch)
position = position.repeat(2, 1)
batch = torch.LongTensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])
expected = torch.LongTensor([0, 3, 1, 0, 2, 4, 7, 5, 4, 6])
expected_batch = torch.LongTensor([0, 0, 0, 0, 1, 1, 1, 1])
output, reduced_batch = grid_cluster(position, size, batch)
assert output.tolist() == expected.tolist()
output = grid_cluster(position.expand(2, 5, 2), size, batch.expand(2, 5))
assert output.tolist() == expected.expand(2, 5).tolist()
assert reduced_batch.tolist() == expected_batch.tolist()
@pytest.mark.skipif(not torch.cuda.is_available(), reason='no CUDA')
......@@ -59,10 +59,10 @@ def test_grid_cluster_gpu(tensor): # pragma: no cover
output = grid_cluster(position.expand(2, 5, 2), size)
assert output.tolist() == expected.expand(2, 5).tolist()
expected = torch.LongTensor([0, 1, 3, 2, 4])
batch = torch.cuda.LongTensor([0, 0, 1, 1, 1])
output = grid_cluster(position, size, batch)
position = position.repeat(2, 1)
batch = torch.cuda.LongTensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])
expected = torch.LongTensor([0, 3, 1, 0, 2, 4, 7, 5, 4, 6])
expected_batch = torch.LongTensor([0, 0, 0, 0, 1, 1, 1, 1])
output, reduced_batch = grid_cluster(position, size, batch)
assert output.cpu().tolist() == expected.tolist()
output = grid_cluster(position.expand(2, 5, 2), size, batch.expand(2, 5))
assert output.cpu().tolist() == expected.expand(2, 5).tolist()
assert reduced_batch.cpu().tolist() == expected_batch.tolist()
......@@ -43,6 +43,10 @@ def grid_cluster(position, size, batch=None):
func = get_func('grid', position)
func(C, cluster, position, size, c_max)
cluster = cluster.squeeze(dim=-1)
cluster = consecutive(cluster)
cluster, u = consecutive(cluster)
return cluster
if batch is None:
return cluster
else:
batch = (u / c_max[1:].prod()).long()
return cluster, batch
......@@ -22,7 +22,7 @@ def get_type(max, cuda):
return torch.cuda.LongTensor if cuda else torch.LongTensor
def consecutive(tensor):
def consecutive(tensor, return_batch=None):
size = tensor.size()
u = unique(tensor.view(-1))
len = u[-1] + 1
......@@ -31,4 +31,4 @@ def consecutive(tensor):
arg = type(len)
arg[u] = torch.arange(0, max, out=type(max))
tensor = arg[tensor.view(-1)]
return tensor.view(size).long()
return tensor.view(size).long(), u
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment