"git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "3a30bf56dcc4e8b2b8e4cb7f4615954acb883155"
Commit 37778e99 authored by rusty1s's avatar rusty1s
Browse files

added new batch calculation

parent f5812714
...@@ -26,13 +26,13 @@ def test_grid_cluster_cpu(tensor): ...@@ -26,13 +26,13 @@ def test_grid_cluster_cpu(tensor):
output = grid_cluster(position.expand(2, 5, 2), size) output = grid_cluster(position.expand(2, 5, 2), size)
assert output.tolist() == expected.expand(2, 5).tolist() assert output.tolist() == expected.expand(2, 5).tolist()
expected = torch.LongTensor([0, 1, 3, 2, 4]) position = position.repeat(2, 1)
batch = torch.LongTensor([0, 0, 1, 1, 1]) batch = torch.LongTensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])
output = grid_cluster(position, size, batch) expected = torch.LongTensor([0, 3, 1, 0, 2, 4, 7, 5, 4, 6])
expected_batch = torch.LongTensor([0, 0, 0, 0, 1, 1, 1, 1])
output, reduced_batch = grid_cluster(position, size, batch)
assert output.tolist() == expected.tolist() assert output.tolist() == expected.tolist()
assert reduced_batch.tolist() == expected_batch.tolist()
output = grid_cluster(position.expand(2, 5, 2), size, batch.expand(2, 5))
assert output.tolist() == expected.expand(2, 5).tolist()
@pytest.mark.skipif(not torch.cuda.is_available(), reason='no CUDA') @pytest.mark.skipif(not torch.cuda.is_available(), reason='no CUDA')
...@@ -59,10 +59,10 @@ def test_grid_cluster_gpu(tensor): # pragma: no cover ...@@ -59,10 +59,10 @@ def test_grid_cluster_gpu(tensor): # pragma: no cover
output = grid_cluster(position.expand(2, 5, 2), size) output = grid_cluster(position.expand(2, 5, 2), size)
assert output.tolist() == expected.expand(2, 5).tolist() assert output.tolist() == expected.expand(2, 5).tolist()
expected = torch.LongTensor([0, 1, 3, 2, 4]) position = position.repeat(2, 1)
batch = torch.cuda.LongTensor([0, 0, 1, 1, 1]) batch = torch.cuda.LongTensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])
output = grid_cluster(position, size, batch) expected = torch.LongTensor([0, 3, 1, 0, 2, 4, 7, 5, 4, 6])
expected_batch = torch.LongTensor([0, 0, 0, 0, 1, 1, 1, 1])
output, reduced_batch = grid_cluster(position, size, batch)
assert output.cpu().tolist() == expected.tolist() assert output.cpu().tolist() == expected.tolist()
assert reduced_batch.cpu().tolist() == expected_batch.tolist()
output = grid_cluster(position.expand(2, 5, 2), size, batch.expand(2, 5))
assert output.cpu().tolist() == expected.expand(2, 5).tolist()
...@@ -43,6 +43,10 @@ def grid_cluster(position, size, batch=None): ...@@ -43,6 +43,10 @@ def grid_cluster(position, size, batch=None):
func = get_func('grid', position) func = get_func('grid', position)
func(C, cluster, position, size, c_max) func(C, cluster, position, size, c_max)
cluster = cluster.squeeze(dim=-1) cluster = cluster.squeeze(dim=-1)
cluster = consecutive(cluster) cluster, u = consecutive(cluster)
return cluster if batch is None:
return cluster
else:
batch = (u / c_max[1:].prod()).long()
return cluster, batch
...@@ -22,7 +22,7 @@ def get_type(max, cuda): ...@@ -22,7 +22,7 @@ def get_type(max, cuda):
return torch.cuda.LongTensor if cuda else torch.LongTensor return torch.cuda.LongTensor if cuda else torch.LongTensor
def consecutive(tensor): def consecutive(tensor, return_batch=None):
size = tensor.size() size = tensor.size()
u = unique(tensor.view(-1)) u = unique(tensor.view(-1))
len = u[-1] + 1 len = u[-1] + 1
...@@ -31,4 +31,4 @@ def consecutive(tensor): ...@@ -31,4 +31,4 @@ def consecutive(tensor):
arg = type(len) arg = type(len)
arg[u] = torch.arange(0, max, out=type(max)) arg[u] = torch.arange(0, max, out=type(max))
tensor = arg[tensor.view(-1)] tensor = arg[tensor.view(-1)]
return tensor.view(size).long() return tensor.view(size).long(), u
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment