# grid.py
import torch

from .utils import get_func, consecutive


def grid_cluster(position, size, batch=None):
    """Assign every point to a cell of a regular grid and return cell ids.

    Args:
        position: Point coordinates; the last dimension indexes coordinates.
            A 1-D tensor is treated as one coordinate per point (a trailing
            dimension of size 1 is added).
        size: 1-D tensor holding one cell size per coordinate dimension; its
            length must equal ``position.size(-1)``.
        batch: Optional per-point batch index. When given it is prepended as
            an extra coordinate with cell size 1, so (assuming integer batch
            values) points from different batch entries land in different
            cells.

    Returns:
        The cluster tensor with its trailing singleton dimension squeezed
        away, post-processed by ``consecutive`` (presumably relabeling ids to
        a contiguous range -- confirm in ``.utils``).

    Raises:
        AssertionError: If ``size`` is not 1-D, if its length does not match
            the coordinate dimension, or if ``batch`` does not match the
            leading shape of ``position``. These checks are ``assert``
            statements and are skipped when Python runs with ``-O``.
    """
    # Allow one-dimensional positions.
    if position.dim() == 1:
        position = position.unsqueeze(-1)

    assert size.dim() == 1, 'Size tensor must be one-dimensional'
    assert position.size(-1) == size.size(-1), (
        'Last dimension of position tensor must have same size as size tensor')

    # If given, append batch to position tensor as the first coordinate, with
    # a cell size of 1 along that new dimension.
    if batch is not None:
        batch = batch.unsqueeze(-1).type_as(position)
        assert position.size()[:-1] == batch.size()[:-1], (
            'Position tensor must have same size as batch tensor apart from '
            'the last dimension')
        position = torch.cat([batch, position], dim=-1)
        size = torch.cat([size.new(1).fill_(1), size], dim=-1)

    # Translate to minimal positive positions: subtract the per-coordinate
    # minimum taken over the point dimension (dim -2). The locals are named
    # pos_min / pos_max so they do not shadow the ``min``/``max`` builtins.
    pos_min = position.min(dim=-2, keepdim=True)[0]
    position = position - pos_min

    # Per-coordinate maximum, reduced over every leading dimension until only
    # the coordinate dimension remains.
    pos_max = position.max(dim=0)[0]
    while pos_max.dim() > 1:
        pos_max = pos_max.max(dim=0)[0]

    # Number of cells along each dimension: floor(max / size) + 1 covers the
    # range [0, max]. Division is done in double precision, and every
    # dimension is forced to have at least one cell.
    c_max = torch.floor(pos_max.double() / size.double() + 1).long()
    c_max = torch.clamp(c_max, min=1)
    # Total number of grid cells.
    C = c_max.prod()

    # Allocate the output with the same shape as position but a trailing
    # dimension of 1. ``new`` leaves it uninitialized; the kernel below is
    # presumably responsible for writing every entry -- confirm in .utils.
    s = list(position.size())
    s[-1] = 1
    cluster = c_max.new(torch.Size(s))

    # Fill cluster tensor via the backend kernel (get_func presumably
    # dispatches on position's type/device -- confirm), then reshape.
    size = size.type_as(position)
    func = get_func('grid', position)
    func(C, cluster, position, size, c_max)
    cluster = cluster.squeeze(dim=-1)
    cluster = consecutive(cluster)

    return cluster