Commit a4a7bb49 authored by liyinhao's avatar liyinhao
Browse files

change get, change registry to builder

parent 4f1a5e52
import numpy as np import numpy as np
from mmdet.datasets.registry import PIPELINES from mmdet.datasets.builder import PIPELINES
@PIPELINES.register_module() @PIPELINES.register_module()
...@@ -18,8 +18,8 @@ class IndoorFlipData(object): ...@@ -18,8 +18,8 @@ class IndoorFlipData(object):
self.flip_ratio = flip_ratio self.flip_ratio = flip_ratio
def __call__(self, results): def __call__(self, results):
points = results.get('points', None) points = results['points']
gt_bboxes_3d = results.get('gt_bboxes_3d', None) gt_bboxes_3d = results['gt_bboxes_3d']
name = 'scannet' if gt_bboxes_3d.shape[1] == 6 else 'sunrgbd' name = 'scannet' if gt_bboxes_3d.shape[1] == 6 else 'sunrgbd'
if np.random.random() > self.flip_ratio: if np.random.random() > self.flip_ratio:
# Flipping along the YZ plane # Flipping along the YZ plane
...@@ -59,7 +59,7 @@ class IndoorPointsColorJitter(object): ...@@ -59,7 +59,7 @@ class IndoorPointsColorJitter(object):
Default: [0.95, 1.05]. Default: [0.95, 1.05].
jitter_range (List[float]): Range of jittering. jitter_range (List[float]): Range of jittering.
Default: [-0.025, 0.025]. Default: [-0.025, 0.025].
prob_drop (float): Probability to drop out points' color. drop_prob (float): Probability to drop out points' color.
Default: 0.3 Default: 0.3
""" """
...@@ -68,15 +68,15 @@ class IndoorPointsColorJitter(object): ...@@ -68,15 +68,15 @@ class IndoorPointsColorJitter(object):
bright_range=[0.8, 1.2], bright_range=[0.8, 1.2],
color_shift_range=[0.95, 1.05], color_shift_range=[0.95, 1.05],
jitter_range=[-0.025, 0.025], jitter_range=[-0.025, 0.025],
prob_drop=0.3): drop_prob=0.3):
self.color_mean = color_mean self.color_mean = color_mean
self.bright_range = bright_range self.bright_range = bright_range
self.color_shift_range = color_shift_range self.color_shift_range = color_shift_range
self.jitter_range = jitter_range self.jitter_range = jitter_range
self.prob_drop = prob_drop self.drop_prob = drop_prob
def __call__(self, results): def __call__(self, results):
points = results.get('points', None) points = results['points']
assert points.shape[1] >= 6 assert points.shape[1] >= 6
rgb_color = points[:, 3:6] + self.color_mean rgb_color = points[:, 3:6] + self.color_mean
# brightness change for each channel # brightness change for each channel
...@@ -91,7 +91,7 @@ class IndoorPointsColorJitter(object): ...@@ -91,7 +91,7 @@ class IndoorPointsColorJitter(object):
rgb_color = np.clip(rgb_color, 0, 1) rgb_color = np.clip(rgb_color, 0, 1)
# randomly drop out points' colors # randomly drop out points' colors
rgb_color *= np.expand_dims( rgb_color *= np.expand_dims(
np.random.random(points.shape[0]) > self.prob_drop, -1) np.random.random(points.shape[0]) > self.drop_prob, -1)
points[:, 3:6] = rgb_color - self.color_mean points[:, 3:6] = rgb_color - self.color_mean
results['points'] = points results['points'] = points
return results return results
...@@ -102,7 +102,7 @@ class IndoorPointsColorJitter(object): ...@@ -102,7 +102,7 @@ class IndoorPointsColorJitter(object):
repr_str += '(bright_range={})'.format(self.bright_range) repr_str += '(bright_range={})'.format(self.bright_range)
repr_str += '(color_shift_range={})'.format(self.color_shift_range) repr_str += '(color_shift_range={})'.format(self.color_shift_range)
repr_str += '(jitter_range={})'.format(self.jitter_range) repr_str += '(jitter_range={})'.format(self.jitter_range)
repr_str += '(prob_drop={})'.format(self.prob_drop) repr_str += '(drop_prob={})'.format(self.drop_prob)
# TODO: merge outdoor indoor transform. # TODO: merge outdoor indoor transform.
...@@ -177,8 +177,8 @@ class IndoorGlobalRotScale(object): ...@@ -177,8 +177,8 @@ class IndoorGlobalRotScale(object):
return np.concatenate([new_centers, new_lengths], axis=1) return np.concatenate([new_centers, new_lengths], axis=1)
def __call__(self, results): def __call__(self, results):
points = results.get('points', None) points = results['points']
gt_bboxes_3d = results.get('gt_bboxes_3d', None) gt_bboxes_3d = results['gt_bboxes_3d']
name = 'scannet' if gt_bboxes_3d.shape[1] == 6 else 'sunrgbd' name = 'scannet' if gt_bboxes_3d.shape[1] == 6 else 'sunrgbd'
if self.rot_range is not None: if self.rot_range is not None:
......
...@@ -21,8 +21,8 @@ def test_indoor_flip_data(): ...@@ -21,8 +21,8 @@ def test_indoor_flip_data():
-1.58242359 -1.58242359
]]) ]])
sunrgbd_results = sunrgbd_indoor_flip_data(sunrgbd_results) sunrgbd_results = sunrgbd_indoor_flip_data(sunrgbd_results)
sunrgbd_points = sunrgbd_results.get('points', None) sunrgbd_points = sunrgbd_results['points']
sunrgbd_gt_bboxes_3d = sunrgbd_results.get('gt_bboxes_3d', None) sunrgbd_gt_bboxes_3d = sunrgbd_results['gt_bboxes_3d']
expected_sunrgbd_points = np.array( expected_sunrgbd_points = np.array(
[[-1.02828765, 3.65790772, 0.1972947, 1.61959505], [[-1.02828765, 3.65790772, 0.1972947, 1.61959505],
...@@ -47,8 +47,8 @@ def test_indoor_flip_data(): ...@@ -47,8 +47,8 @@ def test_indoor_flip_data():
-0.03226406, 1.70392646, 0.60348618, 0.65165804, 0.72084366, 0.64667457 -0.03226406, 1.70392646, 0.60348618, 0.65165804, 0.72084366, 0.64667457
]]) ]])
scannet_results = scannet_indoor_flip_data(scannet_results) scannet_results = scannet_indoor_flip_data(scannet_results)
scannet_points = scannet_results.get('points', None) scannet_points = scannet_results['points']
scannet_gt_bboxes_3d = scannet_results.get('gt_bboxes_3d', None) scannet_gt_bboxes_3d = scannet_results['gt_bboxes_3d']
expected_scannet_points = np.array( expected_scannet_points = np.array(
[[-1.6110241, 0.16903955, 0.5811581, 0.5989725], [[-1.6110241, 0.16903955, 0.5811581, 0.5989725],
...@@ -81,8 +81,8 @@ def test_global_rot_scale(): ...@@ -81,8 +81,8 @@ def test_global_rot_scale():
]]) ]])
sunrgbd_results = sunrgbd_augment(sunrgbd_results) sunrgbd_results = sunrgbd_augment(sunrgbd_results)
sunrgbd_points = sunrgbd_results.get('points', None) sunrgbd_points = sunrgbd_results['points']
sunrgbd_gt_bboxes_3d = sunrgbd_results.get('gt_bboxes_3d', None) sunrgbd_gt_bboxes_3d = sunrgbd_results['gt_bboxes_3d']
expected_sunrgbd_points = np.array( expected_sunrgbd_points = np.array(
[[0.89427376, 3.94489646, 0.21003141, 1.72415094], [[0.89427376, 3.94489646, 0.21003141, 1.72415094],
...@@ -113,8 +113,8 @@ def test_global_rot_scale(): ...@@ -113,8 +113,8 @@ def test_global_rot_scale():
-0.03226406, 1.70392646, 0.60348618, 0.65165804, 0.72084366, 0.64667457 -0.03226406, 1.70392646, 0.60348618, 0.65165804, 0.72084366, 0.64667457
]]) ]])
scannet_results = scannet_augment(scannet_results) scannet_results = scannet_augment(scannet_results)
scannet_points = scannet_results.get('points', None) scannet_points = scannet_results['points']
scannet_gt_bboxes_3d = scannet_results.get('gt_bboxes_3d', None) scannet_gt_bboxes_3d = scannet_results['gt_bboxes_3d']
expected_scannet_points = np.array( expected_scannet_points = np.array(
[[1.61240576, -0.15530836, 0.5811581, 0.5989725], [[1.61240576, -0.15530836, 0.5811581, 0.5989725],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment