Unverified Commit 07590418 authored by xiliu8006's avatar xiliu8006 Committed by GitHub
Browse files

[Refactor]: Unified parameter initialization (#622)

* support 3dssd

* support one-stage method

* for lint

* support two_stage

* Support all methods

* remove init_cfg=[] in configs

* test

* support h3dnet

* fix lint error

* fix isort

* fix code style error

* fix imvotenet bug

* rename init_weight->init_weights

* clean comma

* fix test_apis does not init weights

* support newest mmdet and mmcv

* fix test_heads h3dnet bug

* rm *.swp

* remove the wrong code in build.yml

* fix ssn low map

* modify docs

* modified ssn init_config

* modify params in backbone pointnet2_sa_ssg

* add ssn direction init_cfg

* support segmentor

* add conv a=sqrt(5)

* Convmodule uses kaiming_init

* fix centerpointhead init bug

* add second conv2d init cfg

* add unittest to confirm the input is not be modified

* assert gt_bboxes_3d

* rm .swag

* modify docs mmdet version

* adopt fcosmono3d

* add fcos 3d original init method

* fix mmseg version

* add init cfg in fcos_mono3d.py

* merge newest master

* remove unused code

* modify focs config due to changes of resnet

* support imvoxelnet pointnet2

* modified the dependencies version

* support decode head

* fix inference bug

* modify the useless init_cfg

* fix multi_modality BC-breaking

* fix error blank

* modify docs error
parent 318499ac
import copy
import numpy as np
import torch
from mmcv.cnn import ConvModule, build_conv_layer, kaiming_init
from mmcv.runner import force_fp32
from mmcv.cnn import ConvModule, build_conv_layer
from mmcv.runner import BaseModule, force_fp32
from torch import nn
from mmdet3d.core import (circle_nms, draw_heatmap_gaussian, gaussian_radius,
......@@ -15,7 +15,7 @@ from mmdet.core import build_bbox_coder, multi_apply
@HEADS.register_module()
class SeparateHead(nn.Module):
class SeparateHead(BaseModule):
"""SeparateHead for CenterHead.
Args:
......@@ -42,9 +42,11 @@ class SeparateHead(nn.Module):
conv_cfg=dict(type='Conv2d'),
norm_cfg=dict(type='BN2d'),
bias='auto',
init_cfg=None,
**kwargs):
super(SeparateHead, self).__init__()
assert init_cfg is None, 'To prevent abnormal initialization ' \
'behavior, init_cfg is not allowed to be set'
super(SeparateHead, self).__init__(init_cfg=init_cfg)
self.heads = heads
self.init_bias = init_bias
for head in self.heads:
......@@ -78,15 +80,15 @@ class SeparateHead(nn.Module):
self.__setattr__(head, conv_layers)
if init_cfg is None:
self.init_cfg = dict(type='Kaiming', layer='Conv2d')
def init_weights(self):
"""Initialize weights."""
super().init_weights()
for head in self.heads:
if head == 'heatmap':
self.__getattr__(head)[-1].bias.data.fill_(self.init_bias)
else:
for m in self.__getattr__(head).modules():
if isinstance(m, nn.Conv2d):
kaiming_init(m)
def forward(self, x):
"""Forward function for SepHead.
......@@ -119,7 +121,7 @@ class SeparateHead(nn.Module):
@HEADS.register_module()
class DCNSeparateHead(nn.Module):
class DCNSeparateHead(BaseModule):
r"""DCNSeparateHead for CenterHead.
.. code-block:: none
......@@ -154,8 +156,11 @@ class DCNSeparateHead(nn.Module):
conv_cfg=dict(type='Conv2d'),
norm_cfg=dict(type='BN2d'),
bias='auto',
init_cfg=None,
**kwargs):
super(DCNSeparateHead, self).__init__()
assert init_cfg is None, 'To prevent abnormal initialization ' \
'behavior, init_cfg is not allowed to be set'
super(DCNSeparateHead, self).__init__(init_cfg=init_cfg)
if 'heatmap' in heads:
heads.pop('heatmap')
# feature adaptation with dcn
......@@ -192,11 +197,13 @@ class DCNSeparateHead(nn.Module):
head_conv=head_conv,
final_kernel=final_kernel,
bias=bias)
if init_cfg is None:
self.init_cfg = dict(type='Kaiming', layer='Conv2d')
def init_weights(self):
"""Initialize weights."""
super().init_weights()
self.cls_head[-1].bias.data.fill_(self.init_bias)
self.task_head.init_weights()
def forward(self, x):
"""Forward function for DCNSepHead.
......@@ -232,7 +239,7 @@ class DCNSeparateHead(nn.Module):
@HEADS.register_module()
class CenterHead(nn.Module):
class CenterHead(BaseModule):
"""CenterHead for CenterPoint.
Args:
......@@ -280,8 +287,11 @@ class CenterHead(nn.Module):
conv_cfg=dict(type='Conv2d'),
norm_cfg=dict(type='BN2d'),
bias='auto',
norm_bbox=True):
super(CenterHead, self).__init__()
norm_bbox=True,
init_cfg=None):
assert init_cfg is None, 'To prevent abnormal initialization ' \
'behavior, init_cfg is not allowed to be set'
super(CenterHead, self).__init__(init_cfg=init_cfg)
num_classes = [len(t['class_names']) for t in tasks]
self.class_names = [t['class_names'] for t in tasks]
......@@ -316,11 +326,6 @@ class CenterHead(nn.Module):
in_channels=share_conv_channel, heads=heads, num_cls=num_cls)
self.task_heads.append(builder.build_head(separate_head))
def init_weights(self):
"""Initialize weights."""
for task_head in self.task_heads:
task_head.init_weights()
def forward_single(self, x):
"""Forward function for CenterPoint.
......
......@@ -74,6 +74,7 @@ class FCOSMono3DHead(AnchorFreeMono3DHead):
loss_weight=1.0),
norm_cfg=dict(type='GN', num_groups=32, requires_grad=True),
centerness_branch=(64, ),
init_cfg=None,
**kwargs):
self.regress_ranges = regress_ranges
self.center_sampling = center_sampling
......@@ -90,6 +91,7 @@ class FCOSMono3DHead(AnchorFreeMono3DHead):
loss_dir=loss_dir,
loss_attr=loss_attr,
norm_cfg=norm_cfg,
init_cfg=init_cfg,
**kwargs)
self.loss_centerness = build_loss(loss_centerness)
......
......@@ -32,8 +32,9 @@ class FreeAnchor3DHead(Anchor3DHead):
bbox_thr=0.6,
gamma=2.0,
alpha=0.5,
init_cfg=None,
**kwargs):
super().__init__(**kwargs)
super().__init__(init_cfg=init_cfg, **kwargs)
self.pre_anchor_topk = pre_anchor_topk
self.bbox_thr = bbox_thr
self.gamma = gamma
......
......@@ -75,12 +75,13 @@ class PartA2RPNHead(Anchor3DHead):
loss_weight=1.0),
loss_bbox=dict(
type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=2.0),
loss_dir=dict(type='CrossEntropyLoss', loss_weight=0.2)):
loss_dir=dict(type='CrossEntropyLoss', loss_weight=0.2),
init_cfg=None):
super().__init__(num_classes, in_channels, train_cfg, test_cfg,
feat_channels, use_direction_classifier,
anchor_generator, assigner_per_size, assign_per_class,
diff_rad_by_sin, dir_offset, dir_limit_offset,
bbox_coder, loss_cls, loss_bbox, loss_dir)
bbox_coder, loss_cls, loss_bbox, loss_dir, init_cfg)
@force_fp32(apply_to=('cls_scores', 'bbox_preds', 'dir_cls_preds'))
def loss(self,
......
import numpy as np
import torch
from mmcv.cnn import ConvModule, bias_init_with_prob, normal_init
import warnings
from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule
from torch import nn as nn
from mmdet3d.core import box3d_multiclass_nms, limit_period, xywhr2xyxyr
......@@ -11,7 +13,7 @@ from .anchor3d_head import Anchor3DHead
@HEADS.register_module()
class BaseShapeHead(nn.Module):
class BaseShapeHead(BaseModule):
"""Base Shape-aware Head in Shape Signature Network.
Note:
......@@ -48,8 +50,9 @@ class BaseShapeHead(nn.Module):
use_direction_classifier=True,
conv_cfg=dict(type='Conv2d'),
norm_cfg=dict(type='BN2d'),
bias=False):
super().__init__()
bias=False,
init_cfg=None):
super().__init__(init_cfg=init_cfg)
self.num_cls = num_cls
self.num_base_anchors = num_base_anchors
self.use_direction_classifier = use_direction_classifier
......@@ -84,15 +87,36 @@ class BaseShapeHead(nn.Module):
if use_direction_classifier:
self.conv_dir_cls = nn.Conv2d(out_channels, num_base_anchors * 2,
1)
def init_weights(self):
"""Initialize weights."""
bias_cls = bias_init_with_prob(0.01)
# shared conv layers have already been initialized by ConvModule
normal_init(self.conv_cls, std=0.01, bias=bias_cls)
normal_init(self.conv_reg, std=0.01)
if self.use_direction_classifier:
normal_init(self.conv_dir_cls, std=0.01, bias=bias_cls)
if init_cfg is None:
if use_direction_classifier:
self.init_cfg = dict(
type='Kaiming',
layer='Conv2d',
override=[
dict(type='Normal', name='conv_reg', std=0.01),
dict(
type='Normal',
name='conv_cls',
std=0.01,
bias_prob=0.01),
dict(
type='Normal',
name='conv_dir_cls',
std=0.01,
bias_prob=0.01)
])
else:
self.init_cfg = dict(
type='Kaiming',
layer='Conv2d',
override=[
dict(type='Normal', name='conv_reg', std=0.01),
dict(
type='Normal',
name='conv_cls',
std=0.01,
bias_prob=0.01)
])
def forward(self, x):
"""Forward function for SmallHead.
......@@ -149,10 +173,21 @@ class ShapeAwareHead(Anchor3DHead):
:class:`Anchor3DHead`.
"""
def __init__(self, tasks, assign_per_class=True, **kwargs):
def __init__(self, tasks, assign_per_class=True, init_cfg=None, **kwargs):
self.tasks = tasks
self.featmap_sizes = []
super().__init__(assign_per_class=assign_per_class, **kwargs)
super().__init__(
assign_per_class=assign_per_class, init_cfg=init_cfg, **kwargs)
def init_weights(self):
if not self._is_init:
for m in self.heads:
if hasattr(m, 'init_weights'):
m.init_weights()
self._is_init = True
else:
warnings.warn(f'init_weights of {self.__class__.__name__} has '
f'been called more than once.')
def _init_layers(self):
"""Initialize neural network layers of the head."""
......@@ -175,11 +210,6 @@ class ShapeAwareHead(Anchor3DHead):
self.heads.append(build_head(branch))
cls_ptr += task['num_class']
def init_weights(self):
"""Initialize the weights of head."""
for head in self.heads:
head.init_weights()
def forward_single(self, x):
"""Forward function on a single-scale feature map.
......
......@@ -58,7 +58,8 @@ class SSD3DHead(VoteHead):
dir_res_loss=None,
size_res_loss=None,
corner_loss=None,
vote_loss=None):
vote_loss=None,
init_cfg=None):
super(SSD3DHead, self).__init__(
num_classes,
bbox_coder,
......@@ -75,7 +76,8 @@ class SSD3DHead(VoteHead):
dir_res_loss=dir_res_loss,
size_class_loss=None,
size_res_loss=size_res_loss,
semantic_loss=None)
semantic_loss=None,
init_cfg=init_cfg)
self.corner_loss = build_loss(corner_loss)
self.vote_loss = build_loss(vote_loss)
......
import numpy as np
import torch
from mmcv.runner import force_fp32
from torch import nn as nn
from mmcv.runner import BaseModule, force_fp32
from torch.nn import functional as F
from mmdet3d.core.post_processing import aligned_3d_nms
......@@ -15,7 +14,7 @@ from .base_conv_bbox_head import BaseConvBboxHead
@HEADS.register_module()
class VoteHead(nn.Module):
class VoteHead(BaseModule):
r"""Bbox head of `Votenet <https://arxiv.org/abs/1904.09664>`_.
Args:
......@@ -56,8 +55,9 @@ class VoteHead(nn.Module):
size_class_loss=None,
size_res_loss=None,
semantic_loss=None,
iou_loss=None):
super(VoteHead, self).__init__()
iou_loss=None,
init_cfg=None):
super(VoteHead, self).__init__(init_cfg=init_cfg)
self.num_classes = num_classes
self.train_cfg = train_cfg
self.test_cfg = test_cfg
......@@ -92,10 +92,6 @@ class VoteHead(nn.Module):
num_cls_out_channels=self._get_cls_out_channels(),
num_reg_out_channels=self._get_reg_out_channels())
def init_weights(self):
"""Initialize weights of VoteHead."""
pass
def _get_cls_out_channels(self):
"""Return the channel number of classification outputs."""
# Class numbers (k) + objectness (2)
......
......@@ -23,13 +23,14 @@ class CenterPoint(MVXTwoStageDetector):
img_rpn_head=None,
train_cfg=None,
test_cfg=None,
pretrained=None):
pretrained=None,
init_cfg=None):
super(CenterPoint,
self).__init__(pts_voxel_layer, pts_voxel_encoder,
pts_middle_encoder, pts_fusion_layer,
img_backbone, pts_backbone, img_neck, pts_neck,
pts_bbox_head, img_roi_head, img_rpn_head,
train_cfg, test_cfg, pretrained)
train_cfg, test_cfg, pretrained, init_cfg)
def extract_pts_feat(self, pts, img_feats, img_metas):
"""Extract features of points."""
......
......@@ -20,7 +20,8 @@ class DynamicVoxelNet(VoxelNet):
bbox_head=None,
train_cfg=None,
test_cfg=None,
pretrained=None):
pretrained=None,
init_cfg=None):
super(DynamicVoxelNet, self).__init__(
voxel_layer=voxel_layer,
voxel_encoder=voxel_encoder,
......@@ -31,7 +32,7 @@ class DynamicVoxelNet(VoxelNet):
train_cfg=train_cfg,
test_cfg=test_cfg,
pretrained=pretrained,
)
init_cfg=init_cfg)
def extract_feat(self, points, img_metas):
"""Extract features from points."""
......
......@@ -19,7 +19,8 @@ class H3DNet(TwoStage3DDetector):
roi_head=None,
train_cfg=None,
test_cfg=None,
pretrained=None):
pretrained=None,
init_cfg=None):
super(H3DNet, self).__init__(
backbone=backbone,
neck=neck,
......@@ -27,7 +28,8 @@ class H3DNet(TwoStage3DDetector):
roi_head=roi_head,
train_cfg=train_cfg,
test_cfg=test_cfg,
pretrained=pretrained)
pretrained=pretrained,
init_cfg=init_cfg)
def forward_train(self,
points,
......
import numpy as np
import torch
from torch import nn as nn
import warnings
from mmdet3d.core import bbox3d2result, merge_aug_bboxes_3d
from mmdet3d.models.utils import MLP
......@@ -69,9 +69,10 @@ class ImVoteNet(Base3DDetector):
num_sampled_seed=None,
train_cfg=None,
test_cfg=None,
pretrained=None):
pretrained=None,
init_cfg=None):
super(ImVoteNet, self).__init__()
super(ImVoteNet, self).__init__(init_cfg=init_cfg)
# point branch
if pts_backbone is not None:
......@@ -134,11 +135,7 @@ class ImVoteNet(Base3DDetector):
self.train_cfg = train_cfg
self.test_cfg = test_cfg
self.init_weights(pretrained=pretrained)
def init_weights(self, pretrained=None):
"""Initialize model weights."""
super(ImVoteNet, self).init_weights(pretrained)
if pretrained is None:
img_pretrained = None
pts_pretrained = None
......@@ -148,29 +145,26 @@ class ImVoteNet(Base3DDetector):
else:
raise ValueError(
f'pretrained should be a dict, got {type(pretrained)}')
if self.with_img_backbone:
self.img_backbone.init_weights(pretrained=img_pretrained)
if self.with_img_neck:
if isinstance(self.img_neck, nn.Sequential):
for m in self.img_neck:
m.init_weights()
else:
self.img_neck.init_weights()
if self.with_img_backbone:
if img_pretrained is not None:
warnings.warn('DeprecationWarning: pretrained is a deprecated \
key, please consider using init_cfg')
self.img_backbone.init_cfg = dict(
type='Pretrained', checkpoint=img_pretrained)
if self.with_img_roi_head:
self.img_roi_head.init_weights(img_pretrained)
if self.with_img_rpn:
self.img_rpn_head.init_weights()
if img_pretrained is not None:
warnings.warn('DeprecationWarning: pretrained is a deprecated \
key, please consider using init_cfg')
self.img_roi_head.init_cfg = dict(
type='Pretrained', checkpoint=img_pretrained)
if self.with_pts_backbone:
self.pts_backbone.init_weights(pretrained=pts_pretrained)
if self.with_pts_bbox:
self.pts_bbox_head.init_weights()
if self.with_pts_neck:
if isinstance(self.pts_neck, nn.Sequential):
for m in self.pts_neck:
m.init_weights()
else:
self.pts_neck.init_weights()
if img_pretrained is not None:
warnings.warn('DeprecationWarning: pretrained is a deprecated \
key, please consider using init_cfg')
self.pts_backbone.init_cfg = dict(
type='Pretrained', checkpoint=pts_pretrained)
def freeze_img_branch_params(self):
"""Freeze all image branch parameters."""
......
......@@ -19,8 +19,9 @@ class ImVoxelNet(BaseDetector):
anchor_generator,
train_cfg=None,
test_cfg=None,
pretrained=None):
super().__init__()
pretrained=None,
init_cfg=None):
super().__init__(init_cfg=init_cfg)
self.backbone = build_backbone(backbone)
self.neck = build_neck(neck)
self.neck_3d = build_neck(neck_3d)
......@@ -31,20 +32,6 @@ class ImVoxelNet(BaseDetector):
self.anchor_generator = build_anchor_generator(anchor_generator)
self.train_cfg = train_cfg
self.test_cfg = test_cfg
self.init_weights(pretrained=pretrained)
def init_weights(self, pretrained=None):
"""Initialize the weights in detector.
Args:
pretrained (str, optional): Path to pre-trained weights.
Defaults to None.
"""
super().init_weights(pretrained)
self.backbone.init_weights(pretrained=pretrained)
self.neck.init_weights()
self.neck_3d.init_weights()
self.bbox_head.init_weights()
def extract_feat(self, img, img_metas):
"""Extract 3d features from the backbone -> fpn -> 3d projection.
......
import mmcv
import torch
import warnings
from mmcv.parallel import DataContainer as DC
from mmcv.runner import force_fp32
from os import path as osp
from torch import nn as nn
from torch.nn import functional as F
from mmdet3d.core import (Box3DMode, Coord3DMode, bbox3d2result,
......@@ -33,8 +33,9 @@ class MVXTwoStageDetector(Base3DDetector):
img_rpn_head=None,
train_cfg=None,
test_cfg=None,
pretrained=None):
super(MVXTwoStageDetector, self).__init__()
pretrained=None,
init_cfg=None):
super(MVXTwoStageDetector, self).__init__(init_cfg=init_cfg)
if pts_voxel_layer:
self.pts_voxel_layer = Voxelization(**pts_voxel_layer)
......@@ -69,11 +70,7 @@ class MVXTwoStageDetector(Base3DDetector):
self.train_cfg = train_cfg
self.test_cfg = test_cfg
self.init_weights(pretrained=pretrained)
def init_weights(self, pretrained=None):
"""Initialize model weights."""
super(MVXTwoStageDetector, self).init_weights(pretrained)
if pretrained is None:
img_pretrained = None
pts_pretrained = None
......@@ -83,23 +80,26 @@ class MVXTwoStageDetector(Base3DDetector):
else:
raise ValueError(
f'pretrained should be a dict, got {type(pretrained)}')
if self.with_img_backbone:
self.img_backbone.init_weights(pretrained=img_pretrained)
if self.with_pts_backbone:
self.pts_backbone.init_weights(pretrained=pts_pretrained)
if self.with_img_neck:
if isinstance(self.img_neck, nn.Sequential):
for m in self.img_neck:
m.init_weights()
else:
self.img_neck.init_weights()
if self.with_img_backbone:
if img_pretrained is not None:
warnings.warn('DeprecationWarning: pretrained is a deprecated \
key, please consider using init_cfg')
self.img_backbone.init_cfg = dict(
type='Pretrained', checkpoint=img_pretrained)
if self.with_img_roi_head:
self.img_roi_head.init_weights(img_pretrained)
if self.with_img_rpn:
self.img_rpn_head.init_weights()
if self.with_pts_bbox:
self.pts_bbox_head.init_weights()
if img_pretrained is not None:
warnings.warn('DeprecationWarning: pretrained is a deprecated \
key, please consider using init_cfg')
self.img_roi_head.init_cfg = dict(
type='Pretrained', checkpoint=img_pretrained)
if self.with_pts_backbone:
if img_pretrained is not None:
warnings.warn('DeprecationWarning: pretrained is a deprecated \
key, please consider using init_cfg')
self.pts_backbone.init_cfg = dict(
type='Pretrained', checkpoint=pts_pretrained)
@property
def with_img_shared_head(self):
......
......@@ -24,7 +24,8 @@ class PartA2(TwoStage3DDetector):
roi_head=None,
train_cfg=None,
test_cfg=None,
pretrained=None):
pretrained=None,
init_cfg=None):
super(PartA2, self).__init__(
backbone=backbone,
neck=neck,
......@@ -33,7 +34,7 @@ class PartA2(TwoStage3DDetector):
train_cfg=train_cfg,
test_cfg=test_cfg,
pretrained=pretrained,
)
init_cfg=init_cfg)
self.voxel_layer = Voxelization(**voxel_layer)
self.voxel_encoder = builder.build_voxel_encoder(voxel_encoder)
self.middle_encoder = builder.build_middle_encoder(middle_encoder)
......
from torch import nn as nn
from mmdet.models import DETECTORS, build_backbone, build_head, build_neck
from .base import Base3DDetector
......@@ -28,8 +26,9 @@ class SingleStage3DDetector(Base3DDetector):
bbox_head=None,
train_cfg=None,
test_cfg=None,
init_cfg=None,
pretrained=None):
super(SingleStage3DDetector, self).__init__()
super(SingleStage3DDetector, self).__init__(init_cfg)
self.backbone = build_backbone(backbone)
if neck is not None:
self.neck = build_neck(neck)
......@@ -38,19 +37,6 @@ class SingleStage3DDetector(Base3DDetector):
self.bbox_head = build_head(bbox_head)
self.train_cfg = train_cfg
self.test_cfg = test_cfg
self.init_weights(pretrained=pretrained)
def init_weights(self, pretrained=None):
"""Initialize weights of detector."""
super(SingleStage3DDetector, self).init_weights(pretrained)
self.backbone.init_weights(pretrained=pretrained)
if self.with_neck:
if isinstance(self.neck, nn.Sequential):
for m in self.neck:
m.init_weights()
else:
self.neck.init_weights()
self.bbox_head.init_weights()
def extract_feat(self, points, img_metas=None):
"""Directly extract features from the backbone+neck.
......
......@@ -14,10 +14,12 @@ class SSD3DNet(VoteNet):
bbox_head=None,
train_cfg=None,
test_cfg=None,
init_cfg=None,
pretrained=None):
super(SSD3DNet, self).__init__(
backbone=backbone,
bbox_head=bbox_head,
train_cfg=train_cfg,
test_cfg=test_cfg,
init_cfg=init_cfg,
pretrained=pretrained)
......@@ -14,12 +14,14 @@ class VoteNet(SingleStage3DDetector):
bbox_head=None,
train_cfg=None,
test_cfg=None,
init_cfg=None,
pretrained=None):
super(VoteNet, self).__init__(
backbone=backbone,
bbox_head=bbox_head,
train_cfg=train_cfg,
test_cfg=test_cfg,
init_cfg=None,
pretrained=pretrained)
def forward_train(self,
......
......@@ -22,6 +22,7 @@ class VoxelNet(SingleStage3DDetector):
bbox_head=None,
train_cfg=None,
test_cfg=None,
init_cfg=None,
pretrained=None):
super(VoxelNet, self).__init__(
backbone=backbone,
......@@ -29,8 +30,8 @@ class VoxelNet(SingleStage3DDetector):
bbox_head=bbox_head,
train_cfg=train_cfg,
test_cfg=test_cfg,
pretrained=pretrained,
)
init_cfg=init_cfg,
pretrained=pretrained)
self.voxel_layer = Voxelization(**voxel_layer)
self.voxel_encoder = builder.build_voxel_encoder(voxel_encoder)
self.middle_encoder = builder.build_middle_encoder(middle_encoder)
......
import torch
from mmcv.cnn import ConvModule, xavier_init
from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule
from torch import nn as nn
from torch.nn import functional as F
......@@ -96,7 +97,7 @@ def point_sample(
@FUSION_LAYERS.register_module()
class PointFusion(nn.Module):
class PointFusion(BaseModule):
"""Fuse image features from multi-scale features.
Args:
......@@ -138,6 +139,7 @@ class PointFusion(nn.Module):
conv_cfg=None,
norm_cfg=None,
act_cfg=None,
init_cfg=None,
activate_out=True,
fuse_out=False,
dropout_ratio=0,
......@@ -145,7 +147,7 @@ class PointFusion(nn.Module):
align_corners=True,
padding_mode='zeros',
lateral_conv=True):
super(PointFusion, self).__init__()
super(PointFusion, self).__init__(init_cfg=init_cfg)
if isinstance(img_levels, int):
img_levels = [img_levels]
if isinstance(img_channels, int):
......@@ -200,14 +202,11 @@ class PointFusion(nn.Module):
nn.BatchNorm1d(out_channels, eps=1e-3, momentum=0.01),
nn.ReLU(inplace=False))
self.init_weights()
# default init_weights for conv(msra) and norm in ConvModule
def init_weights(self):
"""Initialize the weights of modules."""
for m in self.modules():
if isinstance(m, (nn.Conv2d, nn.Linear)):
xavier_init(m, distribution='uniform')
if init_cfg is None:
self.init_cfg = [
dict(type='Xavier', layer='Conv2d', distribution='uniform'),
dict(type='Xavier', layer='Linear', distribution='uniform')
]
def forward(self, img_feats, pts, pts_feats, img_metas):
"""Forward function.
......
import torch
from mmcv.runner import auto_fp16
from torch import nn as nn
from mmcv.runner import BaseModule, auto_fp16
from mmdet3d.ops import SparseBasicBlock, make_sparse_convmodule
from mmdet3d.ops import spconv as spconv
......@@ -8,7 +7,7 @@ from ..builder import MIDDLE_ENCODERS
@MIDDLE_ENCODERS.register_module()
class SparseUNet(nn.Module):
class SparseUNet(BaseModule):
r"""SparseUNet for PartA^2.
See the `paper <https://arxiv.org/abs/1907.03670>`_ for more details.
......@@ -40,8 +39,9 @@ class SparseUNet(nn.Module):
1)),
decoder_channels=((64, 64, 64), (64, 64, 32), (32, 32, 16),
(16, 16, 16)),
decoder_paddings=((1, 0), (1, 0), (0, 0), (0, 1))):
super().__init__()
decoder_paddings=((1, 0), (1, 0), (0, 0), (0, 1)),
init_cfg=None):
super().__init__(init_cfg=init_cfg)
self.sparse_shape = sparse_shape
self.in_channels = in_channels
self.order = order
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment