Unverified commit 07590418 authored by xiliu8006, committed by GitHub

[Refactor]: Unified parameter initialization (#622)

* support 3dssd

* support one-stage method

* for lint

* support two_stage

* Support all methods

* remove init_cfg=[] in configs

* test

* support h3dnet

* fix lint error

* fix isort

* fix code style error

* fix imvotenet bug

* rename init_weight->init_weights

* clean comma

* fix test_apis does not init weights

* support newest mmdet and mmcv

* fix test_heads h3dnet bug

* rm *.swp

* remove the wrong code in build.yml

* fix ssn low map

* modify docs

* modified ssn init_config

* modify params in backbone pointnet2_sa_ssg

* add ssn direction init_cfg

* support segmentor

* add conv a=sqrt(5)

* Convmodule uses kaiming_init

* fix centerpointhead init bug

* add second conv2d init cfg

* add unittest to confirm the input is not modified

* assert gt_bboxes_3d

* rm .swag

* modify docs mmdet version

* adopt fcosmono3d

* add fcos 3d original init method

* fix mmseg version

* add init cfg in fcos_mono3d.py

* merge newest master

* remove unused code

* modify fcos config due to changes of resnet

* support imvoxelnet pointnet2

* modified the dependencies version

* support decode head

* fix inference bug

* modify the useless init_cfg

* fix multi_modality BC-breaking

* fix error blank

* modify docs error
parent 318499ac
@@ -91,8 +91,8 @@ jobs:
       - name: Install mmdet3d dependencies
         run: |
           pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cu101/${{matrix.torch_version}}/index.html
-          pip install mmdet==2.11.0
-          pip install mmsegmentation==0.14.0
+          pip install mmdet==2.14.0
+          pip install mmsegmentation==0.14.1
           pip install -r requirements.txt
       - name: Build and install
         run: |
......
@@ -16,7 +16,6 @@ model = dict(
         out_channels=256,
         start_level=1,
         add_extra_convs=True,
-        extra_convs_on_inputs=False,  # use P5
         num_outs=5,
         relu_before_extra_convs=True),
     bbox_head=dict(
......
@@ -99,8 +99,8 @@ model = dict(
             nms_across_levels=False,
             nms_pre=1000,
             nms_post=1000,
-            max_num=1000,
-            nms_thr=0.7,
+            max_per_img=1000,
+            nms=dict(type='nms', iou_threshold=0.7),
             min_bbox_size=0),
         img_rcnn=dict(
             score_thr=0.05,
......
 model = dict(
     type='ImVoxelNet',
-    pretrained='torchvision://resnet50',
     backbone=dict(
         type='ResNet',
         depth=50,
@@ -9,6 +8,7 @@ model = dict(
         frozen_stages=1,
         norm_cfg=dict(type='BN', requires_grad=False),
         norm_eval=True,
+        init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50'),
         style='pytorch'),
     neck=dict(
         type='FPN',
......
@@ -4,6 +4,15 @@ This document provides detailed descriptions of the BC-breaking changes in MMDetection3D.

 ## MMDetection3D 0.15.0

+### Unified parameter initialization
+
+To unify parameter initialization across OpenMMLab projects, MMCV provides `BaseModule`, which accepts `init_cfg` so that module parameters can be initialized in a flexible and unified manner. Users now need to explicitly call `model.init_weights()` in the training script to initialize the model (as in [here](https://github.com/open-mmlab/mmdetection3d/blob/master/tools/train.py#L183)); previously this was handled by the detector. Please refer to PR #622 for details.
+
+### BackgroundPointsFilter
+
+We modified the dataset augmentation function `BackgroundPointsFilter` (in [here](https://github.com/open-mmlab/mmdetection3d/blob/mmdet3d/datasets/pipelines/transforms_3d.py#L1101)). In previous versions of MMDetection3D, `BackgroundPointsFilter` changed the bottom center of `gt_bboxes_3d` to the gravity center. In MMDetection3D 0.15.0, `BackgroundPointsFilter` no longer changes it. Please refer to PR #609 for details.
+
 ### Enhance `IndoorPatchPointSample` transform

 We enhance the pipeline function `IndoorPatchPointSample` used in point cloud segmentation task by adding more choices for patch selection. Also, we plan to remove the unused parameter `sample_rate` in the future. Please modify the code as well as the config files accordingly if you use this transform.
......
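For readers following along, here is a minimal sketch (not part of the diff; the config path is only an example) of what the explicit initialization call described above looks like in user code, using the `build_model` helper that appears later in this diff:

```python
from mmcv import Config

from mmdet3d.models import build_model

# Any mmdet3d config works the same way; this path is just an example.
cfg = Config.fromfile('configs/votenet/votenet_8x8_scannet-3d-18class.py')

model = build_model(
    cfg.model,
    train_cfg=cfg.get('train_cfg'),
    test_cfg=cfg.get('test_cfg'))

# Since 0.15.0, building the model no longer initializes its weights;
# the training script has to trigger initialization explicitly.
model.init_weights()
```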
@@ -12,8 +12,8 @@ The required versions of MMCV, MMDetection and MMSegmentation for different versions of MMDetection3D

 | MMDetection3D version | MMDetection version | MMSegmentation version | MMCV version |
 |:-------------------:|:-------------------:|:-------------------:|:-------------------:|
-| master | mmdet>=2.10.0, <=2.11.0| mmseg==0.14.0 | mmcv-full>=1.3.1, <=1.4|
-| 0.14.0 | mmdet>=2.10.0, <=2.11.0| mmseg==0.14.0 | mmcv-full>=1.3.1, <=1.4|
+| master | mmdet>=2.12.0 | mmseg>=0.14.1 | mmcv-full>=1.3.2, <=1.4|
+| 0.14.0 | mmdet>=2.10.0, <=2.11.0| mmseg>=0.13.0 | mmcv-full>=1.3.1, <=1.4|
 | 0.13.0 | mmdet>=2.10.0, <=2.11.0| Not required | mmcv-full>=1.2.4, <=1.4|
 | 0.12.0 | mmdet>=2.5.0, <=2.11.0 | Not required | mmcv-full>=1.2.4, <=1.4|
 | 0.11.0 | mmdet>=2.5.0, <=2.11.0 | Not required | mmcv-full>=1.2.4, <=1.4|
......
@@ -33,9 +33,6 @@ class HardVFE(nn.Module):
     def forward(self, x):  # should return a tuple
         pass

-    def init_weights(self, pretrained=None):
-        pass
 ```

 #### 2. Import the module
@@ -83,16 +80,13 @@ from ..builder import BACKBONES

 @BACKBONES.register_module()
-class SECOND(nn.Module):
+class SECOND(BaseModule):

     def __init__(self, arg1, arg2):
         pass

     def forward(self, x):  # should return a tuple
         pass

-    def init_weights(self, pretrained=None):
-        pass
 ```

 #### 2. Import the module
@@ -135,7 +129,7 @@ Create a new file `mmdet3d/models/necks/second_fpn.py`.
 from ..builder import NECKS

 @NECKS.register
-class SECONDFPN(nn.Module):
+class SECONDFPN(BaseModule):

     def __init__(self,
                  in_channels=[128, 128, 256],
@@ -144,7 +138,8 @@ class SECONDFPN(nn.Module):
                  norm_cfg=dict(type='BN', eps=1e-3, momentum=0.01),
                  upsample_cfg=dict(type='deconv', bias=False),
                  conv_cfg=dict(type='Conv2d', bias=False),
-                 use_conv_for_no_stride=False):
+                 use_conv_for_no_stride=False,
+                 init_cfg=None):
         pass

     def forward(self, X):
@@ -198,7 +193,7 @@ from mmdet.models.builder import HEADS
 from .bbox_head import BBoxHead

 @HEADS.register_module()
-class PartA2BboxHead(nn.Module):
+class PartA2BboxHead(BaseModule):
     """PartA2 RoI head."""

     def __init__(self,
@@ -224,11 +219,9 @@ class PartA2BboxHead(nn.Module):
                      type='CrossEntropyLoss',
                      use_sigmoid=True,
                      reduction='none',
-                     loss_weight=1.0)):
-        super(PartA2BboxHead, self).__init__()
+                     loss_weight=1.0),
+                 init_cfg=None):
+        super(PartA2BboxHead, self).__init__(init_cfg=init_cfg)

-    def init_weights(self):
-        # conv layers are already initialized by ConvModule

     def forward(self, seg_feats, part_feats):
@@ -242,7 +235,7 @@ from torch import nn as nn

 @HEADS.register_module()
-class Base3DRoIHead(nn.Module, metaclass=ABCMeta):
+class Base3DRoIHead(BaseModule, metaclass=ABCMeta):
     """Base class for 3d RoIHeads."""

     def __init__(self,
@@ -250,7 +243,8 @@ class Base3DRoIHead(nn.Module, metaclass=ABCMeta):
                  mask_roi_extractor=None,
                  mask_head=None,
                  train_cfg=None,
-                 test_cfg=None):
+                 test_cfg=None,
+                 init_cfg=None):

     @property
     def with_bbox(self):
@@ -333,9 +327,13 @@ class PartAggregationROIHead(Base3DRoIHead):
                  part_roi_extractor=None,
                  bbox_head=None,
                  train_cfg=None,
-                 test_cfg=None):
+                 test_cfg=None,
+                 init_cfg=None):
         super(PartAggregationROIHead, self).__init__(
-            bbox_head=bbox_head, train_cfg=train_cfg, test_cfg=test_cfg)
+            bbox_head=bbox_head,
+            train_cfg=train_cfg,
+            test_cfg=test_cfg,
+            init_cfg=init_cfg)
         self.num_classes = num_classes
         assert semantic_head is not None
         self.semantic_head = build_head(semantic_head)
......
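Reading the tutorial fragments above together, a self-contained sketch of a backbone written against the new convention might look like this (the class name `ToyBackbone` and all channel sizes are illustrative, not taken from the PR):

```python
import torch
from mmcv.runner import BaseModule
from torch import nn as nn

from mmdet3d.models.builder import BACKBONES


@BACKBONES.register_module()
class ToyBackbone(BaseModule):
    """Illustrative backbone following the unified init convention.

    No init_weights() override is needed: parameters are initialized from
    init_cfg when model.init_weights() is called.
    """

    def __init__(self, in_channels=4, out_channels=64, init_cfg=None):
        super(ToyBackbone, self).__init__(init_cfg=init_cfg)
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):  # should return a tuple
        return (self.conv(x), )


# Usage sketch: build from a config dict, then initialize explicitly.
cfg = dict(
    type='ToyBackbone',
    in_channels=4,
    init_cfg=dict(type='Kaiming', layer='Conv1d'))
model = BACKBONES.build(cfg)
model.init_weights()
out = model(torch.rand(2, 4, 100))
```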
@@ -17,7 +17,7 @@ def digit_version(version_str):
     return digit_version

-mmcv_minimum_version = '1.3.1'
+mmcv_minimum_version = '1.3.8'
 mmcv_maximum_version = '1.4.0'
 mmcv_version = digit_version(mmcv.__version__)
@@ -27,8 +27,8 @@ assert (mmcv_version >= digit_version(mmcv_minimum_version)
     f'MMCV=={mmcv.__version__} is used but incompatible. ' \
     f'Please install mmcv>={mmcv_minimum_version}, <={mmcv_maximum_version}.'

-mmdet_minimum_version = '2.10.0'
-mmdet_maximum_version = '2.11.0'
+mmdet_minimum_version = '2.14.0'
+mmdet_maximum_version = '3.0.0'
 mmdet_version = digit_version(mmdet.__version__)
 assert (mmdet_version >= digit_version(mmdet_minimum_version)
         and mmdet_version <= digit_version(mmdet_maximum_version)), \
@@ -36,8 +36,8 @@ assert (mmdet_version >= digit_version(mmdet_minimum_version)
     f'Please install mmdet>={mmdet_minimum_version}, ' \
     f'<={mmdet_maximum_version}.'

-mmseg_minimum_version = '0.14.0'
-mmseg_maximum_version = '0.14.0'
+mmseg_minimum_version = '0.14.1'
+mmseg_maximum_version = '1.0.0'
 mmseg_version = digit_version(mmseg.__version__)
 assert (mmseg_version >= digit_version(mmseg_minimum_version)
         and mmseg_version <= digit_version(mmseg_maximum_version)), \
......
+import warnings
 from abc import ABCMeta

-from mmcv.runner import load_checkpoint
-from torch import nn as nn
+from mmcv.runner import BaseModule


-class BasePointNet(nn.Module, metaclass=ABCMeta):
+class BasePointNet(BaseModule, metaclass=ABCMeta):
     """Base class for PointNet."""

-    def __init__(self):
-        super(BasePointNet, self).__init__()
+    def __init__(self, init_cfg=None, pretrained=None):
+        super(BasePointNet, self).__init__(init_cfg)
         self.fp16_enabled = False
-
-    def init_weights(self, pretrained=None):
-        """Initialize the weights of PointNet backbone."""
-        # Do not initialize the conv layers
-        # to follow the original implementation
+        assert not (init_cfg and pretrained), \
+            'init_cfg and pretrained cannot be setting at the same time'
         if isinstance(pretrained, str):
-            from mmdet3d.utils import get_root_logger
-            logger = get_root_logger()
-            load_checkpoint(self, pretrained, strict=False, logger=logger)
+            warnings.warn('DeprecationWarning: pretrained is a deprecated, '
+                          'please use "init_cfg" instead')
+            self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)

     @staticmethod
     def _split_point_feats(points):
......
 import copy
 import torch
+import warnings
 from mmcv.cnn import ConvModule
-from mmcv.runner import auto_fp16, load_checkpoint
+from mmcv.runner import BaseModule, auto_fp16
 from torch import nn as nn

 from mmdet.models import BACKBONES, build_backbone


 @BACKBONES.register_module()
-class MultiBackbone(nn.Module):
+class MultiBackbone(BaseModule):
     """MultiBackbone with different configs.

     Args:
@@ -31,8 +32,10 @@ class MultiBackbone(nn.Module):
                  norm_cfg=dict(type='BN1d', eps=1e-5, momentum=0.01),
                  act_cfg=dict(type='ReLU'),
                  suffixes=('net0', 'net1'),
+                 init_cfg=None,
+                 pretrained=None,
                  **kwargs):
-        super().__init__()
+        super().__init__(init_cfg=init_cfg)
         assert isinstance(backbones, dict) or isinstance(backbones, list)
         if isinstance(backbones, dict):
             backbones_list = []
@@ -77,14 +80,12 @@ class MultiBackbone(nn.Module):
                     bias=True,
                     inplace=True))

-    def init_weights(self, pretrained=None):
-        """Initialize the weights of PointNet++ backbone."""
-        # Do not initialize the conv layers
-        # to follow the original implementation
+        assert not (init_cfg and pretrained), \
+            'init_cfg and pretrained cannot be setting at the same time'
         if isinstance(pretrained, str):
-            from mmdet3d.utils import get_root_logger
-            logger = get_root_logger()
-            load_checkpoint(self, pretrained, strict=False, logger=logger)
+            warnings.warn('DeprecationWarning: pretrained is a deprecated, '
+                          'please use "init_cfg" instead')
+            self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)

     @auto_fp16()
     def forward(self, points):
......
@@ -57,8 +57,8 @@ class NoStemRegNet(RegNet):
         (1, 1008, 1, 1)
     """

-    def __init__(self, arch, **kwargs):
-        super(NoStemRegNet, self).__init__(arch, **kwargs)
+    def __init__(self, arch, init_cfg=None, **kwargs):
+        super(NoStemRegNet, self).__init__(arch, init_cfg=init_cfg, **kwargs)

     def _make_stem_layer(self, in_channels, base_channels):
         """Override the original function that do not initialize a stem layer
......
@@ -56,8 +56,9 @@ class PointNet2SAMSG(BasePointNet):
                      type='PointSAModuleMSG',
                      pool_mod='max',
                      use_xyz=True,
-                     normalize_xyz=False)):
-        super().__init__()
+                     normalize_xyz=False),
+                 init_cfg=None):
+        super().__init__(init_cfg=init_cfg)
         self.num_sa = len(sa_channels)
         self.out_indices = out_indices
         assert max(out_indices) < self.num_sa
......
@@ -43,8 +43,9 @@ class PointNet2SASSG(BasePointNet):
                      type='PointSAModule',
                      pool_mod='max',
                      use_xyz=True,
-                     normalize_xyz=True)):
-        super().__init__()
+                     normalize_xyz=True),
+                 init_cfg=None):
+        super().__init__(init_cfg=init_cfg)
         self.num_sa = len(sa_channels)
         self.num_fp = len(fp_channels)
......
+import warnings
 from mmcv.cnn import build_conv_layer, build_norm_layer
-from mmcv.runner import load_checkpoint
+from mmcv.runner import BaseModule
 from torch import nn as nn

 from mmdet.models import BACKBONES


 @BACKBONES.register_module()
-class SECOND(nn.Module):
+class SECOND(BaseModule):
     """Backbone network for SECOND/PointPillars/PartA2/MVXNet.

     Args:
@@ -24,8 +25,10 @@ class SECOND(nn.Module):
                  layer_nums=[3, 5, 5],
                  layer_strides=[2, 2, 2],
                  norm_cfg=dict(type='BN', eps=1e-3, momentum=0.01),
-                 conv_cfg=dict(type='Conv2d', bias=False)):
-        super(SECOND, self).__init__()
+                 conv_cfg=dict(type='Conv2d', bias=False),
+                 init_cfg=None,
+                 pretrained=None):
+        super(SECOND, self).__init__(init_cfg=init_cfg)
         assert len(layer_strides) == len(layer_nums)
         assert len(out_channels) == len(layer_nums)
@@ -61,14 +64,14 @@ class SECOND(nn.Module):
         self.blocks = nn.ModuleList(blocks)

-    def init_weights(self, pretrained=None):
-        """Initialize weights of the 2D backbone."""
-        # Do not initialize the conv layers
-        # to follow the original implementation
+        assert not (init_cfg and pretrained), \
+            'init_cfg and pretrained cannot be setting at the same time'
         if isinstance(pretrained, str):
-            from mmdet3d.utils import get_root_logger
-            logger = get_root_logger()
-            load_checkpoint(self, pretrained, strict=False, logger=logger)
+            warnings.warn('DeprecationWarning: pretrained is a deprecated, '
+                          'please use "init_cfg" instead')
+            self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
+        else:
+            self.init_cfg = dict(type='Kaiming', layer='Conv2d')

     def forward(self, x):
         """Forward function.
......
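A small usage sketch of the two initialization paths handled above (the checkpoint path is a placeholder; nothing here is part of the diff):

```python
from mmdet3d.models.backbones import SECOND

# Legacy argument: rewritten into init_cfg, with a DeprecationWarning.
legacy = SECOND(in_channels=64, pretrained='checkpoints/second_backbone.pth')
print(legacy.init_cfg)  # dict(type='Pretrained', checkpoint='checkpoints/...')

# No pretrained weights given: the backbone falls back to Kaiming init
# for its Conv2d layers.
scratch = SECOND(in_channels=64)
print(scratch.init_cfg)  # dict(type='Kaiming', layer='Conv2d')

# Weights are only actually initialized when init_weights() is called,
# e.g. by the training script after building the whole model.
scratch.init_weights()
```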
 import warnings
+from mmcv.cnn import MODELS as MMCV_MODELS
 from mmcv.utils import Registry

 from mmdet.models.builder import (BACKBONES, DETECTORS, HEADS, LOSSES, NECKS,
-                                  ROI_EXTRACTORS, SHARED_HEADS, build)
+                                  ROI_EXTRACTORS, SHARED_HEADS)
 from mmseg.models.builder import SEGMENTORS

-VOXEL_ENCODERS = Registry('voxel_encoder')
-MIDDLE_ENCODERS = Registry('middle_encoder')
-FUSION_LAYERS = Registry('fusion_layer')
+MODELS = Registry('models', parent=MMCV_MODELS)
+
+VOXEL_ENCODERS = MODELS
+MIDDLE_ENCODERS = MODELS
+FUSION_LAYERS = MODELS


 def build_backbone(cfg):
     """Build backbone."""
-    return build(cfg, BACKBONES)
+    return BACKBONES.build(cfg)


 def build_neck(cfg):
     """Build neck."""
-    return build(cfg, NECKS)
+    return NECKS.build(cfg)


 def build_roi_extractor(cfg):
     """Build RoI feature extractor."""
-    return build(cfg, ROI_EXTRACTORS)
+    return ROI_EXTRACTORS.build(cfg)


 def build_shared_head(cfg):
     """Build shared head of detector."""
-    return build(cfg, SHARED_HEADS)
+    return SHARED_HEADS.build(cfg)


 def build_head(cfg):
     """Build head."""
-    return build(cfg, HEADS)
+    return HEADS.build(cfg)


 def build_loss(cfg):
     """Build loss function."""
-    return build(cfg, LOSSES)
+    return LOSSES.build(cfg)


 def build_detector(cfg, train_cfg=None, test_cfg=None):
@@ -50,7 +53,8 @@ def build_detector(cfg, train_cfg=None, test_cfg=None):
         'train_cfg specified in both outer field and model field '
     assert cfg.get('test_cfg') is None or test_cfg is None, \
         'test_cfg specified in both outer field and model field '
-    return build(cfg, DETECTORS, dict(train_cfg=train_cfg, test_cfg=test_cfg))
+    return DETECTORS.build(
+        cfg, default_args=dict(train_cfg=train_cfg, test_cfg=test_cfg))


 def build_segmentor(cfg, train_cfg=None, test_cfg=None):
@@ -63,7 +67,8 @@ def build_segmentor(cfg, train_cfg=None, test_cfg=None):
         'train_cfg specified in both outer field and model field '
     assert cfg.get('test_cfg') is None or test_cfg is None, \
         'test_cfg specified in both outer field and model field '
-    return build(cfg, SEGMENTORS, dict(train_cfg=train_cfg, test_cfg=test_cfg))
+    return SEGMENTORS.build(
+        cfg, default_args=dict(train_cfg=train_cfg, test_cfg=test_cfg))


 def build_model(cfg, train_cfg=None, test_cfg=None):
@@ -80,14 +85,14 @@ def build_model(cfg, train_cfg=None, test_cfg=None):

 def build_voxel_encoder(cfg):
     """Build voxel encoder."""
-    return build(cfg, VOXEL_ENCODERS)
+    return VOXEL_ENCODERS.build(cfg)


 def build_middle_encoder(cfg):
     """Build middle level encoder."""
-    return build(cfg, MIDDLE_ENCODERS)
+    return MIDDLE_ENCODERS.build(cfg)


 def build_fusion_layer(cfg):
     """Build fusion layer."""
-    return build(cfg, FUSION_LAYERS)
+    return FUSION_LAYERS.build(cfg)
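To illustrate the design choice above: `VOXEL_ENCODERS`, `MIDDLE_ENCODERS` and `FUSION_LAYERS` are now all aliases of the same `MODELS` registry (itself a child of MMCV's model registry), so a class registered through one alias is visible to the others. A minimal sketch, with a hypothetical `ToyVoxelEncoder`:

```python
from torch import nn as nn

from mmdet3d.models.builder import (VOXEL_ENCODERS, build_middle_encoder,
                                    build_voxel_encoder)


@VOXEL_ENCODERS.register_module()
class ToyVoxelEncoder(nn.Module):
    """Hypothetical module used only to illustrate the shared registry."""

    def __init__(self, num_features=4):
        super().__init__()
        self.num_features = num_features

    def forward(self, x):
        return x


# Both builders resolve names against the same underlying MODELS registry,
# so the class registered above can be built through either of them.
enc = build_voxel_encoder(dict(type='ToyVoxelEncoder', num_features=8))
same = build_middle_encoder(dict(type='ToyVoxelEncoder'))
assert type(enc) is type(same)
```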
 from abc import ABCMeta, abstractmethod
 from mmcv.cnn import normal_init
-from mmcv.runner import auto_fp16, force_fp32
+from mmcv.runner import BaseModule, auto_fp16, force_fp32
 from torch import nn as nn

 from mmseg.models.builder import build_loss


-class Base3DDecodeHead(nn.Module, metaclass=ABCMeta):
+class Base3DDecodeHead(BaseModule, metaclass=ABCMeta):
     """Base class for BaseDecodeHead.

     Args:
@@ -37,8 +37,9 @@ class Base3DDecodeHead(nn.Module, metaclass=ABCMeta):
                      use_sigmoid=False,
                      class_weight=None,
                      loss_weight=1.0),
-                 ignore_index=255):
-        super(Base3DDecodeHead, self).__init__()
+                 ignore_index=255,
+                 init_cfg=None):
+        super(Base3DDecodeHead, self).__init__(init_cfg=init_cfg)
         self.channels = channels
         self.num_classes = num_classes
         self.dropout_ratio = dropout_ratio
@@ -57,6 +58,7 @@ class Base3DDecodeHead(nn.Module, metaclass=ABCMeta):
     def init_weights(self):
         """Initialize weights of classification layer."""
+        super().init_weights()
         normal_init(self.conv_seg, mean=0, std=0.01)

     @auto_fp16()
......
 import numpy as np
 import torch
-from mmcv.cnn import bias_init_with_prob, normal_init
-from mmcv.runner import force_fp32
+from mmcv.runner import BaseModule, force_fp32
 from torch import nn as nn

 from mmdet3d.core import (PseudoSampler, box3d_multiclass_nms, limit_period,
@@ -14,7 +13,7 @@ from .train_mixins import AnchorTrainMixin

 @HEADS.register_module()
-class Anchor3DHead(nn.Module, AnchorTrainMixin):
+class Anchor3DHead(BaseModule, AnchorTrainMixin):
     """Anchor head for SECOND/PointPillars/MVXNet/PartA2.

     Args:
@@ -67,8 +66,9 @@ class Anchor3DHead(nn.Module, AnchorTrainMixin):
                      loss_weight=1.0),
                  loss_bbox=dict(
                      type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=2.0),
-                 loss_dir=dict(type='CrossEntropyLoss', loss_weight=0.2)):
-        super().__init__()
+                 loss_dir=dict(type='CrossEntropyLoss', loss_weight=0.2),
+                 init_cfg=None):
+        super().__init__(init_cfg=init_cfg)
         self.in_channels = in_channels
         self.num_classes = num_classes
         self.feat_channels = feat_channels
@@ -103,6 +103,14 @@ class Anchor3DHead(nn.Module, AnchorTrainMixin):
         self._init_layers()
         self._init_assigner_sampler()

+        if init_cfg is None:
+            self.init_cfg = dict(
+                type='Normal',
+                layer='Conv2d',
+                std=0.01,
+                override=dict(
+                    type='Normal', name='conv_cls', std=0.01, bias_prob=0.01))
+
     def _init_assigner_sampler(self):
         """Initialize the target assigner and sampler of the head."""
         if self.train_cfg is None:
@@ -129,12 +137,6 @@ class Anchor3DHead(nn.Module, AnchorTrainMixin):
             self.conv_dir_cls = nn.Conv2d(self.feat_channels,
                                           self.num_anchors * 2, 1)

-    def init_weights(self):
-        """Initialize the weights of head."""
-        bias_cls = bias_init_with_prob(0.01)
-        normal_init(self.conv_cls, std=0.01, bias=bias_cls)
-        normal_init(self.conv_reg, std=0.01)
-
     def forward_single(self, x):
         """Forward function on a single-scale feature map.
......
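For reference, the default installed above can equally be written (or overridden) from the config side; this is a simplified excerpt with only the init-related keys shown, not a complete working config:

```python
# Simplified excerpt of a model config.
model = dict(
    bbox_head=dict(
        type='Anchor3DHead',
        # Equivalent to the default installed when init_cfg is None:
        init_cfg=dict(
            type='Normal',
            layer='Conv2d',
            std=0.01,
            override=dict(
                type='Normal', name='conv_cls', std=0.01, bias_prob=0.01))))
```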
@@ -109,8 +109,9 @@ class AnchorFreeMono3DHead(BaseMono3DDenseHead):
                  conv_cfg=None,
                  norm_cfg=None,
                  train_cfg=None,
-                 test_cfg=None):
-        super(AnchorFreeMono3DHead, self).__init__()
+                 test_cfg=None,
+                 init_cfg=None):
+        super(AnchorFreeMono3DHead, self).__init__(init_cfg=init_cfg)
         self.num_classes = num_classes
         self.cls_out_channels = num_classes
         self.in_channels = in_channels
......
 from mmcv.cnn import ConvModule
 from mmcv.cnn.bricks import build_conv_layer
+from mmcv.runner import BaseModule
 from torch import nn as nn

 from mmdet.models.builder import HEADS


 @HEADS.register_module()
-class BaseConvBboxHead(nn.Module):
+class BaseConvBboxHead(BaseModule):
     r"""More general bbox head, with shared conv layers and two optional
     separated branches.
@@ -28,9 +29,11 @@ class BaseConvBboxHead(nn.Module):
                  norm_cfg=dict(type='BN1d'),
                  act_cfg=dict(type='ReLU'),
                  bias='auto',
+                 init_cfg=None,
                  *args,
                  **kwargs):
-        super(BaseConvBboxHead, self).__init__(*args, **kwargs)
+        super(BaseConvBboxHead, self).__init__(
+            init_cfg=init_cfg, *args, **kwargs)
         assert in_channels > 0
         assert num_cls_out_channels > 0
         assert num_reg_out_channels > 0
@@ -98,10 +101,6 @@ class BaseConvBboxHead(nn.Module):
                     inplace=True))
         return conv_layers

-    def init_weights(self):
-        # conv layers are already initialized by ConvModule
-        pass
-
     def forward(self, feats):
         """Forward.
......
 from abc import ABCMeta, abstractmethod

-from torch import nn as nn
+from mmcv.runner import BaseModule


-class BaseMono3DDenseHead(nn.Module, metaclass=ABCMeta):
+class BaseMono3DDenseHead(BaseModule, metaclass=ABCMeta):
     """Base class for Monocular 3D DenseHeads."""

-    def __init__(self):
-        super(BaseMono3DDenseHead, self).__init__()
+    def __init__(self, init_cfg=None):
+        super(BaseMono3DDenseHead, self).__init__(init_cfg=init_cfg)

     @abstractmethod
     def loss(self, **kwargs):
......