"examples/python-rag-newssummary/utils.py" did not exist on "08b0e04f40025c41dea6a40fa434acf9e2672d4a"
ball_query.py 1.28 KB
Newer Older
wuyuefeng's avatar
wuyuefeng committed
1
2
3
4
5
6
7
import torch
from torch.autograd import Function

from . import ball_query_ext


class BallQuery(Function):
    """Ball Query.

    Find nearby points in spherical space.
    """

    @staticmethod
    def forward(ctx, radius: float, sample_num: int, xyz: torch.Tensor,
                center_xyz: torch.Tensor) -> torch.Tensor:
        """Find the indices of the points that fall inside each query ball.

        Args:
            radius (float): radius of the balls.
            sample_num (int): maximum number of features in the balls.
            xyz (Tensor): (B, N, 3) xyz coordinates of the features.
            center_xyz (Tensor): (B, npoint, 3) centers of the ball query.

        Returns:
            Tensor: (B, npoint, sample_num) int32 tensor with the indices of
                the features that form the query balls. The tensor is
                marked as non-differentiable.
        """
        # The extension is handed the tensors directly, so both inputs are
        # required to be contiguous.
        assert center_xyz.is_contiguous(), 'center_xyz must be contiguous'
        assert xyz.is_contiguous(), 'xyz must be contiguous'

        B, N, _ = xyz.size()
        npoint = center_xyz.size(1)
        # Allocate the output on the same device as ``xyz`` instead of using
        # the legacy ``torch.cuda.IntTensor`` constructor, which always
        # targets the *current* CUDA device and can therefore end up on a
        # different GPU than the inputs.
        idx = xyz.new_zeros((B, npoint, sample_num), dtype=torch.int32)

        # NOTE(review): ``idx`` is presumably filled in place by the
        # extension -- confirm against the ball_query_ext sources.
        ball_query_ext.ball_query_wrapper(B, N, npoint, radius, sample_num,
                                          center_xyz, xyz, idx)
        # Integer indices carry no gradient.
        ctx.mark_non_differentiable(idx)
        return idx

    @staticmethod
    def backward(ctx, a=None):
        return None, None, None, None


# Functional entry point: ``ball_query(radius, sample_num, xyz, center_xyz)``
# runs ``BallQuery.forward`` through ``Function.apply``.
ball_query = BallQuery.apply