Unverified commit 07590418 authored by xiliu8006, committed by GitHub

[Refactor]: Unified parameter initialization (#622)

* support 3dssd

* support one-stage method

* for lint

* support two_stage

* Support all methods

* remove init_cfg=[] in configs

* test

* support h3dnet

* fix lint error

* fix isort

* fix code style error

* fix imvotenet bug

* rename init_weight->init_weights

* clean comma

* fix test_apis does not init weights

* support newest mmdet and mmcv

* fix test_heads h3dnet bug

* rm *.swp

* remove the wrong code in build.yml

* fix ssn low map

* modify docs

* modified ssn init_config

* modify params in backbone pointnet2_sa_ssg

* add ssn direction init_cfg

* support segmentor

* add conv a=sqrt(5)

* Convmodule uses kaiming_init

* fix centerpointhead init bug

* add second conv2d init cfg

* add unittest to confirm the input is not modified

* assert gt_bboxes_3d

* rm .swag

* modify docs mmdet version

* adopt fcosmono3d

* add fcos 3d original init method

* fix mmseg version

* add init cfg in fcos_mono3d.py

* merge newest master

* remove unused code

* modify fcos config due to changes of resnet

* support imvoxelnet pointnet2

* modified the dependencies version

* support decode head

* fix inference bug

* modify the useless init_cfg

* fix multi_modality BC-breaking

* fix error blank

* modify docs error
parent 318499ac
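
The changes below migrate mmdet3d modules from hand-written `init_weights` methods to MMCV's unified `BaseModule`/`init_cfg` mechanism. As a reference for reading the diff, here is a minimal sketch of that pattern with a hypothetical `ToyHead` (not part of this commit), assuming mmcv >= 1.3, where `BaseModule` and the weight-initialization registry are available:

```python
from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule
from torch import nn


class ToyHead(BaseModule):
    """Hypothetical head illustrating the init_cfg pattern used in this PR."""

    def __init__(self, in_channels=64, num_classes=3, init_cfg=None):
        super(ToyHead, self).__init__(init_cfg=init_cfg)
        # ConvModule layers are covered by the layer-wide default below.
        self.shared_conv = ConvModule(in_channels, in_channels, 3, padding=1)
        self.conv_cls = nn.Conv2d(in_channels, num_classes, 1)
        if init_cfg is None:
            # Default: Kaiming for all Conv2d, normal init for the cls branch.
            self.init_cfg = dict(
                type='Kaiming',
                layer='Conv2d',
                override=dict(
                    type='Normal', name='conv_cls', std=0.01, bias_prob=0.01))


head = ToyHead()
head.init_weights()  # BaseModule applies init_cfg recursively
```

The diff applies the same idea throughout: constructors accept `init_cfg`, forward it to the parent class, and only fill in a default when the caller passed none, so configs can still override initialization.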
 import copy
 import numpy as np
 import torch
-from mmcv.cnn import ConvModule, build_conv_layer, kaiming_init
-from mmcv.runner import force_fp32
+from mmcv.cnn import ConvModule, build_conv_layer
+from mmcv.runner import BaseModule, force_fp32
 from torch import nn

 from mmdet3d.core import (circle_nms, draw_heatmap_gaussian, gaussian_radius,
@@ -15,7 +15,7 @@ from mmdet.core import build_bbox_coder, multi_apply
 @HEADS.register_module()
-class SeparateHead(nn.Module):
+class SeparateHead(BaseModule):
     """SeparateHead for CenterHead.

     Args:
@@ -42,9 +42,11 @@ class SeparateHead(nn.Module):
                  conv_cfg=dict(type='Conv2d'),
                  norm_cfg=dict(type='BN2d'),
                  bias='auto',
+                 init_cfg=None,
                  **kwargs):
-        super(SeparateHead, self).__init__()
+        assert init_cfg is None, 'To prevent abnormal initialization ' \
+            'behavior, init_cfg is not allowed to be set'
+        super(SeparateHead, self).__init__(init_cfg=init_cfg)
         self.heads = heads
         self.init_bias = init_bias
         for head in self.heads:
@@ -78,15 +80,15 @@ class SeparateHead(nn.Module):
             self.__setattr__(head, conv_layers)

+        if init_cfg is None:
+            self.init_cfg = dict(type='Kaiming', layer='Conv2d')
+
     def init_weights(self):
         """Initialize weights."""
+        super().init_weights()
         for head in self.heads:
             if head == 'heatmap':
                 self.__getattr__(head)[-1].bias.data.fill_(self.init_bias)
-            else:
-                for m in self.__getattr__(head).modules():
-                    if isinstance(m, nn.Conv2d):
-                        kaiming_init(m)

     def forward(self, x):
         """Forward function for SepHead.
@@ -119,7 +121,7 @@ class SeparateHead(nn.Module):
 @HEADS.register_module()
-class DCNSeparateHead(nn.Module):
+class DCNSeparateHead(BaseModule):
     r"""DCNSeparateHead for CenterHead.

     .. code-block:: none
@@ -154,8 +156,11 @@ class DCNSeparateHead(nn.Module):
                  conv_cfg=dict(type='Conv2d'),
                  norm_cfg=dict(type='BN2d'),
                  bias='auto',
+                 init_cfg=None,
                  **kwargs):
-        super(DCNSeparateHead, self).__init__()
+        assert init_cfg is None, 'To prevent abnormal initialization ' \
+            'behavior, init_cfg is not allowed to be set'
+        super(DCNSeparateHead, self).__init__(init_cfg=init_cfg)
         if 'heatmap' in heads:
             heads.pop('heatmap')
         # feature adaptation with dcn
@@ -192,11 +197,13 @@ class DCNSeparateHead(nn.Module):
             head_conv=head_conv,
             final_kernel=final_kernel,
             bias=bias)
+        if init_cfg is None:
+            self.init_cfg = dict(type='Kaiming', layer='Conv2d')
+
     def init_weights(self):
         """Initialize weights."""
+        super().init_weights()
         self.cls_head[-1].bias.data.fill_(self.init_bias)
-        self.task_head.init_weights()

     def forward(self, x):
         """Forward function for DCNSepHead.
@@ -232,7 +239,7 @@ class DCNSeparateHead(nn.Module):
 @HEADS.register_module()
-class CenterHead(nn.Module):
+class CenterHead(BaseModule):
     """CenterHead for CenterPoint.

     Args:
@@ -280,8 +287,11 @@ class CenterHead(nn.Module):
                  conv_cfg=dict(type='Conv2d'),
                  norm_cfg=dict(type='BN2d'),
                  bias='auto',
-                 norm_bbox=True):
-        super(CenterHead, self).__init__()
+                 norm_bbox=True,
+                 init_cfg=None):
+        assert init_cfg is None, 'To prevent abnormal initialization ' \
+            'behavior, init_cfg is not allowed to be set'
+        super(CenterHead, self).__init__(init_cfg=init_cfg)
         num_classes = [len(t['class_names']) for t in tasks]
         self.class_names = [t['class_names'] for t in tasks]
@@ -316,11 +326,6 @@ class CenterHead(nn.Module):
                 in_channels=share_conv_channel, heads=heads, num_cls=num_cls)
             self.task_heads.append(builder.build_head(separate_head))

-    def init_weights(self):
-        """Initialize weights."""
-        for task_head in self.task_heads:
-            task_head.init_weights()
-
     def forward_single(self, x):
         """Forward function for CenterPoint.
......
@@ -74,6 +74,7 @@ class FCOSMono3DHead(AnchorFreeMono3DHead):
                      loss_weight=1.0),
                  norm_cfg=dict(type='GN', num_groups=32, requires_grad=True),
                  centerness_branch=(64, ),
+                 init_cfg=None,
                  **kwargs):
         self.regress_ranges = regress_ranges
         self.center_sampling = center_sampling
@@ -90,6 +91,7 @@ class FCOSMono3DHead(AnchorFreeMono3DHead):
             loss_dir=loss_dir,
             loss_attr=loss_attr,
             norm_cfg=norm_cfg,
+            init_cfg=init_cfg,
             **kwargs)
         self.loss_centerness = build_loss(loss_centerness)
......
@@ -32,8 +32,9 @@ class FreeAnchor3DHead(Anchor3DHead):
                  bbox_thr=0.6,
                  gamma=2.0,
                  alpha=0.5,
+                 init_cfg=None,
                  **kwargs):
-        super().__init__(**kwargs)
+        super().__init__(init_cfg=init_cfg, **kwargs)
         self.pre_anchor_topk = pre_anchor_topk
         self.bbox_thr = bbox_thr
         self.gamma = gamma
......
@@ -75,12 +75,13 @@ class PartA2RPNHead(Anchor3DHead):
                      loss_weight=1.0),
                  loss_bbox=dict(
                      type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=2.0),
-                 loss_dir=dict(type='CrossEntropyLoss', loss_weight=0.2)):
+                 loss_dir=dict(type='CrossEntropyLoss', loss_weight=0.2),
+                 init_cfg=None):
         super().__init__(num_classes, in_channels, train_cfg, test_cfg,
                          feat_channels, use_direction_classifier,
                          anchor_generator, assigner_per_size, assign_per_class,
                          diff_rad_by_sin, dir_offset, dir_limit_offset,
-                         bbox_coder, loss_cls, loss_bbox, loss_dir)
+                         bbox_coder, loss_cls, loss_bbox, loss_dir, init_cfg)

     @force_fp32(apply_to=('cls_scores', 'bbox_preds', 'dir_cls_preds'))
     def loss(self,
......
 import numpy as np
 import torch
-from mmcv.cnn import ConvModule, bias_init_with_prob, normal_init
+import warnings
+from mmcv.cnn import ConvModule
+from mmcv.runner import BaseModule
 from torch import nn as nn

 from mmdet3d.core import box3d_multiclass_nms, limit_period, xywhr2xyxyr
@@ -11,7 +13,7 @@ from .anchor3d_head import Anchor3DHead
 @HEADS.register_module()
-class BaseShapeHead(nn.Module):
+class BaseShapeHead(BaseModule):
     """Base Shape-aware Head in Shape Signature Network.

     Note:
@@ -48,8 +50,9 @@ class BaseShapeHead(nn.Module):
                  use_direction_classifier=True,
                  conv_cfg=dict(type='Conv2d'),
                  norm_cfg=dict(type='BN2d'),
-                 bias=False):
-        super().__init__()
+                 bias=False,
+                 init_cfg=None):
+        super().__init__(init_cfg=init_cfg)
         self.num_cls = num_cls
         self.num_base_anchors = num_base_anchors
         self.use_direction_classifier = use_direction_classifier
@@ -84,15 +87,36 @@ class BaseShapeHead(nn.Module):
         if use_direction_classifier:
             self.conv_dir_cls = nn.Conv2d(out_channels, num_base_anchors * 2,
                                           1)

-    def init_weights(self):
-        """Initialize weights."""
-        bias_cls = bias_init_with_prob(0.01)
-        # shared conv layers have already been initialized by ConvModule
-        normal_init(self.conv_cls, std=0.01, bias=bias_cls)
-        normal_init(self.conv_reg, std=0.01)
-        if self.use_direction_classifier:
-            normal_init(self.conv_dir_cls, std=0.01, bias=bias_cls)
+        if init_cfg is None:
+            if use_direction_classifier:
+                self.init_cfg = dict(
+                    type='Kaiming',
+                    layer='Conv2d',
+                    override=[
+                        dict(type='Normal', name='conv_reg', std=0.01),
+                        dict(
+                            type='Normal',
+                            name='conv_cls',
+                            std=0.01,
+                            bias_prob=0.01),
+                        dict(
+                            type='Normal',
+                            name='conv_dir_cls',
+                            std=0.01,
+                            bias_prob=0.01)
+                    ])
+            else:
+                self.init_cfg = dict(
+                    type='Kaiming',
+                    layer='Conv2d',
+                    override=[
+                        dict(type='Normal', name='conv_reg', std=0.01),
+                        dict(
+                            type='Normal',
+                            name='conv_cls',
+                            std=0.01,
+                            bias_prob=0.01)
+                    ])

     def forward(self, x):
         """Forward function for SmallHead.
@@ -149,10 +173,21 @@ class ShapeAwareHead(Anchor3DHead):
         :class:`Anchor3DHead`.
     """

-    def __init__(self, tasks, assign_per_class=True, **kwargs):
+    def __init__(self, tasks, assign_per_class=True, init_cfg=None, **kwargs):
         self.tasks = tasks
         self.featmap_sizes = []
-        super().__init__(assign_per_class=assign_per_class, **kwargs)
+        super().__init__(
+            assign_per_class=assign_per_class, init_cfg=init_cfg, **kwargs)
+
+    def init_weights(self):
+        if not self._is_init:
+            for m in self.heads:
+                if hasattr(m, 'init_weights'):
+                    m.init_weights()
+            self._is_init = True
+        else:
+            warnings.warn(f'init_weights of {self.__class__.__name__} has '
+                          f'been called more than once.')

     def _init_layers(self):
         """Initialize neural network layers of the head."""
@@ -175,11 +210,6 @@ class ShapeAwareHead(Anchor3DHead):
             self.heads.append(build_head(branch))
             cls_ptr += task['num_class']

-    def init_weights(self):
-        """Initialize the weights of head."""
-        for head in self.heads:
-            head.init_weights()
-
     def forward_single(self, x):
         """Forward function on a single-scale feature map.
......
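
For readers unfamiliar with the `override` key used in the `BaseShapeHead` default above: MMCV first applies the layer-wide initializer, then re-initializes the named children listed in `override`. A small standalone sketch of the same behavior (hypothetical `TinyBranch`; `conv_reg`/`conv_cls` mirror the names referenced above; assumes mmcv >= 1.3):

```python
from mmcv.cnn import initialize
from torch import nn


class TinyBranch(nn.Module):
    """Hypothetical module used only to demonstrate init_cfg overrides."""

    def __init__(self):
        super().__init__()
        self.conv_shared = nn.Conv2d(8, 8, 3, padding=1)
        self.conv_reg = nn.Conv2d(8, 4, 1)
        self.conv_cls = nn.Conv2d(8, 2, 1)


branch = TinyBranch()
initialize(
    branch,
    dict(
        type='Kaiming',
        layer='Conv2d',
        override=[
            dict(type='Normal', name='conv_reg', std=0.01),
            dict(type='Normal', name='conv_cls', std=0.01, bias_prob=0.01)
        ]))
# conv_shared keeps the Kaiming default; conv_reg and conv_cls are
# re-initialized with N(0, 0.01), and conv_cls gets a bias derived from a
# prior probability of 0.01 (the old bias_init_with_prob behavior).
```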
@@ -58,7 +58,8 @@ class SSD3DHead(VoteHead):
                  dir_res_loss=None,
                  size_res_loss=None,
                  corner_loss=None,
-                 vote_loss=None):
+                 vote_loss=None,
+                 init_cfg=None):
         super(SSD3DHead, self).__init__(
             num_classes,
             bbox_coder,
@@ -75,7 +76,8 @@ class SSD3DHead(VoteHead):
             dir_res_loss=dir_res_loss,
             size_class_loss=None,
             size_res_loss=size_res_loss,
-            semantic_loss=None)
+            semantic_loss=None,
+            init_cfg=init_cfg)
         self.corner_loss = build_loss(corner_loss)
         self.vote_loss = build_loss(vote_loss)
......
 import numpy as np
 import torch
-from mmcv.runner import force_fp32
-from torch import nn as nn
+from mmcv.runner import BaseModule, force_fp32
 from torch.nn import functional as F

 from mmdet3d.core.post_processing import aligned_3d_nms
@@ -15,7 +14,7 @@ from .base_conv_bbox_head import BaseConvBboxHead
 @HEADS.register_module()
-class VoteHead(nn.Module):
+class VoteHead(BaseModule):
     r"""Bbox head of `Votenet <https://arxiv.org/abs/1904.09664>`_.

     Args:
@@ -56,8 +55,9 @@ class VoteHead(nn.Module):
                  size_class_loss=None,
                  size_res_loss=None,
                  semantic_loss=None,
-                 iou_loss=None):
-        super(VoteHead, self).__init__()
+                 iou_loss=None,
+                 init_cfg=None):
+        super(VoteHead, self).__init__(init_cfg=init_cfg)
         self.num_classes = num_classes
         self.train_cfg = train_cfg
         self.test_cfg = test_cfg
@@ -92,10 +92,6 @@ class VoteHead(nn.Module):
             num_cls_out_channels=self._get_cls_out_channels(),
             num_reg_out_channels=self._get_reg_out_channels())

-    def init_weights(self):
-        """Initialize weights of VoteHead."""
-        pass
-
     def _get_cls_out_channels(self):
         """Return the channel number of classification outputs."""
         # Class numbers (k) + objectness (2)
......
@@ -23,13 +23,14 @@ class CenterPoint(MVXTwoStageDetector):
                  img_rpn_head=None,
                  train_cfg=None,
                  test_cfg=None,
-                 pretrained=None):
+                 pretrained=None,
+                 init_cfg=None):
         super(CenterPoint,
               self).__init__(pts_voxel_layer, pts_voxel_encoder,
                              pts_middle_encoder, pts_fusion_layer,
                              img_backbone, pts_backbone, img_neck, pts_neck,
                              pts_bbox_head, img_roi_head, img_rpn_head,
-                             train_cfg, test_cfg, pretrained)
+                             train_cfg, test_cfg, pretrained, init_cfg)

     def extract_pts_feat(self, pts, img_feats, img_metas):
         """Extract features of points."""
......
@@ -20,7 +20,8 @@ class DynamicVoxelNet(VoxelNet):
                  bbox_head=None,
                  train_cfg=None,
                  test_cfg=None,
-                 pretrained=None):
+                 pretrained=None,
+                 init_cfg=None):
         super(DynamicVoxelNet, self).__init__(
             voxel_layer=voxel_layer,
             voxel_encoder=voxel_encoder,
@@ -31,7 +32,7 @@ class DynamicVoxelNet(VoxelNet):
             train_cfg=train_cfg,
             test_cfg=test_cfg,
             pretrained=pretrained,
-        )
+            init_cfg=init_cfg)

     def extract_feat(self, points, img_metas):
         """Extract features from points."""
......
@@ -19,7 +19,8 @@ class H3DNet(TwoStage3DDetector):
                  roi_head=None,
                  train_cfg=None,
                  test_cfg=None,
-                 pretrained=None):
+                 pretrained=None,
+                 init_cfg=None):
         super(H3DNet, self).__init__(
             backbone=backbone,
             neck=neck,
@@ -27,7 +28,8 @@ class H3DNet(TwoStage3DDetector):
             roi_head=roi_head,
             train_cfg=train_cfg,
             test_cfg=test_cfg,
-            pretrained=pretrained)
+            pretrained=pretrained,
+            init_cfg=init_cfg)

     def forward_train(self,
                       points,
......
 import numpy as np
 import torch
-from torch import nn as nn
+import warnings

 from mmdet3d.core import bbox3d2result, merge_aug_bboxes_3d
 from mmdet3d.models.utils import MLP
@@ -69,9 +69,10 @@ class ImVoteNet(Base3DDetector):
                  num_sampled_seed=None,
                  train_cfg=None,
                  test_cfg=None,
-                 pretrained=None):
+                 pretrained=None,
+                 init_cfg=None):

-        super(ImVoteNet, self).__init__()
+        super(ImVoteNet, self).__init__(init_cfg=init_cfg)

         # point branch
         if pts_backbone is not None:
@@ -134,11 +135,7 @@ class ImVoteNet(Base3DDetector):
         self.train_cfg = train_cfg
         self.test_cfg = test_cfg
-        self.init_weights(pretrained=pretrained)

-    def init_weights(self, pretrained=None):
-        """Initialize model weights."""
-        super(ImVoteNet, self).init_weights(pretrained)
         if pretrained is None:
             img_pretrained = None
             pts_pretrained = None
@@ -148,29 +145,26 @@ class ImVoteNet(Base3DDetector):
         else:
             raise ValueError(
                 f'pretrained should be a dict, got {type(pretrained)}')
-        if self.with_img_backbone:
-            self.img_backbone.init_weights(pretrained=img_pretrained)
-        if self.with_img_neck:
-            if isinstance(self.img_neck, nn.Sequential):
-                for m in self.img_neck:
-                    m.init_weights()
-            else:
-                self.img_neck.init_weights()
+        if self.with_img_backbone:
+            if img_pretrained is not None:
+                warnings.warn('DeprecationWarning: pretrained is a deprecated \
+                    key, please consider using init_cfg')
+                self.img_backbone.init_cfg = dict(
+                    type='Pretrained', checkpoint=img_pretrained)
         if self.with_img_roi_head:
-            self.img_roi_head.init_weights(img_pretrained)
-        if self.with_img_rpn:
-            self.img_rpn_head.init_weights()
+            if img_pretrained is not None:
+                warnings.warn('DeprecationWarning: pretrained is a deprecated \
+                    key, please consider using init_cfg')
+                self.img_roi_head.init_cfg = dict(
+                    type='Pretrained', checkpoint=img_pretrained)
         if self.with_pts_backbone:
-            self.pts_backbone.init_weights(pretrained=pts_pretrained)
-        if self.with_pts_bbox:
-            self.pts_bbox_head.init_weights()
-        if self.with_pts_neck:
-            if isinstance(self.pts_neck, nn.Sequential):
-                for m in self.pts_neck:
-                    m.init_weights()
-            else:
-                self.pts_neck.init_weights()
+            if img_pretrained is not None:
+                warnings.warn('DeprecationWarning: pretrained is a deprecated \
+                    key, please consider using init_cfg')
+                self.pts_backbone.init_cfg = dict(
+                    type='Pretrained', checkpoint=pts_pretrained)

     def freeze_img_branch_params(self):
         """Freeze all image branch parameters."""
......
@@ -19,8 +19,9 @@ class ImVoxelNet(BaseDetector):
                  anchor_generator,
                  train_cfg=None,
                  test_cfg=None,
-                 pretrained=None):
-        super().__init__()
+                 pretrained=None,
+                 init_cfg=None):
+        super().__init__(init_cfg=init_cfg)
         self.backbone = build_backbone(backbone)
         self.neck = build_neck(neck)
         self.neck_3d = build_neck(neck_3d)
@@ -31,20 +32,6 @@ class ImVoxelNet(BaseDetector):
         self.anchor_generator = build_anchor_generator(anchor_generator)
         self.train_cfg = train_cfg
         self.test_cfg = test_cfg
-        self.init_weights(pretrained=pretrained)
-
-    def init_weights(self, pretrained=None):
-        """Initialize the weights in detector.
-
-        Args:
-            pretrained (str, optional): Path to pre-trained weights.
-                Defaults to None.
-        """
-        super().init_weights(pretrained)
-        self.backbone.init_weights(pretrained=pretrained)
-        self.neck.init_weights()
-        self.neck_3d.init_weights()
-        self.bbox_head.init_weights()

     def extract_feat(self, img, img_metas):
         """Extract 3d features from the backbone -> fpn -> 3d projection.
......
 import mmcv
 import torch
+import warnings
 from mmcv.parallel import DataContainer as DC
 from mmcv.runner import force_fp32
 from os import path as osp
-from torch import nn as nn
 from torch.nn import functional as F

 from mmdet3d.core import (Box3DMode, Coord3DMode, bbox3d2result,
@@ -33,8 +33,9 @@ class MVXTwoStageDetector(Base3DDetector):
                  img_rpn_head=None,
                  train_cfg=None,
                  test_cfg=None,
-                 pretrained=None):
-        super(MVXTwoStageDetector, self).__init__()
+                 pretrained=None,
+                 init_cfg=None):
+        super(MVXTwoStageDetector, self).__init__(init_cfg=init_cfg)
         if pts_voxel_layer:
             self.pts_voxel_layer = Voxelization(**pts_voxel_layer)
@@ -69,11 +70,7 @@ class MVXTwoStageDetector(Base3DDetector):
         self.train_cfg = train_cfg
         self.test_cfg = test_cfg
-        self.init_weights(pretrained=pretrained)

-    def init_weights(self, pretrained=None):
-        """Initialize model weights."""
-        super(MVXTwoStageDetector, self).init_weights(pretrained)
         if pretrained is None:
             img_pretrained = None
             pts_pretrained = None
@@ -83,23 +80,26 @@ class MVXTwoStageDetector(Base3DDetector):
         else:
             raise ValueError(
                 f'pretrained should be a dict, got {type(pretrained)}')
-        if self.with_img_backbone:
-            self.img_backbone.init_weights(pretrained=img_pretrained)
-        if self.with_pts_backbone:
-            self.pts_backbone.init_weights(pretrained=pts_pretrained)
-        if self.with_img_neck:
-            if isinstance(self.img_neck, nn.Sequential):
-                for m in self.img_neck:
-                    m.init_weights()
-            else:
-                self.img_neck.init_weights()
+        if self.with_img_backbone:
+            if img_pretrained is not None:
+                warnings.warn('DeprecationWarning: pretrained is a deprecated \
+                    key, please consider using init_cfg')
+                self.img_backbone.init_cfg = dict(
+                    type='Pretrained', checkpoint=img_pretrained)
         if self.with_img_roi_head:
-            self.img_roi_head.init_weights(img_pretrained)
-        if self.with_img_rpn:
-            self.img_rpn_head.init_weights()
-        if self.with_pts_bbox:
-            self.pts_bbox_head.init_weights()
+            if img_pretrained is not None:
+                warnings.warn('DeprecationWarning: pretrained is a deprecated \
+                    key, please consider using init_cfg')
+                self.img_roi_head.init_cfg = dict(
+                    type='Pretrained', checkpoint=img_pretrained)
+        if self.with_pts_backbone:
+            if img_pretrained is not None:
+                warnings.warn('DeprecationWarning: pretrained is a deprecated \
+                    key, please consider using init_cfg')
+                self.pts_backbone.init_cfg = dict(
+                    type='Pretrained', checkpoint=pts_pretrained)

     @property
     def with_img_shared_head(self):
......
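
The `ImVoteNet` and `MVXTwoStageDetector` hunks above keep backward compatibility: a legacy `pretrained` dict is no longer fed to `init_weights(pretrained=...)` but translated into a `Pretrained` init_cfg on the corresponding sub-module, with a deprecation warning. A hedged config-level sketch of the two styles (keys abbreviated, checkpoint path is a placeholder, not taken from this commit):

```python
# Deprecated style: checkpoints passed through the detector-level `pretrained`
# argument and loaded inside each detector's hand-written init_weights().
old_style = dict(
    type='MVXFasterRCNN',
    pretrained=dict(img='checkpoints/resnet50_img.pth'))

# Preferred style after this refactor: attach a `Pretrained` init_cfg to the
# sub-module itself; BaseModule.init_weights() then loads the checkpoint.
new_style = dict(
    type='MVXFasterRCNN',
    img_backbone=dict(
        type='ResNet',
        depth=50,
        init_cfg=dict(
            type='Pretrained', checkpoint='checkpoints/resnet50_img.pth')))
```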
@@ -24,7 +24,8 @@ class PartA2(TwoStage3DDetector):
                  roi_head=None,
                  train_cfg=None,
                  test_cfg=None,
-                 pretrained=None):
+                 pretrained=None,
+                 init_cfg=None):
         super(PartA2, self).__init__(
             backbone=backbone,
             neck=neck,
@@ -33,7 +34,7 @@ class PartA2(TwoStage3DDetector):
             train_cfg=train_cfg,
             test_cfg=test_cfg,
             pretrained=pretrained,
-        )
+            init_cfg=init_cfg)
         self.voxel_layer = Voxelization(**voxel_layer)
         self.voxel_encoder = builder.build_voxel_encoder(voxel_encoder)
         self.middle_encoder = builder.build_middle_encoder(middle_encoder)
......
-from torch import nn as nn
 from mmdet.models import DETECTORS, build_backbone, build_head, build_neck
 from .base import Base3DDetector
@@ -28,8 +26,9 @@ class SingleStage3DDetector(Base3DDetector):
                  bbox_head=None,
                  train_cfg=None,
                  test_cfg=None,
+                 init_cfg=None,
                  pretrained=None):
-        super(SingleStage3DDetector, self).__init__()
+        super(SingleStage3DDetector, self).__init__(init_cfg)
         self.backbone = build_backbone(backbone)
         if neck is not None:
             self.neck = build_neck(neck)
@@ -38,19 +37,6 @@ class SingleStage3DDetector(Base3DDetector):
         self.bbox_head = build_head(bbox_head)
         self.train_cfg = train_cfg
         self.test_cfg = test_cfg
-        self.init_weights(pretrained=pretrained)
-
-    def init_weights(self, pretrained=None):
-        """Initialize weights of detector."""
-        super(SingleStage3DDetector, self).init_weights(pretrained)
-        self.backbone.init_weights(pretrained=pretrained)
-        if self.with_neck:
-            if isinstance(self.neck, nn.Sequential):
-                for m in self.neck:
-                    m.init_weights()
-            else:
-                self.neck.init_weights()
-        self.bbox_head.init_weights()

     def extract_feat(self, points, img_metas=None):
         """Directly extract features from the backbone+neck.
......
@@ -14,10 +14,12 @@ class SSD3DNet(VoteNet):
                  bbox_head=None,
                  train_cfg=None,
                  test_cfg=None,
+                 init_cfg=None,
                  pretrained=None):
         super(SSD3DNet, self).__init__(
             backbone=backbone,
             bbox_head=bbox_head,
             train_cfg=train_cfg,
             test_cfg=test_cfg,
+            init_cfg=init_cfg,
             pretrained=pretrained)
@@ -14,12 +14,14 @@ class VoteNet(SingleStage3DDetector):
                  bbox_head=None,
                  train_cfg=None,
                  test_cfg=None,
+                 init_cfg=None,
                  pretrained=None):
         super(VoteNet, self).__init__(
             backbone=backbone,
             bbox_head=bbox_head,
             train_cfg=train_cfg,
             test_cfg=test_cfg,
+            init_cfg=None,
             pretrained=pretrained)

     def forward_train(self,
......
@@ -22,6 +22,7 @@ class VoxelNet(SingleStage3DDetector):
                  bbox_head=None,
                  train_cfg=None,
                  test_cfg=None,
+                 init_cfg=None,
                  pretrained=None):
         super(VoxelNet, self).__init__(
             backbone=backbone,
@@ -29,8 +30,8 @@ class VoxelNet(SingleStage3DDetector):
             bbox_head=bbox_head,
             train_cfg=train_cfg,
             test_cfg=test_cfg,
-            pretrained=pretrained,
-        )
+            init_cfg=init_cfg,
+            pretrained=pretrained)
         self.voxel_layer = Voxelization(**voxel_layer)
         self.voxel_encoder = builder.build_voxel_encoder(voxel_encoder)
         self.middle_encoder = builder.build_middle_encoder(middle_encoder)
......
 import torch
-from mmcv.cnn import ConvModule, xavier_init
+from mmcv.cnn import ConvModule
+from mmcv.runner import BaseModule
 from torch import nn as nn
 from torch.nn import functional as F
@@ -96,7 +97,7 @@ def point_sample(
 @FUSION_LAYERS.register_module()
-class PointFusion(nn.Module):
+class PointFusion(BaseModule):
     """Fuse image features from multi-scale features.

     Args:
@@ -138,6 +139,7 @@ class PointFusion(nn.Module):
                  conv_cfg=None,
                  norm_cfg=None,
                  act_cfg=None,
+                 init_cfg=None,
                  activate_out=True,
                  fuse_out=False,
                  dropout_ratio=0,
@@ -145,7 +147,7 @@ class PointFusion(nn.Module):
                  align_corners=True,
                  padding_mode='zeros',
                  lateral_conv=True):
-        super(PointFusion, self).__init__()
+        super(PointFusion, self).__init__(init_cfg=init_cfg)
         if isinstance(img_levels, int):
             img_levels = [img_levels]
         if isinstance(img_channels, int):
@@ -200,14 +202,11 @@ class PointFusion(nn.Module):
                 nn.BatchNorm1d(out_channels, eps=1e-3, momentum=0.01),
                 nn.ReLU(inplace=False))
-        self.init_weights()
-
-    # default init_weights for conv(msra) and norm in ConvModule
-    def init_weights(self):
-        """Initialize the weights of modules."""
-        for m in self.modules():
-            if isinstance(m, (nn.Conv2d, nn.Linear)):
-                xavier_init(m, distribution='uniform')
+        if init_cfg is None:
+            self.init_cfg = [
+                dict(type='Xavier', layer='Conv2d', distribution='uniform'),
+                dict(type='Xavier', layer='Linear', distribution='uniform')
+            ]

     def forward(self, img_feats, pts, pts_feats, img_metas):
         """Forward function.
......
 import torch
-from mmcv.runner import auto_fp16
-from torch import nn as nn
+from mmcv.runner import BaseModule, auto_fp16

 from mmdet3d.ops import SparseBasicBlock, make_sparse_convmodule
 from mmdet3d.ops import spconv as spconv
@@ -8,7 +7,7 @@ from ..builder import MIDDLE_ENCODERS
 @MIDDLE_ENCODERS.register_module()
-class SparseUNet(nn.Module):
+class SparseUNet(BaseModule):
     r"""SparseUNet for PartA^2.

     See the `paper <https://arxiv.org/abs/1907.03670>`_ for more details.
@@ -40,8 +39,9 @@ class SparseUNet(nn.Module):
                                   1)),
                  decoder_channels=((64, 64, 64), (64, 64, 32), (32, 32, 16),
                                    (16, 16, 16)),
-                 decoder_paddings=((1, 0), (1, 0), (0, 0), (0, 1))):
-        super().__init__()
+                 decoder_paddings=((1, 0), (1, 0), (0, 0), (0, 1)),
+                 init_cfg=None):
+        super().__init__(init_cfg=init_cfg)
         self.sparse_shape = sparse_shape
         self.in_channels = in_channels
         self.order = order
......