"""Grid clustering: assign positions to cells of a regular grid."""
from __future__ import division

import torch

from .utils import get_func, consecutive


def _preprocess(position, size, batch=None, start=None):
    size = size.type_as(position)

rusty1s's avatar
rusty1s committed
11
12
13
    # Allow one-dimensional positions.
    if position.dim() == 1:
        position = position.unsqueeze(-1)
rusty1s's avatar
rusty1s committed
14

rusty1s's avatar
rusty1s committed
15
16
17
    assert size.dim() == 1, 'Size tensor must be one-dimensional'
    assert position.size(-1) == size.size(-1), (
        'Last dimension of position tensor must have same size as size tensor')
rusty1s's avatar
rusty1s committed
18

rusty1s's avatar
rusty1s committed
19
20
21
22
23
24
25
26
    # Translate to minimal positive positions if no start was passed.
    if start is None:
        position = position - position.min(dim=-2, keepdim=True)[0]
    else:
        position = position - start
        assert position.min() >= 0, (
            'Passed origin resulting in unallowed negative positions')

rusty1s's avatar
rusty1s committed
27
28
29
30
31
32
33
34
    # If given, append batch to position tensor.
    if batch is not None:
        batch = batch.unsqueeze(-1).type_as(position)
        assert position.size()[:-1] == batch.size()[:-1], (
            'Position tensor must have same size as batch tensor apart from '
            'the last dimension')
        position = torch.cat([batch, position], dim=-1)
        size = torch.cat([size.new(1).fill_(1), size], dim=-1)
rusty1s's avatar
rusty1s committed
35

rusty1s's avatar
rusty1s committed
36
37
    return position, size

rusty1s's avatar
rusty1s committed
38

rusty1s's avatar
rusty1s committed
39
def _minimal_cluster_size(position, size):
40
41
42
    max = position.max(dim=0)[0]
    while max.dim() > 1:
        max = max.max(dim=0)[0]
rusty1s's avatar
rusty1s committed
43
44
    cluster_size = (max / size).long() + 1
    return cluster_size
rusty1s's avatar
rusty1s committed
45
46


rusty1s's avatar
rusty1s committed
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
def _fixed_cluster_size(position, size, batch=None, end=None):
    if end is None:
        return _minimal_cluster_size(position, size)

    eps = 0.000001  # Model [start, end).
    if batch is None:
        cluster_size = ((end / size).float() - eps).long() + 1
    else:
        cluster_size = ((end / size[1:]).float() - eps).long() + 1
        cluster_size = torch.cat([batch.max() + 1, cluster_size], dim=0)

    return cluster_size


def _grid_cluster(position, size, cluster_size):
    """Assign every point to a grid cell using the backend ``grid`` kernel.

    Args:
        position: Preprocessed positions whose last dimension holds the
            coordinates.
        size: Cell width per coordinate.
        cluster_size: Number of cells per coordinate.

    Returns:
        Tuple ``(cluster, C)``.

        - ``cluster`` has the shape of ``position`` without its last
          dimension, and the dtype of ``cluster_size``. It is filled by the
          kernel, presumably with a linear cell index in ``[0, C)``; confirm
          against the kernel.
        - ``C`` is ``cluster_size.prod()``, the total number of cells.
    """
    # Wrap in torch.Size so ``new`` treats the value as a shape, not as data.
    out_shape = torch.Size(list(position.size())[:-1])
    # Uninitialized output buffer. The trailing singleton axis is presumably
    # the layout the kernel writes into.
    cluster = cluster_size.new(out_shape).unsqueeze(dim=-1)
    num_cells = cluster_size.prod()

    # Presumably selects the implementation matching ``position``'s
    # backend/type; see ``utils.get_func``.
    grid_kernel = get_func('grid', position)
    grid_kernel(num_cells, cluster, position, size, cluster_size)

    return cluster.squeeze(dim=-1), num_cells


def sparse_grid_cluster(position, size, batch=None, start=None):
    """Cluster positions into grid cells and relabel occupied cells densely.

    Args:
        position: Point coordinates, with the last dimension holding the
            coordinates; a 1-D tensor means a single coordinate.
        size: One-dimensional tensor of cell widths, one per coordinate.
        batch: Optional per-point batch index.
        start: Optional grid origin. Defaults to the minimum position.

    Returns:
        ``cluster`` if ``batch`` is None. Otherwise ``(cluster, batch)``,
        where the returned ``batch`` holds an integer batch index for each
        consecutive cluster id.
    """
    position, size = _preprocess(position, size, batch, start)
    cluster_size = _minimal_cluster_size(position, size)
    cluster, C = _grid_cluster(position, size, cluster_size)
    # Presumably relabels cluster ids to 0..K-1 and returns in ``u`` the
    # original cell id of each of the K occupied cells; confirm against
    # ``utils.consecutive``.
    cluster, u = consecutive(cluster)

    if batch is None:
        return cluster

    # The batch index is the leading grid coordinate (see ``_preprocess``),
    # so each batch entry spans ``C // cluster_size[0]`` cell ids (assuming
    # the kernel linearizes coordinates row-major).
    # Floor division keeps the result an integer index. True division
    # (``/``, enabled by ``from __future__ import division`` and the default
    # for integer tensors on modern PyTorch) would give fractional floats.
    batch = u // (C // cluster_size[0])
    return cluster, batch


def dense_grid_cluster(position, size, batch=None, start=None, end=None):
    """Cluster positions into a fixed grid without relabelling cells.

    Args:
        position: Point coordinates, with the last dimension holding the
            coordinates.
        size: One-dimensional tensor of cell widths, one per coordinate.
        batch: Optional per-point batch index.
        start: Optional grid origin. Defaults to the minimum position.
        end: Optional per-coordinate extent. Defaults to the minimal extent
            covering all positions.

    Returns:
        Tuple ``(cluster, C)``.

        - ``cluster`` holds the grid cell id of each point.
        - ``C`` is the total number of cells, or the number of cells per
          batch entry when ``batch`` is given.
    """
    position, size = _preprocess(position, size, batch, start)
    cluster_size = _fixed_cluster_size(position, size, batch, end)
    cluster, num_cells = _grid_cluster(position, size, cluster_size)

    if batch is not None:
        # ``cluster_size[0]`` counts batch entries, so dividing by it gives
        # the number of cells available to a single batch entry.
        num_cells = num_cells // cluster_size[0]

    return cluster, num_cells