Commit 629213e4 authored by rusty1s's avatar rusty1s
Browse files

better consecutive impl

parent 5f633fcb
...@@ -4,17 +4,11 @@ from torch_cluster.functions.utils.consecutive import consecutive ...@@ -4,17 +4,11 @@ from torch_cluster.functions.utils.consecutive import consecutive
def test_consecutive_cpu():
    # Values should be remapped to consecutive ids ordered by value.
    values = torch.LongTensor([0, 3, 2, 2, 3])
    result = consecutive(values)
    assert result.tolist() == [0, 2, 1, 1, 2]
@pytest.mark.skipif(not torch.cuda.is_available(), reason='no CUDA')
def test_consecutive_gpu():  # pragma: no cover
    # Same remapping as the CPU test, run on a CUDA tensor.
    values = torch.cuda.LongTensor([0, 3, 2, 2, 3])
    result = consecutive(values).cpu()
    assert result.tolist() == [0, 2, 1, 1, 2]
...@@ -14,13 +14,21 @@ def _get_type(max_value, cuda): ...@@ -14,13 +14,21 @@ def _get_type(max_value, cuda):
def consecutive(x):
    """Remap the values of `x` onto consecutive integers 0..k-1.

    Each distinct value in `x` is replaced by its rank among the unique
    values, so equal inputs map to equal outputs and order is preserved.
    """
    initial_size = x.size()
    flat = x.view(-1)

    # Unique (sorted) values present in the input.
    u = unique(flat)
    num_unique = u.size(0)
    max_value = u[-1] + 1

    # Build a lookup table of length `max_value` with table[u[i]] = i.
    # The table can get large (it depends on the maximum value in `x`),
    # so `_get_type` picks the smallest dtype able to hold `num_unique`.
    mask = _get_type(num_unique, x.is_cuda)(max_value)
    mask[u] = torch.arange(0, num_unique, out=mask.new())

    # Gather the consecutive ids and restore the original shape.
    return mask[flat].view(initial_size).long()
...@@ -4,5 +4,5 @@ import torch ...@@ -4,5 +4,5 @@ import torch
def node_degree(index, num_nodes, out=None):
    """Count, for each node id in 0..num_nodes-1, its occurrences in `index`.

    Writes the counts into `out` when given (must have `num_nodes` entries),
    otherwise allocates a tensor matching `index`'s type.
    """
    target = index.new(num_nodes) if out is None else out
    counts = torch.zeros(num_nodes, out=target)
    increments = torch.ones(index.size(0), out=counts.new())
    counts.scatter_add_(0, index, increments)
    return counts
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment