Unverified Commit 07590418 authored by xiliu8006's avatar xiliu8006 Committed by GitHub
Browse files

[Refactor]: Unified parameter initialization (#622)

* support 3dssd

* support one-stage method

* for lint

* support two_stage

* Support all methods

* remove init_cfg=[] in configs

* test

* support h3dnet

* fix lint error

* fix isort

* fix code style error

* fix imvotenet bug

* rename init_weight->init_weights

* clean comma

* fix test_apis does not init weights

* support newest mmdet and mmcv

* fix test_heads h3dnet bug

* rm *.swp

* remove the wrong code in build.yml

* fix ssn low map

* modify docs

* modified ssn init_config

* modify params in backbone pointnet2_sa_ssg

* add ssn direction init_cfg

* support segmentor

* add conv a=sqrt(5)

* Convmodule uses kaiming_init

* fix centerpointhead init bug

* add second conv2d init cfg

* add unittest to confirm the input is not be modified

* assert gt_bboxes_3d

* rm .swag

* modify docs mmdet version

* adopt fcosmono3d

* add fcos 3d original init method

* fix mmseg version

* add init cfg in fcos_mono3d.py

* merge newest master

* remove unused code

* modify focs config due to changes of resnet

* support imvoxelnet pointnet2

* modified the dependencies version

* support decode head

* fix inference bug

* modify the useless init_cfg

* fix multi_modality BC-breaking

* fix error blank

* modify docs error
parent 318499ac
import numpy as np import numpy as np
import torch import torch
from mmcv.cnn import (build_conv_layer, build_norm_layer, build_upsample_layer, from mmcv.cnn import build_conv_layer, build_norm_layer, build_upsample_layer
constant_init, is_norm, kaiming_init) from mmcv.runner import BaseModule, auto_fp16
from mmcv.runner import auto_fp16
from torch import nn as nn from torch import nn as nn
from mmdet.models import NECKS from mmdet.models import NECKS
@NECKS.register_module() @NECKS.register_module()
class SECONDFPN(nn.Module): class SECONDFPN(BaseModule):
"""FPN used in SECOND/PointPillars/PartA2/MVXNet. """FPN used in SECOND/PointPillars/PartA2/MVXNet.
Args: Args:
...@@ -30,10 +29,11 @@ class SECONDFPN(nn.Module): ...@@ -30,10 +29,11 @@ class SECONDFPN(nn.Module):
norm_cfg=dict(type='BN', eps=1e-3, momentum=0.01), norm_cfg=dict(type='BN', eps=1e-3, momentum=0.01),
upsample_cfg=dict(type='deconv', bias=False), upsample_cfg=dict(type='deconv', bias=False),
conv_cfg=dict(type='Conv2d', bias=False), conv_cfg=dict(type='Conv2d', bias=False),
use_conv_for_no_stride=False): use_conv_for_no_stride=False,
init_cfg=None):
# if for GroupNorm, # if for GroupNorm,
# cfg is dict(type='GN', num_groups=num_groups, eps=1e-3, affine=True) # cfg is dict(type='GN', num_groups=num_groups, eps=1e-3, affine=True)
super(SECONDFPN, self).__init__() super(SECONDFPN, self).__init__(init_cfg=init_cfg)
assert len(out_channels) == len(upsample_strides) == len(in_channels) assert len(out_channels) == len(upsample_strides) == len(in_channels)
self.in_channels = in_channels self.in_channels = in_channels
self.out_channels = out_channels self.out_channels = out_channels
...@@ -64,13 +64,11 @@ class SECONDFPN(nn.Module): ...@@ -64,13 +64,11 @@ class SECONDFPN(nn.Module):
deblocks.append(deblock) deblocks.append(deblock)
self.deblocks = nn.ModuleList(deblocks) self.deblocks = nn.ModuleList(deblocks)
def init_weights(self): if init_cfg is None:
"""Initialize weights of FPN.""" self.init_cfg = [
for m in self.modules(): dict(type='Kaiming', layer='ConvTranspose2d'),
if isinstance(m, nn.Conv2d): dict(type='Constant', layer='NaiveSyncBatchNorm2d', val=1.0)
kaiming_init(m) ]
elif is_norm(m):
constant_init(m, 1)
@auto_fp16() @auto_fp16()
def forward(self, x): def forward(self, x):
......
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod
from torch import nn as nn from mmcv.runner import BaseModule
class Base3DRoIHead(nn.Module, metaclass=ABCMeta): class Base3DRoIHead(BaseModule, metaclass=ABCMeta):
"""Base class for 3d RoIHeads.""" """Base class for 3d RoIHeads."""
def __init__(self, def __init__(self,
...@@ -10,8 +10,10 @@ class Base3DRoIHead(nn.Module, metaclass=ABCMeta): ...@@ -10,8 +10,10 @@ class Base3DRoIHead(nn.Module, metaclass=ABCMeta):
mask_roi_extractor=None, mask_roi_extractor=None,
mask_head=None, mask_head=None,
train_cfg=None, train_cfg=None,
test_cfg=None): test_cfg=None,
super(Base3DRoIHead, self).__init__() pretrained=None,
init_cfg=None):
super(Base3DRoIHead, self).__init__(init_cfg=init_cfg)
self.train_cfg = train_cfg self.train_cfg = train_cfg
self.test_cfg = test_cfg self.test_cfg = test_cfg
...@@ -33,11 +35,6 @@ class Base3DRoIHead(nn.Module, metaclass=ABCMeta): ...@@ -33,11 +35,6 @@ class Base3DRoIHead(nn.Module, metaclass=ABCMeta):
"""bool: whether the RoIHead has mask head""" """bool: whether the RoIHead has mask head"""
return hasattr(self, 'mask_head') and self.mask_head is not None return hasattr(self, 'mask_head') and self.mask_head is not None
@abstractmethod
def init_weights(self, pretrained):
"""Initialize the module with pre-trained weights."""
pass
@abstractmethod @abstractmethod
def init_bbox_head(self): def init_bbox_head(self):
"""Initialize the box head.""" """Initialize the box head."""
......
import torch import torch
from mmcv.cnn import ConvModule from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule
from torch import nn as nn from torch import nn as nn
from torch.nn import functional as F from torch.nn import functional as F
...@@ -13,7 +14,7 @@ from mmdet.models import HEADS ...@@ -13,7 +14,7 @@ from mmdet.models import HEADS
@HEADS.register_module() @HEADS.register_module()
class H3DBboxHead(nn.Module): class H3DBboxHead(BaseModule):
r"""Bbox head of `H3DNet <https://arxiv.org/abs/2006.05682>`_. r"""Bbox head of `H3DNet <https://arxiv.org/abs/2006.05682>`_.
Args: Args:
...@@ -80,8 +81,9 @@ class H3DBboxHead(nn.Module): ...@@ -80,8 +81,9 @@ class H3DBboxHead(nn.Module):
cues_objectness_loss=None, cues_objectness_loss=None,
cues_semantic_loss=None, cues_semantic_loss=None,
proposal_objectness_loss=None, proposal_objectness_loss=None,
primitive_center_loss=None): primitive_center_loss=None,
super(H3DBboxHead, self).__init__() init_cfg=None):
super(H3DBboxHead, self).__init__(init_cfg=init_cfg)
self.num_classes = num_classes self.num_classes = num_classes
self.train_cfg = train_cfg self.train_cfg = train_cfg
self.test_cfg = test_cfg self.test_cfg = test_cfg
...@@ -198,15 +200,6 @@ class H3DBboxHead(nn.Module): ...@@ -198,15 +200,6 @@ class H3DBboxHead(nn.Module):
bbox_coder['num_sizes'] * 4 + self.num_classes) bbox_coder['num_sizes'] * 4 + self.num_classes)
self.bbox_pred.append(nn.Conv1d(prev_channel, conv_out_channel, 1)) self.bbox_pred.append(nn.Conv1d(prev_channel, conv_out_channel, 1))
def init_weights(self, pretrained=None):
"""Initialize the weights in detector.
Args:
pretrained (str, optional): Path to pre-trained weights.
Defaults to None.
"""
pass
def forward(self, feats_dict, sample_mod): def forward(self, feats_dict, sample_mod):
"""Forward pass. """Forward pass.
......
import numpy as np import numpy as np
import torch import torch
from mmcv.cnn import ConvModule, normal_init, xavier_init from mmcv.cnn import ConvModule, normal_init
from mmcv.runner import BaseModule
from torch import nn as nn from torch import nn as nn
from mmdet3d.core.bbox.structures import (LiDARInstance3DBoxes, from mmdet3d.core.bbox.structures import (LiDARInstance3DBoxes,
...@@ -14,7 +15,7 @@ from mmdet.models import HEADS ...@@ -14,7 +15,7 @@ from mmdet.models import HEADS
@HEADS.register_module() @HEADS.register_module()
class PartA2BboxHead(nn.Module): class PartA2BboxHead(BaseModule):
"""PartA2 RoI head. """PartA2 RoI head.
Args: Args:
...@@ -67,8 +68,9 @@ class PartA2BboxHead(nn.Module): ...@@ -67,8 +68,9 @@ class PartA2BboxHead(nn.Module):
type='CrossEntropyLoss', type='CrossEntropyLoss',
use_sigmoid=True, use_sigmoid=True,
reduction='none', reduction='none',
loss_weight=1.0)): loss_weight=1.0),
super(PartA2BboxHead, self).__init__() init_cfg=None):
super(PartA2BboxHead, self).__init__(init_cfg=init_cfg)
self.num_classes = num_classes self.num_classes = num_classes
self.with_corner_loss = with_corner_loss self.with_corner_loss = with_corner_loss
self.bbox_coder = build_bbox_coder(bbox_coder) self.bbox_coder = build_bbox_coder(bbox_coder)
...@@ -220,14 +222,14 @@ class PartA2BboxHead(nn.Module): ...@@ -220,14 +222,14 @@ class PartA2BboxHead(nn.Module):
self.conv_reg = nn.Sequential(*reg_layers) self.conv_reg = nn.Sequential(*reg_layers)
self.init_weights() if init_cfg is None:
self.init_cfg = dict(
type='Xavier',
layer=['Conv2d', 'Conv1d'],
distribution='uniform')
def init_weights(self): def init_weights(self):
"""Initialize weights of the bbox head.""" super().init_weights()
for m in self.modules():
if isinstance(m, (nn.Conv2d, nn.Conv1d)):
xavier_init(m, distribution='uniform')
normal_init(self.conv_reg[-1].conv, mean=0, std=0.001) normal_init(self.conv_reg[-1].conv, mean=0, std=0.001)
def forward(self, seg_feats, part_feats): def forward(self, seg_feats, part_feats):
......
...@@ -19,20 +19,21 @@ class H3DRoIHead(Base3DRoIHead): ...@@ -19,20 +19,21 @@ class H3DRoIHead(Base3DRoIHead):
primitive_list, primitive_list,
bbox_head=None, bbox_head=None,
train_cfg=None, train_cfg=None,
test_cfg=None): test_cfg=None,
pretrained=None,
init_cfg=None):
super(H3DRoIHead, self).__init__( super(H3DRoIHead, self).__init__(
bbox_head=bbox_head, train_cfg=train_cfg, test_cfg=test_cfg) bbox_head=bbox_head,
train_cfg=train_cfg,
test_cfg=test_cfg,
pretrained=pretrained,
init_cfg=init_cfg)
# Primitive module # Primitive module
assert len(primitive_list) == 3 assert len(primitive_list) == 3
self.primitive_z = build_head(primitive_list[0]) self.primitive_z = build_head(primitive_list[0])
self.primitive_xy = build_head(primitive_list[1]) self.primitive_xy = build_head(primitive_list[1])
self.primitive_line = build_head(primitive_list[2]) self.primitive_line = build_head(primitive_list[2])
def init_weights(self, pretrained):
"""Initialize weights, skip since ``H3DROIHead`` does not need to
initialize weights."""
pass
def init_mask_head(self): def init_mask_head(self):
"""Initialize mask head, skip since ``H3DROIHead`` does not have """Initialize mask head, skip since ``H3DROIHead`` does not have
one.""" one."""
......
import torch import torch
from mmcv.runner import BaseModule
from torch import nn as nn from torch import nn as nn
from torch.nn import functional as F from torch.nn import functional as F
...@@ -9,7 +10,7 @@ from mmdet.models import HEADS ...@@ -9,7 +10,7 @@ from mmdet.models import HEADS
@HEADS.register_module() @HEADS.register_module()
class PointwiseSemanticHead(nn.Module): class PointwiseSemanticHead(BaseModule):
"""Semantic segmentation head for point-wise segmentation. """Semantic segmentation head for point-wise segmentation.
Predict point-wise segmentation and part regression results for PartA2. Predict point-wise segmentation and part regression results for PartA2.
...@@ -28,6 +29,7 @@ class PointwiseSemanticHead(nn.Module): ...@@ -28,6 +29,7 @@ class PointwiseSemanticHead(nn.Module):
num_classes=3, num_classes=3,
extra_width=0.2, extra_width=0.2,
seg_score_thr=0.3, seg_score_thr=0.3,
init_cfg=None,
loss_seg=dict( loss_seg=dict(
type='FocalLoss', type='FocalLoss',
use_sigmoid=True, use_sigmoid=True,
...@@ -39,7 +41,7 @@ class PointwiseSemanticHead(nn.Module): ...@@ -39,7 +41,7 @@ class PointwiseSemanticHead(nn.Module):
type='CrossEntropyLoss', type='CrossEntropyLoss',
use_sigmoid=True, use_sigmoid=True,
loss_weight=1.0)): loss_weight=1.0)):
super(PointwiseSemanticHead, self).__init__() super(PointwiseSemanticHead, self).__init__(init_cfg=init_cfg)
self.extra_width = extra_width self.extra_width = extra_width
self.num_classes = num_classes self.num_classes = num_classes
self.seg_score_thr = seg_score_thr self.seg_score_thr = seg_score_thr
......
import torch import torch
from mmcv.cnn import ConvModule from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule
from torch import nn as nn from torch import nn as nn
from torch.nn import functional as F from torch.nn import functional as F
...@@ -11,7 +12,7 @@ from mmdet.models import HEADS ...@@ -11,7 +12,7 @@ from mmdet.models import HEADS
@HEADS.register_module() @HEADS.register_module()
class PrimitiveHead(nn.Module): class PrimitiveHead(BaseModule):
r"""Primitive head of `H3DNet <https://arxiv.org/abs/2006.05682>`_. r"""Primitive head of `H3DNet <https://arxiv.org/abs/2006.05682>`_.
Args: Args:
...@@ -52,8 +53,9 @@ class PrimitiveHead(nn.Module): ...@@ -52,8 +53,9 @@ class PrimitiveHead(nn.Module):
objectness_loss=None, objectness_loss=None,
center_loss=None, center_loss=None,
semantic_reg_loss=None, semantic_reg_loss=None,
semantic_cls_loss=None): semantic_cls_loss=None,
super(PrimitiveHead, self).__init__() init_cfg=None):
super(PrimitiveHead, self).__init__(init_cfg=init_cfg)
assert primitive_mode in ['z', 'xy', 'line'] assert primitive_mode in ['z', 'xy', 'line']
# The dimension of primitive semantic information. # The dimension of primitive semantic information.
self.num_dims = num_dims self.num_dims = num_dims
...@@ -110,10 +112,6 @@ class PrimitiveHead(nn.Module): ...@@ -110,10 +112,6 @@ class PrimitiveHead(nn.Module):
self.conv_pred.add_module('conv_out', self.conv_pred.add_module('conv_out',
nn.Conv1d(prev_channel, conv_out_channel, 1)) nn.Conv1d(prev_channel, conv_out_channel, 1))
def init_weights(self):
"""Initialize weights of VoteHead."""
pass
def forward(self, feats_dict, sample_mod): def forward(self, feats_dict, sample_mod):
"""Forward pass. """Forward pass.
......
import warnings
from torch.nn import functional as F from torch.nn import functional as F
from mmdet3d.core import AssignResult from mmdet3d.core import AssignResult
...@@ -29,9 +30,14 @@ class PartAggregationROIHead(Base3DRoIHead): ...@@ -29,9 +30,14 @@ class PartAggregationROIHead(Base3DRoIHead):
part_roi_extractor=None, part_roi_extractor=None,
bbox_head=None, bbox_head=None,
train_cfg=None, train_cfg=None,
test_cfg=None): test_cfg=None,
pretrained=None,
init_cfg=None):
super(PartAggregationROIHead, self).__init__( super(PartAggregationROIHead, self).__init__(
bbox_head=bbox_head, train_cfg=train_cfg, test_cfg=test_cfg) bbox_head=bbox_head,
train_cfg=train_cfg,
test_cfg=test_cfg,
init_cfg=init_cfg)
self.num_classes = num_classes self.num_classes = num_classes
assert semantic_head is not None assert semantic_head is not None
self.semantic_head = build_head(semantic_head) self.semantic_head = build_head(semantic_head)
...@@ -43,10 +49,12 @@ class PartAggregationROIHead(Base3DRoIHead): ...@@ -43,10 +49,12 @@ class PartAggregationROIHead(Base3DRoIHead):
self.init_assigner_sampler() self.init_assigner_sampler()
def init_weights(self, pretrained): assert not (init_cfg and pretrained), \
"""Initialize weights, skip since ``PartAggregationROIHead`` does not 'init_cfg and pretrained cannot be setting at the same time'
need to initialize weights.""" if isinstance(pretrained, str):
pass warnings.warn('DeprecationWarning: pretrained is a deprecated, '
'please use "init_cfg" instead')
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
def init_mask_head(self): def init_mask_head(self):
"""Initialize mask head, skip since ``PartAggregationROIHead`` does not """Initialize mask head, skip since ``PartAggregationROIHead`` does not
......
import torch import torch
from torch import nn as nn from mmcv.runner import BaseModule
from mmdet3d import ops from mmdet3d import ops
from mmdet.models.builder import ROI_EXTRACTORS from mmdet.models.builder import ROI_EXTRACTORS
@ROI_EXTRACTORS.register_module() @ROI_EXTRACTORS.register_module()
class Single3DRoIAwareExtractor(nn.Module): class Single3DRoIAwareExtractor(BaseModule):
"""Point-wise roi-aware Extractor. """Point-wise roi-aware Extractor.
Extract Point-wise roi features. Extract Point-wise roi features.
...@@ -15,8 +15,8 @@ class Single3DRoIAwareExtractor(nn.Module): ...@@ -15,8 +15,8 @@ class Single3DRoIAwareExtractor(nn.Module):
roi_layer (dict): The config of roi layer. roi_layer (dict): The config of roi layer.
""" """
def __init__(self, roi_layer=None): def __init__(self, roi_layer=None, init_cfg=None):
super(Single3DRoIAwareExtractor, self).__init__() super(Single3DRoIAwareExtractor, self).__init__(init_cfg=init_cfg)
self.roi_layer = self.build_roi_layers(roi_layer) self.roi_layer = self.build_roi_layers(roi_layer)
def build_roi_layers(self, layer_cfg): def build_roi_layers(self, layer_cfg):
......
...@@ -25,8 +25,9 @@ class EncoderDecoder3D(Base3DSegmentor): ...@@ -25,8 +25,9 @@ class EncoderDecoder3D(Base3DSegmentor):
auxiliary_head=None, auxiliary_head=None,
train_cfg=None, train_cfg=None,
test_cfg=None, test_cfg=None,
pretrained=None): pretrained=None,
super(EncoderDecoder3D, self).__init__() init_cfg=None):
super(EncoderDecoder3D, self).__init__(init_cfg=init_cfg)
self.backbone = build_backbone(backbone) self.backbone = build_backbone(backbone)
if neck is not None: if neck is not None:
self.neck = build_neck(neck) self.neck = build_neck(neck)
...@@ -35,9 +36,6 @@ class EncoderDecoder3D(Base3DSegmentor): ...@@ -35,9 +36,6 @@ class EncoderDecoder3D(Base3DSegmentor):
self.train_cfg = train_cfg self.train_cfg = train_cfg
self.test_cfg = test_cfg self.test_cfg = test_cfg
self.init_weights(pretrained=pretrained)
assert self.with_decode_head, \ assert self.with_decode_head, \
'3D EncoderDecoder Segmentor should have a decode_head' '3D EncoderDecoder Segmentor should have a decode_head'
...@@ -56,24 +54,6 @@ class EncoderDecoder3D(Base3DSegmentor): ...@@ -56,24 +54,6 @@ class EncoderDecoder3D(Base3DSegmentor):
else: else:
self.auxiliary_head = build_head(auxiliary_head) self.auxiliary_head = build_head(auxiliary_head)
def init_weights(self, pretrained=None):
"""Initialize the weights in backbone and heads.
Args:
pretrained (str, optional): Path to pre-trained weights.
Defaults to None.
"""
super(EncoderDecoder3D, self).init_weights(pretrained)
self.backbone.init_weights(pretrained=pretrained)
self.decode_head.init_weights()
if self.with_auxiliary_head:
if isinstance(self.auxiliary_head, nn.ModuleList):
for aux_head in self.auxiliary_head:
aux_head.init_weights()
else:
self.auxiliary_head.init_weights()
def extract_feat(self, points): def extract_feat(self, points):
"""Extract features from points.""" """Extract features from points."""
x = self.backbone(points) x = self.backbone(points)
......
from mmcv.cnn import ConvModule from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule
from torch import nn as nn from torch import nn as nn
class MLP(nn.Module): class MLP(BaseModule):
"""A simple MLP module. """A simple MLP module.
Pass features (B, C, N) through an MLP. Pass features (B, C, N) through an MLP.
...@@ -25,8 +26,9 @@ class MLP(nn.Module): ...@@ -25,8 +26,9 @@ class MLP(nn.Module):
conv_channels=(256, 256), conv_channels=(256, 256),
conv_cfg=dict(type='Conv1d'), conv_cfg=dict(type='Conv1d'),
norm_cfg=dict(type='BN1d'), norm_cfg=dict(type='BN1d'),
act_cfg=dict(type='ReLU')): act_cfg=dict(type='ReLU'),
super().__init__() init_cfg=None):
super().__init__(init_cfg=init_cfg)
self.mlp = nn.Sequential() self.mlp = nn.Sequential()
prev_channels = in_channel prev_channels = in_channel
for i, conv_channel in enumerate(conv_channels): for i, conv_channel in enumerate(conv_channels):
......
import torch import torch
from mmcv.cnn import ConvModule from mmcv.cnn import ConvModule
from mmcv.runner import force_fp32 from mmcv.runner import BaseModule, force_fp32
from torch import nn as nn from torch import nn as nn
from typing import List from typing import List
from mmdet3d.ops import three_interpolate, three_nn from mmdet3d.ops import three_interpolate, three_nn
class PointFPModule(nn.Module): class PointFPModule(BaseModule):
"""Point feature propagation module used in PointNets. """Point feature propagation module used in PointNets.
Propagate the features from one set to another. Propagate the features from one set to another.
...@@ -20,8 +20,9 @@ class PointFPModule(nn.Module): ...@@ -20,8 +20,9 @@ class PointFPModule(nn.Module):
def __init__(self, def __init__(self,
mlp_channels: List[int], mlp_channels: List[int],
norm_cfg: dict = dict(type='BN2d')): norm_cfg: dict = dict(type='BN2d'),
super().__init__() init_cfg=None):
super().__init__(init_cfg=init_cfg)
self.fp16_enabled = False self.fp16_enabled = False
self.mlps = nn.Sequential() self.mlps = nn.Sequential()
for i in range(len(mlp_channels) - 1): for i in range(len(mlp_channels) - 1):
......
...@@ -244,10 +244,13 @@ def test_show_result_meshlab(): ...@@ -244,10 +244,13 @@ def test_show_result_meshlab():
def test_inference_detector(): def test_inference_detector():
if not torch.cuda.is_available():
pytest.skip('test requires GPU and torch+cuda')
pcd = 'tests/data/kitti/training/velodyne_reduced/000000.bin' pcd = 'tests/data/kitti/training/velodyne_reduced/000000.bin'
detector_cfg = 'configs/pointpillars/hv_pointpillars_secfpn_' \ detector_cfg = 'configs/pointpillars/hv_pointpillars_secfpn_' \
'6x8_160e_kitti-3d-3class.py' '6x8_160e_kitti-3d-3class.py'
detector = init_model(detector_cfg, device='cpu') detector = init_model(detector_cfg, device='cuda:0')
results = inference_detector(detector, pcd) results = inference_detector(detector, pcd)
bboxes_3d = results[0][0]['boxes_3d'] bboxes_3d = results[0][0]['boxes_3d']
scores_3d = results[0][0]['scores_3d'] scores_3d = results[0][0]['scores_3d']
......
...@@ -180,6 +180,7 @@ def main(): ...@@ -180,6 +180,7 @@ def main():
cfg.model, cfg.model,
train_cfg=cfg.get('train_cfg'), train_cfg=cfg.get('train_cfg'),
test_cfg=cfg.get('test_cfg')) test_cfg=cfg.get('test_cfg'))
model.init_weights()
logger.info(f'Model:\n{model}') logger.info(f'Model:\n{model}')
datasets = [build_dataset(cfg.data.train)] datasets = [build_dataset(cfg.data.train)]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment