"git@developer.sourcefind.cn:OpenDAS/megatron-lm.git" did not exist on "554d1cc04242bc2273630c672af5ac5f6e2883fe"
Commit a4a7bb49 authored by liyinhao's avatar liyinhao
Browse files

change get, change registry to builder

parent 4f1a5e52
import numpy as np
from mmdet.datasets.registry import PIPELINES
from mmdet.datasets.builder import PIPELINES
@PIPELINES.register_module()
......@@ -18,8 +18,8 @@ class IndoorFlipData(object):
self.flip_ratio = flip_ratio
def __call__(self, results):
points = results.get('points', None)
gt_bboxes_3d = results.get('gt_bboxes_3d', None)
points = results['points']
gt_bboxes_3d = results['gt_bboxes_3d']
name = 'scannet' if gt_bboxes_3d.shape[1] == 6 else 'sunrgbd'
if np.random.random() > self.flip_ratio:
# Flipping along the YZ plane
......@@ -59,7 +59,7 @@ class IndoorPointsColorJitter(object):
Default: [0.95, 1.05].
jitter_range (List[float]): Range of jittering.
Default: [-0.025, 0.025].
prob_drop (float): Probability to drop out points' color.
drop_prob (float): Probability to drop out points' color.
Default: 0.3
"""
......@@ -68,15 +68,15 @@ class IndoorPointsColorJitter(object):
bright_range=[0.8, 1.2],
color_shift_range=[0.95, 1.05],
jitter_range=[-0.025, 0.025],
prob_drop=0.3):
drop_prob=0.3):
self.color_mean = color_mean
self.bright_range = bright_range
self.color_shift_range = color_shift_range
self.jitter_range = jitter_range
self.prob_drop = prob_drop
self.drop_prob = drop_prob
def __call__(self, results):
points = results.get('points', None)
points = results['points']
assert points.shape[1] >= 6
rgb_color = points[:, 3:6] + self.color_mean
# brightness change for each channel
......@@ -91,7 +91,7 @@ class IndoorPointsColorJitter(object):
rgb_color = np.clip(rgb_color, 0, 1)
# randomly drop out points' colors
rgb_color *= np.expand_dims(
np.random.random(points.shape[0]) > self.prob_drop, -1)
np.random.random(points.shape[0]) > self.drop_prob, -1)
points[:, 3:6] = rgb_color - self.color_mean
results['points'] = points
return results
......@@ -102,7 +102,7 @@ class IndoorPointsColorJitter(object):
repr_str += '(bright_range={})'.format(self.bright_range)
repr_str += '(color_shift_range={})'.format(self.color_shift_range)
repr_str += '(jitter_range={})'.format(self.jitter_range)
repr_str += '(prob_drop={})'.format(self.prob_drop)
repr_str += '(drop_prob={})'.format(self.drop_prob)
# TODO: merge outdoor indoor transform.
......@@ -177,8 +177,8 @@ class IndoorGlobalRotScale(object):
return np.concatenate([new_centers, new_lengths], axis=1)
def __call__(self, results):
points = results.get('points', None)
gt_bboxes_3d = results.get('gt_bboxes_3d', None)
points = results['points']
gt_bboxes_3d = results['gt_bboxes_3d']
name = 'scannet' if gt_bboxes_3d.shape[1] == 6 else 'sunrgbd'
if self.rot_range is not None:
......
......@@ -21,8 +21,8 @@ def test_indoor_flip_data():
-1.58242359
]])
sunrgbd_results = sunrgbd_indoor_flip_data(sunrgbd_results)
sunrgbd_points = sunrgbd_results.get('points', None)
sunrgbd_gt_bboxes_3d = sunrgbd_results.get('gt_bboxes_3d', None)
sunrgbd_points = sunrgbd_results['points']
sunrgbd_gt_bboxes_3d = sunrgbd_results['gt_bboxes_3d']
expected_sunrgbd_points = np.array(
[[-1.02828765, 3.65790772, 0.1972947, 1.61959505],
......@@ -47,8 +47,8 @@ def test_indoor_flip_data():
-0.03226406, 1.70392646, 0.60348618, 0.65165804, 0.72084366, 0.64667457
]])
scannet_results = scannet_indoor_flip_data(scannet_results)
scannet_points = scannet_results.get('points', None)
scannet_gt_bboxes_3d = scannet_results.get('gt_bboxes_3d', None)
scannet_points = scannet_results['points']
scannet_gt_bboxes_3d = scannet_results['gt_bboxes_3d']
expected_scannet_points = np.array(
[[-1.6110241, 0.16903955, 0.5811581, 0.5989725],
......@@ -81,8 +81,8 @@ def test_global_rot_scale():
]])
sunrgbd_results = sunrgbd_augment(sunrgbd_results)
sunrgbd_points = sunrgbd_results.get('points', None)
sunrgbd_gt_bboxes_3d = sunrgbd_results.get('gt_bboxes_3d', None)
sunrgbd_points = sunrgbd_results['points']
sunrgbd_gt_bboxes_3d = sunrgbd_results['gt_bboxes_3d']
expected_sunrgbd_points = np.array(
[[0.89427376, 3.94489646, 0.21003141, 1.72415094],
......@@ -113,8 +113,8 @@ def test_global_rot_scale():
-0.03226406, 1.70392646, 0.60348618, 0.65165804, 0.72084366, 0.64667457
]])
scannet_results = scannet_augment(scannet_results)
scannet_points = scannet_results.get('points', None)
scannet_gt_bboxes_3d = scannet_results.get('gt_bboxes_3d', None)
scannet_points = scannet_results['points']
scannet_gt_bboxes_3d = scannet_results['gt_bboxes_3d']
expected_scannet_points = np.array(
[[1.61240576, -0.15530836, 0.5811581, 0.5989725],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment