Commit 0b147600 authored by VVsssssk's avatar VVsssssk Committed by ChaimZhu
Browse files

[Refactor]Fix PointSample + ObjectSample + ObjectNoise

parent 65df06b8
# Copyright (c) OpenMMLab. All rights reserved.
import platform
from mmcv.utils import Registry, build_from_cfg
from mmcv.utils import build_from_cfg
from mmdet.datasets import DATASETS as MMDET_DATASETS
from mmdet.datasets.builder import _concat_dataset
......
# Copyright (c) OpenMMLab. All rights reserved.
import random
import warnings
from typing import Dict
import cv2
import numpy as np
from mmcv import is_tuple_of
from mmcv.utils import build_from_cfg
from mmcv.transforms import BaseTransform
from mmengine.registry import build_from_cfg
from mmdet3d.core import VoxelGenerator
from mmdet3d.core.bbox import (CameraInstance3DBoxes, DepthInstance3DBoxes,
LiDARInstance3DBoxes, box_np_ops)
from mmdet3d.registry import OBJECTSAMPLERS, TRANSFORMS
from mmdet.datasets.builder import PIPELINES
from mmdet.datasets.pipelines import RandomFlip
from ..builder import OBJECTSAMPLERS, PIPELINES
from .data_augment_utils import noise_per_object_v3_
......@@ -258,16 +261,36 @@ class RandomJitterPoints(object):
return repr_str
@PIPELINES.register_module()
class ObjectSample(object):
@TRANSFORMS.register_module()
class ObjectSample(BaseTransform):
"""Sample GT objects to the data.
Required Keys:
- points
- ann_info
- gt_bboxes_3d
- gt_labels_3d
- img (optional)
- gt_bboxes (optional)
Modified Keys:
- points
- gt_bboxes_3d
- gt_labels_3d
- img (optional)
- gt_bboxes (optional)
Added Keys:
- plane (optional)
Args:
db_sampler (dict): Config dict of the database sampler.
sample_2d (bool): Whether to also paste 2D image patch to the images
This should be true when applying multi-modality cut-and-paste.
Defaults to False.
use_ground_plane (bool): Whether to use gound plane to adjust the
use_ground_plane (bool): Whether to use ground plane to adjust the
3D labels.
"""
......@@ -294,8 +317,8 @@ class ObjectSample(object):
points = points[np.logical_not(masks.any(-1))]
return points
def __call__(self, input_dict):
"""Call function to sample ground truth objects to the data.
def transform(self, input_dict: dict) -> dict:
"""Transform function to sample ground truth objects to the data.
Args:
input_dict (dict): Result dict from loading pipeline.
......@@ -373,10 +396,20 @@ class ObjectSample(object):
return repr_str
@PIPELINES.register_module()
class ObjectNoise(object):
@TRANSFORMS.register_module()
class ObjectNoise(BaseTransform):
"""Apply noise to each GT objects in the scene.
Required Keys:
- points
- gt_bboxes_3d
Modified Keys:
- points
- gt_bboxes_3d
Args:
translation_std (list[float], optional): Standard deviation of the
distribution where translation noise are sampled from.
......@@ -399,8 +432,8 @@ class ObjectNoise(object):
self.rot_range = rot_range
self.num_try = num_try
def __call__(self, input_dict):
"""Call function to apply noise to each ground truth in the scene.
def transform(self, input_dict: dict) -> dict:
"""Transform function to apply noise to each ground truth in the scene.
Args:
input_dict (dict): Result dict from loading pipeline.
......@@ -857,12 +890,22 @@ class ObjectNameFilter(object):
return repr_str
@PIPELINES.register_module()
class PointSample(object):
@TRANSFORMS.register_module()
class PointSample(BaseTransform):
"""Point sample.
Sampling data to a certain number.
Required Keys:
- points
- pts_instance_mask (optional)
- pts_semantic_mask (optional)
Modified Keys:
- points
- pts_instance_mask (optional)
- pts_semantic_mask (optional)
Args:
num_points (int): Number of points to be sampled.
sample_range (float, optional): The range where to sample points.
......@@ -925,8 +968,8 @@ class PointSample(object):
else:
return points[choices]
def __call__(self, results):
"""Call function to sample points to in indoor scenes.
def transform(self, input_dict: Dict) -> Dict:
"""Transform function to sample points to in indoor scenes.
Args:
input_dict (dict): Result dict from loading pipeline.
......@@ -934,27 +977,27 @@ class PointSample(object):
dict: Results after sampling, 'points', 'pts_instance_mask'
and 'pts_semantic_mask' keys are updated in the result dict.
"""
points = results['points']
points = input_dict['points']
points, choices = self._points_random_sampling(
points,
self.num_points,
self.sample_range,
self.replace,
return_choices=True)
results['points'] = points
input_dict['points'] = points
pts_instance_mask = results.get('pts_instance_mask', None)
pts_semantic_mask = results.get('pts_semantic_mask', None)
pts_instance_mask = input_dict.get('pts_instance_mask', None)
pts_semantic_mask = input_dict.get('pts_semantic_mask', None)
if pts_instance_mask is not None:
pts_instance_mask = pts_instance_mask[choices]
results['pts_instance_mask'] = pts_instance_mask
input_dict['pts_instance_mask'] = pts_instance_mask
if pts_semantic_mask is not None:
pts_semantic_mask = pts_semantic_mask[choices]
results['pts_semantic_mask'] = pts_semantic_mask
input_dict['pts_semantic_mask'] = pts_semantic_mask
return results
return input_dict
def __repr__(self):
"""str: Return a string that describes the module."""
......
# Copyright (c) OpenMMLab. All rights reserved.
from .registry import OBJECTSAMPLERS, TRANSFORMS
__all__ = ['TRANSFORMS', 'OBJECTSAMPLERS']
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.registry import TRANSFORMS as MMENGINE_TRANSFORMS
from mmengine.registry import Registry
TRANSFORMS = Registry('transform', parent=MMENGINE_TRANSFORMS)
OBJECTSAMPLERS = Registry('Object sampler')
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment