"vscode:/vscode.git/clone" did not exist on "e5c6715003da433da5cf57d143fc5794f9d5c942"
utils.py
import torch
from torch_unique import unique

from .._ext import ffi


def get_func(name, tensor):
    # Look up the compiled FFI kernel matching the tensor's scalar type and
    # device.
    typename = type(tensor).__name__.replace('Tensor', '')
    cuda = 'cuda_' if tensor.is_cuda else ''
    func = getattr(ffi, 'cluster_{}_{}{}'.format(name, cuda, typename))
    return func
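
# Illustrative note (a sketch, not part of the original file): following the
# format string above, a call like get_func('graclus', torch.FloatTensor(8))
# resolves to `ffi.cluster_graclus_Float`, and to `ffi.cluster_graclus_cuda_Float`
# for a CUDA tensor. The kernel name 'graclus' is only an assumed example here.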


def get_type(max, cuda):
    # Choose the smallest integer tensor type that can hold values up to `max`.
    if max <= 255:
        return torch.cuda.ByteTensor if cuda else torch.ByteTensor
    elif max <= 32767:  # pragma: no cover
        return torch.cuda.ShortTensor if cuda else torch.ShortTensor
    elif max <= 2147483647:  # pragma: no cover
        return torch.cuda.IntTensor if cuda else torch.IntTensor
    else:  # pragma: no cover
        return torch.cuda.LongTensor if cuda else torch.LongTensor
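
# For illustration (derived from the thresholds above, added as a sketch):
# get_type(200, False) yields torch.ByteTensor, while get_type(70000, False)
# yields torch.IntTensor, since 70000 exceeds the 16-bit range of ShortTensor.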


def consecutive(tensor):
    # Relabel the values of `tensor` as consecutive integers 0..n-1, keeping
    # their relative order.
    size = tensor.size()
    u = unique(tensor.view(-1))
    len = u[-1] + 1
    max = u.size(0)
    type = get_type(max, tensor.is_cuda)
    # Build a lookup table of size `len` (largest value + 1) that maps each
    # unique value to its rank in the sorted unique set.
    arg = type(len)
    arg[u] = torch.arange(0, max, out=type(max))
    tensor = arg[tensor.view(-1)]
    return tensor.view(size).long()
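
# Minimal usage sketch (added for illustration; assumes `unique` returns the
# sorted unique values, which the `u[-1] + 1` table size above relies on):
# consecutive(torch.LongTensor([7, 2, 7, 5])) relabels each value by its rank
# among the unique values, returning a LongTensor holding [2, 0, 2, 1].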