utils.py 543 Bytes
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
import torch
from torch_unique import unique

rusty1s's avatar
rusty1s committed
4
5
6
7
8
9
10
11
from .._ext import ffi


def get_func(name, tensor):
    """Return the compiled FFI function ``name`` specialised for ``tensor``.

    The symbol looked up on ``ffi`` is ``cluster_<name>_<prefix><scalar>``:

    * ``<prefix>`` is ``'cuda_'`` when ``tensor.is_cuda`` is true, else empty.
    * ``<scalar>`` is the tensor's class name with ``'Tensor'`` stripped
      (e.g. a class named ``FloatTensor`` yields ``Float``).

    Raises ``AttributeError`` (from ``getattr``) if no such symbol exists.
    """
    prefix = 'cuda_' if tensor.is_cuda else ''
    scalar = type(tensor).__name__.replace('Tensor', '')
    symbol = 'cluster_{}_{}{}'.format(name, prefix, scalar)
    return getattr(ffi, symbol)
rusty1s's avatar
rusty1s committed
12
13
14
15
16
17
18
19
20


def consecutive(tensor):
    """Relabel the values of an integer tensor to consecutive ids ``0..k-1``.

    Each distinct value in ``tensor`` is replaced by its position among the
    distinct values returned by ``unique``. This assumes ``unique`` returns
    them sorted ascending (the table size below relies on ``u[-1]`` being the
    maximum; verify against torch_unique). With that ordering, the smallest
    value maps to 0, the next to 1, and so on.

    Args:
        tensor: integer tensor of any shape. Its values are used directly as
            indices into a lookup table, so they are presumably non-negative.
            TODO confirm with callers.

    Returns:
        A long tensor with the same shape as ``tensor``, on the same device.
    """
    if tensor.numel() == 0:
        # ``u[-1]`` below would fail on an empty tensor; nothing to relabel.
        return tensor.long()

    size = tensor.size()
    flat = tensor.view(-1)
    u = unique(flat)

    # The table is indexed by the original values, so it must contain the
    # index ``max`` itself: size ``max + 1`` (the previous ``u[-1]`` sizing
    # wrote one element past the end).
    num_entries = int(u[-1]) + 1
    num_unique = u.size(0)

    # Build the table as a long tensor on the input's device. A ByteTensor
    # wraps around past 256 distinct values and always lives on the CPU,
    # which breaks CUDA inputs.
    arg = tensor.new(num_entries).long()
    arg[u] = torch.arange(0, num_unique, out=arg.new())

    return arg[flat].view(size)