Commit 7ed412b7 authored by zhangwenwei's avatar zhangwenwei
Browse files

Migrate to MMDet V2.0-stable using mmcv cnn bricks

parent 99db60dd
......@@ -374,7 +374,7 @@ from .coco import CocoDataset
from .registry import DATASETS
@DATASETS.register_module
@DATASETS.register_module()
class MyDataset(CocoDataset):
CLASSES = ('a', 'b', 'c', 'd', 'e')
......@@ -444,7 +444,7 @@ from .registry import OPTIMIZERS
from torch.optim import Optimizer
@OPTIMIZERS.register_module
@OPTIMIZERS.register_module()
class MyOptimizer(Optimizer):
```
......@@ -476,7 +476,7 @@ import torch.nn as nn
from ..registry import BACKBONES
@BACKBONES.register_module
@BACKBONES.register_module()
class MobileNet(nn.Module):
def __init__(self, arg1, arg2):
......
......@@ -4,7 +4,7 @@ import torch
from mmdet.core.anchor import ANCHOR_GENERATORS
@ANCHOR_GENERATORS.register_module
@ANCHOR_GENERATORS.register_module()
class Anchor3DRangeGenerator(object):
"""3D Anchor Generator by range
......@@ -183,7 +183,7 @@ class Anchor3DRangeGenerator(object):
return ret
@ANCHOR_GENERATORS.register_module
@ANCHOR_GENERATORS.register_module()
class AlignedAnchor3DRangeGenerator(Anchor3DRangeGenerator):
"""Aligned 3D Anchor Generator by range
......
......@@ -4,7 +4,7 @@ from mmdet.core.bbox import BaseBBoxCoder
from mmdet.core.bbox.builder import BBOX_CODERS
@BBOX_CODERS.register_module
@BBOX_CODERS.register_module()
class DeltaXYZWLHRBBoxCoder(BaseBBoxCoder):
def __init__(self, code_size=7):
......
......@@ -6,7 +6,7 @@ from mmdet.core.bbox.iou_calculators.builder import IOU_CALCULATORS
from .. import box_torch_ops
@IOU_CALCULATORS.register_module
@IOU_CALCULATORS.register_module()
class BboxOverlapsNearest3D(object):
"""Nearest 3D IoU Calculator"""
......@@ -20,7 +20,7 @@ class BboxOverlapsNearest3D(object):
return repr_str
@IOU_CALCULATORS.register_module
@IOU_CALCULATORS.register_module()
class BboxOverlaps3D(object):
"""3D IoU Calculator"""
......
......@@ -5,7 +5,7 @@ from mmdet.utils import get_root_logger
from .cocktail_optimizer import CocktailOptimizer
@OPTIMIZER_BUILDERS.register_module
@OPTIMIZER_BUILDERS.register_module()
class CocktailOptimizerConstructor(object):
"""Special constructor for cocktail optimizers.
......
......@@ -3,7 +3,7 @@ from torch.optim import Optimizer
from mmdet.core.optimizer import OPTIMIZERS
@OPTIMIZERS.register_module
@OPTIMIZERS.register_module()
class CocktailOptimizer(Optimizer):
"""Cocktail Optimizer that contains multiple optimizers
......
......@@ -7,7 +7,7 @@ from mmdet.datasets import DATASETS
# Modified from https://github.com/facebookresearch/detectron2/blob/41d475b75a230221e21d9cac5d69655e3415e3a4/detectron2/data/samplers/distributed_sampler.py#L57 # noqa
@DATASETS.register_module
@DATASETS.register_module()
class RepeatFactorDataset(object):
"""A wrapper of repeated dataset with repeat factor.
......
......@@ -4,7 +4,7 @@ import numpy as np
from mmdet.datasets import DATASETS, CustomDataset
@DATASETS.register_module
@DATASETS.register_module()
class Kitti2DDataset(CustomDataset):
CLASSES = ('car', 'pedestrian', 'cyclist')
......
......@@ -15,7 +15,7 @@ from .pipelines import Compose
from .utils import remove_dontcare
@DATASETS.register_module
@DATASETS.register_module()
class KittiDataset(torch_data.Dataset):
CLASSES = ('car', 'pedestrian', 'cyclist')
......
......@@ -13,7 +13,7 @@ from ..core.bbox import box_np_ops
from .pipelines import Compose
@DATASETS.register_module
@DATASETS.register_module()
class NuScenesDataset(torch_data.Dataset):
NumPointFeatures = 4 # xyz, timestamp. set 4 to use kitti pretrain
NameMapping = {
......
......@@ -52,7 +52,7 @@ class BatchSampler:
return [self._sampled_list[i] for i in indices]
@OBJECTSAMPLERS.register_module
@OBJECTSAMPLERS.register_module()
class DataBaseSampler(object):
def __init__(self, info_path, root_path, rate, prepare, object_rot_range,
......@@ -255,7 +255,7 @@ class DataBaseSampler(object):
return valid_samples
@OBJECTSAMPLERS.register_module
@OBJECTSAMPLERS.register_module()
class MMDataBaseSampler(DataBaseSampler):
def __init__(self,
......
......@@ -7,7 +7,7 @@ from mmdet.datasets.pipelines import to_tensor
PIPELINES._module_dict.pop('DefaultFormatBundle')
@PIPELINES.register_module
@PIPELINES.register_module()
class DefaultFormatBundle(object):
"""Default formatting bundle.
......@@ -59,7 +59,7 @@ class DefaultFormatBundle(object):
return self.__class__.__name__
@PIPELINES.register_module
@PIPELINES.register_module()
class Collect3D(object):
def __init__(self,
......@@ -90,7 +90,7 @@ class Collect3D(object):
self.keys, self.meta_keys)
@PIPELINES.register_module
@PIPELINES.register_module()
class DefaultFormatBundle3D(DefaultFormatBundle):
"""Default formatting bundle.
......
......@@ -6,7 +6,7 @@ import numpy as np
from mmdet.datasets.builder import PIPELINES
@PIPELINES.register_module
@PIPELINES.register_module()
class LoadPointsFromFile(object):
def __init__(self, points_dim=4, with_reflectivity=True):
......@@ -31,7 +31,7 @@ class LoadPointsFromFile(object):
return repr_str
@PIPELINES.register_module
@PIPELINES.register_module()
class LoadMultiViewImageFromFiles(object):
""" Load multi channel images from a list of separate channel files.
Expects results['filename'] to be a list of filenames
......
......@@ -8,7 +8,7 @@ from ..registry import OBJECTSAMPLERS
from .data_augment_utils import noise_per_object_v3_
@PIPELINES.register_module
@PIPELINES.register_module()
class RandomFlip3D(RandomFlip):
"""Flip the points & bbox.
......@@ -51,7 +51,7 @@ class RandomFlip3D(RandomFlip):
return input_dict
@PIPELINES.register_module
@PIPELINES.register_module()
class ObjectSample(object):
def __init__(self, db_sampler, sample_2d=False):
......@@ -128,7 +128,7 @@ class ObjectSample(object):
return self.__class__.__name__
@PIPELINES.register_module
@PIPELINES.register_module()
class ObjectNoise(object):
def __init__(self,
......@@ -167,7 +167,7 @@ class ObjectNoise(object):
return repr_str
@PIPELINES.register_module
@PIPELINES.register_module()
class GlobalRotScale(object):
def __init__(self,
......@@ -241,7 +241,7 @@ class GlobalRotScale(object):
return repr_str
@PIPELINES.register_module
@PIPELINES.register_module()
class PointShuffle(object):
def __call__(self, input_dict):
......@@ -252,7 +252,7 @@ class PointShuffle(object):
return self.__class__.__name__
@PIPELINES.register_module
@PIPELINES.register_module()
class ObjectRangeFilter(object):
def __init__(self, point_cloud_range):
......@@ -304,7 +304,7 @@ class ObjectRangeFilter(object):
return repr_str
@PIPELINES.register_module
@PIPELINES.register_module()
class PointsRangeFilter(object):
def __init__(self, point_cloud_range):
......
......@@ -8,7 +8,7 @@ from mmdet.models import HEADS
from .second_head import SECONDHead
@HEADS.register_module
@HEADS.register_module()
class Anchor3DVeloHead(SECONDHead):
"""Anchor-based head for 3D anchor with velocity
Args:
......
......@@ -9,7 +9,7 @@ from mmdet.models import HEADS
from .second_head import SECONDHead
@HEADS.register_module
@HEADS.register_module()
class PartA2RPNHead(SECONDHead):
"""rpn head for PartA2
......
......@@ -13,7 +13,7 @@ from ..builder import build_loss
from .train_mixins import AnchorTrainMixin
@HEADS.register_module
@HEADS.register_module()
class SECONDHead(nn.Module, AnchorTrainMixin):
"""Anchor-based head for VoxelNet detectors.
......
from functools import partial
import torch.nn as nn
from mmcv.cnn import build_norm_layer
from mmcv.runner import load_checkpoint
from mmdet.models import BACKBONES
from mmdet.ops import build_norm_layer
class Empty(nn.Module):
......@@ -20,7 +20,7 @@ class Empty(nn.Module):
return args
@BACKBONES.register_module
@BACKBONES.register_module()
class SECOND(nn.Module):
"""Compare with RPN, RPNV2 support arbitrary number of stage.
"""
......
......@@ -5,7 +5,7 @@ from mmdet.models import DETECTORS
from .mvx_two_stage import MVXTwoStageDetector
@DETECTORS.register_module
@DETECTORS.register_module()
class DynamicMVXFasterRCNN(MVXTwoStageDetector):
def __init__(self, **kwargs):
......@@ -42,7 +42,7 @@ class DynamicMVXFasterRCNN(MVXTwoStageDetector):
return points, coors_batch
@DETECTORS.register_module
@DETECTORS.register_module()
class DynamicMVXFasterRCNNV2(DynamicMVXFasterRCNN):
def __init__(self, **kwargs):
......@@ -62,7 +62,7 @@ class DynamicMVXFasterRCNNV2(DynamicMVXFasterRCNN):
return x
@DETECTORS.register_module
@DETECTORS.register_module()
class MVXFasterRCNNV2(MVXTwoStageDetector):
def __init__(self, **kwargs):
......@@ -84,7 +84,7 @@ class MVXFasterRCNNV2(MVXTwoStageDetector):
return x
@DETECTORS.register_module
@DETECTORS.register_module()
class DynamicMVXFasterRCNNV3(DynamicMVXFasterRCNN):
def __init__(self, **kwargs):
......
......@@ -8,7 +8,7 @@ from .. import builder
from .base import BaseDetector
@DETECTORS.register_module
@DETECTORS.register_module()
class MVXSingleStageDetector(BaseDetector):
def __init__(self,
......@@ -162,7 +162,7 @@ class MVXSingleStageDetector(BaseDetector):
raise NotImplementedError
@DETECTORS.register_module
@DETECTORS.register_module()
class DynamicMVXNet(MVXSingleStageDetector):
def __init__(self,
......@@ -230,7 +230,7 @@ class DynamicMVXNet(MVXSingleStageDetector):
return points, coors_batch
@DETECTORS.register_module
@DETECTORS.register_module()
class DynamicMVXNetV2(DynamicMVXNet):
def __init__(self,
......@@ -281,7 +281,7 @@ class DynamicMVXNetV2(DynamicMVXNet):
return x
@DETECTORS.register_module
@DETECTORS.register_module()
class DynamicMVXNetV3(DynamicMVXNet):
def __init__(self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment