Unverified Commit 0287048a authored by ChaimZhu's avatar ChaimZhu Committed by GitHub
Browse files

[Enhance] Update Registry in MMDet3D (#1412)

* Update Registry in MMDet3D

* fix compose pipeline bug

* update registry

* fix some bugs

* fix comments

* fix comments
parent e013bab5
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from mmdet.models import DETECTORS, build_backbone, build_head, build_neck from ..builder import DETECTORS, build_backbone, build_head, build_neck
from .base import Base3DDetector from .base import Base3DDetector
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import warnings
from os import path as osp from os import path as osp
import mmcv import mmcv
...@@ -8,8 +9,8 @@ from mmcv.parallel import DataContainer as DC ...@@ -8,8 +9,8 @@ from mmcv.parallel import DataContainer as DC
from mmdet3d.core import (CameraInstance3DBoxes, bbox3d2result, from mmdet3d.core import (CameraInstance3DBoxes, bbox3d2result,
show_multi_modality_result) show_multi_modality_result)
from mmdet.models.builder import DETECTORS from mmdet.models.detectors import SingleStageDetector
from mmdet.models.detectors.single_stage import SingleStageDetector from ..builder import DETECTORS, build_backbone, build_head, build_neck
@DETECTORS.register_module() @DETECTORS.register_module()
...@@ -20,6 +21,28 @@ class SingleStageMono3DDetector(SingleStageDetector): ...@@ -20,6 +21,28 @@ class SingleStageMono3DDetector(SingleStageDetector):
output features of the backbone+neck. output features of the backbone+neck.
""" """
def __init__(self,
backbone,
neck=None,
bbox_head=None,
train_cfg=None,
test_cfg=None,
pretrained=None,
init_cfg=None):
super(SingleStageDetector, self).__init__(init_cfg)
if pretrained:
warnings.warn('DeprecationWarning: pretrained is deprecated, '
'please use "init_cfg" instead')
backbone.pretrained = pretrained
self.backbone = build_backbone(backbone)
if neck is not None:
self.neck = build_neck(neck)
bbox_head.update(train_cfg=train_cfg)
bbox_head.update(test_cfg=test_cfg)
self.bbox_head = build_head(bbox_head)
self.train_cfg = train_cfg
self.test_cfg = test_cfg
def extract_feats(self, imgs): def extract_feats(self, imgs):
"""Directly extract features from the backbone+neck.""" """Directly extract features from the backbone+neck."""
assert isinstance(imgs, list) assert isinstance(imgs, list)
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from mmdet.models.builder import DETECTORS from ..builder import DETECTORS
from .single_stage_mono3d import SingleStageMono3DDetector from .single_stage_mono3d import SingleStageMono3DDetector
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from mmdet.models import DETECTORS from ..builder import DETECTORS
from .votenet import VoteNet from .votenet import VoteNet
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from mmdet.models import DETECTORS, TwoStageDetector import warnings
from mmdet.models import TwoStageDetector
from ..builder import DETECTORS, build_backbone, build_head, build_neck
from .base import Base3DDetector from .base import Base3DDetector
...@@ -12,5 +15,36 @@ class TwoStage3DDetector(Base3DDetector, TwoStageDetector): ...@@ -12,5 +15,36 @@ class TwoStage3DDetector(Base3DDetector, TwoStageDetector):
two-stage 3D detectors. two-stage 3D detectors.
""" """
def __init__(self, **kwargs): def __init__(self,
super(TwoStage3DDetector, self).__init__(**kwargs) backbone,
neck=None,
rpn_head=None,
roi_head=None,
train_cfg=None,
test_cfg=None,
pretrained=None,
init_cfg=None):
super(TwoStageDetector, self).__init__(init_cfg)
if pretrained:
warnings.warn('DeprecationWarning: pretrained is deprecated, '
'please use "init_cfg" instead')
backbone.pretrained = pretrained
self.backbone = build_backbone(backbone)
if neck is not None:
self.neck = build_neck(neck)
if rpn_head is not None:
rpn_train_cfg = train_cfg.rpn if train_cfg is not None else None
rpn_head_ = rpn_head.copy()
rpn_head_.update(train_cfg=rpn_train_cfg, test_cfg=test_cfg.rpn)
self.rpn_head = build_head(rpn_head_)
if roi_head is not None:
# update train and test cfg here for now
# TODO: refactor assigner & sampler
rcnn_train_cfg = train_cfg.rcnn if train_cfg is not None else None
roi_head.update(train_cfg=rcnn_train_cfg)
roi_head.update(test_cfg=test_cfg.rcnn)
roi_head.pretrained = pretrained
self.roi_head = build_head(roi_head)
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
import torch import torch
from mmdet3d.core import bbox3d2result, merge_aug_bboxes_3d from mmdet3d.core import bbox3d2result, merge_aug_bboxes_3d
from mmdet.models import DETECTORS from ..builder import DETECTORS
from .single_stage import SingleStage3DDetector from .single_stage import SingleStage3DDetector
......
...@@ -5,8 +5,8 @@ from mmcv.runner import force_fp32 ...@@ -5,8 +5,8 @@ from mmcv.runner import force_fp32
from torch.nn import functional as F from torch.nn import functional as F
from mmdet3d.core import bbox3d2result, merge_aug_bboxes_3d from mmdet3d.core import bbox3d2result, merge_aug_bboxes_3d
from mmdet.models import DETECTORS
from .. import builder from .. import builder
from ..builder import DETECTORS
from .single_stage import SingleStage3DDetector from .single_stage import SingleStage3DDetector
......
...@@ -2,9 +2,9 @@ ...@@ -2,9 +2,9 @@
import torch import torch
from torch import nn as nn from torch import nn as nn
from mmdet.models.builder import LOSSES
from mmdet.models.losses.utils import weighted_loss from mmdet.models.losses.utils import weighted_loss
from ...core.bbox import AxisAlignedBboxOverlaps3D from ...core.bbox import AxisAlignedBboxOverlaps3D
from ..builder import LOSSES
@weighted_loss @weighted_loss
......
...@@ -3,7 +3,7 @@ import torch ...@@ -3,7 +3,7 @@ import torch
from torch import nn as nn from torch import nn as nn
from torch.nn.functional import l1_loss, mse_loss, smooth_l1_loss from torch.nn.functional import l1_loss, mse_loss, smooth_l1_loss
from mmdet.models.builder import LOSSES from ..builder import LOSSES
def chamfer_distance(src, def chamfer_distance(src,
......
...@@ -3,8 +3,8 @@ import torch ...@@ -3,8 +3,8 @@ import torch
from torch import nn as nn from torch import nn as nn
from torch.nn import functional as F from torch.nn import functional as F
from mmdet.models.builder import LOSSES
from mmdet.models.losses.utils import weighted_loss from mmdet.models.losses.utils import weighted_loss
from ..builder import LOSSES
@weighted_loss @weighted_loss
......
...@@ -3,8 +3,8 @@ import torch ...@@ -3,8 +3,8 @@ import torch
from torch import nn as nn from torch import nn as nn
from mmdet3d.ops import PAConv, PAConvCUDA from mmdet3d.ops import PAConv, PAConvCUDA
from mmdet.models.builder import LOSSES
from mmdet.models.losses.utils import weight_reduce_loss from mmdet.models.losses.utils import weight_reduce_loss
from ..builder import LOSSES
def weight_correlation(conv): def weight_correlation(conv):
......
...@@ -2,8 +2,8 @@ ...@@ -2,8 +2,8 @@
import torch import torch
from torch import nn as nn from torch import nn as nn
from mmdet.models.builder import LOSSES
from mmdet.models.losses.utils import weighted_loss from mmdet.models.losses.utils import weighted_loss
from ..builder import LOSSES
@weighted_loss @weighted_loss
......
...@@ -6,7 +6,7 @@ from mmcv.cnn import ConvModule, build_conv_layer ...@@ -6,7 +6,7 @@ from mmcv.cnn import ConvModule, build_conv_layer
from mmcv.runner import BaseModule from mmcv.runner import BaseModule
from torch import nn as nn from torch import nn as nn
from mmdet.models.builder import NECKS from ..builder import NECKS
def fill_up_weights(up): def fill_up_weights(up):
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
from mmcv.cnn import ConvModule from mmcv.cnn import ConvModule
from torch import nn from torch import nn
from mmdet.models import NECKS from ..builder import NECKS
@NECKS.register_module() @NECKS.register_module()
......
...@@ -3,7 +3,7 @@ from mmcv.runner import BaseModule ...@@ -3,7 +3,7 @@ from mmcv.runner import BaseModule
from torch import nn as nn from torch import nn as nn
from mmdet3d.ops import PointFPModule from mmdet3d.ops import PointFPModule
from mmdet.models import NECKS from ..builder import NECKS
@NECKS.register_module() @NECKS.register_module()
......
...@@ -5,7 +5,7 @@ from mmcv.cnn import build_conv_layer, build_norm_layer, build_upsample_layer ...@@ -5,7 +5,7 @@ from mmcv.cnn import build_conv_layer, build_norm_layer, build_upsample_layer
from mmcv.runner import BaseModule, auto_fp16 from mmcv.runner import BaseModule, auto_fp16
from torch import nn as nn from torch import nn as nn
from mmdet.models import NECKS from ..builder import NECKS
@NECKS.register_module() @NECKS.register_module()
......
...@@ -7,11 +7,10 @@ from torch.nn import functional as F ...@@ -7,11 +7,10 @@ from torch.nn import functional as F
from mmdet3d.core.bbox import DepthInstance3DBoxes from mmdet3d.core.bbox import DepthInstance3DBoxes
from mmdet3d.core.post_processing import aligned_3d_nms from mmdet3d.core.post_processing import aligned_3d_nms
from mmdet3d.models.builder import build_loss from mmdet3d.models.builder import HEADS, build_loss
from mmdet3d.models.losses import chamfer_distance from mmdet3d.models.losses import chamfer_distance
from mmdet3d.ops import build_sa_module from mmdet3d.ops import build_sa_module
from mmdet.core import build_bbox_coder, multi_apply from mmdet.core import build_bbox_coder, multi_apply
from mmdet.models import HEADS
@HEADS.register_module() @HEADS.register_module()
......
...@@ -9,10 +9,9 @@ from torch import nn as nn ...@@ -9,10 +9,9 @@ from torch import nn as nn
from mmdet3d.core.bbox.structures import (LiDARInstance3DBoxes, from mmdet3d.core.bbox.structures import (LiDARInstance3DBoxes,
rotation_3d_in_axis, xywhr2xyxyr) rotation_3d_in_axis, xywhr2xyxyr)
from mmdet3d.core.post_processing import nms_bev, nms_normal_bev from mmdet3d.core.post_processing import nms_bev, nms_normal_bev
from mmdet3d.models.builder import build_loss from mmdet3d.models.builder import HEADS, build_loss
from mmdet3d.ops import make_sparse_convmodule from mmdet3d.ops import make_sparse_convmodule
from mmdet.core import build_bbox_coder, multi_apply from mmdet.core import build_bbox_coder, multi_apply
from mmdet.models import HEADS
@HEADS.register_module() @HEADS.register_module()
......
...@@ -9,10 +9,9 @@ from torch import nn as nn ...@@ -9,10 +9,9 @@ from torch import nn as nn
from mmdet3d.core.bbox.structures import (LiDARInstance3DBoxes, from mmdet3d.core.bbox.structures import (LiDARInstance3DBoxes,
rotation_3d_in_axis, xywhr2xyxyr) rotation_3d_in_axis, xywhr2xyxyr)
from mmdet3d.core.post_processing import nms_bev, nms_normal_bev from mmdet3d.core.post_processing import nms_bev, nms_normal_bev
from mmdet3d.models.builder import build_loss from mmdet3d.models.builder import HEADS, build_loss
from mmdet3d.ops import build_sa_module from mmdet3d.ops import build_sa_module
from mmdet.core import build_bbox_coder, multi_apply from mmdet.core import build_bbox_coder, multi_apply
from mmdet.models import HEADS
@HEADS.register_module() @HEADS.register_module()
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from mmdet3d.core.bbox import bbox3d2result from mmdet3d.core.bbox import bbox3d2result
from mmdet.models import HEADS from ..builder import HEADS, build_head
from ..builder import build_head
from .base_3droi_head import Base3DRoIHead from .base_3droi_head import Base3DRoIHead
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment