Commit 6a2e1a08 authored by rusty1s's avatar rusty1s
Browse files

rename

parent fb3d8a81
...@@ -14,7 +14,7 @@ def test_grid_cluster_cpu(tensor): ...@@ -14,7 +14,7 @@ def test_grid_cluster_cpu(tensor):
assert output.tolist() == expected.tolist() assert output.tolist() == expected.tolist()
expected = torch.LongTensor([0, 1]) expected = torch.LongTensor([0, 1])
output, _ = grid_cluster(position, size, offset=0) output, _ = grid_cluster(position, size, origin=0)
assert output.tolist() == expected.tolist() assert output.tolist() == expected.tolist()
position = Tensor(tensor, [0, 17, 2, 8, 3]) position = Tensor(tensor, [0, 17, 2, 8, 3])
...@@ -63,7 +63,7 @@ def test_grid_cluster_gpu(tensor): # pragma: no cover ...@@ -63,7 +63,7 @@ def test_grid_cluster_gpu(tensor): # pragma: no cover
assert output.cpu().tolist() == expected.tolist() assert output.cpu().tolist() == expected.tolist()
expected = torch.LongTensor([0, 1]) expected = torch.LongTensor([0, 1])
output, _ = grid_cluster(position, size, offset=0) output, _ = grid_cluster(position, size, origin=0)
assert output.cpu().tolist() == expected.tolist() assert output.cpu().tolist() == expected.tolist()
position = Tensor(tensor, [0, 17, 2, 8, 3]).cuda() position = Tensor(tensor, [0, 17, 2, 8, 3]).cuda()
......
...@@ -3,7 +3,7 @@ import torch ...@@ -3,7 +3,7 @@ import torch
from .utils import get_func, consecutive from .utils import get_func, consecutive
def grid_cluster(position, size, batch=None, offset=None, fake_nodes=False): def grid_cluster(position, size, batch=None, origin=None, fake_nodes=False):
# Allow one-dimensional positions. # Allow one-dimensional positions.
if position.dim() == 1: if position.dim() == 1:
position = position.unsqueeze(-1) position = position.unsqueeze(-1)
...@@ -21,14 +21,14 @@ def grid_cluster(position, size, batch=None, offset=None, fake_nodes=False): ...@@ -21,14 +21,14 @@ def grid_cluster(position, size, batch=None, offset=None, fake_nodes=False):
position = torch.cat([batch, position], dim=-1) position = torch.cat([batch, position], dim=-1)
size = torch.cat([size.new(1).fill_(1), size], dim=-1) size = torch.cat([size.new(1).fill_(1), size], dim=-1)
# Translate to minimal positive positions if no offset is passed. # Translate to minimal positive positions if no origin was passed.
if offset is None: if origin is None:
min = position.min(dim=-2, keepdim=True)[0] min = position.min(dim=-2, keepdim=True)[0]
position = position - min position = position - min
else: else:
position = position + offset position = position + origin
assert position.min() >= 0, ( assert position.min() >= 0, (
'Passed offset resulting in unallowed negative positions') 'Passed origin resulting in unallowed negative positions')
# Compute cluster count for each dimension. # Compute cluster count for each dimension.
max = position.max(dim=0)[0] max = position.max(dim=0)[0]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment