# grid.py
import torch

from .utils import get_func


def grid_cluster(position, size, batch=None):
    """Assign every point in ``position`` to a cell of a regular grid.

    Positions are shifted so that their minimum over the point axis is zero.
    The number of cells along each coordinate axis is then derived from
    ``size``, and the backend ``'grid'`` function obtained from ``get_func``
    fills in one cell id per point.

    Args:
        position: Point coordinates, with coordinates on the last axis
            (e.g. ``(N, D)``). A 1-D tensor is treated as ``(N, 1)``.
        size: Cell edge length per coordinate axis, shape ``(D,)``.
        batch: Optional per-point batch index. For a 2-D ``position`` it may
            have shape ``(N,)`` or ``(N, 1)``. It is appended as an extra
            coordinate with cell size 1, presumably so that points from
            different batch entries never share a cell.

    Returns:
        A ``long`` tensor of cell ids with the trailing coordinate axis
        removed (e.g. shape ``(N,)``).
    """
    # TODO: Check types and sizes
    position, size = _prepare_grid_input(position, size, batch)
    dim = position.dim()

    # Translate to minimal non-negative positions. The minimum is taken over
    # the point axis (dim - 2), i.e. per leading slice for rank > 2 inputs.
    offset = position.min(dim=dim - 2, keepdim=True)[0]
    position = position - offset

    c_max = _grid_cell_counts(position, size)
    C = c_max.prod()

    # Output buffer: same leading shape as ``position`` with one trailing
    # slot per point; the kernel fills it with cell ids.
    s = list(position.size())
    s[-1] = 1
    cluster = c_max.new(torch.Size(s))

    func = get_func('grid', position)
    func(C, cluster, position, size, c_max)
    return cluster.squeeze(dim=dim - 1)


def _prepare_grid_input(position, size, batch):
    """Lift 1-D positions to 2-D and append ``batch`` as a unit-size axis."""
    # Lift before concatenating: appending a batch column to a 1-D tensor
    # would otherwise concatenate along the point axis and corrupt the input.
    if position.dim() == 1:
        position = position.unsqueeze(1)

    if batch is not None:
        batch = batch.type_as(position)
        # Accept a batch vector without the trailing singleton axis.
        if batch.dim() == position.dim() - 1:
            batch = batch.unsqueeze(-1)
        position = torch.cat([position, batch], dim=position.dim() - 1)
        size = torch.cat([size, size.new(1).fill_(1)], dim=0)

    return position, size


def _grid_cell_counts(position, size):
    """Return the number of grid cells needed along each coordinate axis."""
    # Largest (already non-negative) offset per coordinate, over all points.
    extent = position.reshape(-1, position.size(-1)).max(dim=0)[0]
    # Assuming the 'grid' kernel puts offset p into cell floor(p / size)
    # (TODO confirm against the kernel), indices span 0..floor(extent / size),
    # i.e. floor(extent / size) + 1 cells. ceil() would be one short whenever
    # extent is an exact multiple of size (always the case for the unit-size
    # batch axis), letting cell ids collide. The result is always >= 1.
    return torch.floor(extent / size.type_as(extent)).long() + 1