"vscode:/vscode.git/clone" did not exist on "2ac5586e8a6d5dd1f40da8d1a1ae288cac5fafec"
Commit f4f8ae22 authored by jshilong's avatar jshilong Committed by ChaimZhu
Browse files

Refactor GlobalAlignment and PointSegclassMappin

parent 7c6810e3
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from typing import Sequence
import mmcv import mmcv
import numpy as np import numpy as np
from mmcv import BaseTransform
from mmcv.transforms import LoadImageFromFile from mmcv.transforms import LoadImageFromFile
from mmcv.transforms.base import BaseTransform
from mmdet3d.core.points import BasePoints, get_points_type from mmdet3d.core.points import BasePoints, get_points_type
from mmdet3d.registry import TRANSFORMS from mmdet3d.registry import TRANSFORMS
...@@ -241,9 +243,19 @@ class LoadPointsFromMultiSweeps(object): ...@@ -241,9 +243,19 @@ class LoadPointsFromMultiSweeps(object):
@TRANSFORMS.register_module() @TRANSFORMS.register_module()
class PointSegClassMapping(object): class PointSegClassMapping(BaseTransform):
"""Map original semantic class to valid category ids. """Map original semantic class to valid category ids.
Required Keys:
- lidar_points (dict)
- lidar_path (str)
Added Keys:
- points (np.float32)
Map valid classes as 0~len(valid_cat_ids)-1 and Map valid classes as 0~len(valid_cat_ids)-1 and
others as len(valid_cat_ids). others as len(valid_cat_ids).
...@@ -253,7 +265,9 @@ class PointSegClassMapping(object): ...@@ -253,7 +265,9 @@ class PointSegClassMapping(object):
segmentation mask. Defaults to 40. segmentation mask. Defaults to 40.
""" """
def __init__(self, valid_cat_ids, max_cat_id=40): def __init__(self,
valid_cat_ids: Sequence[int],
max_cat_id: int = 40) -> None:
assert max_cat_id >= np.max(valid_cat_ids), \ assert max_cat_id >= np.max(valid_cat_ids), \
'max_cat_id should be greater than maximum id in valid_cat_ids' 'max_cat_id should be greater than maximum id in valid_cat_ids'
...@@ -267,7 +281,7 @@ class PointSegClassMapping(object): ...@@ -267,7 +281,7 @@ class PointSegClassMapping(object):
for cls_idx, cat_id in enumerate(valid_cat_ids): for cls_idx, cat_id in enumerate(valid_cat_ids):
self.cat_id2class[cat_id] = cls_idx self.cat_id2class[cat_id] = cls_idx
def __call__(self, results): def transform(self, results: dict) -> None:
"""Call function to map original semantic class to valid category ids. """Call function to map original semantic class to valid category ids.
Args: Args:
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import random import random
import warnings import warnings
from typing import List from typing import Dict, List
import cv2 import cv2
import numpy as np import numpy as np
...@@ -507,7 +507,7 @@ class ObjectNoise(BaseTransform): ...@@ -507,7 +507,7 @@ class ObjectNoise(BaseTransform):
@TRANSFORMS.register_module() @TRANSFORMS.register_module()
class GlobalAlignment(object): class GlobalAlignment(BaseTransform):
"""Apply global alignment to 3D scene points by rotation and translation. """Apply global alignment to 3D scene points by rotation and translation.
Args: Args:
...@@ -521,10 +521,10 @@ class GlobalAlignment(object): ...@@ -521,10 +521,10 @@ class GlobalAlignment(object):
bounding boxes for evaluation. bounding boxes for evaluation.
""" """
def __init__(self, rotation_axis): def __init__(self, rotation_axis: int) -> None:
self.rotation_axis = rotation_axis self.rotation_axis = rotation_axis
def _trans_points(self, input_dict, trans_factor): def _trans_points(self, results: Dict, trans_factor: np.ndarray) -> None:
"""Private function to translate points. """Private function to translate points.
Args: Args:
...@@ -534,9 +534,9 @@ class GlobalAlignment(object): ...@@ -534,9 +534,9 @@ class GlobalAlignment(object):
Returns: Returns:
dict: Results after translation, 'points' is updated in the dict. dict: Results after translation, 'points' is updated in the dict.
""" """
input_dict['points'].translate(trans_factor) results['points'].translate(trans_factor)
def _rot_points(self, input_dict, rot_mat): def _rot_points(self, results: Dict, rot_mat: np.ndarray) -> None:
"""Private function to rotate bounding boxes and points. """Private function to rotate bounding boxes and points.
Args: Args:
...@@ -547,9 +547,9 @@ class GlobalAlignment(object): ...@@ -547,9 +547,9 @@ class GlobalAlignment(object):
dict: Results after rotation, 'points' is updated in the dict. dict: Results after rotation, 'points' is updated in the dict.
""" """
# input should be rot_mat_T so I transpose it here # input should be rot_mat_T so I transpose it here
input_dict['points'].rotate(rot_mat.T) results['points'].rotate(rot_mat.T)
def _check_rot_mat(self, rot_mat): def _check_rot_mat(self, rot_mat: np.ndarray) -> None:
"""Check if rotation matrix is valid for self.rotation_axis. """Check if rotation matrix is valid for self.rotation_axis.
Args: Args:
...@@ -562,7 +562,7 @@ class GlobalAlignment(object): ...@@ -562,7 +562,7 @@ class GlobalAlignment(object):
is_valid &= (rot_mat[:, self.rotation_axis] == valid_array).all() is_valid &= (rot_mat[:, self.rotation_axis] == valid_array).all()
assert is_valid, f'invalid rotation matrix {rot_mat}' assert is_valid, f'invalid rotation matrix {rot_mat}'
def __call__(self, input_dict): def transform(self, results: Dict) -> Dict:
"""Call function to shuffle points. """Call function to shuffle points.
Args: Args:
...@@ -572,20 +572,20 @@ class GlobalAlignment(object): ...@@ -572,20 +572,20 @@ class GlobalAlignment(object):
dict: Results after global alignment, 'points' and keys in dict: Results after global alignment, 'points' and keys in
input_dict['bbox3d_fields'] are updated in the result dict. input_dict['bbox3d_fields'] are updated in the result dict.
""" """
assert 'axis_align_matrix' in input_dict['ann_info'].keys(), \ assert 'axis_align_matrix' in results, \
'axis_align_matrix is not provided in GlobalAlignment' 'axis_align_matrix is not provided in GlobalAlignment'
axis_align_matrix = input_dict['ann_info']['axis_align_matrix'] axis_align_matrix = results['axis_align_matrix']
assert axis_align_matrix.shape == (4, 4), \ assert axis_align_matrix.shape == (4, 4), \
f'invalid shape {axis_align_matrix.shape} for axis_align_matrix' f'invalid shape {axis_align_matrix.shape} for axis_align_matrix'
rot_mat = axis_align_matrix[:3, :3] rot_mat = axis_align_matrix[:3, :3]
trans_vec = axis_align_matrix[:3, -1] trans_vec = axis_align_matrix[:3, -1]
self._check_rot_mat(rot_mat) self._check_rot_mat(rot_mat)
self._rot_points(input_dict, rot_mat) self._rot_points(results, rot_mat)
self._trans_points(input_dict, trans_vec) self._trans_points(results, trans_vec)
return input_dict return results
def __repr__(self): def __repr__(self):
repr_str = self.__class__.__name__ repr_str = self.__class__.__name__
......
...@@ -2,11 +2,12 @@ ...@@ -2,11 +2,12 @@
import copy import copy
import unittest import unittest
import numpy as np
import torch import torch
from mmengine.testing import assert_allclose from mmengine.testing import assert_allclose
from utils import create_data_info_after_loading from utils import create_data_info_after_loading
from mmdet3d.datasets import RandomFlip3D from mmdet3d.datasets import GlobalAlignment, RandomFlip3D
from mmdet3d.datasets.pipelines import GlobalRotScaleTrans from mmdet3d.datasets.pipelines import GlobalRotScaleTrans
...@@ -77,3 +78,24 @@ class TestRandomFlip3D(unittest.TestCase): ...@@ -77,3 +78,24 @@ class TestRandomFlip3D(unittest.TestCase):
-ori_data_info['gt_bboxes_3d'].tensor[:, 1]) -ori_data_info['gt_bboxes_3d'].tensor[:, 1])
assert_allclose(data_info['gt_bboxes_3d'].tensor[:, 2], assert_allclose(data_info['gt_bboxes_3d'].tensor[:, 2],
ori_data_info['gt_bboxes_3d'].tensor[:, 2]) ori_data_info['gt_bboxes_3d'].tensor[:, 2])
class TestGlobalAlignment(unittest.TestCase):
def test_global_alignment(self):
data_info = create_data_info_after_loading()
global_align_transform = GlobalAlignment(rotation_axis=2)
data_info['axis_align_matrix'] = np.array(
[[0.945519, 0.325568, 0., -5.38439],
[-0.325568, 0.945519, 0., -2.87178], [0., 0., 1., -0.06435],
[0., 0., 0., 1.]],
dtype=np.float32)
global_align_transform(data_info)
data_info['axis_align_matrix'] = np.array(
[[0.945519, 0.325568, 0., -5.38439], [0, 2, 0., -2.87178],
[0., 0., 1., -0.06435], [0., 0., 0., 1.]],
dtype=np.float32)
# assert the rot metric
with self.assertRaises(AssertionError):
global_align_transform(data_info)
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import unittest import unittest
import numpy as np
import torch import torch
from mmengine.testing import assert_allclose from mmengine.testing import assert_allclose
from utils import create_dummy_data_info from utils import create_dummy_data_info
from mmdet3d.core import DepthPoints, LiDARPoints from mmdet3d.core import DepthPoints, LiDARPoints
from mmdet3d.datasets.pipelines import PointSegClassMapping
from mmdet3d.datasets.pipelines.loading import (LoadAnnotations3D, from mmdet3d.datasets.pipelines.loading import (LoadAnnotations3D,
LoadPointsFromFile) LoadPointsFromFile)
...@@ -71,3 +73,17 @@ class TestLoadAnnotations3D(unittest.TestCase): ...@@ -71,3 +73,17 @@ class TestLoadAnnotations3D(unittest.TestCase):
self.assertIn('with_bbox_3d=True', repr_str) self.assertIn('with_bbox_3d=True', repr_str)
self.assertIn('with_label_3d=True', repr_str) self.assertIn('with_label_3d=True', repr_str)
self.assertIn('with_bbox_depth=False', repr_str) self.assertIn('with_bbox_depth=False', repr_str)
class TestPointSegClassMapping(unittest.TestCase):
def test_point_seg_class_mapping(self):
results = dict()
results['pts_semantic_mask'] = np.array([1, 2, 3, 4, 5])
point_seg_mapping_transform = PointSegClassMapping(
valid_cat_ids=[1, 2, 3],
max_cat_id=results['pts_semantic_mask'].max())
results = point_seg_mapping_transform(results)
assert_allclose(results['pts_semantic_mask'], np.array([0, 1, 2, 3,
3]))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment