Commit 343267ed authored by zhangwenwei's avatar zhangwenwei
Browse files

Merge branch 'cnn-bricks' into 'master'

Migrate to MMDet V2.0-stable using mmcv cnn bricks

See merge request open-mmlab/mmdet.3d!14
parents 99db60dd 7ed412b7
...@@ -374,7 +374,7 @@ from .coco import CocoDataset ...@@ -374,7 +374,7 @@ from .coco import CocoDataset
from .registry import DATASETS from .registry import DATASETS
@DATASETS.register_module @DATASETS.register_module()
class MyDataset(CocoDataset): class MyDataset(CocoDataset):
CLASSES = ('a', 'b', 'c', 'd', 'e') CLASSES = ('a', 'b', 'c', 'd', 'e')
...@@ -444,7 +444,7 @@ from .registry import OPTIMIZERS ...@@ -444,7 +444,7 @@ from .registry import OPTIMIZERS
from torch.optim import Optimizer from torch.optim import Optimizer
@OPTIMIZERS.register_module @OPTIMIZERS.register_module()
class MyOptimizer(Optimizer): class MyOptimizer(Optimizer):
``` ```
...@@ -476,7 +476,7 @@ import torch.nn as nn ...@@ -476,7 +476,7 @@ import torch.nn as nn
from ..registry import BACKBONES from ..registry import BACKBONES
@BACKBONES.register_module @BACKBONES.register_module()
class MobileNet(nn.Module): class MobileNet(nn.Module):
def __init__(self, arg1, arg2): def __init__(self, arg1, arg2):
......
...@@ -4,7 +4,7 @@ import torch ...@@ -4,7 +4,7 @@ import torch
from mmdet.core.anchor import ANCHOR_GENERATORS from mmdet.core.anchor import ANCHOR_GENERATORS
@ANCHOR_GENERATORS.register_module @ANCHOR_GENERATORS.register_module()
class Anchor3DRangeGenerator(object): class Anchor3DRangeGenerator(object):
"""3D Anchor Generator by range """3D Anchor Generator by range
...@@ -183,7 +183,7 @@ class Anchor3DRangeGenerator(object): ...@@ -183,7 +183,7 @@ class Anchor3DRangeGenerator(object):
return ret return ret
@ANCHOR_GENERATORS.register_module @ANCHOR_GENERATORS.register_module()
class AlignedAnchor3DRangeGenerator(Anchor3DRangeGenerator): class AlignedAnchor3DRangeGenerator(Anchor3DRangeGenerator):
"""Aligned 3D Anchor Generator by range """Aligned 3D Anchor Generator by range
......
...@@ -4,7 +4,7 @@ from mmdet.core.bbox import BaseBBoxCoder ...@@ -4,7 +4,7 @@ from mmdet.core.bbox import BaseBBoxCoder
from mmdet.core.bbox.builder import BBOX_CODERS from mmdet.core.bbox.builder import BBOX_CODERS
@BBOX_CODERS.register_module @BBOX_CODERS.register_module()
class DeltaXYZWLHRBBoxCoder(BaseBBoxCoder): class DeltaXYZWLHRBBoxCoder(BaseBBoxCoder):
def __init__(self, code_size=7): def __init__(self, code_size=7):
......
...@@ -6,7 +6,7 @@ from mmdet.core.bbox.iou_calculators.builder import IOU_CALCULATORS ...@@ -6,7 +6,7 @@ from mmdet.core.bbox.iou_calculators.builder import IOU_CALCULATORS
from .. import box_torch_ops from .. import box_torch_ops
@IOU_CALCULATORS.register_module @IOU_CALCULATORS.register_module()
class BboxOverlapsNearest3D(object): class BboxOverlapsNearest3D(object):
"""Nearest 3D IoU Calculator""" """Nearest 3D IoU Calculator"""
...@@ -20,7 +20,7 @@ class BboxOverlapsNearest3D(object): ...@@ -20,7 +20,7 @@ class BboxOverlapsNearest3D(object):
return repr_str return repr_str
@IOU_CALCULATORS.register_module @IOU_CALCULATORS.register_module()
class BboxOverlaps3D(object): class BboxOverlaps3D(object):
"""3D IoU Calculator""" """3D IoU Calculator"""
......
...@@ -5,7 +5,7 @@ from mmdet.utils import get_root_logger ...@@ -5,7 +5,7 @@ from mmdet.utils import get_root_logger
from .cocktail_optimizer import CocktailOptimizer from .cocktail_optimizer import CocktailOptimizer
@OPTIMIZER_BUILDERS.register_module @OPTIMIZER_BUILDERS.register_module()
class CocktailOptimizerConstructor(object): class CocktailOptimizerConstructor(object):
"""Special constructor for cocktail optimizers. """Special constructor for cocktail optimizers.
......
...@@ -3,7 +3,7 @@ from torch.optim import Optimizer ...@@ -3,7 +3,7 @@ from torch.optim import Optimizer
from mmdet.core.optimizer import OPTIMIZERS from mmdet.core.optimizer import OPTIMIZERS
@OPTIMIZERS.register_module @OPTIMIZERS.register_module()
class CocktailOptimizer(Optimizer): class CocktailOptimizer(Optimizer):
"""Cocktail Optimizer that contains multiple optimizers """Cocktail Optimizer that contains multiple optimizers
......
...@@ -7,7 +7,7 @@ from mmdet.datasets import DATASETS ...@@ -7,7 +7,7 @@ from mmdet.datasets import DATASETS
# Modified from https://github.com/facebookresearch/detectron2/blob/41d475b75a230221e21d9cac5d69655e3415e3a4/detectron2/data/samplers/distributed_sampler.py#L57 # noqa # Modified from https://github.com/facebookresearch/detectron2/blob/41d475b75a230221e21d9cac5d69655e3415e3a4/detectron2/data/samplers/distributed_sampler.py#L57 # noqa
@DATASETS.register_module @DATASETS.register_module()
class RepeatFactorDataset(object): class RepeatFactorDataset(object):
"""A wrapper of repeated dataset with repeat factor. """A wrapper of repeated dataset with repeat factor.
......
...@@ -4,7 +4,7 @@ import numpy as np ...@@ -4,7 +4,7 @@ import numpy as np
from mmdet.datasets import DATASETS, CustomDataset from mmdet.datasets import DATASETS, CustomDataset
@DATASETS.register_module @DATASETS.register_module()
class Kitti2DDataset(CustomDataset): class Kitti2DDataset(CustomDataset):
CLASSES = ('car', 'pedestrian', 'cyclist') CLASSES = ('car', 'pedestrian', 'cyclist')
......
...@@ -15,7 +15,7 @@ from .pipelines import Compose ...@@ -15,7 +15,7 @@ from .pipelines import Compose
from .utils import remove_dontcare from .utils import remove_dontcare
@DATASETS.register_module @DATASETS.register_module()
class KittiDataset(torch_data.Dataset): class KittiDataset(torch_data.Dataset):
CLASSES = ('car', 'pedestrian', 'cyclist') CLASSES = ('car', 'pedestrian', 'cyclist')
......
...@@ -13,7 +13,7 @@ from ..core.bbox import box_np_ops ...@@ -13,7 +13,7 @@ from ..core.bbox import box_np_ops
from .pipelines import Compose from .pipelines import Compose
@DATASETS.register_module @DATASETS.register_module()
class NuScenesDataset(torch_data.Dataset): class NuScenesDataset(torch_data.Dataset):
NumPointFeatures = 4 # xyz, timestamp. set 4 to use kitti pretrain NumPointFeatures = 4 # xyz, timestamp. set 4 to use kitti pretrain
NameMapping = { NameMapping = {
......
...@@ -52,7 +52,7 @@ class BatchSampler: ...@@ -52,7 +52,7 @@ class BatchSampler:
return [self._sampled_list[i] for i in indices] return [self._sampled_list[i] for i in indices]
@OBJECTSAMPLERS.register_module @OBJECTSAMPLERS.register_module()
class DataBaseSampler(object): class DataBaseSampler(object):
def __init__(self, info_path, root_path, rate, prepare, object_rot_range, def __init__(self, info_path, root_path, rate, prepare, object_rot_range,
...@@ -255,7 +255,7 @@ class DataBaseSampler(object): ...@@ -255,7 +255,7 @@ class DataBaseSampler(object):
return valid_samples return valid_samples
@OBJECTSAMPLERS.register_module @OBJECTSAMPLERS.register_module()
class MMDataBaseSampler(DataBaseSampler): class MMDataBaseSampler(DataBaseSampler):
def __init__(self, def __init__(self,
......
...@@ -7,7 +7,7 @@ from mmdet.datasets.pipelines import to_tensor ...@@ -7,7 +7,7 @@ from mmdet.datasets.pipelines import to_tensor
PIPELINES._module_dict.pop('DefaultFormatBundle') PIPELINES._module_dict.pop('DefaultFormatBundle')
@PIPELINES.register_module @PIPELINES.register_module()
class DefaultFormatBundle(object): class DefaultFormatBundle(object):
"""Default formatting bundle. """Default formatting bundle.
...@@ -59,7 +59,7 @@ class DefaultFormatBundle(object): ...@@ -59,7 +59,7 @@ class DefaultFormatBundle(object):
return self.__class__.__name__ return self.__class__.__name__
@PIPELINES.register_module @PIPELINES.register_module()
class Collect3D(object): class Collect3D(object):
def __init__(self, def __init__(self,
...@@ -90,7 +90,7 @@ class Collect3D(object): ...@@ -90,7 +90,7 @@ class Collect3D(object):
self.keys, self.meta_keys) self.keys, self.meta_keys)
@PIPELINES.register_module @PIPELINES.register_module()
class DefaultFormatBundle3D(DefaultFormatBundle): class DefaultFormatBundle3D(DefaultFormatBundle):
"""Default formatting bundle. """Default formatting bundle.
......
...@@ -6,7 +6,7 @@ import numpy as np ...@@ -6,7 +6,7 @@ import numpy as np
from mmdet.datasets.builder import PIPELINES from mmdet.datasets.builder import PIPELINES
@PIPELINES.register_module @PIPELINES.register_module()
class LoadPointsFromFile(object): class LoadPointsFromFile(object):
def __init__(self, points_dim=4, with_reflectivity=True): def __init__(self, points_dim=4, with_reflectivity=True):
...@@ -31,7 +31,7 @@ class LoadPointsFromFile(object): ...@@ -31,7 +31,7 @@ class LoadPointsFromFile(object):
return repr_str return repr_str
@PIPELINES.register_module @PIPELINES.register_module()
class LoadMultiViewImageFromFiles(object): class LoadMultiViewImageFromFiles(object):
""" Load multi channel images from a list of separate channel files. """ Load multi channel images from a list of separate channel files.
Expects results['filename'] to be a list of filenames Expects results['filename'] to be a list of filenames
......
...@@ -8,7 +8,7 @@ from ..registry import OBJECTSAMPLERS ...@@ -8,7 +8,7 @@ from ..registry import OBJECTSAMPLERS
from .data_augment_utils import noise_per_object_v3_ from .data_augment_utils import noise_per_object_v3_
@PIPELINES.register_module @PIPELINES.register_module()
class RandomFlip3D(RandomFlip): class RandomFlip3D(RandomFlip):
"""Flip the points & bbox. """Flip the points & bbox.
...@@ -51,7 +51,7 @@ class RandomFlip3D(RandomFlip): ...@@ -51,7 +51,7 @@ class RandomFlip3D(RandomFlip):
return input_dict return input_dict
@PIPELINES.register_module @PIPELINES.register_module()
class ObjectSample(object): class ObjectSample(object):
def __init__(self, db_sampler, sample_2d=False): def __init__(self, db_sampler, sample_2d=False):
...@@ -128,7 +128,7 @@ class ObjectSample(object): ...@@ -128,7 +128,7 @@ class ObjectSample(object):
return self.__class__.__name__ return self.__class__.__name__
@PIPELINES.register_module @PIPELINES.register_module()
class ObjectNoise(object): class ObjectNoise(object):
def __init__(self, def __init__(self,
...@@ -167,7 +167,7 @@ class ObjectNoise(object): ...@@ -167,7 +167,7 @@ class ObjectNoise(object):
return repr_str return repr_str
@PIPELINES.register_module @PIPELINES.register_module()
class GlobalRotScale(object): class GlobalRotScale(object):
def __init__(self, def __init__(self,
...@@ -241,7 +241,7 @@ class GlobalRotScale(object): ...@@ -241,7 +241,7 @@ class GlobalRotScale(object):
return repr_str return repr_str
@PIPELINES.register_module @PIPELINES.register_module()
class PointShuffle(object): class PointShuffle(object):
def __call__(self, input_dict): def __call__(self, input_dict):
...@@ -252,7 +252,7 @@ class PointShuffle(object): ...@@ -252,7 +252,7 @@ class PointShuffle(object):
return self.__class__.__name__ return self.__class__.__name__
@PIPELINES.register_module @PIPELINES.register_module()
class ObjectRangeFilter(object): class ObjectRangeFilter(object):
def __init__(self, point_cloud_range): def __init__(self, point_cloud_range):
...@@ -304,7 +304,7 @@ class ObjectRangeFilter(object): ...@@ -304,7 +304,7 @@ class ObjectRangeFilter(object):
return repr_str return repr_str
@PIPELINES.register_module @PIPELINES.register_module()
class PointsRangeFilter(object): class PointsRangeFilter(object):
def __init__(self, point_cloud_range): def __init__(self, point_cloud_range):
......
...@@ -8,7 +8,7 @@ from mmdet.models import HEADS ...@@ -8,7 +8,7 @@ from mmdet.models import HEADS
from .second_head import SECONDHead from .second_head import SECONDHead
@HEADS.register_module @HEADS.register_module()
class Anchor3DVeloHead(SECONDHead): class Anchor3DVeloHead(SECONDHead):
"""Anchor-based head for 3D anchor with velocity """Anchor-based head for 3D anchor with velocity
Args: Args:
......
...@@ -9,7 +9,7 @@ from mmdet.models import HEADS ...@@ -9,7 +9,7 @@ from mmdet.models import HEADS
from .second_head import SECONDHead from .second_head import SECONDHead
@HEADS.register_module @HEADS.register_module()
class PartA2RPNHead(SECONDHead): class PartA2RPNHead(SECONDHead):
"""rpn head for PartA2 """rpn head for PartA2
......
...@@ -13,7 +13,7 @@ from ..builder import build_loss ...@@ -13,7 +13,7 @@ from ..builder import build_loss
from .train_mixins import AnchorTrainMixin from .train_mixins import AnchorTrainMixin
@HEADS.register_module @HEADS.register_module()
class SECONDHead(nn.Module, AnchorTrainMixin): class SECONDHead(nn.Module, AnchorTrainMixin):
"""Anchor-based head for VoxelNet detectors. """Anchor-based head for VoxelNet detectors.
......
from functools import partial from functools import partial
import torch.nn as nn import torch.nn as nn
from mmcv.cnn import build_norm_layer
from mmcv.runner import load_checkpoint from mmcv.runner import load_checkpoint
from mmdet.models import BACKBONES from mmdet.models import BACKBONES
from mmdet.ops import build_norm_layer
class Empty(nn.Module): class Empty(nn.Module):
...@@ -20,7 +20,7 @@ class Empty(nn.Module): ...@@ -20,7 +20,7 @@ class Empty(nn.Module):
return args return args
@BACKBONES.register_module @BACKBONES.register_module()
class SECOND(nn.Module): class SECOND(nn.Module):
"""Compare with RPN, RPNV2 support arbitrary number of stage. """Compare with RPN, RPNV2 support arbitrary number of stage.
""" """
......
...@@ -5,7 +5,7 @@ from mmdet.models import DETECTORS ...@@ -5,7 +5,7 @@ from mmdet.models import DETECTORS
from .mvx_two_stage import MVXTwoStageDetector from .mvx_two_stage import MVXTwoStageDetector
@DETECTORS.register_module @DETECTORS.register_module()
class DynamicMVXFasterRCNN(MVXTwoStageDetector): class DynamicMVXFasterRCNN(MVXTwoStageDetector):
def __init__(self, **kwargs): def __init__(self, **kwargs):
...@@ -42,7 +42,7 @@ class DynamicMVXFasterRCNN(MVXTwoStageDetector): ...@@ -42,7 +42,7 @@ class DynamicMVXFasterRCNN(MVXTwoStageDetector):
return points, coors_batch return points, coors_batch
@DETECTORS.register_module @DETECTORS.register_module()
class DynamicMVXFasterRCNNV2(DynamicMVXFasterRCNN): class DynamicMVXFasterRCNNV2(DynamicMVXFasterRCNN):
def __init__(self, **kwargs): def __init__(self, **kwargs):
...@@ -62,7 +62,7 @@ class DynamicMVXFasterRCNNV2(DynamicMVXFasterRCNN): ...@@ -62,7 +62,7 @@ class DynamicMVXFasterRCNNV2(DynamicMVXFasterRCNN):
return x return x
@DETECTORS.register_module @DETECTORS.register_module()
class MVXFasterRCNNV2(MVXTwoStageDetector): class MVXFasterRCNNV2(MVXTwoStageDetector):
def __init__(self, **kwargs): def __init__(self, **kwargs):
...@@ -84,7 +84,7 @@ class MVXFasterRCNNV2(MVXTwoStageDetector): ...@@ -84,7 +84,7 @@ class MVXFasterRCNNV2(MVXTwoStageDetector):
return x return x
@DETECTORS.register_module @DETECTORS.register_module()
class DynamicMVXFasterRCNNV3(DynamicMVXFasterRCNN): class DynamicMVXFasterRCNNV3(DynamicMVXFasterRCNN):
def __init__(self, **kwargs): def __init__(self, **kwargs):
......
...@@ -8,7 +8,7 @@ from .. import builder ...@@ -8,7 +8,7 @@ from .. import builder
from .base import BaseDetector from .base import BaseDetector
@DETECTORS.register_module @DETECTORS.register_module()
class MVXSingleStageDetector(BaseDetector): class MVXSingleStageDetector(BaseDetector):
def __init__(self, def __init__(self,
...@@ -162,7 +162,7 @@ class MVXSingleStageDetector(BaseDetector): ...@@ -162,7 +162,7 @@ class MVXSingleStageDetector(BaseDetector):
raise NotImplementedError raise NotImplementedError
@DETECTORS.register_module @DETECTORS.register_module()
class DynamicMVXNet(MVXSingleStageDetector): class DynamicMVXNet(MVXSingleStageDetector):
def __init__(self, def __init__(self,
...@@ -230,7 +230,7 @@ class DynamicMVXNet(MVXSingleStageDetector): ...@@ -230,7 +230,7 @@ class DynamicMVXNet(MVXSingleStageDetector):
return points, coors_batch return points, coors_batch
@DETECTORS.register_module @DETECTORS.register_module()
class DynamicMVXNetV2(DynamicMVXNet): class DynamicMVXNetV2(DynamicMVXNet):
def __init__(self, def __init__(self,
...@@ -281,7 +281,7 @@ class DynamicMVXNetV2(DynamicMVXNet): ...@@ -281,7 +281,7 @@ class DynamicMVXNetV2(DynamicMVXNet):
return x return x
@DETECTORS.register_module @DETECTORS.register_module()
class DynamicMVXNetV3(DynamicMVXNet): class DynamicMVXNetV3(DynamicMVXNet):
def __init__(self, def __init__(self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment