Commit 0a1d7c75 authored by rusty1s's avatar rusty1s
Browse files

consecutive impl

parent 1750d110
......@@ -12,7 +12,7 @@ setup_requires = ['pytest-runner', 'cffi']
tests_require = ['pytest', 'pytest-cov']
setup(
name='torch_unique',
name='torch_cluster',
version=__version__,
description='PyTorch Geometric Deep Learning Graph Cluster Algorithms',
author='Matthias Fey',
......
......@@ -43,7 +43,6 @@ def grid_cluster(position, size, batch=None):
func = get_func('grid', position)
func(C, cluster, position, size, c_max)
cluster = cluster.squeeze(dim=-1)
cluster = consecutive(cluster)
return cluster
......@@ -11,10 +11,24 @@ def get_func(name, tensor):
return func
def get_type(max, cuda):
if max <= 255:
return torch.cuda.ByteTensor if cuda else torch.ByteTensor
elif max <= 32767:
return torch.cuda.ShortTensor if cuda else torch.ShortTensor
elif max <= 2147483647:
return torch.cuda.IntTensor if cuda else torch.IntTensor
else:
return torch.cuda.LongTensor if cuda else torch.LongTensor
def consecutive(tensor):
size = tensor.size()
u = unique(tensor.view(-1))
arg = torch.ByteTensor(u[-1])
arg[u] = torch.arange(0, u.size(0), out=torch.ByteTensor())
len = u[-1] + 1
max = u.size(0)
type = get_type(max, tensor.is_cuda)
arg = type(len)
arg[u] = torch.arange(0, max, out=type(max))
tensor = arg[tensor.view(-1)]
return tensor.view(size).long()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment