Commit e3b5253b authored by ZCMax's avatar ZCMax Committed by ChaimZhu
Browse files

Update all registries and fix some ut problems

parent 8dd8da12
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import torch import torch
from mmdet3d.registry import TASK_UTILS
from mmdet.core.bbox import bbox_overlaps from mmdet.core.bbox import bbox_overlaps
from mmdet.core.bbox.iou_calculators.builder import IOU_CALCULATORS
from ..structures import get_box_type from ..structures import get_box_type
@IOU_CALCULATORS.register_module() @TASK_UTILS.register_module()
class BboxOverlapsNearest3D(object): class BboxOverlapsNearest3D(object):
"""Nearest 3D IoU Calculator. """Nearest 3D IoU Calculator.
...@@ -54,7 +54,7 @@ class BboxOverlapsNearest3D(object): ...@@ -54,7 +54,7 @@ class BboxOverlapsNearest3D(object):
return repr_str return repr_str
@IOU_CALCULATORS.register_module() @TASK_UTILS.register_module()
class BboxOverlaps3D(object): class BboxOverlaps3D(object):
"""3D IoU Calculator. """3D IoU Calculator.
...@@ -176,7 +176,7 @@ def bbox_overlaps_3d(bboxes1, bboxes2, mode='iou', coordinate='camera'): ...@@ -176,7 +176,7 @@ def bbox_overlaps_3d(bboxes1, bboxes2, mode='iou', coordinate='camera'):
return bboxes1.overlaps(bboxes1, bboxes2, mode=mode) return bboxes1.overlaps(bboxes1, bboxes2, mode=mode)
@IOU_CALCULATORS.register_module() @TASK_UTILS.register_module()
class AxisAlignedBboxOverlaps3D(object): class AxisAlignedBboxOverlaps3D(object):
"""Axis-aligned 3D Overlaps (IoU) Calculator.""" """Axis-aligned 3D Overlaps (IoU) Calculator."""
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import torch import torch
from mmdet.core.bbox.builder import BBOX_SAMPLERS from mmdet3d.registry import TASK_UTILS
from . import RandomSampler, SamplingResult from . import RandomSampler, SamplingResult
@BBOX_SAMPLERS.register_module() @TASK_UTILS.register_module()
class IoUNegPiecewiseSampler(RandomSampler): class IoUNegPiecewiseSampler(RandomSampler):
"""IoU Piece-wise Sampling. """IoU Piece-wise Sampling.
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from mmdet.datasets.builder import build_dataloader
from .builder import DATASETS, PIPELINES, build_dataset from .builder import DATASETS, PIPELINES, build_dataset
from .custom_3d import Custom3DDataset from .custom_3d import Custom3DDataset
from .custom_3d_seg import Custom3DSegDataset from .custom_3d_seg import Custom3DSegDataset
...@@ -29,17 +28,17 @@ from .utils import get_loading_pipeline ...@@ -29,17 +28,17 @@ from .utils import get_loading_pipeline
from .waymo_dataset import WaymoDataset from .waymo_dataset import WaymoDataset
__all__ = [ __all__ = [
'KittiDataset', 'KittiMonoDataset', 'build_dataloader', 'DATASETS', 'KittiDataset', 'KittiMonoDataset', 'DATASETS', 'build_dataset',
'build_dataset', 'NuScenesDataset', 'NuScenesMonoDataset', 'LyftDataset', 'NuScenesDataset', 'NuScenesMonoDataset', 'LyftDataset', 'ObjectSample',
'ObjectSample', 'RandomFlip3D', 'ObjectNoise', 'GlobalRotScaleTrans', 'RandomFlip3D', 'ObjectNoise', 'GlobalRotScaleTrans', 'PointShuffle',
'PointShuffle', 'ObjectRangeFilter', 'PointsRangeFilter', 'ObjectRangeFilter', 'PointsRangeFilter', 'LoadPointsFromFile',
'LoadPointsFromFile', 'S3DISSegDataset', 'S3DISDataset', 'S3DISSegDataset', 'S3DISDataset', 'NormalizePointsColor',
'NormalizePointsColor', 'IndoorPatchPointSample', 'IndoorPointSample', 'IndoorPatchPointSample', 'IndoorPointSample', 'PointSample',
'PointSample', 'LoadAnnotations3D', 'GlobalAlignment', 'SUNRGBDDataset', 'LoadAnnotations3D', 'GlobalAlignment', 'SUNRGBDDataset', 'ScanNetDataset',
'ScanNetDataset', 'ScanNetSegDataset', 'ScanNetInstanceSegDataset', 'ScanNetSegDataset', 'ScanNetInstanceSegDataset', 'SemanticKITTIDataset',
'SemanticKITTIDataset', 'Custom3DDataset', 'Custom3DSegDataset', 'Custom3DDataset', 'Custom3DSegDataset', 'LoadPointsFromMultiSweeps',
'LoadPointsFromMultiSweeps', 'WaymoDataset', 'BackgroundPointsFilter', 'WaymoDataset', 'BackgroundPointsFilter', 'VoxelBasedPointSampler',
'VoxelBasedPointSampler', 'get_loading_pipeline', 'RandomDropPointsColor', 'get_loading_pipeline', 'RandomDropPointsColor', 'RandomJitterPoints',
'RandomJitterPoints', 'ObjectNameFilter', 'AffineResize', 'ObjectNameFilter', 'AffineResize', 'RandomShiftScale',
'RandomShiftScale', 'LoadPointsFromDict', 'PIPELINES' 'LoadPointsFromDict', 'PIPELINES'
] ]
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import platform import platform
from mmcv.utils import build_from_cfg from mmdet3d.registry import DATASETS, TRANSFORMS
from mmdet3d.registry import DATASETS
from mmdet.datasets.builder import _concat_dataset from mmdet.datasets.builder import _concat_dataset
if platform.system() != 'Windows': if platform.system() != 'Windows':
...@@ -15,9 +13,8 @@ if platform.system() != 'Windows': ...@@ -15,9 +13,8 @@ if platform.system() != 'Windows':
soft_limit = min(max(4096, base_soft_limit), hard_limit) soft_limit = min(max(4096, base_soft_limit), hard_limit)
resource.setrlimit(resource.RLIMIT_NOFILE, (soft_limit, hard_limit)) resource.setrlimit(resource.RLIMIT_NOFILE, (soft_limit, hard_limit))
OBJECTSAMPLERS = Registry('Object sampler') OBJECTSAMPLERS = TRANSFORMS
DATASETS = Registry('dataset') PIPELINES = TRANSFORMS
PIPELINES = Registry('pipeline')
def build_dataset(cfg, default_args=None): def build_dataset(cfg, default_args=None):
...@@ -40,8 +37,7 @@ def build_dataset(cfg, default_args=None): ...@@ -40,8 +37,7 @@ def build_dataset(cfg, default_args=None):
dataset = CBGSDataset(build_dataset(cfg['dataset'], default_args)) dataset = CBGSDataset(build_dataset(cfg['dataset'], default_args))
elif isinstance(cfg.get('ann_file'), (list, tuple)): elif isinstance(cfg.get('ann_file'), (list, tuple)):
dataset = _concat_dataset(cfg, default_args) dataset = _concat_dataset(cfg, default_args)
elif cfg['type'] in DATASETS._module_dict.keys():
dataset = build_from_cfg(cfg, DATASETS, default_args)
else: else:
dataset = build_from_cfg(cfg, MMDET_DATASETS, default_args) dataset = DATASETS.build(cfg, default_args=default_args)
return dataset return dataset
...@@ -7,8 +7,8 @@ import mmcv ...@@ -7,8 +7,8 @@ import mmcv
import numpy as np import numpy as np
from torch.utils.data import Dataset from torch.utils.data import Dataset
from mmdet3d.registry import DATASETS
from ..core.bbox import get_box_type from ..core.bbox import get_box_type
from .builder import DATASETS
from .pipelines import Compose from .pipelines import Compose
from .utils import extract_result_dict, get_loading_pipeline from .utils import extract_result_dict, get_loading_pipeline
......
...@@ -7,14 +7,12 @@ import mmcv ...@@ -7,14 +7,12 @@ import mmcv
import numpy as np import numpy as np
from torch.utils.data import Dataset from torch.utils.data import Dataset
from mmseg.datasets import DATASETS as SEG_DATASETS from mmdet3d.registry import DATASETS
from .builder import DATASETS
from .pipelines import Compose from .pipelines import Compose
from .utils import extract_result_dict, get_loading_pipeline from .utils import extract_result_dict, get_loading_pipeline
@DATASETS.register_module() @DATASETS.register_module()
@SEG_DATASETS.register_module()
class Custom3DSegDataset(Dataset): class Custom3DSegDataset(Dataset):
"""Customized 3D dataset for semantic segmentation task. """Customized 3D dataset for semantic segmentation task.
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import numpy as np import numpy as np
from .builder import DATASETS from mmdet3d.registry import DATASETS
@DATASETS.register_module() @DATASETS.register_module()
......
...@@ -2,8 +2,8 @@ ...@@ -2,8 +2,8 @@
import mmcv import mmcv
import numpy as np import numpy as np
from mmdet.datasets import CustomDataset from mmdet3d.datasets import CustomDataset
from .builder import DATASETS from mmdet3d.registry import DATASETS
@DATASETS.register_module() @DATASETS.register_module()
......
...@@ -9,10 +9,10 @@ import numpy as np ...@@ -9,10 +9,10 @@ import numpy as np
import torch import torch
from mmcv.utils import print_log from mmcv.utils import print_log
from mmdet3d.registry import DATASETS
from ..core import show_multi_modality_result, show_result from ..core import show_multi_modality_result, show_result
from ..core.bbox import (Box3DMode, CameraInstance3DBoxes, Coord3DMode, from ..core.bbox import (Box3DMode, CameraInstance3DBoxes, Coord3DMode,
LiDARInstance3DBoxes, points_cam2img) LiDARInstance3DBoxes, points_cam2img)
from .builder import DATASETS
from .custom_3d import Custom3DDataset from .custom_3d import Custom3DDataset
from .pipelines import Compose from .pipelines import Compose
......
...@@ -8,8 +8,8 @@ import numpy as np ...@@ -8,8 +8,8 @@ import numpy as np
import torch import torch
from mmcv.utils import print_log from mmcv.utils import print_log
from mmdet3d.registry import DATASETS
from ..core.bbox import Box3DMode, CameraInstance3DBoxes, points_cam2img from ..core.bbox import Box3DMode, CameraInstance3DBoxes, points_cam2img
from .builder import DATASETS
from .nuscenes_mono_dataset import NuScenesMonoDataset from .nuscenes_mono_dataset import NuScenesMonoDataset
...@@ -35,8 +35,6 @@ class KittiMonoDataset(NuScenesMonoDataset): ...@@ -35,8 +35,6 @@ class KittiMonoDataset(NuScenesMonoDataset):
def __init__(self, def __init__(self,
data_root, data_root,
info_file, info_file,
ann_file,
pipeline,
load_interval=1, load_interval=1,
with_velocity=False, with_velocity=False,
eval_version=None, eval_version=None,
...@@ -44,8 +42,6 @@ class KittiMonoDataset(NuScenesMonoDataset): ...@@ -44,8 +42,6 @@ class KittiMonoDataset(NuScenesMonoDataset):
**kwargs): **kwargs):
super().__init__( super().__init__(
data_root=data_root, data_root=data_root,
ann_file=ann_file,
pipeline=pipeline,
load_interval=load_interval, load_interval=load_interval,
with_velocity=with_velocity, with_velocity=with_velocity,
eval_version=eval_version, eval_version=eval_version,
......
...@@ -11,9 +11,9 @@ from lyft_dataset_sdk.utils.data_classes import Box as LyftBox ...@@ -11,9 +11,9 @@ from lyft_dataset_sdk.utils.data_classes import Box as LyftBox
from pyquaternion import Quaternion from pyquaternion import Quaternion
from mmdet3d.core.evaluation.lyft_eval import lyft_eval from mmdet3d.core.evaluation.lyft_eval import lyft_eval
from mmdet3d.registry import DATASETS
from ..core import show_result from ..core import show_result
from ..core.bbox import Box3DMode, Coord3DMode, LiDARInstance3DBoxes from ..core.bbox import Box3DMode, Coord3DMode, LiDARInstance3DBoxes
from .builder import DATASETS
from .custom_3d import Custom3DDataset from .custom_3d import Custom3DDataset
from .pipelines import Compose from .pipelines import Compose
......
...@@ -7,9 +7,9 @@ import numpy as np ...@@ -7,9 +7,9 @@ import numpy as np
import pyquaternion import pyquaternion
from nuscenes.utils.data_classes import Box as NuScenesBox from nuscenes.utils.data_classes import Box as NuScenesBox
from mmdet3d.registry import DATASETS
from ..core import show_result from ..core import show_result
from ..core.bbox import Box3DMode, Coord3DMode, LiDARInstance3DBoxes from ..core.bbox import Box3DMode, Coord3DMode, LiDARInstance3DBoxes
from .builder import DATASETS
from .custom_3d import Custom3DDataset from .custom_3d import Custom3DDataset
from .pipelines import Compose from .pipelines import Compose
......
...@@ -11,10 +11,10 @@ import torch ...@@ -11,10 +11,10 @@ import torch
from nuscenes.utils.data_classes import Box as NuScenesBox from nuscenes.utils.data_classes import Box as NuScenesBox
from mmdet3d.core import bbox3d2result, box3d_multiclass_nms, xywhr2xyxyr from mmdet3d.core import bbox3d2result, box3d_multiclass_nms, xywhr2xyxyr
from mmdet3d.registry import DATASETS
from mmdet.datasets import CocoDataset from mmdet.datasets import CocoDataset
from ..core import show_multi_modality_result from ..core import show_multi_modality_result
from ..core.bbox import CameraInstance3DBoxes, get_box_type from ..core.bbox import CameraInstance3DBoxes, get_box_type
from .builder import DATASETS
from .pipelines import Compose from .pipelines import Compose
from .utils import extract_result_dict, get_loading_pipeline from .utils import extract_result_dict, get_loading_pipeline
...@@ -77,8 +77,6 @@ class NuScenesMonoDataset(CocoDataset): ...@@ -77,8 +77,6 @@ class NuScenesMonoDataset(CocoDataset):
def __init__(self, def __init__(self,
data_root, data_root,
ann_file,
pipeline,
load_interval=1, load_interval=1,
with_velocity=True, with_velocity=True,
modality=None, modality=None,
...@@ -86,46 +84,9 @@ class NuScenesMonoDataset(CocoDataset): ...@@ -86,46 +84,9 @@ class NuScenesMonoDataset(CocoDataset):
eval_version='detection_cvpr_2019', eval_version='detection_cvpr_2019',
use_valid_flag=False, use_valid_flag=False,
version='v1.0-trainval', version='v1.0-trainval',
classes=None, **kwargs):
img_prefix='', super().__init__(**kwargs)
seg_prefix=None,
proposal_file=None,
test_mode=False,
filter_empty_gt=True,
file_client_args=dict(backend='disk')):
self.ann_file = ann_file
self.data_root = data_root self.data_root = data_root
self.img_prefix = img_prefix
self.seg_prefix = seg_prefix
self.proposal_file = proposal_file
self.test_mode = test_mode
self.filter_empty_gt = filter_empty_gt
self.CLASSES = self.get_classes(classes)
self.file_client = mmcv.FileClient(**file_client_args)
# load annotations (and proposals)
with self.file_client.get_local_path(self.ann_file) as local_path:
self.data_infos = self.load_annotations(local_path)
if self.proposal_file is not None:
with self.file_client.get_local_path(
self.proposal_file) as local_path:
self.proposals = self.load_proposals(local_path)
else:
self.proposals = None
# filter images too small and containing no annotations
if not test_mode:
valid_inds = self._filter_imgs()
self.data_infos = [self.data_infos[i] for i in valid_inds]
if self.proposals is not None:
self.proposals = [self.proposals[i] for i in valid_inds]
# set group flag for the sampler
self._set_group_flag()
# processing pipeline
self.pipeline = Compose(pipeline)
self.load_interval = load_interval self.load_interval = load_interval
self.with_velocity = with_velocity self.with_velocity = with_velocity
self.modality = modality self.modality = modality
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import collections import collections
from mmcv.utils import build_from_cfg from mmdet3d.registry import TRANSFORMS
from mmdet.datasets.builder import PIPELINES as MMDET_PIPELINES
from ..builder import PIPELINES
@TRANSFORMS.register_module()
@PIPELINES.register_module()
class Compose: class Compose:
"""Compose multiple transforms sequentially. The pipeline registry of """Compose multiple transforms sequentially.
mmdet3d separates with mmdet, however, sometimes we may need to use mmdet's
pipeline. So the class is rewritten to be able to use pipelines from both
mmdet3d and mmdet.
Args: Args:
transforms (Sequence[dict | callable]): Sequence of transform object or transforms (Sequence[dict | callable]): Sequence of transform object or
...@@ -24,11 +18,7 @@ class Compose: ...@@ -24,11 +18,7 @@ class Compose:
self.transforms = [] self.transforms = []
for transform in transforms: for transform in transforms:
if isinstance(transform, dict): if isinstance(transform, dict):
_, key = PIPELINES.split_scope_key(transform['type']) transform = TRANSFORMS.build(transform)
if key in PIPELINES._module_dict.keys():
transform = build_from_cfg(transform, PIPELINES)
else:
transform = build_from_cfg(transform, MMDET_PIPELINES)
self.transforms.append(transform) self.transforms.append(transform)
elif callable(transform): elif callable(transform):
self.transforms.append(transform) self.transforms.append(transform)
...@@ -54,7 +44,10 @@ class Compose: ...@@ -54,7 +44,10 @@ class Compose:
def __repr__(self): def __repr__(self):
format_string = self.__class__.__name__ + '(' format_string = self.__class__.__name__ + '('
for t in self.transforms: for t in self.transforms:
str_ = t.__repr__()
if 'Compose(' in str_:
str_ = str_.replace('\n', '\n ')
format_string += '\n' format_string += '\n'
format_string += f' {t}' format_string += f' {str_}'
format_string += '\n)' format_string += '\n)'
return format_string return format_string
...@@ -8,7 +8,7 @@ import numpy as np ...@@ -8,7 +8,7 @@ import numpy as np
from mmdet3d.core.bbox import box_np_ops from mmdet3d.core.bbox import box_np_ops
from mmdet3d.datasets.pipelines import data_augment_utils from mmdet3d.datasets.pipelines import data_augment_utils
from ..builder import OBJECTSAMPLERS, PIPELINES from mmdet3d.registry import TRANSFORMS
class BatchSampler: class BatchSampler:
...@@ -78,7 +78,7 @@ class BatchSampler: ...@@ -78,7 +78,7 @@ class BatchSampler:
return [self._sampled_list[i] for i in indices] return [self._sampled_list[i] for i in indices]
@OBJECTSAMPLERS.register_module() @TRANSFORMS.register_module()
class DataBaseSampler(object): class DataBaseSampler(object):
"""Class for sampling data from the ground truth database. """Class for sampling data from the ground truth database.
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import numpy as np import numpy as np
from mmcv.parallel import DataContainer as DC from mmcv.parallel import DataContainer as DC
from mmcv.transforms import to_tensor
from mmdet3d.core.bbox import BaseInstance3DBoxes from mmdet3d.core.bbox import BaseInstance3DBoxes
from mmdet3d.core.points import BasePoints from mmdet3d.core.points import BasePoints
from mmdet.datasets.pipelines import to_tensor from mmdet3d.registry import TRANSFORMS
from ..builder import PIPELINES
@PIPELINES.register_module() @TRANSFORMS.register_module()
class DefaultFormatBundle(object): class DefaultFormatBundle(object):
"""Default formatting bundle. """Default formatting bundle.
...@@ -79,7 +79,7 @@ class DefaultFormatBundle(object): ...@@ -79,7 +79,7 @@ class DefaultFormatBundle(object):
return self.__class__.__name__ return self.__class__.__name__
@PIPELINES.register_module() @TRANSFORMS.register_module()
class Collect3D(object): class Collect3D(object):
"""Collect data from the loader relevant to the specific task. """Collect data from the loader relevant to the specific task.
...@@ -170,7 +170,7 @@ class Collect3D(object): ...@@ -170,7 +170,7 @@ class Collect3D(object):
f'(keys={self.keys}, meta_keys={self.meta_keys})' f'(keys={self.keys}, meta_keys={self.meta_keys})'
@PIPELINES.register_module() @TRANSFORMS.register_module()
class DefaultFormatBundle3D(DefaultFormatBundle): class DefaultFormatBundle3D(DefaultFormatBundle):
"""Default formatting bundle. """Default formatting bundle.
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import mmcv import mmcv
import numpy as np import numpy as np
from mmcv.transforms import LoadImageFromFile
from mmdet3d.core.points import BasePoints, get_points_type from mmdet3d.core.points import BasePoints, get_points_type
from mmdet.datasets.pipelines import LoadAnnotations, LoadImageFromFile from mmdet3d.registry import TRANSFORMS
from ..builder import PIPELINES from mmdet.datasets.pipelines import LoadAnnotations
@PIPELINES.register_module() @TRANSFORMS.register_module()
class LoadMultiViewImageFromFiles(object): class LoadMultiViewImageFromFiles(object):
"""Load multi channel images from a list of separate channel files. """Load multi channel images from a list of separate channel files.
...@@ -72,7 +73,7 @@ class LoadMultiViewImageFromFiles(object): ...@@ -72,7 +73,7 @@ class LoadMultiViewImageFromFiles(object):
return repr_str return repr_str
@PIPELINES.register_module() @TRANSFORMS.register_module()
class LoadImageFromFileMono3D(LoadImageFromFile): class LoadImageFromFileMono3D(LoadImageFromFile):
"""Load an image from file in monocular 3D object detection. Compared to 2D """Load an image from file in monocular 3D object detection. Compared to 2D
detection, additional camera parameters need to be loaded. detection, additional camera parameters need to be loaded.
...@@ -96,7 +97,7 @@ class LoadImageFromFileMono3D(LoadImageFromFile): ...@@ -96,7 +97,7 @@ class LoadImageFromFileMono3D(LoadImageFromFile):
return results return results
@PIPELINES.register_module() @TRANSFORMS.register_module()
class LoadPointsFromMultiSweeps(object): class LoadPointsFromMultiSweeps(object):
"""Load points from multiple sweeps. """Load points from multiple sweeps.
...@@ -238,7 +239,7 @@ class LoadPointsFromMultiSweeps(object): ...@@ -238,7 +239,7 @@ class LoadPointsFromMultiSweeps(object):
return f'{self.__class__.__name__}(sweeps_num={self.sweeps_num})' return f'{self.__class__.__name__}(sweeps_num={self.sweeps_num})'
@PIPELINES.register_module() @TRANSFORMS.register_module()
class PointSegClassMapping(object): class PointSegClassMapping(object):
"""Map original semantic class to valid category ids. """Map original semantic class to valid category ids.
...@@ -293,7 +294,7 @@ class PointSegClassMapping(object): ...@@ -293,7 +294,7 @@ class PointSegClassMapping(object):
return repr_str return repr_str
@PIPELINES.register_module() @TRANSFORMS.register_module()
class NormalizePointsColor(object): class NormalizePointsColor(object):
"""Normalize color of points. """Normalize color of points.
...@@ -334,7 +335,7 @@ class NormalizePointsColor(object): ...@@ -334,7 +335,7 @@ class NormalizePointsColor(object):
return repr_str return repr_str
@PIPELINES.register_module() @TRANSFORMS.register_module()
class LoadPointsFromFile(object): class LoadPointsFromFile(object):
"""Load Points From File. """Load Points From File.
...@@ -460,7 +461,7 @@ class LoadPointsFromFile(object): ...@@ -460,7 +461,7 @@ class LoadPointsFromFile(object):
return repr_str return repr_str
@PIPELINES.register_module() @TRANSFORMS.register_module()
class LoadPointsFromDict(LoadPointsFromFile): class LoadPointsFromDict(LoadPointsFromFile):
"""Load Points From Dict.""" """Load Points From Dict."""
...@@ -469,7 +470,7 @@ class LoadPointsFromDict(LoadPointsFromFile): ...@@ -469,7 +470,7 @@ class LoadPointsFromDict(LoadPointsFromFile):
return results return results
@PIPELINES.register_module() @TRANSFORMS.register_module()
class LoadAnnotations3D(LoadAnnotations): class LoadAnnotations3D(LoadAnnotations):
"""Load Annotations3D. """Load Annotations3D.
......
...@@ -4,118 +4,11 @@ from copy import deepcopy ...@@ -4,118 +4,11 @@ from copy import deepcopy
import mmcv import mmcv
from ..builder import PIPELINES from mmdet3d.registry import TRANSFORMS
from .compose import Compose from .compose import Compose
@PIPELINES.register_module() @TRANSFORMS.register_module()
class MultiScaleFlipAug:
"""Test-time augmentation with multiple scales and flipping. An example
configuration is as followed:
.. code-block::
img_scale=[(1333, 400), (1333, 800)],
flip=True,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img']),
]
After MultiScaleFLipAug with above configuration, the results are wrapped
into lists of the same length as followed:
.. code-block::
dict(
img=[...],
img_shape=[...],
scale=[(1333, 400), (1333, 400), (1333, 800), (1333, 800)]
flip=[False, True, False, True]
...
)
Args:
transforms (list[dict]): Transforms to apply in each augmentation.
img_scale (tuple | list[tuple] | None): Images scales for resizing.
scale_factor (float | list[float] | None): Scale factors for resizing.
flip (bool): Whether apply flip augmentation. Default: False.
flip_direction (str | list[str]): Flip augmentation directions,
options are "horizontal", "vertical" and "diagonal". If
flip_direction is a list, multiple flip augmentations will be
applied. It has no effect when flip == False. Default:
"horizontal".
"""
def __init__(self,
transforms,
img_scale=None,
scale_factor=None,
flip=False,
flip_direction='horizontal'):
self.transforms = Compose(transforms)
assert (img_scale is None) ^ (scale_factor is None), (
'Must have but only one variable can be set')
if img_scale is not None:
self.img_scale = img_scale if isinstance(img_scale,
list) else [img_scale]
self.scale_key = 'scale'
assert mmcv.is_list_of(self.img_scale, tuple)
else:
self.img_scale = scale_factor if isinstance(
scale_factor, list) else [scale_factor]
self.scale_key = 'scale_factor'
self.flip = flip
self.flip_direction = flip_direction if isinstance(
flip_direction, list) else [flip_direction]
assert mmcv.is_list_of(self.flip_direction, str)
if not self.flip and self.flip_direction != ['horizontal']:
warnings.warn(
'flip_direction has no effect when flip is set to False')
if (self.flip
and not any([t['type'] == 'RandomFlip' for t in transforms])):
warnings.warn(
'flip has no effect when RandomFlip is not in transforms')
def __call__(self, results):
"""Call function to apply test time augment transforms on results.
Args:
results (dict): Result dict contains the data to transform.
Returns:
dict[str: list]: The augmented data, where each value is wrapped
into a list.
"""
aug_data = []
flip_args = [(False, None)]
if self.flip:
flip_args += [(True, direction)
for direction in self.flip_direction]
for scale in self.img_scale:
for flip, direction in flip_args:
_results = results.copy()
_results[self.scale_key] = scale
_results['flip'] = flip
_results['flip_direction'] = direction
data = self.transforms(_results)
aug_data.append(data)
# list of dict to dict of list
aug_data_dict = {key: [] for key in aug_data[0]}
for data in aug_data:
for key, val in data.items():
aug_data_dict[key].append(val)
return aug_data_dict
def __repr__(self):
repr_str = self.__class__.__name__
repr_str += f'(transforms={self.transforms}, '
repr_str += f'img_scale={self.img_scale}, flip={self.flip}, '
repr_str += f'flip_direction={self.flip_direction})'
return repr_str
@PIPELINES.register_module()
class MultiScaleFlipAug3D(object): class MultiScaleFlipAug3D(object):
"""Test-time augmentation with multiple scales and flipping. """Test-time augmentation with multiple scales and flipping.
......
...@@ -12,13 +12,12 @@ from mmengine.registry import build_from_cfg ...@@ -12,13 +12,12 @@ from mmengine.registry import build_from_cfg
from mmdet3d.core import VoxelGenerator from mmdet3d.core import VoxelGenerator
from mmdet3d.core.bbox import (CameraInstance3DBoxes, DepthInstance3DBoxes, from mmdet3d.core.bbox import (CameraInstance3DBoxes, DepthInstance3DBoxes,
LiDARInstance3DBoxes, box_np_ops) LiDARInstance3DBoxes, box_np_ops)
from mmdet3d.registry import OBJECTSAMPLERS, TRANSFORMS from mmdet3d.registry import TRANSFORMS
from mmdet.datasets.builder import PIPELINES
from mmdet.datasets.pipelines import RandomFlip from mmdet.datasets.pipelines import RandomFlip
from .data_augment_utils import noise_per_object_v3_ from .data_augment_utils import noise_per_object_v3_
@PIPELINES.register_module() @TRANSFORMS.register_module()
class RandomDropPointsColor(object): class RandomDropPointsColor(object):
r"""Randomly set the color of points to all zeros. r"""Randomly set the color of points to all zeros.
...@@ -68,7 +67,7 @@ class RandomDropPointsColor(object): ...@@ -68,7 +67,7 @@ class RandomDropPointsColor(object):
return repr_str return repr_str
@PIPELINES.register_module() @TRANSFORMS.register_module()
class RandomFlip3D(RandomFlip): class RandomFlip3D(RandomFlip):
"""Flip the points & bbox. """Flip the points & bbox.
...@@ -193,7 +192,7 @@ class RandomFlip3D(RandomFlip): ...@@ -193,7 +192,7 @@ class RandomFlip3D(RandomFlip):
return repr_str return repr_str
@PIPELINES.register_module() @TRANSFORMS.register_module()
class RandomJitterPoints(object): class RandomJitterPoints(object):
"""Randomly jitter point coordinates. """Randomly jitter point coordinates.
...@@ -299,7 +298,7 @@ class ObjectSample(BaseTransform): ...@@ -299,7 +298,7 @@ class ObjectSample(BaseTransform):
self.sample_2d = sample_2d self.sample_2d = sample_2d
if 'type' not in db_sampler.keys(): if 'type' not in db_sampler.keys():
db_sampler['type'] = 'DataBaseSampler' db_sampler['type'] = 'DataBaseSampler'
self.db_sampler = build_from_cfg(db_sampler, OBJECTSAMPLERS) self.db_sampler = build_from_cfg(db_sampler, TRANSFORMS)
self.use_ground_plane = use_ground_plane self.use_ground_plane = use_ground_plane
@staticmethod @staticmethod
...@@ -471,7 +470,7 @@ class ObjectNoise(BaseTransform): ...@@ -471,7 +470,7 @@ class ObjectNoise(BaseTransform):
return repr_str return repr_str
@PIPELINES.register_module() @TRANSFORMS.register_module()
class GlobalAlignment(object): class GlobalAlignment(object):
"""Apply global alignment to 3D scene points by rotation and translation. """Apply global alignment to 3D scene points by rotation and translation.
...@@ -558,7 +557,7 @@ class GlobalAlignment(object): ...@@ -558,7 +557,7 @@ class GlobalAlignment(object):
return repr_str return repr_str
@PIPELINES.register_module() @TRANSFORMS.register_module()
class GlobalRotScaleTrans(object): class GlobalRotScaleTrans(object):
"""Apply global rotation, scaling and translation to a 3D scene. """Apply global rotation, scaling and translation to a 3D scene.
...@@ -724,7 +723,7 @@ class GlobalRotScaleTrans(object): ...@@ -724,7 +723,7 @@ class GlobalRotScaleTrans(object):
return repr_str return repr_str
@PIPELINES.register_module() @TRANSFORMS.register_module()
class PointShuffle(object): class PointShuffle(object):
"""Shuffle input points.""" """Shuffle input points."""
...@@ -756,7 +755,7 @@ class PointShuffle(object): ...@@ -756,7 +755,7 @@ class PointShuffle(object):
return self.__class__.__name__ return self.__class__.__name__
@PIPELINES.register_module() @TRANSFORMS.register_module()
class ObjectRangeFilter(object): class ObjectRangeFilter(object):
"""Filter objects by the range. """Filter objects by the range.
...@@ -808,7 +807,7 @@ class ObjectRangeFilter(object): ...@@ -808,7 +807,7 @@ class ObjectRangeFilter(object):
return repr_str return repr_str
@PIPELINES.register_module() @TRANSFORMS.register_module()
class PointsRangeFilter(object): class PointsRangeFilter(object):
"""Filter points by the range. """Filter points by the range.
...@@ -853,7 +852,7 @@ class PointsRangeFilter(object): ...@@ -853,7 +852,7 @@ class PointsRangeFilter(object):
return repr_str return repr_str
@PIPELINES.register_module() @TRANSFORMS.register_module()
class ObjectNameFilter(object): class ObjectNameFilter(object):
"""Filter GT objects by their names. """Filter GT objects by their names.
...@@ -1009,7 +1008,7 @@ class PointSample(BaseTransform): ...@@ -1009,7 +1008,7 @@ class PointSample(BaseTransform):
return repr_str return repr_str
@PIPELINES.register_module() @TRANSFORMS.register_module()
class IndoorPointSample(PointSample): class IndoorPointSample(PointSample):
"""Indoor point sample. """Indoor point sample.
...@@ -1026,7 +1025,7 @@ class IndoorPointSample(PointSample): ...@@ -1026,7 +1025,7 @@ class IndoorPointSample(PointSample):
super(IndoorPointSample, self).__init__(*args, **kwargs) super(IndoorPointSample, self).__init__(*args, **kwargs)
@PIPELINES.register_module() @TRANSFORMS.register_module()
class IndoorPatchPointSample(object): class IndoorPatchPointSample(object):
r"""Indoor point sample within a patch. Modified from `PointNet++ <https:// r"""Indoor point sample within a patch. Modified from `PointNet++ <https://
github.com/charlesq34/pointnet2/blob/master/scannet/scannet_dataset.py>`_. github.com/charlesq34/pointnet2/blob/master/scannet/scannet_dataset.py>`_.
...@@ -1271,7 +1270,7 @@ class IndoorPatchPointSample(object): ...@@ -1271,7 +1270,7 @@ class IndoorPatchPointSample(object):
return repr_str return repr_str
@PIPELINES.register_module() @TRANSFORMS.register_module()
class BackgroundPointsFilter(object): class BackgroundPointsFilter(object):
"""Filter background points near the bounding box. """Filter background points near the bounding box.
...@@ -1336,7 +1335,7 @@ class BackgroundPointsFilter(object): ...@@ -1336,7 +1335,7 @@ class BackgroundPointsFilter(object):
return repr_str return repr_str
@PIPELINES.register_module() @TRANSFORMS.register_module()
class VoxelBasedPointSampler(object): class VoxelBasedPointSampler(object):
"""Voxel based point sampler. """Voxel based point sampler.
...@@ -1478,7 +1477,7 @@ class VoxelBasedPointSampler(object): ...@@ -1478,7 +1477,7 @@ class VoxelBasedPointSampler(object):
return repr_str return repr_str
@PIPELINES.register_module() @TRANSFORMS.register_module()
class AffineResize(object): class AffineResize(object):
"""Get the affine transform matrices to the target size. """Get the affine transform matrices to the target size.
...@@ -1674,13 +1673,13 @@ class AffineResize(object): ...@@ -1674,13 +1673,13 @@ class AffineResize(object):
return repr_str return repr_str
@PIPELINES.register_module() @TRANSFORMS.register_module()
class RandomShiftScale(object): class RandomShiftScale(object):
"""Random shift scale. """Random shift scale.
Different from the normal shift and scale function, it doesn't Different from the normal shift and scale function, it doesn't
directly shift or scale image. It can record the shift and scale directly shift or scale image. It can record the shift and scale
infos into loading pipelines. It's designed to be used with infos into loading TRANSFORMS. It's designed to be used with
AffineResize together. AffineResize together.
Args: Args:
......
...@@ -5,8 +5,7 @@ import numpy as np ...@@ -5,8 +5,7 @@ import numpy as np
from mmdet3d.core import show_seg_result from mmdet3d.core import show_seg_result
from mmdet3d.core.bbox import DepthInstance3DBoxes from mmdet3d.core.bbox import DepthInstance3DBoxes
from mmseg.datasets import DATASETS as SEG_DATASETS from mmdet3d.registry import DATASETS
from .builder import DATASETS
from .custom_3d import Custom3DDataset from .custom_3d import Custom3DDataset
from .custom_3d_seg import Custom3DSegDataset from .custom_3d_seg import Custom3DSegDataset
from .pipelines import Compose from .pipelines import Compose
...@@ -308,7 +307,6 @@ class _S3DISSegDataset(Custom3DSegDataset): ...@@ -308,7 +307,6 @@ class _S3DISSegDataset(Custom3DSegDataset):
@DATASETS.register_module() @DATASETS.register_module()
@SEG_DATASETS.register_module()
class S3DISSegDataset(_S3DISSegDataset): class S3DISSegDataset(_S3DISSegDataset):
r"""S3DIS Dataset for Semantic Segmentation Task. r"""S3DIS Dataset for Semantic Segmentation Task.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment