Commit 0a1d7c75 authored by rusty1s's avatar rusty1s
Browse files

consecutive impl

parent 1750d110
...@@ -12,7 +12,7 @@ setup_requires = ['pytest-runner', 'cffi'] ...@@ -12,7 +12,7 @@ setup_requires = ['pytest-runner', 'cffi']
tests_require = ['pytest', 'pytest-cov'] tests_require = ['pytest', 'pytest-cov']
setup( setup(
name='torch_unique', name='torch_cluster',
version=__version__, version=__version__,
description='PyTorch Geometric Deep Learning Graph Cluster Algorithms', description='PyTorch Geometric Deep Learning Graph Cluster Algorithms',
author='Matthias Fey', author='Matthias Fey',
......
...@@ -43,7 +43,6 @@ def grid_cluster(position, size, batch=None): ...@@ -43,7 +43,6 @@ def grid_cluster(position, size, batch=None):
func = get_func('grid', position) func = get_func('grid', position)
func(C, cluster, position, size, c_max) func(C, cluster, position, size, c_max)
cluster = cluster.squeeze(dim=-1) cluster = cluster.squeeze(dim=-1)
cluster = consecutive(cluster) cluster = consecutive(cluster)
return cluster return cluster
...@@ -11,10 +11,24 @@ def get_func(name, tensor): ...@@ -11,10 +11,24 @@ def get_func(name, tensor):
return func return func
def get_type(max, cuda):
    """Return the narrowest integer tensor class able to hold ``max``.

    Picks Byte/Short/Int/Long by comparing ``max`` against the type's
    upper bound, taking the CUDA variant when ``cuda`` is true.
    """
    if max <= 255:
        name = 'ByteTensor'
    elif max <= 32767:
        name = 'ShortTensor'
    elif max <= 2147483647:
        name = 'IntTensor'
    else:
        name = 'LongTensor'
    module = torch.cuda if cuda else torch
    return getattr(module, name)
def consecutive(tensor):
    """Relabel the values of ``tensor`` as consecutive ids ``0..k-1``.

    Equal input values map to the same id, and ids preserve the sorted
    order of the distinct input values. Returns a ``LongTensor`` with the
    same shape as the input.
    """
    size = tensor.size()
    # Sorted distinct values of the flattened input (project helper).
    u = unique(tensor.view(-1))
    # Lookup table must be indexable by the largest value present.
    table_len = u[-1] + 1
    num_unique = u.size(0)
    # Narrowest integer type that can hold the largest assigned id.
    # NOTE: renamed locals — the original shadowed the builtins
    # `len`, `max` and `type`.
    tensor_type = get_type(num_unique, tensor.is_cuda)
    # Uninitialized table is fine: only positions listed in `u` are ever
    # read back, because every input value appears in `u`.
    arg = tensor_type(table_len)
    arg[u] = torch.arange(0, num_unique, out=tensor_type(num_unique))
    tensor = arg[tensor.view(-1)]
    return tensor.view(size).long()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment