grid.py 1.2 KB
Newer Older
rusty1s's avatar
new api  
rusty1s committed
1
import torch
2
import torch_cluster.grid_cpu
rusty1s's avatar
new api  
rusty1s committed
3
4

if torch.cuda.is_available():
5
    import torch_cluster.grid_cuda
rusty1s's avatar
rusty1s committed
6
7


rusty1s's avatar
rusty1s committed
8
def grid_cluster(pos, size, start=None, end=None):
    """A clustering algorithm, which overlays a regular grid of user-defined
    size over a point cloud and clusters all points within a voxel.

    Args:
        pos (Tensor): D-dimensional position of points. A 1-dimensional
            tensor is treated as points in a single dimension.
        size (Tensor): Size of a voxel in each dimension.
        start (Tensor, optional): Start position of the grid (in each
            dimension). If :obj:`None`, the per-dimension minimum of
            :obj:`pos` is used. (default: :obj:`None`)
        end (Tensor, optional): End position of the grid (in each
            dimension). If :obj:`None`, the per-dimension maximum of
            :obj:`pos` is used. (default: :obj:`None`)

    :rtype: :class:`LongTensor`

    Examples::

        >>> pos = torch.Tensor([[0, 0], [11, 9], [2, 8], [2, 2], [8, 3]])
        >>> size = torch.Tensor([5, 5])
        >>> cluster = grid_cluster(pos, size)
    """
    # A flat tensor of N scalars is treated as N points in one dimension.
    if pos.dim() == 1:
        pos = pos.unsqueeze(-1)

    # With no points there is nothing to cluster. Reducing an empty tensor
    # with min/max below would raise, so return an empty result directly.
    if pos.size(0) == 0:
        return torch.empty(0, dtype=torch.long, device=pos.device)

    # Default the grid bounds to the bounding box of the point cloud.
    if start is None:
        start = pos.min(dim=0)[0]
    if end is None:
        end = pos.max(dim=0)[0]

    # Dispatch to the compiled extension matching the device of `pos`.
    op = torch_cluster.grid_cuda if pos.is_cuda else torch_cluster.grid_cpu
    return op.grid(pos, size, start, end)