"backends/client/src/lib.rs" did not exist on "017a2a8c2f72c4c30b01eab0da53c00c9c1f7057"
Unverified Commit 0287048a authored by ChaimZhu's avatar ChaimZhu Committed by GitHub
Browse files

[Enhance] Update Registry in MMDet3D (#1412)

* Update Registry in MMDet3D

* fix compose pipeline bug

* update registry

* fix some bugs

* fix comments

* fix comments
parent e013bab5
# Copyright (c) OpenMMLab. All rights reserved.
from mmdet.models import DETECTORS, build_backbone, build_head, build_neck
from ..builder import DETECTORS, build_backbone, build_head, build_neck
from .base import Base3DDetector
......
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from os import path as osp
import mmcv
......@@ -8,8 +9,8 @@ from mmcv.parallel import DataContainer as DC
from mmdet3d.core import (CameraInstance3DBoxes, bbox3d2result,
show_multi_modality_result)
from mmdet.models.builder import DETECTORS
from mmdet.models.detectors.single_stage import SingleStageDetector
from mmdet.models.detectors import SingleStageDetector
from ..builder import DETECTORS, build_backbone, build_head, build_neck
@DETECTORS.register_module()
......@@ -20,6 +21,28 @@ class SingleStageMono3DDetector(SingleStageDetector):
output features of the backbone+neck.
"""
def __init__(self,
backbone,
neck=None,
bbox_head=None,
train_cfg=None,
test_cfg=None,
pretrained=None,
init_cfg=None):
super(SingleStageDetector, self).__init__(init_cfg)
if pretrained:
warnings.warn('DeprecationWarning: pretrained is deprecated, '
'please use "init_cfg" instead')
backbone.pretrained = pretrained
self.backbone = build_backbone(backbone)
if neck is not None:
self.neck = build_neck(neck)
bbox_head.update(train_cfg=train_cfg)
bbox_head.update(test_cfg=test_cfg)
self.bbox_head = build_head(bbox_head)
self.train_cfg = train_cfg
self.test_cfg = test_cfg
def extract_feats(self, imgs):
"""Directly extract features from the backbone+neck."""
assert isinstance(imgs, list)
......
# Copyright (c) OpenMMLab. All rights reserved.
from mmdet.models.builder import DETECTORS
from ..builder import DETECTORS
from .single_stage_mono3d import SingleStageMono3DDetector
......
# Copyright (c) OpenMMLab. All rights reserved.
from mmdet.models import DETECTORS
from ..builder import DETECTORS
from .votenet import VoteNet
......
# Copyright (c) OpenMMLab. All rights reserved.
from mmdet.models import DETECTORS, TwoStageDetector
import warnings
from mmdet.models import TwoStageDetector
from ..builder import DETECTORS, build_backbone, build_head, build_neck
from .base import Base3DDetector
......@@ -12,5 +15,36 @@ class TwoStage3DDetector(Base3DDetector, TwoStageDetector):
two-stage 3D detectors.
"""
def __init__(self, **kwargs):
super(TwoStage3DDetector, self).__init__(**kwargs)
def __init__(self,
backbone,
neck=None,
rpn_head=None,
roi_head=None,
train_cfg=None,
test_cfg=None,
pretrained=None,
init_cfg=None):
super(TwoStageDetector, self).__init__(init_cfg)
if pretrained:
warnings.warn('DeprecationWarning: pretrained is deprecated, '
'please use "init_cfg" instead')
backbone.pretrained = pretrained
self.backbone = build_backbone(backbone)
if neck is not None:
self.neck = build_neck(neck)
if rpn_head is not None:
rpn_train_cfg = train_cfg.rpn if train_cfg is not None else None
rpn_head_ = rpn_head.copy()
rpn_head_.update(train_cfg=rpn_train_cfg, test_cfg=test_cfg.rpn)
self.rpn_head = build_head(rpn_head_)
if roi_head is not None:
# update train and test cfg here for now
# TODO: refactor assigner & sampler
rcnn_train_cfg = train_cfg.rcnn if train_cfg is not None else None
roi_head.update(train_cfg=rcnn_train_cfg)
roi_head.update(test_cfg=test_cfg.rcnn)
roi_head.pretrained = pretrained
self.roi_head = build_head(roi_head)
......@@ -2,7 +2,7 @@
import torch
from mmdet3d.core import bbox3d2result, merge_aug_bboxes_3d
from mmdet.models import DETECTORS
from ..builder import DETECTORS
from .single_stage import SingleStage3DDetector
......
......@@ -5,8 +5,8 @@ from mmcv.runner import force_fp32
from torch.nn import functional as F
from mmdet3d.core import bbox3d2result, merge_aug_bboxes_3d
from mmdet.models import DETECTORS
from .. import builder
from ..builder import DETECTORS
from .single_stage import SingleStage3DDetector
......
......@@ -2,9 +2,9 @@
import torch
from torch import nn as nn
from mmdet.models.builder import LOSSES
from mmdet.models.losses.utils import weighted_loss
from ...core.bbox import AxisAlignedBboxOverlaps3D
from ..builder import LOSSES
@weighted_loss
......
......@@ -3,7 +3,7 @@ import torch
from torch import nn as nn
from torch.nn.functional import l1_loss, mse_loss, smooth_l1_loss
from mmdet.models.builder import LOSSES
from ..builder import LOSSES
def chamfer_distance(src,
......
......@@ -3,8 +3,8 @@ import torch
from torch import nn as nn
from torch.nn import functional as F
from mmdet.models.builder import LOSSES
from mmdet.models.losses.utils import weighted_loss
from ..builder import LOSSES
@weighted_loss
......
......@@ -3,8 +3,8 @@ import torch
from torch import nn as nn
from mmdet3d.ops import PAConv, PAConvCUDA
from mmdet.models.builder import LOSSES
from mmdet.models.losses.utils import weight_reduce_loss
from ..builder import LOSSES
def weight_correlation(conv):
......
......@@ -2,8 +2,8 @@
import torch
from torch import nn as nn
from mmdet.models.builder import LOSSES
from mmdet.models.losses.utils import weighted_loss
from ..builder import LOSSES
@weighted_loss
......
......@@ -6,7 +6,7 @@ from mmcv.cnn import ConvModule, build_conv_layer
from mmcv.runner import BaseModule
from torch import nn as nn
from mmdet.models.builder import NECKS
from ..builder import NECKS
def fill_up_weights(up):
......
......@@ -2,7 +2,7 @@
from mmcv.cnn import ConvModule
from torch import nn
from mmdet.models import NECKS
from ..builder import NECKS
@NECKS.register_module()
......
......@@ -3,7 +3,7 @@ from mmcv.runner import BaseModule
from torch import nn as nn
from mmdet3d.ops import PointFPModule
from mmdet.models import NECKS
from ..builder import NECKS
@NECKS.register_module()
......
......@@ -5,7 +5,7 @@ from mmcv.cnn import build_conv_layer, build_norm_layer, build_upsample_layer
from mmcv.runner import BaseModule, auto_fp16
from torch import nn as nn
from mmdet.models import NECKS
from ..builder import NECKS
@NECKS.register_module()
......
......@@ -7,11 +7,10 @@ from torch.nn import functional as F
from mmdet3d.core.bbox import DepthInstance3DBoxes
from mmdet3d.core.post_processing import aligned_3d_nms
from mmdet3d.models.builder import build_loss
from mmdet3d.models.builder import HEADS, build_loss
from mmdet3d.models.losses import chamfer_distance
from mmdet3d.ops import build_sa_module
from mmdet.core import build_bbox_coder, multi_apply
from mmdet.models import HEADS
@HEADS.register_module()
......
......@@ -9,10 +9,9 @@ from torch import nn as nn
from mmdet3d.core.bbox.structures import (LiDARInstance3DBoxes,
rotation_3d_in_axis, xywhr2xyxyr)
from mmdet3d.core.post_processing import nms_bev, nms_normal_bev
from mmdet3d.models.builder import build_loss
from mmdet3d.models.builder import HEADS, build_loss
from mmdet3d.ops import make_sparse_convmodule
from mmdet.core import build_bbox_coder, multi_apply
from mmdet.models import HEADS
@HEADS.register_module()
......
......@@ -9,10 +9,9 @@ from torch import nn as nn
from mmdet3d.core.bbox.structures import (LiDARInstance3DBoxes,
rotation_3d_in_axis, xywhr2xyxyr)
from mmdet3d.core.post_processing import nms_bev, nms_normal_bev
from mmdet3d.models.builder import build_loss
from mmdet3d.models.builder import HEADS, build_loss
from mmdet3d.ops import build_sa_module
from mmdet.core import build_bbox_coder, multi_apply
from mmdet.models import HEADS
@HEADS.register_module()
......
# Copyright (c) OpenMMLab. All rights reserved.
from mmdet3d.core.bbox import bbox3d2result
from mmdet.models import HEADS
from ..builder import build_head
from ..builder import HEADS, build_head
from .base_3droi_head import Base3DRoIHead
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment