grid.py 3.08 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
from __future__ import division

rusty1s's avatar
rusty1s committed
3
4
import torch

rusty1s's avatar
rusty1s committed
5
from .utils import get_func, consecutive
rusty1s's avatar
rusty1s committed
6
7


rusty1s's avatar
rusty1s committed
8
9
10
def _preprocess(position, size, batch=None, start=None):
    size = size.type_as(position)

rusty1s's avatar
rusty1s committed
11
12
13
    # Allow one-dimensional positions.
    if position.dim() == 1:
        position = position.unsqueeze(-1)
rusty1s's avatar
rusty1s committed
14

rusty1s's avatar
rusty1s committed
15
16
17
    assert size.dim() == 1, 'Size tensor must be one-dimensional'
    assert position.size(-1) == size.size(-1), (
        'Last dimension of position tensor must have same size as size tensor')
rusty1s's avatar
rusty1s committed
18

rusty1s's avatar
rusty1s committed
19
20
    # Translate to minimal positive positions if no start was passed.
    if start is None:
rusty1s's avatar
rusty1s committed
21
22
23
24
        min = []
        for i in range(position.size(-1)):
            min.append(position[:, i].min())
        position = position - position.new(min)
rusty1s's avatar
rusty1s committed
25
    else:
rusty1s's avatar
rusty1s committed
26
27
        assert start.numel() == size.numel(), (
            'Start tensor must have same size as size tensor')
rusty1s's avatar
rusty1s committed
28
        position = position - start.type_as(position)
rusty1s's avatar
rusty1s committed
29

rusty1s's avatar
rusty1s committed
30
31
32
33
34
35
36
37
    # If given, append batch to position tensor.
    if batch is not None:
        batch = batch.unsqueeze(-1).type_as(position)
        assert position.size()[:-1] == batch.size()[:-1], (
            'Position tensor must have same size as batch tensor apart from '
            'the last dimension')
        position = torch.cat([batch, position], dim=-1)
        size = torch.cat([size.new(1).fill_(1), size], dim=-1)
rusty1s's avatar
rusty1s committed
38

rusty1s's avatar
rusty1s committed
39
40
    return position, size

rusty1s's avatar
rusty1s committed
41

rusty1s's avatar
rusty1s committed
42
def _minimal_cluster_size(position, size):
rusty1s's avatar
rusty1s committed
43
44
45
46
    max = []
    for i in range(position.size(-1)):
        max.append(position[:, i].max())
    cluster_size = (size.new(max) / size).long() + 1
rusty1s's avatar
rusty1s committed
47
    return cluster_size
rusty1s's avatar
rusty1s committed
48
49


rusty1s's avatar
rusty1s committed
50
51
52
53
def _fixed_cluster_size(position, size, batch=None, end=None):
    if end is None:
        return _minimal_cluster_size(position, size)

rusty1s's avatar
rusty1s committed
54
    end = end.type_as(size)
rusty1s's avatar
rusty1s committed
55
    eps = 0.000001  # Simulate [start, end) interval.
rusty1s's avatar
rusty1s committed
56
57
58
59
    if batch is None:
        cluster_size = ((end / size).float() - eps).long() + 1
    else:
        cluster_size = ((end / size[1:]).float() - eps).long() + 1
rusty1s's avatar
rusty1s committed
60
61
        max_batch = cluster_size.new(1).fill_(batch.max() + 1)
        cluster_size = torch.cat([max_batch, cluster_size], dim=0)
rusty1s's avatar
rusty1s committed
62
63
64
65
66
67
68
69
70

    return cluster_size


def _grid_cluster(position, size, cluster_size):
    """Assign every point to a grid cell using the backend ``grid`` function.

    Args:
        position: preprocessed point coordinates (coordinates on the last axis).
        size: per-coordinate cell sizes.
        cluster_size: per-coordinate number of cells.

    Returns:
        ``(cluster, C)``: a tensor of per-point cell ids with the leading shape
        of ``position``, and ``C``, the product of ``cluster_size``.
    """
    num_cells = cluster_size.prod()

    # Uninitialized output of the same type as ``cluster_size``, shaped like
    # ``position`` without its last axis plus a trailing singleton axis.
    # Presumably the backend writes every entry in place — TODO confirm.
    out_shape = torch.Size(list(position.size())[:-1])
    cluster = cluster_size.new(out_shape).unsqueeze(dim=-1)

    grid_kernel = get_func('grid', position)
    grid_kernel(num_cells, cluster, position, size, cluster_size)

    return cluster.squeeze(dim=-1), num_cells
rusty1s's avatar
rusty1s committed
76

rusty1s's avatar
rusty1s committed
77

rusty1s's avatar
rusty1s committed
78
79
80
81
def sparse_grid_cluster(position, size, batch=None, start=None):
    """Cluster points into grid cells and relabel occupied cells consecutively.

    Args:
        position: point coordinates (see ``_preprocess``).
        size: 1-D tensor of cell sizes, one per coordinate.
        batch: optional per-point batch index.
        start: optional per-coordinate grid origin; defaults to the data
            minimum.

    Returns:
        ``cluster`` when ``batch`` is None. Otherwise ``(cluster, batch)``,
        where the second entry is the batch index of every consecutive cluster.
    """
    position, size = _preprocess(position, size, batch, start)
    cluster_size = _minimal_cluster_size(position, size)
    cluster, C = _grid_cluster(position, size, cluster_size)
    # ``u`` presumably holds the original (sparse) cell id of each consecutive
    # cluster — confirm against ``consecutive`` in .utils.
    cluster, u = consecutive(cluster)

    if batch is None:
        return cluster

    # The batch axis has ``cluster_size[0]`` cells. Assuming it is the most
    # significant axis of the cell id (TODO confirm in the grid kernel), each
    # batch spans ``C // cluster_size[0]`` consecutive ids. Floor division is
    # required: ``/`` on integer tensors is true division in current PyTorch
    # and would yield fractional batch indices.
    batch = u // (C // cluster_size[0])
    return cluster, batch


def dense_grid_cluster(position, size, batch=None, start=None, end=None):
    """Cluster points into a grid whose cell counts may be fixed by ``end``.

    Unlike ``sparse_grid_cluster``, cell ids are not relabelled, so empty
    cells keep their slot in the id range.

    Args:
        position: point coordinates (see ``_preprocess``).
        size: 1-D tensor of cell sizes, one per coordinate.
        batch: optional per-point batch index.
        start: optional per-coordinate grid origin; defaults to the data
            minimum.
        end: optional per-coordinate upper bound; without it the extent is
            derived from the data.

    Returns:
        ``(cluster, C)``: per-point cell ids and the total number of cells.
    """
    position, size = _preprocess(position, size, batch, start)
    fixed_size = _fixed_cluster_size(position, size, batch, end)
    return _grid_cluster(position, size, fixed_size)