points_in_boxes.py 1.62 KB
Newer Older
wuyuefeng's avatar
wuyuefeng committed
1
2
3
4
5
6
import torch

from . import roiaware_pool3d_ext


def points_in_boxes_gpu(points, boxes):
    """Find the box that contains each point (CUDA).

    Args:
        points (torch.Tensor): [B, M, 3], [x, y, z] in LiDAR coordinate.
        boxes (torch.Tensor): [B, T, 7],
            num_valid_boxes <= T, [x, y, z, w, l, h, ry] in LiDAR coordinate,
            (x, y, z) is the bottom center.

    Returns:
        box_idxs_of_pts (torch.Tensor): (B, M) int32 tensor on the device of
            ``points`` holding the index of the containing box for each
            point; default background = -1.

    Raises:
        AssertionError: if the batch sizes differ, ``boxes`` does not have 7
            values per box, ``points`` does not have 3 coordinates per point,
            or ``points`` and ``boxes`` are on different devices.
    """
    assert boxes.shape[0] == points.shape[0], (
        'points and boxes should have the same batch size, '
        f'got {points.shape[0]} and {boxes.shape[0]}')
    assert boxes.shape[2] == 7, (
        f'boxes should have 7 values per box, got {boxes.shape[2]}')
    # The documented contract is [B, M, 3]; reject wider point tensors
    # (e.g. with extra feature channels) up front.
    # NOTE(review): the kernel presumably reads points with a fixed stride
    # of 3 - confirm against the CUDA op.
    assert points.shape[2] == 3, (
        f'points should have 3 coordinates per point, got {points.shape[2]}')
    assert points.device == boxes.device, (
        'points and boxes should be on the same device, '
        f'got {points.device} and {boxes.device}')
    batch_size, num_points, _ = points.shape

    # Every point starts as background (-1). The buffer is handed to the
    # extension as its output argument and then returned.
    box_idxs_of_pts = points.new_full((batch_size, num_points), -1,
                                      dtype=torch.int)
    boxes = boxes.contiguous()
    points = points.contiguous()

    if points.is_cuda:
        # NOTE(review): the extension is assumed to allocate temporaries and
        # launch on the *current* CUDA device. Pin it to the inputs' device
        # for the call (restored on exit) so results stay correct when the
        # tensors live on a non-default GPU.
        with torch.cuda.device(points.device):
            roiaware_pool3d_ext.points_in_boxes_gpu(boxes, points,
                                                    box_idxs_of_pts)
    else:
        # No CUDA device to pin; the extension reports unsupported input.
        roiaware_pool3d_ext.points_in_boxes_gpu(boxes, points,
                                                box_idxs_of_pts)

    return box_idxs_of_pts


def points_in_boxes_cpu(points, boxes):
    """Compute a box-by-point membership table (CPU).

    Args:
        points (torch.Tensor): [npoints, 3], [x, y, z].
        boxes (torch.Tensor): [N, 7], in LiDAR coordinate,
            (x, y, z) is the bottom center.

    Returns:
        point_indices (torch.Tensor): (N, npoints) int32 tensor on the
            device of ``points``, zero-initialised and filled by the
            extension (presumably 1 where point j lies in box i, else 0 -
            TODO confirm against the C++ op).

    Raises:
        AssertionError: if ``boxes`` is not 2-D with 7 columns or ``points``
            is not 2-D with 3 columns.
    """
    # Check the rank as well as the width: a batched [B, 7, 7] boxes tensor
    # would otherwise pass a bare ``shape[1] == 7`` check.
    assert boxes.dim() == 2 and boxes.shape[1] == 7, (
        f'boxes should have shape [N, 7], got {tuple(boxes.shape)}')
    assert points.dim() == 2 and points.shape[1] == 3, (
        f'points should have shape [npoints, 3], got {tuple(points.shape)}')

    # Output buffer passed to the extension, which writes into it.
    point_indices = points.new_zeros((boxes.shape[0], points.shape[0]),
                                     dtype=torch.int)
    # Hand the extension float32, contiguous copies of the inputs.
    roiaware_pool3d_ext.points_in_boxes_cpu(boxes.float().contiguous(),
                                            points.float().contiguous(),
                                            point_indices)

    return point_indices