Commit a4a7bb49 authored by liyinhao's avatar liyinhao
Browse files

change get, change registry to builder

parent 4f1a5e52
import numpy as np
from mmdet.datasets.registry import PIPELINES
from mmdet.datasets.builder import PIPELINES
@PIPELINES.register_module()
......@@ -18,8 +18,8 @@ class IndoorFlipData(object):
self.flip_ratio = flip_ratio
def __call__(self, results):
points = results.get('points', None)
gt_bboxes_3d = results.get('gt_bboxes_3d', None)
points = results['points']
gt_bboxes_3d = results['gt_bboxes_3d']
name = 'scannet' if gt_bboxes_3d.shape[1] == 6 else 'sunrgbd'
if np.random.random() > self.flip_ratio:
# Flipping along the YZ plane
......@@ -59,7 +59,7 @@ class IndoorPointsColorJitter(object):
Default: [0.95, 1.05].
jitter_range (List[float]): Range of jittering.
Default: [-0.025, 0.025].
prob_drop (float): Probability to drop out points' color.
drop_prob (float): Probability to drop out points' color.
Default: 0.3
"""
......@@ -68,15 +68,15 @@ class IndoorPointsColorJitter(object):
bright_range=[0.8, 1.2],
color_shift_range=[0.95, 1.05],
jitter_range=[-0.025, 0.025],
prob_drop=0.3):
drop_prob=0.3):
self.color_mean = color_mean
self.bright_range = bright_range
self.color_shift_range = color_shift_range
self.jitter_range = jitter_range
self.prob_drop = prob_drop
self.drop_prob = drop_prob
def __call__(self, results):
points = results.get('points', None)
points = results['points']
assert points.shape[1] >= 6
rgb_color = points[:, 3:6] + self.color_mean
# brightness change for each channel
......@@ -91,7 +91,7 @@ class IndoorPointsColorJitter(object):
rgb_color = np.clip(rgb_color, 0, 1)
# randomly drop out points' colors
rgb_color *= np.expand_dims(
np.random.random(points.shape[0]) > self.prob_drop, -1)
np.random.random(points.shape[0]) > self.drop_prob, -1)
points[:, 3:6] = rgb_color - self.color_mean
results['points'] = points
return results
......@@ -102,7 +102,7 @@ class IndoorPointsColorJitter(object):
repr_str += '(bright_range={})'.format(self.bright_range)
repr_str += '(color_shift_range={})'.format(self.color_shift_range)
repr_str += '(jitter_range={})'.format(self.jitter_range)
repr_str += '(prob_drop={})'.format(self.prob_drop)
repr_str += '(drop_prob={})'.format(self.drop_prob)
# TODO: merge outdoor indoor transform.
......@@ -177,8 +177,8 @@ class IndoorGlobalRotScale(object):
return np.concatenate([new_centers, new_lengths], axis=1)
def __call__(self, results):
points = results.get('points', None)
gt_bboxes_3d = results.get('gt_bboxes_3d', None)
points = results['points']
gt_bboxes_3d = results['gt_bboxes_3d']
name = 'scannet' if gt_bboxes_3d.shape[1] == 6 else 'sunrgbd'
if self.rot_range is not None:
......
......@@ -21,8 +21,8 @@ def test_indoor_flip_data():
-1.58242359
]])
sunrgbd_results = sunrgbd_indoor_flip_data(sunrgbd_results)
sunrgbd_points = sunrgbd_results.get('points', None)
sunrgbd_gt_bboxes_3d = sunrgbd_results.get('gt_bboxes_3d', None)
sunrgbd_points = sunrgbd_results['points']
sunrgbd_gt_bboxes_3d = sunrgbd_results['gt_bboxes_3d']
expected_sunrgbd_points = np.array(
[[-1.02828765, 3.65790772, 0.1972947, 1.61959505],
......@@ -47,8 +47,8 @@ def test_indoor_flip_data():
-0.03226406, 1.70392646, 0.60348618, 0.65165804, 0.72084366, 0.64667457
]])
scannet_results = scannet_indoor_flip_data(scannet_results)
scannet_points = scannet_results.get('points', None)
scannet_gt_bboxes_3d = scannet_results.get('gt_bboxes_3d', None)
scannet_points = scannet_results['points']
scannet_gt_bboxes_3d = scannet_results['gt_bboxes_3d']
expected_scannet_points = np.array(
[[-1.6110241, 0.16903955, 0.5811581, 0.5989725],
......@@ -81,8 +81,8 @@ def test_global_rot_scale():
]])
sunrgbd_results = sunrgbd_augment(sunrgbd_results)
sunrgbd_points = sunrgbd_results.get('points', None)
sunrgbd_gt_bboxes_3d = sunrgbd_results.get('gt_bboxes_3d', None)
sunrgbd_points = sunrgbd_results['points']
sunrgbd_gt_bboxes_3d = sunrgbd_results['gt_bboxes_3d']
expected_sunrgbd_points = np.array(
[[0.89427376, 3.94489646, 0.21003141, 1.72415094],
......@@ -113,8 +113,8 @@ def test_global_rot_scale():
-0.03226406, 1.70392646, 0.60348618, 0.65165804, 0.72084366, 0.64667457
]])
scannet_results = scannet_augment(scannet_results)
scannet_points = scannet_results.get('points', None)
scannet_gt_bboxes_3d = scannet_results.get('gt_bboxes_3d', None)
scannet_points = scannet_results['points']
scannet_gt_bboxes_3d = scannet_results['gt_bboxes_3d']
expected_scannet_points = np.array(
[[1.61240576, -0.15530836, 0.5811581, 0.5989725],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment