grid.py 3.18 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
from __future__ import division

rusty1s's avatar
rusty1s committed
3
4
import torch

rusty1s's avatar
rusty1s committed
5
from .utils import get_func, consecutive
rusty1s's avatar
rusty1s committed
6
7


rusty1s's avatar
rusty1s committed
8
9
10
def _preprocess(position, size, batch=None, start=None):
    size = size.type_as(position)

rusty1s's avatar
rusty1s committed
11
12
13
    # Allow one-dimensional positions.
    if position.dim() == 1:
        position = position.unsqueeze(-1)
rusty1s's avatar
rusty1s committed
14

rusty1s's avatar
rusty1s committed
15
16
17
    assert size.dim() == 1, 'Size tensor must be one-dimensional'
    assert position.size(-1) == size.size(-1), (
        'Last dimension of position tensor must have same size as size tensor')
rusty1s's avatar
rusty1s committed
18

rusty1s's avatar
rusty1s committed
19
20
    # Translate to minimal positive positions if no start was passed.
    if start is None:
rusty1s's avatar
rusty1s committed
21
22
23
24
        min = []
        for i in range(position.size(-1)):
            min.append(position[:, i].min())
        position = position - position.new(min)
rusty1s's avatar
rusty1s committed
25
    else:
rusty1s's avatar
rusty1s committed
26
27
        assert start.numel() == size.numel(), (
            'Start tensor must have same size as size tensor')
rusty1s's avatar
rusty1s committed
28
        position = position - start.type_as(position)
rusty1s's avatar
rusty1s committed
29

rusty1s's avatar
rusty1s committed
30
31
32
33
34
35
36
37
    # If given, append batch to position tensor.
    if batch is not None:
        batch = batch.unsqueeze(-1).type_as(position)
        assert position.size()[:-1] == batch.size()[:-1], (
            'Position tensor must have same size as batch tensor apart from '
            'the last dimension')
        position = torch.cat([batch, position], dim=-1)
        size = torch.cat([size.new(1).fill_(1), size], dim=-1)
rusty1s's avatar
rusty1s committed
38

rusty1s's avatar
rusty1s committed
39
40
    return position, size

rusty1s's avatar
rusty1s committed
41

rusty1s's avatar
rusty1s committed
42
def _minimal_cluster_size(position, size):
rusty1s's avatar
rusty1s committed
43
44
45
46
    max = []
    for i in range(position.size(-1)):
        max.append(position[:, i].max())
    cluster_size = (size.new(max) / size).long() + 1
rusty1s's avatar
rusty1s committed
47
    return cluster_size
rusty1s's avatar
rusty1s committed
48
49


rusty1s's avatar
rusty1s committed
50
51
52
53
def _fixed_cluster_size(position, size, batch=None, end=None):
    if end is None:
        return _minimal_cluster_size(position, size)

rusty1s's avatar
rusty1s committed
54
55
56
    assert end.numel() == size.numel(), (
        'End tensor must have same size as size tensor')

rusty1s's avatar
rusty1s committed
57
    end = end.type_as(size)
rusty1s's avatar
rusty1s committed
58
    eps = 0.000001  # Simulate [start, end) interval.
rusty1s's avatar
rusty1s committed
59
60
61
62
    if batch is None:
        cluster_size = ((end / size).float() - eps).long() + 1
    else:
        cluster_size = ((end / size[1:]).float() - eps).long() + 1
rusty1s's avatar
rusty1s committed
63
64
        max_batch = cluster_size.new(1).fill_(batch.max() + 1)
        cluster_size = torch.cat([max_batch, cluster_size], dim=0)
rusty1s's avatar
rusty1s committed
65
66
67
68
69
70
71
72
73

    return cluster_size


def _grid_cluster(position, size, cluster_size):
    """Assign every position to a grid cell via the compiled kernel.

    Returns the per-node cluster id tensor and the total number of cells
    ``C`` (product of ``cluster_size``). The kernel resolved through
    ``get_func`` writes the ids into a preallocated output tensor in place.
    """
    C = cluster_size.prod()

    # Preallocate one cluster id per position row; the trailing singleton
    # dimension is the layout the kernel expects.
    out_shape = torch.Size(list(position.size())[:-1])
    cluster = cluster_size.new(out_shape).unsqueeze(dim=-1)

    kernel = get_func('grid', position)
    kernel(C, cluster, position, size, cluster_size)

    return cluster.squeeze(dim=-1), C
rusty1s's avatar
rusty1s committed
79

rusty1s's avatar
rusty1s committed
80

rusty1s's avatar
rusty1s committed
81
82
83
84
def sparse_grid_cluster(position, size, batch=None, start=None):
    """Cluster positions on a regular grid with consecutive cluster ids.

    Empty cells are squeezed out by ``consecutive``. Returns the per-node
    cluster id, plus — when ``batch`` is given — the batch id of every
    (consecutive) cluster.
    """
    position, size = _preprocess(position, size, batch, start)
    cluster_size = _minimal_cluster_size(position, size)
    cluster, C = _grid_cluster(position, size, cluster_size)
    cluster, u = consecutive(cluster)

    if batch is None:
        return cluster
    else:
        # Clusters are laid out with the batch index as the slowest-varying
        # dimension, so floor-dividing a cluster id by the number of spatial
        # cells per example recovers its batch id. Use `//` explicitly:
        # plain `/` is true division on newer torch versions and would
        # silently produce a float batch tensor.
        batch = u // (C // cluster_size[0])
        return cluster, batch


def dense_grid_cluster(position, size, batch=None, start=None, end=None):
    """Cluster positions on a regular grid of fixed extent.

    Unlike ``sparse_grid_cluster``, cluster ids are not relabeled, so empty
    cells keep their slot. Returns the per-node cluster id together with
    the total number of grid cells ``C``.
    """
    position, size = _preprocess(position, size, batch, start)
    grid_shape = _fixed_cluster_size(position, size, batch, end)
    cluster, C = _grid_cluster(position, size, grid_shape)
    return cluster, C