Commit f4f8ae22 authored by jshilong's avatar jshilong Committed by ChaimZhu
Browse files

Refactor GlobalAlignment and PointSegclassMappin

parent 7c6810e3
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Sequence
import mmcv
import numpy as np
from mmcv import BaseTransform
from mmcv.transforms import LoadImageFromFile
from mmcv.transforms.base import BaseTransform
from mmdet3d.core.points import BasePoints, get_points_type
from mmdet3d.registry import TRANSFORMS
......@@ -241,9 +243,19 @@ class LoadPointsFromMultiSweeps(object):
@TRANSFORMS.register_module()
class PointSegClassMapping(object):
class PointSegClassMapping(BaseTransform):
"""Map original semantic class to valid category ids.
Required Keys:
- lidar_points (dict)
- lidar_path (str)
Added Keys:
- points (np.float32)
Map valid classes as 0~len(valid_cat_ids)-1 and
others as len(valid_cat_ids).
......@@ -253,7 +265,9 @@ class PointSegClassMapping(object):
segmentation mask. Defaults to 40.
"""
def __init__(self, valid_cat_ids, max_cat_id=40):
def __init__(self,
valid_cat_ids: Sequence[int],
max_cat_id: int = 40) -> None:
assert max_cat_id >= np.max(valid_cat_ids), \
'max_cat_id should be greater than maximum id in valid_cat_ids'
......@@ -267,7 +281,7 @@ class PointSegClassMapping(object):
for cls_idx, cat_id in enumerate(valid_cat_ids):
self.cat_id2class[cat_id] = cls_idx
def __call__(self, results):
def transform(self, results: dict) -> None:
"""Call function to map original semantic class to valid category ids.
Args:
......@@ -320,11 +334,11 @@ class NormalizePointsColor(object):
"""
points = results['points']
assert points.attribute_dims is not None and \
'color' in points.attribute_dims.keys(), \
'Expect points have color attribute'
'color' in points.attribute_dims.keys(), \
'Expect points have color attribute'
if self.color_mean is not None:
points.color = points.color - \
points.color.new_tensor(self.color_mean)
points.color.new_tensor(self.color_mean)
points.color = points.color / 255.0
results['points'] = points
return results
......
# Copyright (c) OpenMMLab. All rights reserved.
import random
import warnings
from typing import List
from typing import Dict, List
import cv2
import numpy as np
......@@ -507,7 +507,7 @@ class ObjectNoise(BaseTransform):
@TRANSFORMS.register_module()
class GlobalAlignment(object):
class GlobalAlignment(BaseTransform):
"""Apply global alignment to 3D scene points by rotation and translation.
Args:
......@@ -521,10 +521,10 @@ class GlobalAlignment(object):
bounding boxes for evaluation.
"""
def __init__(self, rotation_axis):
def __init__(self, rotation_axis: int) -> None:
self.rotation_axis = rotation_axis
def _trans_points(self, input_dict, trans_factor):
def _trans_points(self, results: Dict, trans_factor: np.ndarray) -> None:
"""Private function to translate points.
Args:
......@@ -534,9 +534,9 @@ class GlobalAlignment(object):
Returns:
dict: Results after translation, 'points' is updated in the dict.
"""
input_dict['points'].translate(trans_factor)
results['points'].translate(trans_factor)
def _rot_points(self, input_dict, rot_mat):
def _rot_points(self, results: Dict, rot_mat: np.ndarray) -> None:
"""Private function to rotate bounding boxes and points.
Args:
......@@ -547,9 +547,9 @@ class GlobalAlignment(object):
dict: Results after rotation, 'points' is updated in the dict.
"""
# input should be rot_mat_T so I transpose it here
input_dict['points'].rotate(rot_mat.T)
results['points'].rotate(rot_mat.T)
def _check_rot_mat(self, rot_mat):
def _check_rot_mat(self, rot_mat: np.ndarray) -> None:
"""Check if rotation matrix is valid for self.rotation_axis.
Args:
......@@ -562,7 +562,7 @@ class GlobalAlignment(object):
is_valid &= (rot_mat[:, self.rotation_axis] == valid_array).all()
assert is_valid, f'invalid rotation matrix {rot_mat}'
def __call__(self, input_dict):
def transform(self, results: Dict) -> Dict:
"""Call function to shuffle points.
Args:
......@@ -572,20 +572,20 @@ class GlobalAlignment(object):
dict: Results after global alignment, 'points' and keys in
input_dict['bbox3d_fields'] are updated in the result dict.
"""
assert 'axis_align_matrix' in input_dict['ann_info'].keys(), \
assert 'axis_align_matrix' in results, \
'axis_align_matrix is not provided in GlobalAlignment'
axis_align_matrix = input_dict['ann_info']['axis_align_matrix']
axis_align_matrix = results['axis_align_matrix']
assert axis_align_matrix.shape == (4, 4), \
f'invalid shape {axis_align_matrix.shape} for axis_align_matrix'
rot_mat = axis_align_matrix[:3, :3]
trans_vec = axis_align_matrix[:3, -1]
self._check_rot_mat(rot_mat)
self._rot_points(input_dict, rot_mat)
self._trans_points(input_dict, trans_vec)
self._rot_points(results, rot_mat)
self._trans_points(results, trans_vec)
return input_dict
return results
def __repr__(self):
repr_str = self.__class__.__name__
......
......@@ -2,11 +2,12 @@
import copy
import unittest
import numpy as np
import torch
from mmengine.testing import assert_allclose
from utils import create_data_info_after_loading
from mmdet3d.datasets import RandomFlip3D
from mmdet3d.datasets import GlobalAlignment, RandomFlip3D
from mmdet3d.datasets.pipelines import GlobalRotScaleTrans
......@@ -77,3 +78,24 @@ class TestRandomFlip3D(unittest.TestCase):
-ori_data_info['gt_bboxes_3d'].tensor[:, 1])
assert_allclose(data_info['gt_bboxes_3d'].tensor[:, 2],
ori_data_info['gt_bboxes_3d'].tensor[:, 2])
class TestGlobalAlignment(unittest.TestCase):
def test_global_alignment(self):
data_info = create_data_info_after_loading()
global_align_transform = GlobalAlignment(rotation_axis=2)
data_info['axis_align_matrix'] = np.array(
[[0.945519, 0.325568, 0., -5.38439],
[-0.325568, 0.945519, 0., -2.87178], [0., 0., 1., -0.06435],
[0., 0., 0., 1.]],
dtype=np.float32)
global_align_transform(data_info)
data_info['axis_align_matrix'] = np.array(
[[0.945519, 0.325568, 0., -5.38439], [0, 2, 0., -2.87178],
[0., 0., 1., -0.06435], [0., 0., 0., 1.]],
dtype=np.float32)
# assert the rot metric
with self.assertRaises(AssertionError):
global_align_transform(data_info)
# Copyright (c) OpenMMLab. All rights reserved.
import unittest
import numpy as np
import torch
from mmengine.testing import assert_allclose
from utils import create_dummy_data_info
from mmdet3d.core import DepthPoints, LiDARPoints
from mmdet3d.datasets.pipelines import PointSegClassMapping
from mmdet3d.datasets.pipelines.loading import (LoadAnnotations3D,
LoadPointsFromFile)
......@@ -71,3 +73,17 @@ class TestLoadAnnotations3D(unittest.TestCase):
self.assertIn('with_bbox_3d=True', repr_str)
self.assertIn('with_label_3d=True', repr_str)
self.assertIn('with_bbox_depth=False', repr_str)
class TestPointSegClassMapping(unittest.TestCase):
def test_point_seg_class_mapping(self):
results = dict()
results['pts_semantic_mask'] = np.array([1, 2, 3, 4, 5])
point_seg_mapping_transform = PointSegClassMapping(
valid_cat_ids=[1, 2, 3],
max_cat_id=results['pts_semantic_mask'].max())
results = point_seg_mapping_transform(results)
assert_allclose(results['pts_semantic_mask'], np.array([0, 1, 2, 3,
3]))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment