# base_points.py
# Copyright (c) OpenMMLab. All rights reserved.
2
import warnings
3
from abc import abstractmethod
4
from typing import Iterator, Optional, Sequence, Union
5

6
7
import numpy as np
import torch
8
from torch import Tensor
9

10
from mmdet3d.structures.bbox_3d.utils import rotation_3d_in_axis
11

12

13
class BasePoints:
    """Base class for Points.

    Args:
        tensor (Tensor or np.ndarray or Sequence[Sequence[float]]): The points
            data with shape (N, points_dim).
        points_dim (int): Integer indicating the dimension of a point. Each row
            is (x, y, z, ...). Defaults to 3.
        attribute_dims (dict, optional): Dictionary to indicate the meaning of
            extra dimension. Maps an attribute name to the column index (int)
            or column indices (list of int) that hold it. Defaults to None.

    Attributes:
        tensor (Tensor): Float matrix with shape (N, points_dim).
        points_dim (int): Integer indicating the dimension of a point. Each row
            is (x, y, z, ...).
        attribute_dims (dict, optional): Dictionary to indicate the meaning of
            extra dimension. Defaults to None.
        rotation_axis (int): Default rotation axis for points rotation.
    """

    def __init__(self,
                 tensor: Union[Tensor, np.ndarray, Sequence[Sequence[float]]],
                 points_dim: int = 3,
                 attribute_dims: Optional[dict] = None) -> None:
        # Keep the data on the device it already lives on; anything that is
        # not a Tensor is materialised on the CPU.
        if isinstance(tensor, Tensor):
            device = tensor.device
        else:
            device = torch.device('cpu')
        tensor = torch.as_tensor(tensor, dtype=torch.float32, device=device)
        if tensor.numel() == 0:
            # Use reshape, so we don't end up creating a new tensor that does
            # not depend on the inputs (and consequently confuses jit)
            tensor = tensor.reshape((-1, points_dim))
        assert tensor.dim() == 2 and tensor.size(-1) == points_dim, \
            ('The points dimension must be 2 and the length of the last '
             f'dimension must be {points_dim}, but got points with shape '
             f'{tensor.shape}.')

        # The stored tensor is always a private copy of the input.
        self.tensor = tensor.clone()
        self.points_dim = points_dim
        self.attribute_dims = attribute_dims
        self.rotation_axis = 0

    def _reshape_attribute(self, tensor: Union[Tensor, np.ndarray],
                           *trailing_dims: int) -> Union[Tensor, np.ndarray]:
        """Reshape ``tensor`` to ``(N, *trailing_dims)`` for this object.

        Raises:
            ValueError: If ``tensor`` cannot be reshaped to the target shape.
        """
        try:
            return tensor.reshape(self.shape[0], *trailing_dims)
        except (RuntimeError, ValueError) as err:
            # RuntimeError comes from torch.Tensor, ValueError from np.ndarray.
            raise ValueError(f'got unexpected shape {tensor.shape}') from err

    def _normalize_columns(self, columns) -> list:
        """Turn any supported column index into a list of column numbers.

        Slices are expanded with ``slice.indices`` and boolean masks are
        converted to the positions of their ``True`` entries. Negative indices
        that are in range are wrapped; out-of-range values are left untouched
        so that the subsequent tensor indexing reports them.
        """
        num_cols = self.tensor.shape[1]
        if isinstance(columns, slice):
            return list(range(*columns.indices(num_cols)))

        index = torch.as_tensor(columns).reshape(-1)
        if index.dtype == torch.bool:
            if index.numel() != num_cols:
                raise IndexError(
                    f'The length of the boolean column mask {index.numel()} '
                    f'does not match the number of columns {num_cols}')
            index = index.nonzero().reshape(-1)
        return [
            col + num_cols if -num_cols <= col < 0 else col
            for col in index.tolist()
        ]

    @property
    def coord(self) -> Tensor:
        """Tensor: Coordinates of each point in shape (N, 3)."""
        return self.tensor[:, :3]

    @coord.setter
    def coord(self, tensor: Union[Tensor, np.ndarray]) -> None:
        """Set the coordinates of each point.

        Args:
            tensor (Tensor or np.ndarray): Coordinates of each point with shape
                (N, 3).

        Raises:
            ValueError: If ``tensor`` cannot be reshaped to (N, 3).
        """
        tensor = self._reshape_attribute(tensor, 3)
        if not isinstance(tensor, Tensor):
            tensor = self.tensor.new_tensor(tensor)
        self.tensor[:, :3] = tensor

    @property
    def height(self) -> Union[Tensor, None]:
        """Tensor or None: Returns a vector with height of each point in shape
        (N, )."""
        if self.attribute_dims is not None and \
                'height' in self.attribute_dims:
            return self.tensor[:, self.attribute_dims['height']]
        return None

    @height.setter
    def height(self, tensor: Union[Tensor, np.ndarray]) -> None:
        """Set the height of each point.

        If the points have no height attribute yet, a new column is appended
        and ``points_dim`` grows by one.

        Args:
            tensor (Tensor or np.ndarray): Height of each point with shape
                (N, ).

        Raises:
            ValueError: If ``tensor`` cannot be reshaped to (N, ).
        """
        tensor = self._reshape_attribute(tensor)
        if not isinstance(tensor, Tensor):
            tensor = self.tensor.new_tensor(tensor)
        if self.attribute_dims is not None and \
                'height' in self.attribute_dims:
            self.tensor[:, self.attribute_dims['height']] = tensor
        else:
            # add height attribute
            attr_dim = self.shape[1]
            self.tensor = torch.cat([self.tensor, tensor.unsqueeze(1)], dim=1)
            # Build a new dict instead of updating in place: the current dict
            # may be shared with the caller or with other point objects.
            self.attribute_dims = {
                **(self.attribute_dims or {}), 'height': attr_dim
            }
            self.points_dim += 1

    @property
    def color(self) -> Union[Tensor, None]:
        """Tensor or None: Returns a vector with color of each point in shape
        (N, 3)."""
        if self.attribute_dims is not None and \
                'color' in self.attribute_dims:
            return self.tensor[:, self.attribute_dims['color']]
        return None

    @color.setter
    def color(self, tensor: Union[Tensor, np.ndarray]) -> None:
        """Set the color of each point.

        If the points have no color attribute yet, three new columns are
        appended and ``points_dim`` grows by three. A warning is emitted when
        any value lies outside [0, 255].

        Args:
            tensor (Tensor or np.ndarray): Color of each point with shape
                (N, 3).

        Raises:
            ValueError: If ``tensor`` cannot be reshaped to (N, 3).
        """
        tensor = self._reshape_attribute(tensor, 3)
        # max()/min() are undefined on empty input, so only range-check when
        # there is at least one point.
        if len(self) > 0 and (tensor.max() >= 256 or tensor.min() < 0):
            warnings.warn('point got color value beyond [0, 255]')
        if not isinstance(tensor, Tensor):
            tensor = self.tensor.new_tensor(tensor)
        if self.attribute_dims is not None and \
                'color' in self.attribute_dims:
            self.tensor[:, self.attribute_dims['color']] = tensor
        else:
            # add color attribute
            attr_dim = self.shape[1]
            self.tensor = torch.cat([self.tensor, tensor], dim=1)
            # Build a new dict instead of updating in place: the current dict
            # may be shared with the caller or with other point objects.
            self.attribute_dims = {
                **(self.attribute_dims or {}),
                'color': [attr_dim, attr_dim + 1, attr_dim + 2]
            }
            self.points_dim += 3

    @property
    def shape(self) -> torch.Size:
        """torch.Size: Shape of points."""
        return self.tensor.shape

    def shuffle(self) -> Tensor:
        """Shuffle the points.

        Returns:
            Tensor: The shuffled index.
        """
        idx = torch.randperm(self.__len__(), device=self.tensor.device)
        self.tensor = self.tensor[idx]
        return idx

    def rotate(self,
               rotation: Union[Tensor, np.ndarray, float],
               axis: Optional[int] = None) -> Tensor:
        """Rotate points with the given rotation matrix or angle.

        Args:
            rotation (Tensor or np.ndarray or float): Rotation matrix of shape
                (3, 3) or a single angle.
            axis (int, optional): Axis to rotate at. Defaults to None, in
                which case ``self.rotation_axis`` is used. Only relevant when
                ``rotation`` is an angle.

        Returns:
            Tensor: Rotation matrix. For a (3, 3) input this is the input
            itself, which is right-multiplied onto the coordinates.
        """
        if not isinstance(rotation, Tensor):
            rotation = self.tensor.new_tensor(rotation)
        assert rotation.shape == torch.Size([3, 3]) or rotation.numel() == 1, \
            f'invalid rotation shape {rotation.shape}'

        if axis is None:
            axis = self.rotation_axis

        if rotation.numel() == 1:
            # The helper works on batches, hence the temporary leading dim.
            rotated_points, rot_mat_T = rotation_3d_in_axis(
                self.tensor[:, :3][None], rotation, axis=axis, return_mat=True)
            self.tensor[:, :3] = rotated_points.squeeze(0)
            rot_mat_T = rot_mat_T.squeeze(0)
        else:
            # rotation.numel() == 9
            self.tensor[:, :3] = self.tensor[:, :3] @ rotation
            rot_mat_T = rotation

        return rot_mat_T

    @abstractmethod
    def flip(self, bev_direction: str = 'horizontal') -> None:
        """Flip the points along given BEV direction.

        Args:
            bev_direction (str): Flip direction (horizontal or vertical).
                Defaults to 'horizontal'.
        """
        pass

    def translate(self, trans_vector: Union[Tensor, np.ndarray]) -> None:
        """Translate points with the given translation vector.

        Args:
            trans_vector (Tensor or np.ndarray): Translation vector of size 3
                or nx3.

        Raises:
            NotImplementedError: If ``trans_vector`` is neither 1-D nor 2-D
                after removing a leading singleton dimension.
        """
        if not isinstance(trans_vector, Tensor):
            trans_vector = self.tensor.new_tensor(trans_vector)
        trans_vector = trans_vector.squeeze(0)
        if trans_vector.dim() == 1:
            assert trans_vector.shape[0] == 3
        elif trans_vector.dim() == 2:
            assert trans_vector.shape[0] == self.tensor.shape[0] and \
                trans_vector.shape[1] == 3
        else:
            raise NotImplementedError(
                f'Unsupported translation vector of shape {trans_vector.shape}'
            )
        self.tensor[:, :3] += trans_vector

    def in_range_3d(
            self, point_range: Union[Tensor, np.ndarray,
                                     Sequence[float]]) -> Tensor:
        """Check whether the points are in the given range.

        Args:
            point_range (Tensor or np.ndarray or Sequence[float]): The range of
                point (x_min, y_min, z_min, x_max, y_max, z_max).

        Note:
            In the original implementation of SECOND, checking whether a box in
            the range checks whether the points are in a convex polygon, we try
            to reduce the burden for simpler cases.

        Returns:
            Tensor: A binary vector indicating whether each point is inside the
            reference range. Both bounds are exclusive.
        """
        in_range_flags = ((self.tensor[:, 0] > point_range[0])
                          & (self.tensor[:, 1] > point_range[1])
                          & (self.tensor[:, 2] > point_range[2])
                          & (self.tensor[:, 0] < point_range[3])
                          & (self.tensor[:, 1] < point_range[4])
                          & (self.tensor[:, 2] < point_range[5]))
        return in_range_flags

    @property
    def bev(self) -> Tensor:
        """Tensor: BEV of the points in shape (N, 2)."""
        return self.tensor[:, [0, 1]]

    def in_range_bev(
            self, point_range: Union[Tensor, np.ndarray,
                                     Sequence[float]]) -> Tensor:
        """Check whether the points are in the given range.

        Args:
            point_range (Tensor or np.ndarray or Sequence[float]): The range of
                point in order of (x_min, y_min, x_max, y_max).

        Returns:
            Tensor: A binary vector indicating whether each point is inside the
            reference range. Both bounds are exclusive.
        """
        in_range_flags = ((self.bev[:, 0] > point_range[0])
                          & (self.bev[:, 1] > point_range[1])
                          & (self.bev[:, 0] < point_range[2])
                          & (self.bev[:, 1] < point_range[3]))
        return in_range_flags

    @abstractmethod
    def convert_to(self,
                   dst: int,
                   rt_mat: Optional[Union[Tensor,
                                          np.ndarray]] = None) -> 'BasePoints':
        """Convert self to ``dst`` mode.

        Args:
            dst (int): The target Point mode.
            rt_mat (Tensor or np.ndarray, optional): The rotation and
                translation matrix between different coordinates.
                Defaults to None. The conversion from ``src`` coordinates to
                ``dst`` coordinates usually comes along the change of sensors,
                e.g., from camera to LiDAR. This requires a transformation
                matrix.

        Returns:
            :obj:`BasePoints`: The converted point of the same type in the
            ``dst`` mode.
        """
        pass

    def scale(self, scale_factor: float) -> None:
        """Scale the coordinates of the points in place.

        Args:
            scale_factor (float): Factor the x, y and z columns are multiplied
                by.
        """
        self.tensor[:, :3] *= scale_factor

    def __getitem__(
            self, item: Union[int, tuple, slice, np.ndarray,
                              Tensor]) -> 'BasePoints':
        """
        Args:
            item (int or tuple or slice or np.ndarray or Tensor): Index of
                points.

        Note:
            The following usage are allowed:

            1. `new_points = points[3]`: Return a `Points` that contains only
               one point.
            2. `new_points = points[2:10]`: Return a slice of points.
            3. `new_points = points[vector]`: Whether vector is a
               torch.BoolTensor with `length = len(points)`. Nonzero elements
               in the vector will be selected.
            4. `new_points = points[3:11, vector]`: Return a slice of points
               and attribute dims.
            5. `new_points = points[4:12, 2]`: Return a slice of points with
               single attribute.

            Note that the returned Points might share storage with this Points,
            subject to PyTorch's indexing semantics.

            When columns are selected, ``attribute_dims`` of the result refers
            to column positions in the returned points, and attributes whose
            columns were not selected are dropped.

        Returns:
            :obj:`BasePoints`: A new object of :class:`BasePoints` after
            indexing.

        Raises:
            NotImplementedError: If ``item`` is of an unsupported type.
        """
        original_type = type(self)
        if isinstance(item, (int, np.integer)):
            return original_type(
                self.tensor[int(item)].view(1, -1),
                points_dim=self.points_dim,
                attribute_dims=self.attribute_dims)
        elif isinstance(item, tuple) and len(item) == 2:
            columns = self._normalize_columns(item[1])
            p = self.tensor[item[0], columns]

            if self.attribute_dims is not None:
                # Position of every selected source column in the result
                # (first occurrence wins if a column is selected twice).
                new_position = {}
                for position, column in enumerate(columns):
                    new_position.setdefault(column, position)

                attribute_dims = {}
                for key, dims in self.attribute_dims.items():
                    if isinstance(dims, int):
                        dims = [dims]
                    # Columns 0-2 are coordinates and never count as
                    # attribute columns.
                    kept = [
                        new_position[dim] for dim in dims
                        if dim >= 3 and dim in new_position
                    ]
                    if len(kept) == 1:
                        attribute_dims[key] = kept[0]
                    elif len(kept) > 1:
                        attribute_dims[key] = kept
            else:
                attribute_dims = None
        elif isinstance(item, (slice, np.ndarray, Tensor)):
            p = self.tensor[item]
            attribute_dims = self.attribute_dims
        else:
            raise NotImplementedError(f'Invalid slice {item}!')

        assert p.dim() == 2, \
            f'Indexing on Points with {item} failed to return a matrix!'
        return original_type(
            p, points_dim=p.shape[1], attribute_dims=attribute_dims)

    def __len__(self) -> int:
        """int: Number of points in the current object."""
        return self.tensor.shape[0]

    def __repr__(self) -> str:
        """str: Return a string that describes the object."""
        return self.__class__.__name__ + '(\n    ' + str(self.tensor) + ')'

    @classmethod
    def cat(cls, points_list: Sequence['BasePoints']) -> 'BasePoints':
        """Concatenate a list of Points into a single Points.

        ``points_dim`` and ``attribute_dims`` of the result are taken from the
        first element of ``points_list``.

        Args:
            points_list (Sequence[:obj:`BasePoints`]): List of points.

        Returns:
            :obj:`BasePoints`: The concatenated points.
        """
        assert isinstance(points_list, (list, tuple))
        if len(points_list) == 0:
            return cls(torch.empty(0))
        assert all(isinstance(points, cls) for points in points_list)

        # use torch.cat (v.s. layers.cat)
        # so the returned points never share storage with input
        cat_points = cls(
            torch.cat([p.tensor for p in points_list], dim=0),
            points_dim=points_list[0].points_dim,
            attribute_dims=points_list[0].attribute_dims)
        return cat_points

    def numpy(self) -> np.ndarray:
        """Reload ``numpy`` from self.tensor."""
        return self.tensor.numpy()

    def to(self, device: Union[str, torch.device], *args,
           **kwargs) -> 'BasePoints':
        """Convert current points to a specific device.

        Args:
            device (str or :obj:`torch.device`): The name of the device.
            *args, **kwargs: Forwarded to ``Tensor.to``.

        Returns:
            :obj:`BasePoints`: A new points object on the specific device.
        """
        original_type = type(self)
        return original_type(
            self.tensor.to(device, *args, **kwargs),
            points_dim=self.points_dim,
            attribute_dims=self.attribute_dims)

    def cpu(self) -> 'BasePoints':
        """Convert current points to cpu device.

        Returns:
            :obj:`BasePoints`: A new points object on the cpu device.
        """
        original_type = type(self)
        return original_type(
            self.tensor.cpu(),
            points_dim=self.points_dim,
            attribute_dims=self.attribute_dims)

    def cuda(self, *args, **kwargs) -> 'BasePoints':
        """Convert current points to cuda device.

        Args:
            *args, **kwargs: Forwarded to ``Tensor.cuda``.

        Returns:
            :obj:`BasePoints`: A new points object on the cuda device.
        """
        original_type = type(self)
        return original_type(
            self.tensor.cuda(*args, **kwargs),
            points_dim=self.points_dim,
            attribute_dims=self.attribute_dims)

    def clone(self) -> 'BasePoints':
        """Clone the points.

        Returns:
            :obj:`BasePoints`: Point object with the same properties as self.
        """
        original_type = type(self)
        return original_type(
            self.tensor.clone(),
            points_dim=self.points_dim,
            attribute_dims=self.attribute_dims)

    def detach(self) -> 'BasePoints':
        """Detach the points.

        Returns:
            :obj:`BasePoints`: Point object with the same properties as self.
        """
        original_type = type(self)
        return original_type(
            self.tensor.detach(),
            points_dim=self.points_dim,
            attribute_dims=self.attribute_dims)

    @property
    def device(self) -> torch.device:
        """torch.device: The device of the points are on."""
        return self.tensor.device

    def __iter__(self) -> Iterator[Tensor]:
        """Yield a point as a Tensor at a time.

        Returns:
            Iterator[Tensor]: A point of shape (points_dim, ).
        """
        yield from self.tensor

    def new_point(
        self, data: Union[Tensor, np.ndarray, Sequence[Sequence[float]]]
    ) -> 'BasePoints':
        """Create a new point object with data.

        The new point and its tensor has the similar properties as self and
        self.tensor, respectively.

        Args:
            data (Tensor or np.ndarray or Sequence[Sequence[float]]): Data to
                be copied.

        Returns:
            :obj:`BasePoints`: A new point object with ``data``, the object's
            other properties are similar to ``self``.
        """
        new_tensor = self.tensor.new_tensor(data) \
            if not isinstance(data, Tensor) else data.to(self.device)
        original_type = type(self)
        return original_type(
            new_tensor,
            points_dim=self.points_dim,
            attribute_dims=self.attribute_dims)