# Customize Models

We basically categorize model components into 6 types:

- encoder: voxel-based methods applied before the backbone, including the voxel encoder and the middle encoder, e.g. `HardVFE` and `PointPillarsScatter`.
- backbone: usually an FCN network that extracts feature maps, e.g. `ResNet` and `SECOND`.
- neck: the component between backbones and heads, e.g. `FPN` and `SECONDFPN`.
- head: the component for a specific task, e.g. bounding box prediction and mask prediction.
- RoI extractor: the component that extracts RoI features from feature maps, e.g. `H3DRoIHead` and `PartAggregationROIHead`.
- loss: the component in heads that computes losses, e.g. `FocalLoss`, `L1Loss`, and `GHMLoss`.

## Develop new components

### Add a new encoder

Here we use HardVFE as an example to show how to develop a new component.

#### 1. Define a new voxel encoder (e.g. HardVFE: the voxel feature encoder used in HV-SECOND)

Create a new file `mmdet3d/models/voxel_encoders/voxel_encoder.py`.

```python
import torch.nn as nn

from mmdet3d.registry import MODELS


@MODELS.register_module()
class HardVFE(nn.Module):

    def __init__(self, arg1, arg2):
        pass

    def forward(self, x):  # should return a tuple
        pass
```

#### 2. Import the module

You can add the following line to `mmdet3d/models/voxel_encoders/__init__.py`:

```python
from .voxel_encoder import HardVFE
```

Alternatively, add the following to the config file to avoid modifying the source code:

```python
custom_imports = dict(
    imports=['mmdet3d.models.voxel_encoders.voxel_encoder'],
    allow_failed_imports=False)
```

#### 3. Use the voxel encoder in your config file

```python
model = dict(
    ...
    voxel_encoder=dict(
        type='HardVFE',
        arg1=xxx,
        arg2=yyy),
    ...
)
```
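
The three steps above rely on a registry: the decorator records the class under its name, and the `type` field of a config dict is later used to look the class up and instantiate it with the remaining keys. A minimal pure-Python sketch of that pattern is given below. The `Registry` class and `TOY_MODELS` are hypothetical stand-ins written for this illustration, not MMEngine's implementation.

```python
class Registry:
    """Toy name-to-class registry (illustrative only)."""

    def __init__(self, name):
        self.name = name
        self._modules = {}

    def register_module(self):
        # Used as ``@REGISTRY.register_module()``: stores the class by name.
        def decorator(cls):
            self._modules[cls.__name__] = cls
            return cls
        return decorator

    def build(self, cfg):
        # Copy so the caller's config is not mutated, then pop the type key.
        args = dict(cfg)
        cls = self._modules[args.pop('type')]
        return cls(**args)


TOY_MODELS = Registry('toy_models')  # hypothetical registry for this sketch


@TOY_MODELS.register_module()
class HardVFE:
    def __init__(self, arg1, arg2):
        self.arg1 = arg1
        self.arg2 = arg2


# The config dict plays the role of ``voxel_encoder=dict(type='HardVFE', ...)``.
encoder = TOY_MODELS.build(dict(type='HardVFE', arg1=4, arg2=64))
print(type(encoder).__name__, encoder.arg1, encoder.arg2)
```

Because the lookup is by name, a class that was never imported is never registered, which is why step 2 is required.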

### Add a new backbone

Here we use [SECOND](https://www.mdpi.com/1424-8220/18/10/3337) (Sparsely Embedded Convolutional Detection) as an example to show how to develop a new component.

#### 1. Define a new backbone (e.g. SECOND)

Create a new file `mmdet3d/models/backbones/second.py`.

```python
from mmengine.model import BaseModule

from mmdet3d.registry import MODELS


@MODELS.register_module()
class SECOND(BaseModule):

    def __init__(self, arg1, arg2):
        pass

    def forward(self, x):  # should return a tuple
        pass
```

#### 2. Import the module

You can add the following line to `mmdet3d/models/backbones/__init__.py`:

```python
from .second import SECOND
```

Alternatively, add the following to the config file to avoid modifying the source code:

```python
custom_imports = dict(
    imports=['mmdet3d.models.backbones.second'],
    allow_failed_imports=False)
```
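
Importing a module is what executes its `@MODELS.register_module()` decorators, so `custom_imports` only has to import the listed dotted paths. The sketch below shows one way such a mechanism could work using `importlib`. The helper `import_custom_modules` is hypothetical, and standard-library module names stand in for project modules so the example runs anywhere.

```python
import importlib


def import_custom_modules(imports, allow_failed_imports=False):
    """Import each dotted module path and return the modules that loaded.

    Hypothetical helper for illustration only.
    """
    loaded = []
    for name in imports:
        try:
            loaded.append(importlib.import_module(name))
        except ImportError:
            # With allow_failed_imports=False a bad path fails loudly.
            if not allow_failed_imports:
                raise
    return loaded


# 'no_such_module_xyz' does not exist and is skipped because failures are allowed.
mods = import_custom_modules(
    ['json', 'no_such_module_xyz'], allow_failed_imports=True)
print([m.__name__ for m in mods])
```

Keeping `allow_failed_imports=False`, as in the configs in this document, is usually the safer choice: a typo in the path then surfaces immediately instead of later as an unknown `type`.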

#### 3. Use the backbone in your config file

```python
model = dict(
    ...
    backbone=dict(
        type='SECOND',
        arg1=xxx,
        arg2=yyy),
    ...
)
```

### Add a new neck

#### 1. Define a new neck (e.g. SECONDFPN)

Create a new file `mmdet3d/models/necks/second_fpn.py`.

```python
from mmengine.model import BaseModule

from mmdet3d.registry import MODELS


@MODELS.register_module()
class SECONDFPN(BaseModule):

    def __init__(self,
                 in_channels=[128, 128, 256],
                 out_channels=[256, 256, 256],
                 upsample_strides=[1, 2, 4],
                 norm_cfg=dict(type='BN', eps=1e-3, momentum=0.01),
                 upsample_cfg=dict(type='deconv', bias=False),
                 conv_cfg=dict(type='Conv2d', bias=False),
                 use_conv_for_no_stride=False,
                 init_cfg=None):
        pass

    def forward(self, x):
        # implementation is omitted
        pass
```

#### 2. Import the module

You can add the following line to `mmdet3d/models/necks/__init__.py`:

```python
from .second_fpn import SECONDFPN
```

Alternatively, add the following to the config file to avoid modifying the source code:

```python
custom_imports = dict(
    imports=['mmdet3d.models.necks.second_fpn'],
    allow_failed_imports=False)
```

#### 3. Use the neck in your config file

```python
model = dict(
    ...
    neck=dict(
        type='SECONDFPN',
        in_channels=[64, 128, 256],
        upsample_strides=[1, 2, 4],
        out_channels=[128, 128, 128]),
    ...
)
```
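
The three lists in this config have to agree with each other and with the backbone outputs. The sketch below checks that arithmetic in plain Python, assuming the neck upsamples level `i` by `upsample_strides[i]`, maps it to `out_channels[i]` channels, and concatenates all levels along the channel axis. The helper name and the feature-map sizes are made up for illustration.

```python
def fused_output_shape(input_shapes, in_channels, upsample_strides,
                       out_channels):
    """Return (C, H, W) after upsampling every level and concatenating.

    input_shapes: list of (C, H, W) tuples, one per backbone level.
    """
    assert (len(input_shapes) == len(in_channels) == len(upsample_strides)
            == len(out_channels))
    spatial = set()
    for (c, h, w), c_in, stride in zip(input_shapes, in_channels,
                                       upsample_strides):
        assert c == c_in, f'expected {c_in} channels, got {c}'
        spatial.add((h * stride, w * stride))
    # Concatenation only works if every level ends at the same resolution.
    assert len(spatial) == 1, f'levels do not align: {spatial}'
    (h, w), = spatial
    return (sum(out_channels), h, w)


# Hypothetical feature maps at relative resolutions 1, 1/2 and 1/4.
shape = fused_output_shape(
    input_shapes=[(64, 200, 176), (128, 100, 88), (256, 50, 44)],
    in_channels=[64, 128, 256],
    upsample_strides=[1, 2, 4],
    out_channels=[128, 128, 128])
print(shape)  # (384, 200, 176)
```

Under these assumptions, the module that consumes the neck output would need to expect `sum(out_channels)` input channels, 384 here.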

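### Add a new head

Here we use [PartA2 Head](https://arxiv.org/abs/1907.03670) as an example to show how to develop a new head.

**Note**: The `PartA2 RoI Head` shown here is used in the second stage of the detector. For single-stage heads, please refer to the examples in `mmdet3d/models/dense_heads/`. Because they are simple and efficient, single-stage heads are more commonly used for 3D detection in autonomous driving.

First, add a new bbox head in `mmdet3d/models/roi_heads/bbox_heads/parta2_bbox_head.py`. `PartA2 RoI Head` implements a new bbox head for object detection. To implement a bbox head, we usually need to implement the two functions of the new module shown below. Sometimes other related functions, such as `loss` and `get_targets`, are also required.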

```python
from mmengine.model import BaseModule

from mmdet3d.registry import MODELS


@MODELS.register_module()
class PartA2BboxHead(BaseModule):
    """PartA2 RoI head."""

    def __init__(self,
                 num_classes,
                 seg_in_channels,
                 part_in_channels,
                 seg_conv_channels=None,
                 part_conv_channels=None,
                 merge_conv_channels=None,
                 down_conv_channels=None,
                 shared_fc_channels=None,
                 cls_channels=None,
                 reg_channels=None,
                 dropout_ratio=0.1,
                 roi_feat_size=14,
                 with_corner_loss=True,
                 bbox_coder=dict(type='DeltaXYZWLHRBBoxCoder'),
                 conv_cfg=dict(type='Conv1d'),
                 norm_cfg=dict(type='BN1d', eps=1e-3, momentum=0.01),
                 loss_bbox=dict(
                     type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=2.0),
                 loss_cls=dict(
                     type='CrossEntropyLoss',
                     use_sigmoid=True,
                     reduction='none',
                     loss_weight=1.0),
                 init_cfg=None):
        super(PartA2BboxHead, self).__init__(init_cfg=init_cfg)

    def forward(self, seg_feats, part_feats):
        pass
```

Second, implement a new RoI Head if necessary. We inherit the new `PartAggregationROIHead` from `Base3DRoIHead`, which already implements the following functions.

```python
from mmdet.models.roi_heads import BaseRoIHead

from mmdet3d.registry import MODELS, TASK_UTILS


class Base3DRoIHead(BaseRoIHead):
    """Base class for 3d RoIHeads."""

    def __init__(self,
                 bbox_head=None,
                 bbox_roi_extractor=None,
                 mask_head=None,
                 mask_roi_extractor=None,
                 train_cfg=None,
                 test_cfg=None,
                 init_cfg=None):
        super(Base3DRoIHead, self).__init__(
            bbox_head=bbox_head,
            bbox_roi_extractor=bbox_roi_extractor,
            mask_head=mask_head,
            mask_roi_extractor=mask_roi_extractor,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            init_cfg=init_cfg)

    def init_bbox_head(self, bbox_roi_extractor: dict,
                       bbox_head: dict) -> None:
        """Initialize box head and box roi extractor.

        Args:
            bbox_roi_extractor (dict or ConfigDict): Config of box
                roi extractor.
            bbox_head (dict or ConfigDict): Config of box in box head.
        """
        self.bbox_roi_extractor = MODELS.build(bbox_roi_extractor)
        self.bbox_head = MODELS.build(bbox_head)

    def init_assigner_sampler(self):
        """Initialize assigner and sampler."""
        self.bbox_assigner = None
        self.bbox_sampler = None
        if self.train_cfg:
            if isinstance(self.train_cfg.assigner, dict):
                self.bbox_assigner = TASK_UTILS.build(self.train_cfg.assigner)
            elif isinstance(self.train_cfg.assigner, list):
                self.bbox_assigner = [
                    TASK_UTILS.build(res) for res in self.train_cfg.assigner
                ]
            self.bbox_sampler = TASK_UTILS.build(self.train_cfg.sampler)

    def init_mask_head(self):
        """Initialize mask head, skip since ``PartAggregationROIHead`` does not
        have one."""
        pass
```

Next, the main change is to the bbox_forward logic, while the remaining logic is inherited from `Base3DRoIHead`. The new RoI Head is implemented in `mmdet3d/models/roi_heads/part_aggregation_roi_head.py` as follows:

```python
from typing import Dict, List, Tuple

from mmdet.models.task_modules import AssignResult, SamplingResult
from mmengine import ConfigDict
from torch import Tensor
from torch.nn import functional as F

from mmdet3d.registry import MODELS
from mmdet3d.structures import bbox3d2roi
from mmdet3d.utils import InstanceList
from ...structures.det3d_data_sample import SampleList
from .base_3droi_head import Base3DRoIHead


@MODELS.register_module()
class PartAggregationROIHead(Base3DRoIHead):
    """Part aggregation roi head for PartA2.

    Args:
        semantic_head (ConfigDict): Config of semantic head.
        num_classes (int): The number of classes.
        seg_roi_extractor (ConfigDict): Config of seg_roi_extractor.
        bbox_roi_extractor (ConfigDict): Config of part_roi_extractor.
        bbox_head (ConfigDict): Config of bbox_head.
        train_cfg (ConfigDict): Training config.
        test_cfg (ConfigDict): Testing config.
    """

    def __init__(self,
                 semantic_head: dict,
                 num_classes: int = 3,
                 seg_roi_extractor: dict = None,
                 bbox_head: dict = None,
                 bbox_roi_extractor: dict = None,
                 train_cfg: dict = None,
                 test_cfg: dict = None,
                 init_cfg: dict = None) -> None:
        super(PartAggregationROIHead, self).__init__(
            bbox_head=bbox_head,
            bbox_roi_extractor=bbox_roi_extractor,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            init_cfg=init_cfg)
        self.num_classes = num_classes
        assert semantic_head is not None
        self.init_seg_head(seg_roi_extractor, semantic_head)

    def init_seg_head(self, seg_roi_extractor: dict,
                      semantic_head: dict) -> None:
        """Initialize semantic head and seg roi extractor.

        Args:
            seg_roi_extractor (dict): Config of seg
                roi extractor.
            semantic_head (dict): Config of semantic head.
        """
        self.semantic_head = MODELS.build(semantic_head)
        self.seg_roi_extractor = MODELS.build(seg_roi_extractor)

    @property
    def with_semantic(self):
        """bool: whether the head has semantic branch"""
        return hasattr(self,
                       'semantic_head') and self.semantic_head is not None

    def predict(self,
                feats_dict: Dict,
                rpn_results_list: InstanceList,
                batch_data_samples: SampleList,
                rescale: bool = False,
                **kwargs) -> InstanceList:
        """Perform forward propagation of the roi head and predict detection
        results on the features of the upstream network.

        Args:
            feats_dict (dict): Contains features from the first stage.
            rpn_results_list (List[:obj:`InstanceData`]): Detection results
                of rpn head.
            batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
                samples. It usually includes information such as
                `gt_instance_3d`, `gt_panoptic_seg_3d` and `gt_sem_seg_3d`.
            rescale (bool): If True, return boxes in original image space.
                Defaults to False.

        Returns:
            list[:obj:`InstanceData`]: Detection results of each sample
            after the post process.
            Each item usually contains following keys.

            - scores_3d (Tensor): Classification scores, has a shape
              (num_instances, )
            - labels_3d (Tensor): Labels of bboxes, has a shape
              (num_instances, ).
            - bboxes_3d (BaseInstance3DBoxes): Prediction of bboxes,
              contains a tensor with shape (num_instances, C), where
              C >= 7.
        """
        assert self.with_bbox, 'Bbox head must be implemented in PartA2.'
        assert self.with_semantic, 'Semantic head must be implemented' \
                                   ' in PartA2.'

        batch_input_metas = [
            data_samples.metainfo for data_samples in batch_data_samples
        ]
        voxels_dict = feats_dict.pop('voxels_dict')
        # TODO: Split predict semantic and bbox
        results_list = self.predict_bbox(feats_dict, voxels_dict,
                                         batch_input_metas, rpn_results_list,
                                         self.test_cfg)
        return results_list

    def predict_bbox(self, feats_dict: Dict, voxel_dict: Dict,
                     batch_input_metas: List[dict],
                     rpn_results_list: InstanceList,
                     test_cfg: ConfigDict) -> InstanceList:
        """Perform forward propagation of the bbox head and predict detection
        results on the features of the upstream network.

        Args:
            feats_dict (dict): Contains features from the first stage.
            voxel_dict (dict): Contains information of voxels.
            batch_input_metas (list[dict], Optional): Batch image meta info.
                Defaults to None.
            rpn_results_list (List[:obj:`InstanceData`]): Detection results
                of rpn head.
            test_cfg (Config): Test config.

        Returns:
            list[:obj:`InstanceData`]: Detection results of each sample
            after the post process.
            Each item usually contains following keys.

            - scores_3d (Tensor): Classification scores, has a shape
              (num_instances, )
            - labels_3d (Tensor): Labels of bboxes, has a shape
              (num_instances, ).
            - bboxes_3d (BaseInstance3DBoxes): Prediction of bboxes,
              contains a tensor with shape (num_instances, C), where
              C >= 7.
        """
        ...

    def loss(self, feats_dict: Dict, rpn_results_list: InstanceList,
             batch_data_samples: SampleList, **kwargs) -> dict:
        """Perform forward propagation and loss calculation of the detection
        roi on the features of the upstream network.

        Args:
            feats_dict (dict): Contains features from the first stage.
            rpn_results_list (List[:obj:`InstanceData`]): Detection results
                of rpn head.
            batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
                samples. It usually includes information such as
                `gt_instance_3d`, `gt_panoptic_seg_3d` and `gt_sem_seg_3d`.

        Returns:
            dict[str, Tensor]: A dictionary of loss components
        """
        assert len(rpn_results_list) == len(batch_data_samples)
        losses = dict()
        batch_gt_instances_3d = []
        batch_gt_instances_ignore = []
        voxels_dict = feats_dict.pop('voxels_dict')
        for data_sample in batch_data_samples:
            batch_gt_instances_3d.append(data_sample.gt_instances_3d)
            if 'ignored_instances' in data_sample:
                batch_gt_instances_ignore.append(data_sample.ignored_instances)
            else:
                batch_gt_instances_ignore.append(None)
        if self.with_semantic:
            semantic_results = self._semantic_forward_train(
                feats_dict, voxels_dict, batch_gt_instances_3d)
            losses.update(semantic_results.pop('loss_semantic'))

        sample_results = self._assign_and_sample(rpn_results_list,
                                                 batch_gt_instances_3d)
        if self.with_bbox:
            feats_dict.update(semantic_results)
            bbox_results = self._bbox_forward_train(feats_dict, voxels_dict,
                                                    sample_results)
            losses.update(bbox_results['loss_bbox'])

        return losses
```
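
The `loss` method above collects the outputs of the semantic branch and the bbox branch into a single `losses` dict with `losses.update(...)`. The sketch below illustrates that bookkeeping in plain Python. The loss names and values are invented, and the rule that only keys containing `loss` contribute to the total is an assumption about how a training loop might consume the dict, not something stated in this document.

```python
def merge_losses(*partial_losses):
    """Merge per-branch loss dicts, refusing silent key collisions."""
    merged = {}
    for part in partial_losses:
        for key, value in part.items():
            if key in merged:
                raise KeyError(f'duplicate loss key: {key}')
            merged[key] = value
    return merged


def total_loss(losses):
    # Assumption: entries without "loss" in the key are metrics for logging.
    return sum(v for k, v in losses.items() if 'loss' in k)


semantic = {'loss_seg': 0.5, 'loss_part': 0.25}  # made-up values
bbox = {'loss_cls': 1.0, 'loss_bbox': 2.0, 'acc': 0.9}
losses = merge_losses(semantic, bbox)
print(total_loss(losses))  # 3.75
```

Unlike `dict.update`, which silently overwrites, this variant raises on a duplicate key. That is a design choice of the sketch to make name clashes between branches visible.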

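Here we omit the details of the related functions. Please see the [code](https://github.com/open-mmlab/mmdetection3d/blob/dev-1.x/mmdet3d/models/roi_heads/part_aggregation_roi_head.py) for more details.

Finally, users need to add the modules in `mmdet3d/models/roi_heads/bbox_heads/__init__.py` and `mmdet3d/models/roi_heads/__init__.py` so that the corresponding registry can find and load them.

Alternatively, users can add the following to the config file to achieve the same goal.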

```python
custom_imports = dict(
    imports=[
        'mmdet3d.models.roi_heads.part_aggregation_roi_head',
        'mmdet3d.models.roi_heads.bbox_heads.parta2_bbox_head'
    ],
    allow_failed_imports=False)
```

The config file of `PartAggregationROIHead` is as follows:

```python
model = dict(
    ...
    roi_head=dict(
        type='PartAggregationROIHead',
        num_classes=3,
        semantic_head=dict(
            type='PointwiseSemanticHead',
            in_channels=16,
            extra_width=0.2,
            seg_score_thr=0.3,
            num_classes=3,
            loss_seg=dict(
                type='mmdet.FocalLoss',
                use_sigmoid=True,
                reduction='sum',
                gamma=2.0,
                alpha=0.25,
                loss_weight=1.0),
            loss_part=dict(
                type='mmdet.CrossEntropyLoss',
                use_sigmoid=True,
                loss_weight=1.0)),
        seg_roi_extractor=dict(
            type='Single3DRoIAwareExtractor',
            roi_layer=dict(
                type='RoIAwarePool3d',
                out_size=14,
                max_pts_per_voxel=128,
                mode='max')),
        bbox_roi_extractor=dict(
            type='Single3DRoIAwareExtractor',
            roi_layer=dict(
                type='RoIAwarePool3d',
                out_size=14,
                max_pts_per_voxel=128,
                mode='avg')),
        bbox_head=dict(
            type='PartA2BboxHead',
            num_classes=3,
            seg_in_channels=16,
            part_in_channels=4,
            seg_conv_channels=[64, 64],
            part_conv_channels=[64, 64],
            merge_conv_channels=[128, 128],
            down_conv_channels=[128, 256],
            bbox_coder=dict(type='DeltaXYZWLHRBBoxCoder'),
            shared_fc_channels=[256, 512, 512, 512],
            cls_channels=[256, 256],
            reg_channels=[256, 256],
            dropout_ratio=0.1,
            roi_feat_size=14,
            with_corner_loss=True,
            loss_bbox=dict(
                type='mmdet.SmoothL1Loss',
                beta=1.0 / 9.0,
                reduction='sum',
                loss_weight=1.0),
            loss_cls=dict(
                type='mmdet.CrossEntropyLoss',
                use_sigmoid=True,
                reduction='sum',
                loss_weight=1.0))),
    ...
)
```
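
Every key in a component dict other than `type` is passed as a keyword argument to that component's `__init__`, so a misspelled or missing key only fails when the model is built. The sketch below shows one way to catch such mistakes earlier with `inspect`. `check_cfg_against_init` and `ToyBboxHead` are hypothetical names introduced for this illustration, and the check would over-report for classes that accept `**kwargs`.

```python
import inspect


def check_cfg_against_init(cls, cfg):
    """Return (unknown_keys, missing_required) for cfg vs cls.__init__."""
    params = inspect.signature(cls.__init__).parameters
    names = [n for n in params if n != 'self']
    keys = [k for k in cfg if k != 'type']
    # Keys in the config that __init__ does not accept.
    unknown = [k for k in keys if k not in names]
    # Parameters without a default that the config does not provide.
    missing = [
        n for n in names
        if params[n].default is inspect.Parameter.empty
        and params[n].kind is inspect.Parameter.POSITIONAL_OR_KEYWORD
        and n not in cfg
    ]
    return unknown, missing


class ToyBboxHead:  # hypothetical stand-in for a real head
    def __init__(self, num_classes, seg_in_channels, part_in_channels,
                 dropout_ratio=0.1, roi_feat_size=14):
        pass


# 'roi_feat_sise' is misspelled on purpose; 'part_in_channels' is left out.
cfg = dict(type='ToyBboxHead', num_classes=3, seg_in_channels=16,
           roi_feat_sise=14)
print(check_cfg_against_init(ToyBboxHead, cfg))
```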

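Since MMDetection 2.0, config files support inheritance, so users can focus on the fields they modify. The second stage of PartA2 Head mainly uses the new `PartAggregationROIHead` and `PartA2BboxHead`; their arguments are set according to the `__init__` function of each module.

### Add a new loss

Assume you want to add a new loss `MyLoss` for bounding box regression. To add a new loss function, implement it in `mmdet3d/models/losses/my_loss.py`. The decorator `weighted_loss` allows the loss to be weighted for each element.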

```python
import torch
import torch.nn as nn
from mmdet.models.losses.utils import weighted_loss

from mmdet3d.registry import MODELS


@weighted_loss
def my_loss(pred, target):
    assert pred.size() == target.size() and target.numel() > 0
    loss = torch.abs(pred - target)
    return loss


@MODELS.register_module()
class MyLoss(nn.Module):

    def __init__(self, reduction='mean', loss_weight=1.0):
        super(MyLoss, self).__init__()
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None):
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        loss_bbox = self.loss_weight * my_loss(
            pred, target, weight, reduction=reduction, avg_factor=avg_factor)
        return loss_bbox
```
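
To make the role of the decorator concrete, here is a simplified pure-Python sketch of a `weighted_loss`-style wrapper: it takes a function returning one loss value per element and adds optional element-wise weights, a reduction mode, and an `avg_factor`. It operates on plain lists rather than tensors, and `weighted_loss_sketch` is a hypothetical name; the real decorator may differ in details.

```python
import functools


def weighted_loss_sketch(loss_func):
    """Wrap an element-wise loss with weighting and reduction (simplified)."""

    @functools.wraps(loss_func)
    def wrapper(pred, target, weight=None, reduction='mean', avg_factor=None):
        loss = loss_func(pred, target)  # one value per element
        if weight is not None:
            loss = [v * w for v, w in zip(loss, weight)]
        if reduction == 'none':
            return loss
        if reduction == 'sum':
            if avg_factor is not None:
                raise ValueError('avg_factor cannot be used with "sum"')
            return sum(loss)
        if reduction == 'mean':
            # avg_factor replaces the element count as the denominator.
            denom = len(loss) if avg_factor is None else avg_factor
            return sum(loss) / denom
        raise ValueError(f'unknown reduction: {reduction}')

    return wrapper


@weighted_loss_sketch
def my_l1(pred, target):
    return [abs(p - t) for p, t in zip(pred, target)]


pred, target = [1.0, 2.0, 4.0], [0.0, 2.0, 2.0]
print(my_l1(pred, target))                    # mean of [1, 0, 2]
print(my_l1(pred, target, reduction='sum'))   # sum of [1, 0, 2]
print(my_l1(pred, target, weight=[1.0, 1.0, 0.5], reduction='none'))
```

The wrapped function only has to express the per-element formula, which is why `my_loss` above can consist of a single `torch.abs(pred - target)`.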

Next, users need to add it in `mmdet3d/models/losses/__init__.py`.

```python
from .my_loss import MyLoss, my_loss
```

Alternatively, you can add the following to the config file to achieve the same goal.

```python
custom_imports = dict(
    imports=['mmdet3d.models.losses.my_loss'],
    allow_failed_imports=False)
```

To use it, modify the `loss_xxx` field. Since `MyLoss` is for regression, you need to modify the `loss_bbox` field in the head.

```python
loss_bbox=dict(type='MyLoss', loss_weight=1.0)
```
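
When the head is inherited from a base config, the new `loss_bbox` is merged into the existing one, and arguments that belong to the old loss type (for example `beta`) can leak into `MyLoss`. The sketch below illustrates the issue with a small recursive merge. `merge_cfg` is a hypothetical helper, and its handling of a `_delete_` key is modeled on the convention used by MMEngine-style configs; treat the exact behaviour as an assumption.

```python
import copy


def merge_cfg(base, override):
    """Recursively merge ``override`` into a deep copy of ``base``."""
    merged = copy.deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict):
            value = dict(value)
            delete = value.pop('_delete_', False)
            if not delete and isinstance(merged.get(key), dict):
                # Default: keep base keys and overwrite only the given ones.
                merged[key] = merge_cfg(merged[key], value)
                continue
        # Scalars, new keys, and ``_delete_`` dicts replace the base entry.
        merged[key] = copy.deepcopy(value)
    return merged


base_head = dict(
    type='PartA2BboxHead',
    num_classes=3,
    loss_bbox=dict(
        type='mmdet.SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0))

leaky = merge_cfg(base_head, dict(loss_bbox=dict(type='MyLoss')))
clean = merge_cfg(
    base_head,
    dict(loss_bbox=dict(_delete_=True, type='MyLoss', loss_weight=1.0)))
print(sorted(leaky['loss_bbox']))  # still contains 'beta'
print(clean['loss_bbox'])
```

In the first merge `beta` survives and would be passed to `MyLoss.__init__`, which does not accept it; replacing the whole `loss_bbox` entry avoids that.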