# Copyright (c) OpenMMLab. All rights reserved.
import torch
from torch.autograd import Function

from . import ball_query_ext


class BallQuery(Function):
    """Ball Query.

    Find nearby points in spherical space.
    """

    @staticmethod
    def forward(ctx, min_radius: float, max_radius: float, sample_num: int,
                xyz: torch.Tensor, center_xyz: torch.Tensor) -> torch.Tensor:
        """forward.

        Args:
            min_radius (float): minimum radius of the balls.
            max_radius (float): maximum radius of the balls.
            sample_num (int): maximum number of features in the balls.
            xyz (Tensor): (B, N, 3) xyz coordinates of the features.
            center_xyz (Tensor): (B, npoint, 3) centers of the ball query.

        Returns:
            Tensor: (B, npoint, sample_num) int32 tensor with the indices of
                the features that form the query balls, allocated on the
                same device as ``xyz``.

        Raises:
            AssertionError: if ``xyz`` or ``center_xyz`` is not contiguous,
                or if ``min_radius >= max_radius``.
        """
        # NOTE(review): the extension presumably indexes raw buffers, which
        # is why contiguous inputs are required — confirm against the kernel.
        assert center_xyz.is_contiguous()
        assert xyz.is_contiguous()
        assert min_radius < max_radius

        B, N, _ = xyz.size()
        npoint = center_xyz.size(1)
        # Allocate the output on xyz's own device. The deprecated
        # ``torch.cuda.IntTensor`` constructor always used the *current* CUDA
        # device, which mismatches inputs living on a different GPU.
        idx = xyz.new_zeros((B, npoint, sample_num), dtype=torch.int)

        # ``idx`` is returned without reassignment, so the extension call is
        # expected to write the query results into it in place.
        ball_query_ext.ball_query_wrapper(B, N, npoint, min_radius, max_radius,
                                          sample_num, center_xyz, xyz, idx)

        # Integer indices carry no gradient.
        ctx.mark_non_differentiable(idx)
        return idx

    @staticmethod
    def backward(ctx, a=None):
        """Return no gradients.

        Autograd requires one entry per ``forward`` input (excluding
        ``ctx``): min_radius, max_radius, sample_num, xyz and center_xyz.
        """
        return None, None, None, None, None


ball_query = BallQuery.apply