Unverified Commit 07590418 authored by xiliu8006's avatar xiliu8006 Committed by GitHub
Browse files

[Refactor]: Unified parameter initialization (#622)

* support 3dssd

* support one-stage method

* for lint

* support two_stage

* Support all methods

* remove init_cfg=[] in configs

* test

* support h3dnet

* fix lint error

* fix isort

* fix code style error

* fix imvotenet bug

* rename init_weight->init_weights

* clean comma

* fix test_apis does not init weights

* support newest mmdet and mmcv

* fix test_heads h3dnet bug

* rm *.swp

* remove the wrong code in build.yml

* fix ssn low map

* modify docs

* modified ssn init_config

* modify params in backbone pointnet2_sa_ssg

* add ssn direction init_cfg

* support segmentor

* add conv a=sqrt(5)

* Convmodule uses kaiming_init

* fix centerpointhead init bug

* add second conv2d init cfg

* add unittest to confirm the input is not be modified

* assert gt_bboxes_3d

* rm .swag

* modify docs mmdet version

* adopt fcosmono3d

* add fcos 3d original init method

* fix mmseg version

* add init cfg in fcos_mono3d.py

* merge newest master

* remove unused code

* modify focs config due to changes of resnet

* support imvoxelnet pointnet2

* modified the dependencies version

* support decode head

* fix inference bug

* modify the useless init_cfg

* fix multi_modality BC-breaking

* fix error blank

* modify docs error
parent 318499ac
......@@ -91,8 +91,8 @@ jobs:
- name: Install mmdet3d dependencies
run: |
pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cu101/${{matrix.torch_version}}/index.html
pip install mmdet==2.11.0
pip install mmsegmentation==0.14.0
pip install mmdet==2.14.0
pip install mmsegmentation==0.14.1
pip install -r requirements.txt
- name: Build and install
run: |
......
......@@ -16,7 +16,6 @@ model = dict(
out_channels=256,
start_level=1,
add_extra_convs=True,
extra_convs_on_inputs=False, # use P5
num_outs=5,
relu_before_extra_convs=True),
bbox_head=dict(
......
......@@ -99,8 +99,8 @@ model = dict(
nms_across_levels=False,
nms_pre=1000,
nms_post=1000,
max_num=1000,
nms_thr=0.7,
max_per_img=1000,
nms=dict(type='nms', iou_threshold=0.7),
min_bbox_size=0),
img_rcnn=dict(
score_thr=0.05,
......
model = dict(
type='ImVoxelNet',
pretrained='torchvision://resnet50',
backbone=dict(
type='ResNet',
depth=50,
......@@ -9,6 +8,7 @@ model = dict(
frozen_stages=1,
norm_cfg=dict(type='BN', requires_grad=False),
norm_eval=True,
init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50'),
style='pytorch'),
neck=dict(
type='FPN',
......
......@@ -4,6 +4,15 @@ This document provides detailed descriptions of the BC-breaking changes in MMDet
## MMDetection3D 0.15.0
### Unified parameter initialization
To unify the parameter initialization in OpenMMLab projects, MMCV supports `BaseModule` that accepts `init_cfg` to allow the modules' parameters initialized in a flexible and unified manner. Now the users need to explicitly call `model.init_weights()` in the training script to initialize the model (as in [here](https://github.com/open-mmlab/mmdetection3d/blob/master/tools/train.py#L183), previously this was handled by the detector. Please refer to PR #622 for details.
### BackgroundPointsFilter
We modified the dataset aumentation function `BackgroundPointsFilter`(in [here](https://github.com/open-mmlab/mmdetection3d/blob/mmdet3d/datasets/pipelines/transforms_3d.py#L1101)). In previous version of MMdetection3D, `BackgroundPointsFilter` changes the gt_bboxes_3d's bottom center to the gravity center. In MMDetection3D 0.15.0,
`BackgroundPointsFilter` will not change it. Please refer to PR #609 for details.
### Enhance `IndoorPatchPointSample` transform
We enhance the pipeline function `IndoorPatchPointSample` used in point cloud segmentation task by adding more choices for patch selection. Also, we plan to remove the unused parameter `sample_rate` in the future. Please modify the code as well as the config files accordingly if you use this transform.
......
......@@ -12,8 +12,8 @@ The required versions of MMCV, MMDetection and MMSegmentation for different vers
| MMDetection3D version | MMDetection version | MMSegmentation version | MMCV version |
|:-------------------:|:-------------------:|:-------------------:|:-------------------:|
| master | mmdet>=2.10.0, <=2.11.0| mmseg==0.14.0 | mmcv-full>=1.3.1, <=1.4|
| 0.14.0 | mmdet>=2.10.0, <=2.11.0| mmseg==0.14.0 | mmcv-full>=1.3.1, <=1.4|
| master | mmdet>=2.12.0 | mmseg>=0.14.1 | mmcv-full>=1.3.2, <=1.4|
| 0.14.0 | mmdet>=2.10.0, <=2.11.0| mmseg>=0.13.0 | mmcv-full>=1.3.1, <=1.4|
| 0.13.0 | mmdet>=2.10.0, <=2.11.0| Not required | mmcv-full>=1.2.4, <=1.4|
| 0.12.0 | mmdet>=2.5.0, <=2.11.0 | Not required | mmcv-full>=1.2.4, <=1.4|
| 0.11.0 | mmdet>=2.5.0, <=2.11.0 | Not required | mmcv-full>=1.2.4, <=1.4|
......
......@@ -33,9 +33,6 @@ class HardVFE(nn.Module):
def forward(self, x): # should return a tuple
pass
def init_weights(self, pretrained=None):
pass
```
#### 2. Import the module
......@@ -83,16 +80,13 @@ from ..builder import BACKBONES
@BACKBONES.register_module()
class SECOND(nn.Module):
class SECOND(BaseModule):
def __init__(self, arg1, arg2):
pass
def forward(self, x): # should return a tuple
pass
def init_weights(self, pretrained=None):
pass
```
#### 2. Import the module
......@@ -135,7 +129,7 @@ Create a new file `mmdet3d/models/necks/second_fpn.py`.
from ..builder import NECKS
@NECKS.register
class SECONDFPN(nn.Module):
class SECONDFPN(BaseModule):
def __init__(self,
in_channels=[128, 128, 256],
......@@ -144,7 +138,8 @@ class SECONDFPN(nn.Module):
norm_cfg=dict(type='BN', eps=1e-3, momentum=0.01),
upsample_cfg=dict(type='deconv', bias=False),
conv_cfg=dict(type='Conv2d', bias=False),
use_conv_for_no_stride=False):
use_conv_for_no_stride=False,
init_cfg=None):
pass
def forward(self, X):
......@@ -198,7 +193,7 @@ from mmdet.models.builder import HEADS
from .bbox_head import BBoxHead
@HEADS.register_module()
class PartA2BboxHead(nn.Module):
class PartA2BboxHead(BaseModule):
"""PartA2 RoI head."""
def __init__(self,
......@@ -224,11 +219,9 @@ class PartA2BboxHead(nn.Module):
type='CrossEntropyLoss',
use_sigmoid=True,
reduction='none',
loss_weight=1.0)):
super(PartA2BboxHead, self).__init__()
def init_weights(self):
# conv layers are already initialized by ConvModule
loss_weight=1.0),
init_cfg=None):
super(PartA2BboxHead, self).__init__(init_cfg=init_cfg)
def forward(self, seg_feats, part_feats):
......@@ -242,7 +235,7 @@ from torch import nn as nn
@HEADS.register_module()
class Base3DRoIHead(nn.Module, metaclass=ABCMeta):
class Base3DRoIHead(BaseModule, metaclass=ABCMeta):
"""Base class for 3d RoIHeads."""
def __init__(self,
......@@ -250,7 +243,8 @@ class Base3DRoIHead(nn.Module, metaclass=ABCMeta):
mask_roi_extractor=None,
mask_head=None,
train_cfg=None,
test_cfg=None):
test_cfg=None,
init_cfg=None):
@property
def with_bbox(self):
......@@ -333,9 +327,13 @@ class PartAggregationROIHead(Base3DRoIHead):
part_roi_extractor=None,
bbox_head=None,
train_cfg=None,
test_cfg=None):
test_cfg=None,
init_cfg=None):
super(PartAggregationROIHead, self).__init__(
bbox_head=bbox_head, train_cfg=train_cfg, test_cfg=test_cfg)
bbox_head=bbox_head,
train_cfg=train_cfg,
test_cfg=test_cfg,
init_cfg=init_cfg)
self.num_classes = num_classes
assert semantic_head is not None
self.semantic_head = build_head(semantic_head)
......
......@@ -17,7 +17,7 @@ def digit_version(version_str):
return digit_version
mmcv_minimum_version = '1.3.1'
mmcv_minimum_version = '1.3.8'
mmcv_maximum_version = '1.4.0'
mmcv_version = digit_version(mmcv.__version__)
......@@ -27,8 +27,8 @@ assert (mmcv_version >= digit_version(mmcv_minimum_version)
f'MMCV=={mmcv.__version__} is used but incompatible. ' \
f'Please install mmcv>={mmcv_minimum_version}, <={mmcv_maximum_version}.'
mmdet_minimum_version = '2.10.0'
mmdet_maximum_version = '2.11.0'
mmdet_minimum_version = '2.14.0'
mmdet_maximum_version = '3.0.0'
mmdet_version = digit_version(mmdet.__version__)
assert (mmdet_version >= digit_version(mmdet_minimum_version)
and mmdet_version <= digit_version(mmdet_maximum_version)), \
......@@ -36,8 +36,8 @@ assert (mmdet_version >= digit_version(mmdet_minimum_version)
f'Please install mmdet>={mmdet_minimum_version}, ' \
f'<={mmdet_maximum_version}.'
mmseg_minimum_version = '0.14.0'
mmseg_maximum_version = '0.14.0'
mmseg_minimum_version = '0.14.1'
mmseg_maximum_version = '1.0.0'
mmseg_version = digit_version(mmseg.__version__)
assert (mmseg_version >= digit_version(mmseg_minimum_version)
and mmseg_version <= digit_version(mmseg_maximum_version)), \
......
import warnings
from abc import ABCMeta
from mmcv.runner import load_checkpoint
from torch import nn as nn
from mmcv.runner import BaseModule
class BasePointNet(nn.Module, metaclass=ABCMeta):
class BasePointNet(BaseModule, metaclass=ABCMeta):
"""Base class for PointNet."""
def __init__(self):
super(BasePointNet, self).__init__()
def __init__(self, init_cfg=None, pretrained=None):
super(BasePointNet, self).__init__(init_cfg)
self.fp16_enabled = False
def init_weights(self, pretrained=None):
"""Initialize the weights of PointNet backbone."""
# Do not initialize the conv layers
# to follow the original implementation
assert not (init_cfg and pretrained), \
'init_cfg and pretrained cannot be setting at the same time'
if isinstance(pretrained, str):
from mmdet3d.utils import get_root_logger
logger = get_root_logger()
load_checkpoint(self, pretrained, strict=False, logger=logger)
warnings.warn('DeprecationWarning: pretrained is a deprecated, '
'please use "init_cfg" instead')
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
@staticmethod
def _split_point_feats(points):
......
import copy
import torch
import warnings
from mmcv.cnn import ConvModule
from mmcv.runner import auto_fp16, load_checkpoint
from mmcv.runner import BaseModule, auto_fp16
from torch import nn as nn
from mmdet.models import BACKBONES, build_backbone
@BACKBONES.register_module()
class MultiBackbone(nn.Module):
class MultiBackbone(BaseModule):
"""MultiBackbone with different configs.
Args:
......@@ -31,8 +32,10 @@ class MultiBackbone(nn.Module):
norm_cfg=dict(type='BN1d', eps=1e-5, momentum=0.01),
act_cfg=dict(type='ReLU'),
suffixes=('net0', 'net1'),
init_cfg=None,
pretrained=None,
**kwargs):
super().__init__()
super().__init__(init_cfg=init_cfg)
assert isinstance(backbones, dict) or isinstance(backbones, list)
if isinstance(backbones, dict):
backbones_list = []
......@@ -77,14 +80,12 @@ class MultiBackbone(nn.Module):
bias=True,
inplace=True))
def init_weights(self, pretrained=None):
"""Initialize the weights of PointNet++ backbone."""
# Do not initialize the conv layers
# to follow the original implementation
assert not (init_cfg and pretrained), \
'init_cfg and pretrained cannot be setting at the same time'
if isinstance(pretrained, str):
from mmdet3d.utils import get_root_logger
logger = get_root_logger()
load_checkpoint(self, pretrained, strict=False, logger=logger)
warnings.warn('DeprecationWarning: pretrained is a deprecated, '
'please use "init_cfg" instead')
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
@auto_fp16()
def forward(self, points):
......
......@@ -57,8 +57,8 @@ class NoStemRegNet(RegNet):
(1, 1008, 1, 1)
"""
def __init__(self, arch, **kwargs):
super(NoStemRegNet, self).__init__(arch, **kwargs)
def __init__(self, arch, init_cfg=None, **kwargs):
super(NoStemRegNet, self).__init__(arch, init_cfg=init_cfg, **kwargs)
def _make_stem_layer(self, in_channels, base_channels):
"""Override the original function that do not initialize a stem layer
......
......@@ -56,8 +56,9 @@ class PointNet2SAMSG(BasePointNet):
type='PointSAModuleMSG',
pool_mod='max',
use_xyz=True,
normalize_xyz=False)):
super().__init__()
normalize_xyz=False),
init_cfg=None):
super().__init__(init_cfg=init_cfg)
self.num_sa = len(sa_channels)
self.out_indices = out_indices
assert max(out_indices) < self.num_sa
......
......@@ -43,8 +43,9 @@ class PointNet2SASSG(BasePointNet):
type='PointSAModule',
pool_mod='max',
use_xyz=True,
normalize_xyz=True)):
super().__init__()
normalize_xyz=True),
init_cfg=None):
super().__init__(init_cfg=init_cfg)
self.num_sa = len(sa_channels)
self.num_fp = len(fp_channels)
......
import warnings
from mmcv.cnn import build_conv_layer, build_norm_layer
from mmcv.runner import load_checkpoint
from mmcv.runner import BaseModule
from torch import nn as nn
from mmdet.models import BACKBONES
@BACKBONES.register_module()
class SECOND(nn.Module):
class SECOND(BaseModule):
"""Backbone network for SECOND/PointPillars/PartA2/MVXNet.
Args:
......@@ -24,8 +25,10 @@ class SECOND(nn.Module):
layer_nums=[3, 5, 5],
layer_strides=[2, 2, 2],
norm_cfg=dict(type='BN', eps=1e-3, momentum=0.01),
conv_cfg=dict(type='Conv2d', bias=False)):
super(SECOND, self).__init__()
conv_cfg=dict(type='Conv2d', bias=False),
init_cfg=None,
pretrained=None):
super(SECOND, self).__init__(init_cfg=init_cfg)
assert len(layer_strides) == len(layer_nums)
assert len(out_channels) == len(layer_nums)
......@@ -61,14 +64,14 @@ class SECOND(nn.Module):
self.blocks = nn.ModuleList(blocks)
def init_weights(self, pretrained=None):
"""Initialize weights of the 2D backbone."""
# Do not initialize the conv layers
# to follow the original implementation
assert not (init_cfg and pretrained), \
'init_cfg and pretrained cannot be setting at the same time'
if isinstance(pretrained, str):
from mmdet3d.utils import get_root_logger
logger = get_root_logger()
load_checkpoint(self, pretrained, strict=False, logger=logger)
warnings.warn('DeprecationWarning: pretrained is a deprecated, '
'please use "init_cfg" instead')
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
else:
self.init_cfg = dict(type='Kaiming', layer='Conv2d')
def forward(self, x):
"""Forward function.
......
import warnings
from mmcv.cnn import MODELS as MMCV_MODELS
from mmcv.utils import Registry
from mmdet.models.builder import (BACKBONES, DETECTORS, HEADS, LOSSES, NECKS,
ROI_EXTRACTORS, SHARED_HEADS, build)
ROI_EXTRACTORS, SHARED_HEADS)
from mmseg.models.builder import SEGMENTORS
VOXEL_ENCODERS = Registry('voxel_encoder')
MIDDLE_ENCODERS = Registry('middle_encoder')
FUSION_LAYERS = Registry('fusion_layer')
MODELS = Registry('models', parent=MMCV_MODELS)
VOXEL_ENCODERS = MODELS
MIDDLE_ENCODERS = MODELS
FUSION_LAYERS = MODELS
def build_backbone(cfg):
"""Build backbone."""
return build(cfg, BACKBONES)
return BACKBONES.build(cfg)
def build_neck(cfg):
"""Build neck."""
return build(cfg, NECKS)
return NECKS.build(cfg)
def build_roi_extractor(cfg):
"""Build RoI feature extractor."""
return build(cfg, ROI_EXTRACTORS)
return ROI_EXTRACTORS.build(cfg)
def build_shared_head(cfg):
"""Build shared head of detector."""
return build(cfg, SHARED_HEADS)
return SHARED_HEADS.build(cfg)
def build_head(cfg):
"""Build head."""
return build(cfg, HEADS)
return HEADS.build(cfg)
def build_loss(cfg):
"""Build loss function."""
return build(cfg, LOSSES)
return LOSSES.build(cfg)
def build_detector(cfg, train_cfg=None, test_cfg=None):
......@@ -50,7 +53,8 @@ def build_detector(cfg, train_cfg=None, test_cfg=None):
'train_cfg specified in both outer field and model field '
assert cfg.get('test_cfg') is None or test_cfg is None, \
'test_cfg specified in both outer field and model field '
return build(cfg, DETECTORS, dict(train_cfg=train_cfg, test_cfg=test_cfg))
return DETECTORS.build(
cfg, default_args=dict(train_cfg=train_cfg, test_cfg=test_cfg))
def build_segmentor(cfg, train_cfg=None, test_cfg=None):
......@@ -63,7 +67,8 @@ def build_segmentor(cfg, train_cfg=None, test_cfg=None):
'train_cfg specified in both outer field and model field '
assert cfg.get('test_cfg') is None or test_cfg is None, \
'test_cfg specified in both outer field and model field '
return build(cfg, SEGMENTORS, dict(train_cfg=train_cfg, test_cfg=test_cfg))
return SEGMENTORS.build(
cfg, default_args=dict(train_cfg=train_cfg, test_cfg=test_cfg))
def build_model(cfg, train_cfg=None, test_cfg=None):
......@@ -80,14 +85,14 @@ def build_model(cfg, train_cfg=None, test_cfg=None):
def build_voxel_encoder(cfg):
"""Build voxel encoder."""
return build(cfg, VOXEL_ENCODERS)
return VOXEL_ENCODERS.build(cfg)
def build_middle_encoder(cfg):
"""Build middle level encoder."""
return build(cfg, MIDDLE_ENCODERS)
return MIDDLE_ENCODERS.build(cfg)
def build_fusion_layer(cfg):
"""Build fusion layer."""
return build(cfg, FUSION_LAYERS)
return FUSION_LAYERS.build(cfg)
from abc import ABCMeta, abstractmethod
from mmcv.cnn import normal_init
from mmcv.runner import auto_fp16, force_fp32
from mmcv.runner import BaseModule, auto_fp16, force_fp32
from torch import nn as nn
from mmseg.models.builder import build_loss
class Base3DDecodeHead(nn.Module, metaclass=ABCMeta):
class Base3DDecodeHead(BaseModule, metaclass=ABCMeta):
"""Base class for BaseDecodeHead.
Args:
......@@ -37,8 +37,9 @@ class Base3DDecodeHead(nn.Module, metaclass=ABCMeta):
use_sigmoid=False,
class_weight=None,
loss_weight=1.0),
ignore_index=255):
super(Base3DDecodeHead, self).__init__()
ignore_index=255,
init_cfg=None):
super(Base3DDecodeHead, self).__init__(init_cfg=init_cfg)
self.channels = channels
self.num_classes = num_classes
self.dropout_ratio = dropout_ratio
......@@ -57,6 +58,7 @@ class Base3DDecodeHead(nn.Module, metaclass=ABCMeta):
def init_weights(self):
"""Initialize weights of classification layer."""
super().init_weights()
normal_init(self.conv_seg, mean=0, std=0.01)
@auto_fp16()
......
import numpy as np
import torch
from mmcv.cnn import bias_init_with_prob, normal_init
from mmcv.runner import force_fp32
from mmcv.runner import BaseModule, force_fp32
from torch import nn as nn
from mmdet3d.core import (PseudoSampler, box3d_multiclass_nms, limit_period,
......@@ -14,7 +13,7 @@ from .train_mixins import AnchorTrainMixin
@HEADS.register_module()
class Anchor3DHead(nn.Module, AnchorTrainMixin):
class Anchor3DHead(BaseModule, AnchorTrainMixin):
"""Anchor head for SECOND/PointPillars/MVXNet/PartA2.
Args:
......@@ -67,8 +66,9 @@ class Anchor3DHead(nn.Module, AnchorTrainMixin):
loss_weight=1.0),
loss_bbox=dict(
type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=2.0),
loss_dir=dict(type='CrossEntropyLoss', loss_weight=0.2)):
super().__init__()
loss_dir=dict(type='CrossEntropyLoss', loss_weight=0.2),
init_cfg=None):
super().__init__(init_cfg=init_cfg)
self.in_channels = in_channels
self.num_classes = num_classes
self.feat_channels = feat_channels
......@@ -103,6 +103,14 @@ class Anchor3DHead(nn.Module, AnchorTrainMixin):
self._init_layers()
self._init_assigner_sampler()
if init_cfg is None:
self.init_cfg = dict(
type='Normal',
layer='Conv2d',
std=0.01,
override=dict(
type='Normal', name='conv_cls', std=0.01, bias_prob=0.01))
def _init_assigner_sampler(self):
"""Initialize the target assigner and sampler of the head."""
if self.train_cfg is None:
......@@ -129,12 +137,6 @@ class Anchor3DHead(nn.Module, AnchorTrainMixin):
self.conv_dir_cls = nn.Conv2d(self.feat_channels,
self.num_anchors * 2, 1)
def init_weights(self):
"""Initialize the weights of head."""
bias_cls = bias_init_with_prob(0.01)
normal_init(self.conv_cls, std=0.01, bias=bias_cls)
normal_init(self.conv_reg, std=0.01)
def forward_single(self, x):
"""Forward function on a single-scale feature map.
......
......@@ -109,8 +109,9 @@ class AnchorFreeMono3DHead(BaseMono3DDenseHead):
conv_cfg=None,
norm_cfg=None,
train_cfg=None,
test_cfg=None):
super(AnchorFreeMono3DHead, self).__init__()
test_cfg=None,
init_cfg=None):
super(AnchorFreeMono3DHead, self).__init__(init_cfg=init_cfg)
self.num_classes = num_classes
self.cls_out_channels = num_classes
self.in_channels = in_channels
......
from mmcv.cnn import ConvModule
from mmcv.cnn.bricks import build_conv_layer
from mmcv.runner import BaseModule
from torch import nn as nn
from mmdet.models.builder import HEADS
@HEADS.register_module()
class BaseConvBboxHead(nn.Module):
class BaseConvBboxHead(BaseModule):
r"""More general bbox head, with shared conv layers and two optional
separated branches.
......@@ -28,9 +29,11 @@ class BaseConvBboxHead(nn.Module):
norm_cfg=dict(type='BN1d'),
act_cfg=dict(type='ReLU'),
bias='auto',
init_cfg=None,
*args,
**kwargs):
super(BaseConvBboxHead, self).__init__(*args, **kwargs)
super(BaseConvBboxHead, self).__init__(
init_cfg=init_cfg, *args, **kwargs)
assert in_channels > 0
assert num_cls_out_channels > 0
assert num_reg_out_channels > 0
......@@ -98,10 +101,6 @@ class BaseConvBboxHead(nn.Module):
inplace=True))
return conv_layers
def init_weights(self):
# conv layers are already initialized by ConvModule
pass
def forward(self, feats):
"""Forward.
......
from abc import ABCMeta, abstractmethod
from torch import nn as nn
from mmcv.runner import BaseModule
class BaseMono3DDenseHead(nn.Module, metaclass=ABCMeta):
class BaseMono3DDenseHead(BaseModule, metaclass=ABCMeta):
"""Base class for Monocular 3D DenseHeads."""
def __init__(self):
super(BaseMono3DDenseHead, self).__init__()
def __init__(self, init_cfg=None):
super(BaseMono3DDenseHead, self).__init__(init_cfg=init_cfg)
@abstractmethod
def loss(self, **kwargs):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment