base_box3d.py 20.9 KB
Newer Older
dingchang's avatar
dingchang committed
1
# Copyright (c) OpenMMLab. All rights reserved.
2
import warnings
zhangwenwei's avatar
zhangwenwei committed
3
from abc import abstractmethod
4

5
6
import numpy as np
import torch
7
from mmcv.ops import box_iou_rotated, points_in_boxes_all, points_in_boxes_part
8

9
from .utils import limit_period
zhangwenwei's avatar
zhangwenwei committed
10

11
12

class BaseInstance3DBoxes(object):
zhangwenwei's avatar
zhangwenwei committed
13
    """Base class for 3D Boxes.
14

zhangwenwei's avatar
zhangwenwei committed
15
16
    Note:
        The box is bottom centered, i.e. the relative position of origin in
zhangwenwei's avatar
zhangwenwei committed
17
        the box is (0.5, 0.5, 0).
zhangwenwei's avatar
zhangwenwei committed
18

zhangwenwei's avatar
zhangwenwei committed
19
    Args:
20
        tensor (torch.Tensor | np.ndarray | list): a N x box_dim matrix.
21
        box_dim (int): Number of the dimension of a box.
zhangwenwei's avatar
zhangwenwei committed
22
            Each row is (x, y, z, x_size, y_size, z_size, yaw).
23
            Defaults to 7.
zhangwenwei's avatar
zhangwenwei committed
24
25
        with_yaw (bool): Whether the box is with yaw rotation.
            If False, the value of yaw will be set to 0 as minmax boxes.
26
27
28
            Defaults to True.
        origin (tuple[float], optional): Relative position of the box origin.
            Defaults to (0.5, 0.5, 0). This will guide the box be converted to
wuyuefeng's avatar
wuyuefeng committed
29
            (0.5, 0.5, 0) mode.
Wenwei Zhang's avatar
Wenwei Zhang committed
30
31
32
33
34
35
36

    Attributes:
        tensor (torch.Tensor): Float matrix of N x box_dim.
        box_dim (int): Integer indicating the dimension of a box.
            Each row is (x, y, z, x_size, y_size, z_size, yaw, ...).
        with_yaw (bool): Whether the box is with yaw rotation. If False,
            the value of yaw will be set to 0 as minmax boxes.
37
38
    """

wuyuefeng's avatar
wuyuefeng committed
39
    def __init__(self, tensor, box_dim=7, with_yaw=True, origin=(0.5, 0.5, 0)):
40
41
42
43
44
45
46
47
48
49
50
        if isinstance(tensor, torch.Tensor):
            device = tensor.device
        else:
            device = torch.device('cpu')
        tensor = torch.as_tensor(tensor, dtype=torch.float32, device=device)
        if tensor.numel() == 0:
            # Use reshape, so we don't end up creating a new tensor that
            # does not depend on the inputs (and consequently confuses jit)
            tensor = tensor.reshape((0, box_dim)).to(
                dtype=torch.float32, device=device)
        assert tensor.dim() == 2 and tensor.size(-1) == box_dim, tensor.size()
wuyuefeng's avatar
wuyuefeng committed
51

wuyuefeng's avatar
wuyuefeng committed
52
53
54
        if tensor.shape[-1] == 6:
            # If the dimension of boxes is 6, we expand box_dim by padding
            # 0 as a fake yaw and set with_yaw to False.
wuyuefeng's avatar
wuyuefeng committed
55
56
57
58
            assert box_dim == 6
            fake_rot = tensor.new_zeros(tensor.shape[0], 1)
            tensor = torch.cat((tensor, fake_rot), dim=-1)
            self.box_dim = box_dim + 1
wuyuefeng's avatar
wuyuefeng committed
59
            self.with_yaw = False
wuyuefeng's avatar
wuyuefeng committed
60
61
        else:
            self.box_dim = box_dim
wuyuefeng's avatar
wuyuefeng committed
62
            self.with_yaw = with_yaw
63
        self.tensor = tensor.clone()
64

wuyuefeng's avatar
wuyuefeng committed
65
66
        if origin != (0.5, 0.5, 0):
            dst = self.tensor.new_tensor((0.5, 0.5, 0))
zhangwenwei's avatar
zhangwenwei committed
67
68
69
            src = self.tensor.new_tensor(origin)
            self.tensor[:, :3] += self.tensor[:, 3:6] * (dst - src)

zhangwenwei's avatar
zhangwenwei committed
70
    @property
71
    def volume(self):
Wenwei Zhang's avatar
Wenwei Zhang committed
72
        """torch.Tensor: A vector with volume of each box."""
73
74
        return self.tensor[:, 3] * self.tensor[:, 4] * self.tensor[:, 5]

zhangwenwei's avatar
zhangwenwei committed
75
76
    @property
    def dims(self):
77
        """torch.Tensor: Size dimensions of each box in shape (N, 3)."""
zhangwenwei's avatar
zhangwenwei committed
78
79
        return self.tensor[:, 3:6]

zhangwenwei's avatar
zhangwenwei committed
80
81
    @property
    def yaw(self):
82
        """torch.Tensor: A vector with yaw of each box in shape (N, )."""
zhangwenwei's avatar
zhangwenwei committed
83
84
        return self.tensor[:, 6]

85
86
    @property
    def height(self):
87
        """torch.Tensor: A vector with height of each box in shape (N, )."""
88
89
        return self.tensor[:, 5]

90
91
    @property
    def top_height(self):
92
93
        """torch.Tensor:
            A vector with the top height of each box in shape (N, )."""
94
95
96
97
        return self.bottom_height + self.height

    @property
    def bottom_height(self):
98
99
        """torch.Tensor:
            A vector with bottom's height of each box in shape (N, )."""
100
101
        return self.tensor[:, 2]

zhangwenwei's avatar
zhangwenwei committed
102
103
104
105
106
    @property
    def center(self):
        """Calculate the center of all the boxes.

        Note:
107
            In MMDetection3D's convention, the bottom center is
zhangwenwei's avatar
zhangwenwei committed
108
109
110
111
            usually taken as the default center.

            The relative position of the centers in different kinds of
            boxes are different, e.g., the relative center of a boxes is
wuyuefeng's avatar
wuyuefeng committed
112
            (0.5, 1.0, 0.5) in camera and (0.5, 0.5, 0) in lidar.
Wenwei Zhang's avatar
Wenwei Zhang committed
113
            It is recommended to use ``bottom_center`` or ``gravity_center``
114
            for clearer usage.
zhangwenwei's avatar
zhangwenwei committed
115
116

        Returns:
117
            torch.Tensor: A tensor with center of each box in shape (N, 3).
zhangwenwei's avatar
zhangwenwei committed
118
119
120
        """
        return self.bottom_center

zhangwenwei's avatar
zhangwenwei committed
121
    @property
122
    def bottom_center(self):
123
        """torch.Tensor: A tensor with center of each box in shape (N, 3)."""
zhangwenwei's avatar
zhangwenwei committed
124
        return self.tensor[:, :3]
125

zhangwenwei's avatar
zhangwenwei committed
126
    @property
    def gravity_center(self):
        """torch.Tensor: A tensor with center of each box in shape (N, 3)."""
        # Abstract placeholder: each coordinate-specific subclass computes
        # the gravity (geometric) center from its own layout.
        pass

zhangwenwei's avatar
zhangwenwei committed
131
    @property
    def corners(self):
        """torch.Tensor:
            a tensor with 8 corners of each box in shape (N, 8, 3)."""
        # Abstract placeholder: corner ordering depends on the coordinate
        # system, so each subclass provides its own implementation.
        pass

137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
    @property
    def bev(self):
        """torch.Tensor: 2D BEV box of each box with rotation
            in XYWHR format, in shape (N, 5)."""
        return self.tensor[:, [0, 1, 3, 4, 6]]

    @property
    def nearest_bev(self):
        """torch.Tensor: A tensor of 2D BEV box of each box
            without rotation."""
        # Rotated BEV boxes in XYWHR format.
        rotated = self.bev
        # Normalize each yaw with limit_period and take its magnitude.
        yaws = torch.abs(limit_period(rotated[:, -1], 0.5, np.pi))

        # Boxes rotated closer to 90 degrees swap their width and height.
        swap = (yaws > np.pi / 4)[..., None]
        xywh = torch.where(swap, rotated[:, [0, 1, 3, 2]], rotated[:, :4])

        ctr = xywh[:, :2]
        half = xywh[:, 2:] / 2
        # Convert (center, size) to axis-aligned (x1, y1, x2, y2).
        return torch.cat([ctr - half, ctr + half], dim=-1)

    def in_range_bev(self, box_range):
        """Check whether the boxes are in the given range.

        Args:
            box_range (list | torch.Tensor): the range of box
                (x_min, y_min, x_max, y_max)

        Note:
            The original implementation of SECOND checks whether boxes in
            a range by checking whether the points are in a convex
            polygon, we reduce the burden for simpler cases.

        Returns:
            torch.Tensor: Whether each box is inside the reference range.
        """
        in_range_flags = ((self.bev[:, 0] > box_range[0])
                          & (self.bev[:, 1] > box_range[1])
                          & (self.bev[:, 0] < box_range[2])
                          & (self.bev[:, 1] < box_range[3]))
        return in_range_flags

185
    @abstractmethod
    def rotate(self, angle, points=None):
        """Rotate boxes with points (optional) with the given angle or rotation
        matrix.

        Args:
            angle (float | torch.Tensor | np.ndarray):
                Rotation angle or rotation matrix.
            points (torch.Tensor | numpy.ndarray |
                :obj:`BasePoints`, optional):
                Points to rotate. Defaults to None.
        """
        # Abstract: the rotation axis and sign convention differ per
        # coordinate system, so each subclass implements its own rotation.
        pass

    @abstractmethod
    def flip(self, bev_direction='horizontal'):
        """Flip the boxes in BEV along given BEV direction.

        Args:
            bev_direction (str, optional): Direction by which to flip.
                Can be chosen from 'horizontal' and 'vertical'.
                Defaults to 'horizontal'.
        """
        # Abstract: the flipped axis differs per coordinate system.
        pass

    def translate(self, trans_vector):
211
        """Translate boxes with the given translation vector.
212
213

        Args:
214
            trans_vector (torch.Tensor): Translation vector of size (1, 3).
215
        """
zhangwenwei's avatar
zhangwenwei committed
216
217
218
        if not isinstance(trans_vector, torch.Tensor):
            trans_vector = self.tensor.new_tensor(trans_vector)
        self.tensor[:, :3] += trans_vector
219

zhangwenwei's avatar
zhangwenwei committed
220
    def in_range_3d(self, box_range):
zhangwenwei's avatar
zhangwenwei committed
221
        """Check whether the boxes are in the given range.
222
223

        Args:
224
            box_range (list | torch.Tensor): The range of box
225
226
                (x_min, y_min, z_min, x_max, y_max, z_max)

zhangwenwei's avatar
zhangwenwei committed
227
228
229
        Note:
            In the original implementation of SECOND, checking whether
            a box in the range checks whether the points are in a convex
wangtai's avatar
wangtai committed
230
            polygon, we try to reduce the burden for simpler cases.
zhangwenwei's avatar
zhangwenwei committed
231

232
        Returns:
233
            torch.Tensor: A binary vector indicating whether each box is
zhangwenwei's avatar
zhangwenwei committed
234
                inside the reference range.
235
        """
zhangwenwei's avatar
zhangwenwei committed
236
237
238
239
240
241
242
        in_range_flags = ((self.tensor[:, 0] > box_range[0])
                          & (self.tensor[:, 1] > box_range[1])
                          & (self.tensor[:, 2] > box_range[2])
                          & (self.tensor[:, 0] < box_range[3])
                          & (self.tensor[:, 1] < box_range[4])
                          & (self.tensor[:, 2] < box_range[5]))
        return in_range_flags
243

244
245
    @abstractmethod
    def convert_to(self, dst, rt_mat=None):
        """Convert self to ``dst`` mode.

        Args:
            dst (:obj:`Box3DMode`): The target Box mode.
            rt_mat (np.ndarray | torch.Tensor, optional): The rotation and
                translation matrix between different coordinates.
                Defaults to None.
                The conversion from `src` coordinates to `dst` coordinates
                usually comes along the change of sensors, e.g., from camera
                to LiDAR. This requires a transformation matrix.

        Returns:
            :obj:`BaseInstance3DBoxes`: The converted box of the same type
                in the `dst` mode.
        """
        # Abstract: concrete box classes implement the coordinate conversion.
        pass

zhangwenwei's avatar
zhangwenwei committed
263
    def scale(self, scale_factor):
        """Scale the box with horizontal and vertical scaling factors.

        Args:
            scale_factor (float): Scale factor to scale the boxes.
        """
        # Scale centers (x, y, z) and sizes; yaw (index 6) is invariant
        # under uniform scaling and is skipped.
        self.tensor[:, :6] *= scale_factor
        self.tensor[:, 7:] *= scale_factor  # velocity
zhangwenwei's avatar
zhangwenwei committed
271
272

    def limit_yaw(self, offset=0.5, period=np.pi):
        """Limit the yaw to a given period and offset.

        Args:
            offset (float, optional): The offset of the yaw. Defaults to 0.5.
            period (float, optional): The expected period. Defaults to np.pi.
        """
        yaw = self.tensor[:, 6]
        self.tensor[:, 6] = limit_period(yaw, offset, period)
zhangwenwei's avatar
zhangwenwei committed
280

281
    def nonempty(self, threshold=0.0):
282
283
284
285
286
        """Find boxes that are non-empty.

        A box is considered empty,
        if either of its side is no larger than threshold.

zhangwenwei's avatar
zhangwenwei committed
287
        Args:
288
289
            threshold (float, optional): The threshold of minimal sizes.
                Defaults to 0.0.
zhangwenwei's avatar
zhangwenwei committed
290

291
        Returns:
292
            torch.Tensor: A binary vector which represents whether each
wuyuefeng's avatar
wuyuefeng committed
293
                box is empty (False) or non-empty (True).
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
        """
        box = self.tensor
        size_x = box[..., 3]
        size_y = box[..., 4]
        size_z = box[..., 5]
        keep = ((size_x > threshold)
                & (size_y > threshold) & (size_z > threshold))
        return keep

    def __getitem__(self, item):
        """
        Note:
            The following usage are allowed:
            1. `new_boxes = boxes[3]`:
                return a `Boxes` that contains only one box.
            2. `new_boxes = boxes[2:10]`:
                return a slice of boxes.
            3. `new_boxes = boxes[vector]`:
                where vector is a torch.BoolTensor with `length = len(boxes)`.
                Nonzero elements in the vector will be selected.
            Note that the returned Boxes might share storage with this Boxes,
            subject to Pytorch's indexing semantics.

        Returns:
318
319
            :obj:`BaseInstance3DBoxes`: A new object of
                :class:`BaseInstance3DBoxes` after indexing.
320
321
322
        """
        original_type = type(self)
        if isinstance(item, int):
wuyuefeng's avatar
wuyuefeng committed
323
324
325
326
            return original_type(
                self.tensor[item].view(1, -1),
                box_dim=self.box_dim,
                with_yaw=self.with_yaw)
327
328
329
        b = self.tensor[item]
        assert b.dim() == 2, \
            f'Indexing on Boxes with {item} failed to return a matrix!'
wuyuefeng's avatar
wuyuefeng committed
330
        return original_type(b, box_dim=self.box_dim, with_yaw=self.with_yaw)
331
332

    def __len__(self):
wangtai's avatar
wangtai committed
333
        """int: Number of boxes in the current object."""
334
335
336
        return self.tensor.shape[0]

    def __repr__(self):
wangtai's avatar
wangtai committed
337
        """str: Return a strings that describes the object."""
338
339
340
341
        return self.__class__.__name__ + '(\n    ' + str(self.tensor) + ')'

    @classmethod
    def cat(cls, boxes_list):
342
        """Concatenate a list of Boxes into a single Boxes.
343

liyinhao's avatar
liyinhao committed
344
        Args:
345
            boxes_list (list[:obj:`BaseInstance3DBoxes`]): List of boxes.
zhangwenwei's avatar
zhangwenwei committed
346

347
        Returns:
348
            :obj:`BaseInstance3DBoxes`: The concatenated Boxes.
349
350
351
352
353
354
355
356
        """
        assert isinstance(boxes_list, (list, tuple))
        if len(boxes_list) == 0:
            return cls(torch.empty(0))
        assert all(isinstance(box, cls) for box in boxes_list)

        # use torch.cat (v.s. layers.cat)
        # so the returned boxes never share storage with input
zhangwenwei's avatar
zhangwenwei committed
357
358
359
360
        cat_boxes = cls(
            torch.cat([b.tensor for b in boxes_list], dim=0),
            box_dim=boxes_list[0].tensor.shape[1],
            with_yaw=boxes_list[0].with_yaw)
361
362
363
        return cat_boxes

    def to(self, device):
zhangwenwei's avatar
zhangwenwei committed
364
        """Convert current boxes to a specific device.
zhangwenwei's avatar
zhangwenwei committed
365
366
367
368
369

        Args:
            device (str | :obj:`torch.device`): The name of the device.

        Returns:
370
            :obj:`BaseInstance3DBoxes`: A new boxes object on the
zhangwenwei's avatar
zhangwenwei committed
371
372
                specific device.
        """
373
        original_type = type(self)
wuyuefeng's avatar
wuyuefeng committed
374
375
376
377
        return original_type(
            self.tensor.to(device),
            box_dim=self.box_dim,
            with_yaw=self.with_yaw)
378
379
380
381
382

    def clone(self):
        """Clone the Boxes.

        Returns:
383
            :obj:`BaseInstance3DBoxes`: Box object with the same properties
384
                as self.
385
386
        """
        original_type = type(self)
wuyuefeng's avatar
wuyuefeng committed
387
388
        return original_type(
            self.tensor.clone(), box_dim=self.box_dim, with_yaw=self.with_yaw)
389
390
391

    @property
    def device(self):
        """torch.device: The device the underlying box tensor is on."""
        return self.tensor.device

    def __iter__(self):
        """Yield a box as a Tensor of shape (box_dim,) at a time.

        Returns:
            torch.Tensor: A box of shape (box_dim,).
        """
        # Iterating a 2-D tensor yields its rows, one box each.
        yield from self.tensor
402
403

    @classmethod
404
    def height_overlaps(cls, boxes1, boxes2, mode='iou'):
zhangwenwei's avatar
zhangwenwei committed
405
        """Calculate height overlaps of two boxes.
406
407

        Note:
408
            This function calculates the height overlaps between boxes1 and
409
            boxes2,  boxes1 and boxes2 should be in the same type.
410
411

        Args:
412
413
            boxes1 (:obj:`BaseInstance3DBoxes`): Boxes 1 contain N boxes.
            boxes2 (:obj:`BaseInstance3DBoxes`): Boxes 2 contain M boxes.
414
            mode (str, optional): Mode of IoU calculation. Defaults to 'iou'.
415
416

        Returns:
liyinhao's avatar
liyinhao committed
417
            torch.Tensor: Calculated iou of boxes.
418
        """
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
        assert isinstance(boxes1, BaseInstance3DBoxes)
        assert isinstance(boxes2, BaseInstance3DBoxes)
        assert type(boxes1) == type(boxes2), '"boxes1" and "boxes2" should' \
            f'be in the same type, got {type(boxes1)} and {type(boxes2)}.'

        boxes1_top_height = boxes1.top_height.view(-1, 1)
        boxes1_bottom_height = boxes1.bottom_height.view(-1, 1)
        boxes2_top_height = boxes2.top_height.view(1, -1)
        boxes2_bottom_height = boxes2.bottom_height.view(1, -1)

        heighest_of_bottom = torch.max(boxes1_bottom_height,
                                       boxes2_bottom_height)
        lowest_of_top = torch.min(boxes1_top_height, boxes2_top_height)
        overlaps_h = torch.clamp(lowest_of_top - heighest_of_bottom, min=0)
        return overlaps_h

    @classmethod
    def overlaps(cls, boxes1, boxes2, mode='iou'):
        """Calculate 3D overlaps of two boxes.

        Note:
            This function calculates the overlaps between ``boxes1`` and
            ``boxes2``, ``boxes1`` and ``boxes2`` should be in the same type.

        Args:
            boxes1 (:obj:`BaseInstance3DBoxes`): Boxes 1 contain N boxes.
            boxes2 (:obj:`BaseInstance3DBoxes`): Boxes 2 contain M boxes.
            mode (str, optional): Mode of iou calculation ('iou' or 'iof').
                Defaults to 'iou'.

        Returns:
            torch.Tensor: Calculated 3D overlaps of the boxes, shape (N, M).
        """
        assert isinstance(boxes1, BaseInstance3DBoxes)
        assert isinstance(boxes2, BaseInstance3DBoxes)
        assert type(boxes1) == type(boxes2), '"boxes1" and "boxes2" should' \
            f'be in the same type, got {type(boxes1)} and {type(boxes2)}.'

        assert mode in ['iou', 'iof']

        rows = len(boxes1)
        cols = len(boxes2)
        if rows * cols == 0:
            # No pairs to compare; return an empty (N, M) tensor.
            return boxes1.tensor.new(rows, cols)

        # height overlap
        overlaps_h = cls.height_overlaps(boxes1, boxes2)

        # bev overlap
        iou2d = box_iou_rotated(boxes1.bev, boxes2.bev)
        areas1 = (boxes1.bev[:, 2] * boxes1.bev[:, 3]).unsqueeze(1).expand(
            rows, cols)
        areas2 = (boxes2.bev[:, 2] * boxes2.bev[:, 3]).unsqueeze(0).expand(
            rows, cols)
        # Recover the BEV intersection area from the IoU:
        # iou = I / (A1 + A2 - I)  =>  I = iou * (A1 + A2) / (1 + iou).
        overlaps_bev = iou2d * (areas1 + areas2) / (1 + iou2d)

        # 3d overlaps: BEV intersection area times vertical overlap length.
        overlaps_3d = overlaps_bev.to(boxes1.device) * overlaps_h

        volume1 = boxes1.volume.view(-1, 1)
        volume2 = boxes2.volume.view(1, -1)

        if mode == 'iou':
            # the clamp func is used to avoid division of 0
            iou3d = overlaps_3d / torch.clamp(
                volume1 + volume2 - overlaps_3d, min=1e-8)
        else:
            # 'iof': normalize by the volume of boxes1 only.
            iou3d = overlaps_3d / torch.clamp(volume1, min=1e-8)

        return iou3d
wuyuefeng's avatar
wuyuefeng committed
488
489
490
491

    def new_box(self, data):
        """Create a new box object with data.

492
        The new box and its tensor has the similar properties
wuyuefeng's avatar
wuyuefeng committed
493
494
495
            as self and self.tensor, respectively.

        Args:
496
            data (torch.Tensor | numpy.array | list): Data to be copied.
wuyuefeng's avatar
wuyuefeng committed
497
498

        Returns:
499
            :obj:`BaseInstance3DBoxes`: A new bbox object with ``data``,
Wenwei Zhang's avatar
Wenwei Zhang committed
500
                the object's other properties are similar to ``self``.
wuyuefeng's avatar
wuyuefeng committed
501
        """
zhangwenwei's avatar
zhangwenwei committed
502
503
        new_tensor = self.tensor.new_tensor(data) \
            if not isinstance(data, torch.Tensor) else data.to(self.device)
wuyuefeng's avatar
wuyuefeng committed
504
505
506
        original_type = type(self)
        return original_type(
            new_tensor, box_dim=self.box_dim, with_yaw=self.with_yaw)
507

508
509
    def points_in_boxes_part(self, points, boxes_override=None):
        """Find the box in which each point is.

        Args:
            points (torch.Tensor): Points in shape (1, M, 3) or (M, 3),
                3 dimensions are (x, y, z) in LiDAR or depth coordinate.
            boxes_override (torch.Tensor, optional): Boxes to override
                `self.tensor`. Defaults to None.

        Returns:
            torch.Tensor: The index of the first box that each point
                is in, in shape (M, ). Default value is -1
                (if the point is not enclosed by any box).

        Note:
            If a point is enclosed by multiple boxes, the index of the
            first box will be returned.
        """
        boxes = self.tensor if boxes_override is None else boxes_override
        if points.dim() == 2:
            points = points.unsqueeze(0)
        # The op expects batched input: add a batch dim of 1, then strip it.
        batched_boxes = boxes.unsqueeze(0).to(points.device)
        return points_in_boxes_part(points, batched_boxes).squeeze(0)

537
538
    def points_in_boxes_all(self, points, boxes_override=None):
        """Find all boxes in which each point is.

        Args:
            points (torch.Tensor): Points in shape (1, M, 3) or (M, 3),
                3 dimensions are (x, y, z) in LiDAR or depth coordinate.
            boxes_override (torch.Tensor, optional): Boxes to override
                `self.tensor`. Defaults to None.

        Returns:
            torch.Tensor: A tensor indicating whether a point is in a box,
                in shape (M, T). T is the number of boxes. Denote this
                tensor as A, if the m^th point is in the t^th box, then
                `A[m, t] == 1`, elsewise `A[m, t] == 0`.
        """
        boxes = self.tensor if boxes_override is None else boxes_override

        # Only (x, y, z) are needed; clone to avoid touching the caller's
        # tensor.
        xyz = points.clone()[..., :3]
        if xyz.dim() == 2:
            xyz = xyz.unsqueeze(0)
        else:
            assert xyz.dim() == 3 and xyz.shape[0] == 1

        batched_boxes = boxes.to(xyz.device).unsqueeze(0)
        return points_in_boxes_all(xyz, batched_boxes).squeeze(0)
567
568
569
570
571
572
573
574
575
576
577
578

    def points_in_boxes(self, points, boxes_override=None):
        """Deprecated alias; use :meth:`points_in_boxes_part` instead."""
        msg = ('DeprecationWarning: points_in_boxes is a '
               'deprecated method, please consider using '
               'points_in_boxes_part.')
        warnings.warn(msg)
        return self.points_in_boxes_part(points, boxes_override)

    def points_in_boxes_batch(self, points, boxes_override=None):
        """Deprecated alias; use :meth:`points_in_boxes_all` instead."""
        msg = ('DeprecationWarning: points_in_boxes_batch is a '
               'deprecated method, please consider using '
               'points_in_boxes_all.')
        warnings.warn(msg)
        return self.points_in_boxes_all(points, boxes_override)