base_3droi_head.py 1.95 KB
Newer Older
dingchang's avatar
dingchang committed
1
# Copyright (c) OpenMMLab. All rights reserved.
2
from mmdet.models.roi_heads import BaseRoIHead
3

4
5
from mmdet3d.registry import MODELS, TASK_UTILS

wuyuefeng's avatar
wuyuefeng committed
6

7
class Base3DRoIHead(BaseRoIHead):
    """Base class for 3D RoI heads.

    Builds its box branch from the mmdet3d ``MODELS`` registry and its
    assigner/sampler from the mmdet3d ``TASK_UTILS`` registry. It provides
    no mask branch: ``init_mask_head`` builds nothing.

    Args:
        bbox_head (dict or ConfigDict, optional): Config of the box head.
            Defaults to None.
        bbox_roi_extractor (dict or ConfigDict, optional): Config of the box
            RoI extractor. Defaults to None.
        mask_head (dict or ConfigDict, optional): Config of a mask head.
            Forwarded to ``BaseRoIHead``, but this class's
            ``init_mask_head`` does not build it. Defaults to None.
        mask_roi_extractor (dict or ConfigDict, optional): Config of a mask
            RoI extractor. Forwarded to ``BaseRoIHead``; not built by this
            class. Defaults to None.
        train_cfg (dict or ConfigDict, optional): Training config. When
            truthy, ``init_assigner_sampler`` reads its ``assigner`` and
            ``sampler`` entries. Defaults to None.
        test_cfg (dict or ConfigDict, optional): Testing config, forwarded
            to ``BaseRoIHead``. Defaults to None.
        init_cfg (dict or list, optional): Initialization config, forwarded
            to ``BaseRoIHead``. Defaults to None.
    """

    def __init__(self,
                 bbox_head=None,
                 bbox_roi_extractor=None,
                 mask_head=None,
                 mask_roi_extractor=None,
                 train_cfg=None,
                 test_cfg=None,
                 init_cfg=None) -> None:
        # Forward every argument by keyword so that the parent's positional
        # parameter order does not have to match this signature's order.
        super().__init__(
            bbox_head=bbox_head,
            bbox_roi_extractor=bbox_roi_extractor,
            mask_head=mask_head,
            mask_roi_extractor=mask_roi_extractor,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            init_cfg=init_cfg)

    def init_bbox_head(self, bbox_roi_extractor: dict,
                       bbox_head: dict) -> None:
        """Build the box RoI extractor and the box head.

        Both modules are built from the mmdet3d ``MODELS`` registry and
        stored as ``self.bbox_roi_extractor`` and ``self.bbox_head``.

        Args:
            bbox_roi_extractor (dict or ConfigDict): Config of the box RoI
                extractor.
            bbox_head (dict or ConfigDict): Config of the box head.
        """
        self.bbox_roi_extractor = MODELS.build(bbox_roi_extractor)
        self.bbox_head = MODELS.build(bbox_head)

    def init_assigner_sampler(self) -> None:
        """Build the box assigner(s) and the box sampler from ``train_cfg``.

        ``self.bbox_assigner`` and ``self.bbox_sampler`` are first reset to
        None, and stay None when ``self.train_cfg`` is falsy.

        Otherwise ``self.train_cfg.assigner`` is interpreted as follows:

        * a single config (dict / ConfigDict) -> one built assigner;
        * a list or tuple of configs -> a list with one built assigner per
          entry;
        * anything else (e.g. None) -> ``self.bbox_assigner`` stays None.

        ``self.train_cfg.sampler`` is always built when ``train_cfg`` is
        truthy.
        """
        self.bbox_assigner = None
        self.bbox_sampler = None
        if not self.train_cfg:
            return

        assigner_cfg = self.train_cfg.assigner
        if isinstance(assigner_cfg, dict):
            self.bbox_assigner = TASK_UTILS.build(assigner_cfg)
        elif isinstance(assigner_cfg, (list, tuple)):
            # Several assigners; presumably one per class or stage --
            # NOTE(review): confirm how subclasses index this list.
            self.bbox_assigner = [
                TASK_UTILS.build(cfg) for cfg in assigner_cfg
            ]
        self.bbox_sampler = TASK_UTILS.build(self.train_cfg.sampler)

    def init_mask_head(self,
                       mask_roi_extractor=None,
                       mask_head=None) -> None:
        """Do nothing: this RoI head has no mask branch.

        ``PartAggregationROIHead`` does not use a mask head, so nothing is
        built here. The optional parameters match the
        ``init_mask_head(mask_roi_extractor, mask_head)`` hook signature of
        mmdet's ``BaseRoIHead``, so a call made with a mask config no longer
        fails with ``TypeError``. NOTE(review): confirm the hook signature
        against the installed mmdet version.

        Args:
            mask_roi_extractor (dict or ConfigDict, optional): Ignored.
            mask_head (dict or ConfigDict, optional): Ignored. A warning is
                emitted when it is not None, because the config is dropped.
        """
        if mask_head is not None:
            import warnings
            warnings.warn(
                f'{type(self).__name__} has no mask branch; the mask_head '
                'config is ignored.')