Commit 1750d110 authored by rusty1s's avatar rusty1s
Browse files

first consecutive impl

parent 43fde05a
import torch
from .utils import get_func
from .utils import get_func, consecutive
def grid_cluster(position, size, batch=None):
......@@ -44,4 +44,6 @@ def grid_cluster(position, size, batch=None):
func(C, cluster, position, size, c_max)
cluster = cluster.squeeze(dim=-1)
cluster = consecutive(cluster)
return cluster
import torch
from torch_unique import unique
from .._ext import ffi
......@@ -6,3 +9,12 @@ def get_func(name, tensor):
cuda = 'cuda_' if tensor.is_cuda else ''
func = getattr(ffi, 'cluster_{}_{}{}'.format(name, cuda, typename))
return func
def consecutive(tensor):
size = tensor.size()
u = unique(tensor.view(-1))
arg = torch.ByteTensor(u[-1])
arg[u] = torch.arange(0, u.size(0), out=torch.ByteTensor())
tensor = arg[tensor.view(-1)]
return tensor.view(size).long()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment