grid.py 1.59 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
import torch

rusty1s's avatar
rusty1s committed
3
from .utils import get_func, consecutive
rusty1s's avatar
rusty1s committed
4
5
6


def grid_cluster(position, size, batch=None):
    """Assign points to the cells of a regular grid and return cluster ids.

    Args:
        position: Point coordinates. A 1-D tensor is treated as one
            coordinate per point (it is unsqueezed to ``(..., 1)``);
            otherwise the last dimension holds the coordinates.
        size: 1-D tensor with the grid cell size for each coordinate
            dimension; its length must equal ``position.size(-1)``.
        batch: Optional tensor of batch indices whose shape matches
            ``position`` apart from the last dimension. When given, it is
            prepended to ``position`` as an extra coordinate with cell
            size 1.

    Returns:
        ``cluster`` when ``batch`` is ``None``, otherwise the tuple
        ``(cluster, batch)``. ``cluster`` is the per-point id produced by the
        backend grid function and then passed through ``consecutive``.
        ``batch`` is derived from the second output of ``consecutive``
        divided by the value returned by the backend grid function
        (presumably one batch index per cluster -- TODO confirm).

    Raises:
        AssertionError: If ``size`` is not one-dimensional, if the last
            dimension of ``position`` does not match ``size``, or if
            ``batch`` does not match ``position`` apart from the last
            dimension.
    """
    # Allow one-dimensional positions.
    if position.dim() == 1:
        position = position.unsqueeze(-1)

    assert size.dim() == 1, 'Size tensor must be one-dimensional'
    assert position.size(-1) == size.size(-1), (
        'Last dimension of position tensor must have same size as size tensor')

    # If given, append batch to position tensor.
    if batch is not None:
        batch = batch.unsqueeze(-1).type_as(position)
        assert position.size()[:-1] == batch.size()[:-1], (
            'Position tensor must have same size as batch tensor apart from '
            'the last dimension')
        position = torch.cat([batch, position], dim=-1)
        # The prepended batch coordinate gets a cell size of 1.
        size = torch.cat([size.new(1).fill_(1), size], dim=-1)

    # Translate to minimal positive positions.
    p_min = position.min(dim=-2, keepdim=True)[0]
    position = position - p_min

    # Compute maximal position for each dimension: reduce every axis except
    # the last (coordinate) axis, one at a time.
    p_max = position.max(dim=0)[0]
    while p_max.dim() > 1:
        p_max = p_max.max(dim=0)[0]

    # Generate cluster tensor with one slot per point. Its contents are not
    # initialised here; presumably ``func`` fills it in place (TODO confirm).
    s = list(position.size())[:-1] + [1]
    cluster = size.new(torch.Size(s)).long()

    # Fill cluster tensor and reshape.
    size = size.type_as(position)
    func = get_func('grid', position)
    C = func(cluster, position, size, p_max)
    cluster = cluster.squeeze(dim=-1)
    cluster, u = consecutive(cluster)

    if batch is None:
        return cluster

    # NOTE(review): ``C`` looks like the number of grid cells per batch slice
    # and ``u`` the original (pre-relabelling) cluster ids, so ``u / C``
    # recovers the batch index -- confirm against ``get_func``/``consecutive``.
    # With newer PyTorch, ``/`` on integer tensors is true division; the
    # ``.long()`` truncation matches floor division for non-negative ids, but
    # float precision may be lost for very large ids.
    batch = (u / C).long()
    return cluster, batch