mvx_faster_rcnn.py 1.95 KB
Newer Older
dingchang's avatar
dingchang committed
1
# Copyright (c) OpenMMLab. All rights reserved.
2
from typing import Dict, List, Optional, Sequence
jshilong's avatar
jshilong committed
3
4

from torch import Tensor
zhangwenwei's avatar
zhangwenwei committed
5

6
from mmdet3d.registry import MODELS
zhangwenwei's avatar
zhangwenwei committed
7
8
9
from .mvx_two_stage import MVXTwoStageDetector


10
@MODELS.register_module()
class MVXFasterRCNN(MVXTwoStageDetector):
    """Multi-modality VoxelNet using Faster R-CNN.

    The class is registered in ``MODELS`` and adds no behavior of its own
    on top of :class:`MVXTwoStageDetector`: every keyword argument given
    to the constructor is handed to the parent constructor unchanged.
    """

    def __init__(self, **kwargs):
        # Zero-argument ``super()`` resolves to the same parent initializer
        # as the explicit ``super(MVXFasterRCNN, self)`` spelling.
        super().__init__(**kwargs)


18
@MODELS.register_module()
class DynamicMVXFasterRCNN(MVXTwoStageDetector):
    """Multi-modality VoxelNet using Faster R-CNN and dynamic voxelization.

    The class is registered in ``MODELS``. Construction is delegated
    entirely to :class:`MVXTwoStageDetector`; the only method defined here
    is :meth:`extract_pts_feat`.
    """

    def __init__(self, **kwargs):
        # All configuration is consumed by the parent constructor.
        super().__init__(**kwargs)

    def extract_pts_feat(
            self,
            voxel_dict: Dict[str, Tensor],
            points: Optional[List[Tensor]] = None,
            img_feats: Optional[Sequence[Tensor]] = None,
            batch_input_metas: Optional[List[dict]] = None
    ) -> Sequence[Tensor]:
        """Extract features of points.

        The inputs are run through ``pts_voxel_encoder``,
        ``pts_middle_encoder`` and ``pts_backbone`` in that order, and
        finally through ``pts_neck`` when ``self.with_pts_neck`` is truthy.

        Args:
            voxel_dict (Dict[str, Tensor]): Dict of voxelization infos.
                The entries ``'voxels'`` and ``'coors'`` are read.
            points (List[tensor], optional): Point cloud of multiple
                inputs. Defaults to None.
            img_feats (list[Tensor], tuple[tensor], optional): Features
                from image backbone. Defaults to None.
            batch_input_metas (list[dict], optional): The meta information
                of multiple samples. Defaults to None.

        Returns:
            Sequence[tensor]: points features of multiple inputs from
            backbone or neck. ``None`` is returned instead when
            ``self.with_pts_bbox`` is falsy.
        """
        if not self.with_pts_bbox:
            return None

        encoded_feats, encoded_coors = self.pts_voxel_encoder(
            voxel_dict['voxels'], voxel_dict['coors'], points, img_feats,
            batch_input_metas)

        # Batch size is taken as (column 0 of the last ``coors`` row) + 1.
        # NOTE(review): presumably column 0 holds the sample index and the
        # rows are ordered by it - confirm against the voxelization code.
        num_samples = voxel_dict['coors'][-1, 0] + 1

        middle_out = self.pts_middle_encoder(encoded_feats, encoded_coors,
                                             num_samples)
        backbone_out = self.pts_backbone(middle_out)
        # The neck is optional; without it the backbone output is returned.
        return self.pts_neck(backbone_out) if self.with_pts_neck \
            else backbone_out