parta2.py 2.5 KB
Newer Older
dingchang's avatar
dingchang committed
1
# Copyright (c) OpenMMLab. All rights reserved.
2
from typing import Dict, Optional
wuyuefeng's avatar
wuyuefeng committed
3

4
from mmdet3d.registry import MODELS
zhangwenwei's avatar
zhangwenwei committed
5
from .two_stage import TwoStage3DDetector
wuyuefeng's avatar
wuyuefeng committed
6
7


8
@MODELS.register_module()
class PartA2(TwoStage3DDetector):
    r"""Part-A2 detector.

    Please refer to the `paper <https://arxiv.org/abs/1907.03670>`_

    Args:
        voxel_encoder (dict): Config of the voxel encoder, built with
            ``MODELS.build``.
        middle_encoder (dict): Config of the middle encoder, built with
            ``MODELS.build``.
        backbone (dict): Config of the backbone, forwarded to
            ``TwoStage3DDetector``.
        neck (dict, optional): Config of the neck, forwarded to the parent.
        rpn_head (dict, optional): Config of the RPN head, forwarded to the
            parent.
        roi_head (dict, optional): Config of the RoI head, forwarded to the
            parent.
        train_cfg (dict, optional): Training config, forwarded to the parent.
        test_cfg (dict, optional): Testing config, forwarded to the parent.
        init_cfg (dict, optional): Initialization config, forwarded to the
            parent.
        data_preprocessor (dict, optional): Data preprocessor config,
            forwarded to the parent.
    """

    def __init__(self,
                 voxel_encoder: dict,
                 middle_encoder: dict,
                 backbone: dict,
                 neck: Optional[dict] = None,
                 rpn_head: Optional[dict] = None,
                 roi_head: Optional[dict] = None,
                 train_cfg: Optional[dict] = None,
                 test_cfg: Optional[dict] = None,
                 init_cfg: Optional[dict] = None,
                 data_preprocessor: Optional[dict] = None) -> None:
        super().__init__(
            backbone=backbone,
            neck=neck,
            rpn_head=rpn_head,
            roi_head=roi_head,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            init_cfg=init_cfg,
            data_preprocessor=data_preprocessor)
        # The voxel/middle encoder configs are not passed to the parent
        # constructor, so they are built here.
        self.voxel_encoder = MODELS.build(voxel_encoder)
        self.middle_encoder = MODELS.build(middle_encoder)

    def extract_feat(self, batch_inputs_dict: Dict) -> Dict:
        """Extract features from the voxelized inputs.

        Runs voxel encoder -> middle encoder -> backbone (-> neck) and
        collects the results in the dict produced by the middle encoder.

        Args:
            batch_inputs_dict (dict): The model input dict. Only the
                ``'voxels'`` entry is read; it must be a dict with keys:

                - voxels: forwarded to the voxel encoder.
                - num_points: forwarded to the voxel encoder.
                - coors: voxel coordinates, forwarded to both encoders.
                  Column 0 is treated as the sample index within the batch
                  (the batch size is derived from it).

        Returns:
            dict: The dict returned by the middle encoder (it must contain
            ``'spatial_features'``), extended with:

            - neck_feats: output of the neck; only present when a neck is
              configured.
            - voxels_dict: the input ``batch_inputs_dict['voxels']``.
        """
        voxel_dict = batch_inputs_dict['voxels']
        coors = voxel_dict['coors']
        voxel_features = self.voxel_encoder(voxel_dict['voxels'],
                                            voxel_dict['num_points'], coors)
        # Use the largest sample index instead of the last row so the count
        # does not depend on the voxels being ordered by sample. Identical to
        # the last-row value when they are ordered.
        # NOTE(review): trailing samples with no voxels at all are still not
        # counted — confirm the voxelizer never produces that case.
        batch_size = int(coors[:, 0].max().item()) + 1
        feats_dict = self.middle_encoder(voxel_features, coors, batch_size)

        x = self.backbone(feats_dict['spatial_features'])
        # NOTE(review): without a neck the backbone output is not stored in
        # the returned dict — confirm downstream heads do not need it.
        if self.with_neck:
            feats_dict['neck_feats'] = self.neck(x)
        feats_dict['voxels_dict'] = voxel_dict
        return feats_dict