grid.py 1.26 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
from .utils.ffi import grid


rusty1s's avatar
rusty1s committed
4
def grid_cluster(pos, size, start=None, end=None):
    """Voxel grid clustering algorithm, which overlays a regular grid of
    user-defined size over the point cloud and clusters all points within a
    voxel.

    Args:
        pos (Tensor): Positions of the points, either of shape :obj:`[N]`
            (treated as :obj:`[N, 1]`) or :obj:`[N, D]`.
        size (Tensor): Size of a voxel in each dimension; its first dimension
            must equal :obj:`D`.
        start (Tensor or int, optional): Start position of the grid (in each
            dimension). If :obj:`None`, the per-dimension minimum of
            :obj:`pos` is used. (default: :obj:`None`)
        end (Tensor or int, optional): End position of the grid (in each
            dimension). If :obj:`None`, the per-dimension maximum of
            :obj:`pos` is used. (default: :obj:`None`)

    Returns:
        Tensor: A long tensor with one entry per point (length :obj:`N`),
        filled in by the ``grid`` kernel -- presumably the voxel (cluster)
        index of each point.

    Raises:
        ValueError: If the last dimension of :obj:`pos` does not match the
            size of :obj:`size`.

    Examples::

        >>> pos = torch.Tensor([[0, 0], [11, 9], [2, 8], [2, 2], [8, 3]])
        >>> size = torch.Tensor([5, 5])
        >>> cluster = grid_cluster(pos, size)
    """

    # Treat a flat position vector as N points in a single dimension.
    if pos.dim() == 1:
        pos = pos.unsqueeze(-1)

    # Explicit check rather than ``assert``: asserts are stripped under
    # ``python -O``, and a mismatch here would otherwise reach the native
    # kernel unchecked.
    if pos.size(1) != size.size(0):
        raise ValueError(
            'Last dimension of position tensor must have same size as size '
            'tensor')

    # Default the grid bounds to the bounding box of the point cloud.
    if start is None:
        start = pos.t().min(dim=1)[0]
    if end is None:
        end = pos.t().max(dim=1)[0]

    # Shift positions and the upper bound so the grid origin sits at zero.
    pos, end = pos - start, end - start

    size = size.type_as(pos)
    # Number of voxels per dimension. The +1 keeps a point lying exactly on
    # ``end`` inside the last voxel (assuming the kernel bins by
    # floor(pos / size)).
    count = (end / size).long() + 1

    # Uninitialized output buffer of length N with the same tensor type as
    # ``count``; ``grid``'s return value is not used, so it is expected to
    # fill ``cluster`` in place.
    cluster = count.new(pos.size(0))
    grid(cluster, pos, size, count)

    return cluster