Commit 55b9e78b authored by liyinhao's avatar liyinhao
Browse files

change osp to mmcv, change names of modules

parent e387ec62
import os.path as osp
import mmcv
import numpy as np
from mmdet.datasets.builder import PIPELINES
@PIPELINES.register_module()
class PointsColorNormalize(object):
"""Points Color Normalize
class IndoorPointsColorNormalize(object):
"""Indoor Points Color Normalize
Normalize color of the points.
......@@ -20,7 +19,8 @@ class PointsColorNormalize(object):
def __call__(self, results):
points = results['points']
assert points.shape[1] >= 6, 'Incomplete color channel.'
assert points.shape[
1] >= 6, f'Expect points have channel >=6, got {points.shape[1]}'
points[:, 3:6] = points[:, 3:6] - np.array(self.color_mean) / 256.0
results['points'] = points
return results
......@@ -32,8 +32,8 @@ class PointsColorNormalize(object):
@PIPELINES.register_module()
class LoadPointsFromFile(object):
"""Load Points From File.
class IndoorLoadPointsFromFile(object):
"""Indoor Load Points From File.
Load sunrgbd and scannet points from file.
......@@ -47,13 +47,17 @@ class LoadPointsFromFile(object):
def __init__(self, use_height, load_dim=6, use_dim=[0, 1, 2]):
self.use_height = use_height
assert max(use_dim) < load_dim, 'Wrong dimension is used.'
assert max(
use_dim
) < load_dim, f'Expect all used dimensions < {load_dim}, ' \
f'got {[dim for dim in use_dim if dim >= load_dim]}'
self.load_dim = load_dim
self.use_dim = use_dim
def __call__(self, results):
pts_filename = results['pts_filename']
assert osp.exists(pts_filename), f'{pts_filename} does not exist.'
mmcv.check_file_exist(
pts_filename, msg_tmpl=f'{pts_filename} does not exist.')
points = np.load(pts_filename)
points = points.reshape(-1, self.load_dim)
points = points[:, self.use_dim]
......@@ -75,10 +79,10 @@ class LoadPointsFromFile(object):
@PIPELINES.register_module
class LoadAnnotations3D(object):
"""Load Annotations3D.
class IndoorLoadAnnotations3D(object):
"""Indoor Load Annotations3D.
Load sunrgbd and scannet annotations.
Load instance mask and semantic mask of points.
"""
def __init__(self):
......@@ -88,10 +92,12 @@ class LoadAnnotations3D(object):
pts_instance_mask_path = results['pts_instance_mask_path']
pts_semantic_mask_path = results['pts_semantic_mask_path']
assert osp.exists(pts_instance_mask_path
), f'{pts_instance_mask_path} does not exist.'
assert osp.exists(pts_semantic_mask_path
), f'{pts_semantic_mask_path} does not exist.'
mmcv.check_file_exist(
pts_instance_mask_path,
msg_tmpl=f'{pts_instance_mask_path} does not exist.')
mmcv.check_file_exist(
pts_semantic_mask_path,
msg_tmpl=f'{pts_semantic_mask_path} does not exist.')
pts_instance_mask = np.load(pts_instance_mask_path)
pts_semantic_mask = np.load(pts_semantic_mask_path)
results['pts_instance_mask'] = pts_instance_mask
......
......@@ -3,13 +3,13 @@ import os.path as osp
import mmcv
import numpy as np
from mmdet3d.datasets.pipelines.indoor_loading import (LoadAnnotations3D,
LoadPointsFromFile)
from mmdet3d.datasets.pipelines.indoor_loading import ( # yapf: enable
IndoorLoadAnnotations3D, IndoorLoadPointsFromFile)
def test_load_points_from_file():
def test_indoor_load_points_from_file():
sunrgbd_info = mmcv.load('./tests/data/sunrgbd/sunrgbd_infos.pkl')
sunrgbd_load_points_from_file = LoadPointsFromFile(True, 6)
sunrgbd_load_points_from_file = IndoorLoadPointsFromFile(True, 6)
sunrgbd_results = dict()
data_path = './tests/data/sunrgbd/sunrgbd_trainval'
sunrgbd_info = sunrgbd_info[0]
......@@ -21,7 +21,7 @@ def test_load_points_from_file():
assert sunrgbd_point_cloud.shape == (100, 4)
scannet_info = mmcv.load('./tests/data/scannet/scannet_infos.pkl')
scannet_load_data = LoadPointsFromFile(True)
scannet_load_data = IndoorLoadPointsFromFile(True)
scannet_results = dict()
data_path = './tests/data/scannet/scannet_train_instance_data'
scannet_results['data_path'] = data_path
......@@ -50,7 +50,7 @@ def test_load_annotations3D():
assert sunrgbd_gt_bboxes_3d_mask.shape == (3, 1)
scannet_info = mmcv.load('./tests/data/scannet/scannet_infos.pkl')[0]
scannet_load_annotations3D = LoadAnnotations3D()
scannet_load_annotations3D = IndoorLoadAnnotations3D()
scannet_results = dict()
data_path = './tests/data/scannet/scannet_train_instance_data'
if scannet_info['annos']['gt_num'] != 0:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment