Commit 0b147600 authored by VVsssssk's avatar VVsssssk Committed by ChaimZhu
Browse files

[Refactor]Fix PointSample + ObjectSample + ObjectNoise

parent 65df06b8
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import platform import platform
from mmcv.utils import Registry, build_from_cfg from mmcv.utils import build_from_cfg
from mmdet.datasets import DATASETS as MMDET_DATASETS from mmdet.datasets import DATASETS as MMDET_DATASETS
from mmdet.datasets.builder import _concat_dataset from mmdet.datasets.builder import _concat_dataset
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import random import random
import warnings import warnings
from typing import Dict
import cv2 import cv2
import numpy as np import numpy as np
from mmcv import is_tuple_of from mmcv import is_tuple_of
from mmcv.utils import build_from_cfg from mmcv.transforms import BaseTransform
from mmengine.registry import build_from_cfg
from mmdet3d.core import VoxelGenerator from mmdet3d.core import VoxelGenerator
from mmdet3d.core.bbox import (CameraInstance3DBoxes, DepthInstance3DBoxes, from mmdet3d.core.bbox import (CameraInstance3DBoxes, DepthInstance3DBoxes,
LiDARInstance3DBoxes, box_np_ops) LiDARInstance3DBoxes, box_np_ops)
from mmdet3d.registry import OBJECTSAMPLERS, TRANSFORMS
from mmdet.datasets.builder import PIPELINES
from mmdet.datasets.pipelines import RandomFlip from mmdet.datasets.pipelines import RandomFlip
from ..builder import OBJECTSAMPLERS, PIPELINES
from .data_augment_utils import noise_per_object_v3_ from .data_augment_utils import noise_per_object_v3_
...@@ -258,16 +261,36 @@ class RandomJitterPoints(object): ...@@ -258,16 +261,36 @@ class RandomJitterPoints(object):
return repr_str return repr_str
@PIPELINES.register_module() @TRANSFORMS.register_module()
class ObjectSample(object): class ObjectSample(BaseTransform):
"""Sample GT objects to the data. """Sample GT objects to the data.
Required Keys:
- points
- ann_info
- gt_bboxes_3d
- gt_labels_3d
- img (optional)
- gt_bboxes (optional)
Modified Keys:
- points
- gt_bboxes_3d
- gt_labels_3d
- img (optional)
- gt_bboxes (optional)
Added Keys:
- plane (optional)
Args: Args:
db_sampler (dict): Config dict of the database sampler. db_sampler (dict): Config dict of the database sampler.
sample_2d (bool): Whether to also paste 2D image patch to the images sample_2d (bool): Whether to also paste 2D image patch to the images
This should be true when applying multi-modality cut-and-paste. This should be true when applying multi-modality cut-and-paste.
Defaults to False. Defaults to False.
use_ground_plane (bool): Whether to use gound plane to adjust the use_ground_plane (bool): Whether to use ground plane to adjust the
3D labels. 3D labels.
""" """
...@@ -294,8 +317,8 @@ class ObjectSample(object): ...@@ -294,8 +317,8 @@ class ObjectSample(object):
points = points[np.logical_not(masks.any(-1))] points = points[np.logical_not(masks.any(-1))]
return points return points
def __call__(self, input_dict): def transform(self, input_dict: dict) -> dict:
"""Call function to sample ground truth objects to the data. """Transform function to sample ground truth objects to the data.
Args: Args:
input_dict (dict): Result dict from loading pipeline. input_dict (dict): Result dict from loading pipeline.
...@@ -373,10 +396,20 @@ class ObjectSample(object): ...@@ -373,10 +396,20 @@ class ObjectSample(object):
return repr_str return repr_str
@PIPELINES.register_module() @TRANSFORMS.register_module()
class ObjectNoise(object): class ObjectNoise(BaseTransform):
"""Apply noise to each GT objects in the scene. """Apply noise to each GT objects in the scene.
Required Keys:
- points
- gt_bboxes_3d
Modified Keys:
- points
- gt_bboxes_3d
Args: Args:
translation_std (list[float], optional): Standard deviation of the translation_std (list[float], optional): Standard deviation of the
distribution where translation noise are sampled from. distribution where translation noise are sampled from.
...@@ -399,8 +432,8 @@ class ObjectNoise(object): ...@@ -399,8 +432,8 @@ class ObjectNoise(object):
self.rot_range = rot_range self.rot_range = rot_range
self.num_try = num_try self.num_try = num_try
def __call__(self, input_dict): def transform(self, input_dict: dict) -> dict:
"""Call function to apply noise to each ground truth in the scene. """Transform function to apply noise to each ground truth in the scene.
Args: Args:
input_dict (dict): Result dict from loading pipeline. input_dict (dict): Result dict from loading pipeline.
...@@ -857,12 +890,22 @@ class ObjectNameFilter(object): ...@@ -857,12 +890,22 @@ class ObjectNameFilter(object):
return repr_str return repr_str
@PIPELINES.register_module() @TRANSFORMS.register_module()
class PointSample(object): class PointSample(BaseTransform):
"""Point sample. """Point sample.
Sampling data to a certain number. Sampling data to a certain number.
Required Keys:
- points
- pts_instance_mask (optional)
- pts_semantic_mask (optional)
Modified Keys:
- points
- pts_instance_mask (optional)
- pts_semantic_mask (optional)
Args: Args:
num_points (int): Number of points to be sampled. num_points (int): Number of points to be sampled.
sample_range (float, optional): The range where to sample points. sample_range (float, optional): The range where to sample points.
...@@ -925,8 +968,8 @@ class PointSample(object): ...@@ -925,8 +968,8 @@ class PointSample(object):
else: else:
return points[choices] return points[choices]
def __call__(self, results): def transform(self, input_dict: Dict) -> Dict:
"""Call function to sample points to in indoor scenes. """Transform function to sample points to in indoor scenes.
Args: Args:
input_dict (dict): Result dict from loading pipeline. input_dict (dict): Result dict from loading pipeline.
...@@ -934,27 +977,27 @@ class PointSample(object): ...@@ -934,27 +977,27 @@ class PointSample(object):
dict: Results after sampling, 'points', 'pts_instance_mask' dict: Results after sampling, 'points', 'pts_instance_mask'
and 'pts_semantic_mask' keys are updated in the result dict. and 'pts_semantic_mask' keys are updated in the result dict.
""" """
points = results['points'] points = input_dict['points']
points, choices = self._points_random_sampling( points, choices = self._points_random_sampling(
points, points,
self.num_points, self.num_points,
self.sample_range, self.sample_range,
self.replace, self.replace,
return_choices=True) return_choices=True)
results['points'] = points input_dict['points'] = points
pts_instance_mask = results.get('pts_instance_mask', None) pts_instance_mask = input_dict.get('pts_instance_mask', None)
pts_semantic_mask = results.get('pts_semantic_mask', None) pts_semantic_mask = input_dict.get('pts_semantic_mask', None)
if pts_instance_mask is not None: if pts_instance_mask is not None:
pts_instance_mask = pts_instance_mask[choices] pts_instance_mask = pts_instance_mask[choices]
results['pts_instance_mask'] = pts_instance_mask input_dict['pts_instance_mask'] = pts_instance_mask
if pts_semantic_mask is not None: if pts_semantic_mask is not None:
pts_semantic_mask = pts_semantic_mask[choices] pts_semantic_mask = pts_semantic_mask[choices]
results['pts_semantic_mask'] = pts_semantic_mask input_dict['pts_semantic_mask'] = pts_semantic_mask
return results return input_dict
def __repr__(self): def __repr__(self):
"""str: Return a string that describes the module.""" """str: Return a string that describes the module."""
......
# Copyright (c) OpenMMLab. All rights reserved.
from .registry import OBJECTSAMPLERS, TRANSFORMS
__all__ = ['TRANSFORMS', 'OBJECTSAMPLERS']
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.registry import TRANSFORMS as MMENGINE_TRANSFORMS
from mmengine.registry import Registry
TRANSFORMS = Registry('transform', parent=MMENGINE_TRANSFORMS)
OBJECTSAMPLERS = Registry('Object sampler')
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment