"git@developer.sourcefind.cn:chenzk/alphafold2_jax.git" did not exist on "0be2b30b98f0da7aecb973bde04758fae67eb913"
Commit 55b9e78b authored by liyinhao's avatar liyinhao
Browse files

change osp to mmcv, change names of modules

parent e387ec62
import os.path as osp import mmcv
import numpy as np import numpy as np
from mmdet.datasets.builder import PIPELINES from mmdet.datasets.builder import PIPELINES
@PIPELINES.register_module() @PIPELINES.register_module()
class PointsColorNormalize(object): class IndoorPointsColorNormalize(object):
"""Points Color Normalize """Indoor Points Color Normalize
Normalize color of the points. Normalize color of the points.
...@@ -20,7 +19,8 @@ class PointsColorNormalize(object): ...@@ -20,7 +19,8 @@ class PointsColorNormalize(object):
def __call__(self, results): def __call__(self, results):
points = results['points'] points = results['points']
assert points.shape[1] >= 6, 'Incomplete color channel.' assert points.shape[
1] >= 6, f'Expect points have channel >=6, got {points.shape[1]}'
points[:, 3:6] = points[:, 3:6] - np.array(self.color_mean) / 256.0 points[:, 3:6] = points[:, 3:6] - np.array(self.color_mean) / 256.0
results['points'] = points results['points'] = points
return results return results
...@@ -32,8 +32,8 @@ class PointsColorNormalize(object): ...@@ -32,8 +32,8 @@ class PointsColorNormalize(object):
@PIPELINES.register_module() @PIPELINES.register_module()
class LoadPointsFromFile(object): class IndoorLoadPointsFromFile(object):
"""Load Points From File. """Indoor Load Points From File.
Load sunrgbd and scannet points from file. Load sunrgbd and scannet points from file.
...@@ -47,13 +47,17 @@ class LoadPointsFromFile(object): ...@@ -47,13 +47,17 @@ class LoadPointsFromFile(object):
def __init__(self, use_height, load_dim=6, use_dim=[0, 1, 2]): def __init__(self, use_height, load_dim=6, use_dim=[0, 1, 2]):
self.use_height = use_height self.use_height = use_height
assert max(use_dim) < load_dim, 'Wrong dimension is used.' assert max(
use_dim
) < load_dim, f'Expect all used dimensions < {load_dim}, ' \
f'got {[dim for dim in use_dim if dim >= load_dim]}'
self.load_dim = load_dim self.load_dim = load_dim
self.use_dim = use_dim self.use_dim = use_dim
def __call__(self, results): def __call__(self, results):
pts_filename = results['pts_filename'] pts_filename = results['pts_filename']
assert osp.exists(pts_filename), f'{pts_filename} does not exist.' mmcv.check_file_exist(
pts_filename, msg_tmpl=f'{pts_filename} does not exist.')
points = np.load(pts_filename) points = np.load(pts_filename)
points = points.reshape(-1, self.load_dim) points = points.reshape(-1, self.load_dim)
points = points[:, self.use_dim] points = points[:, self.use_dim]
...@@ -75,10 +79,10 @@ class LoadPointsFromFile(object): ...@@ -75,10 +79,10 @@ class LoadPointsFromFile(object):
@PIPELINES.register_module @PIPELINES.register_module
class LoadAnnotations3D(object): class IndoorLoadAnnotations3D(object):
"""Load Annotations3D. """Indoor Load Annotations3D.
Load sunrgbd and scannet annotations. Load instance mask and semantic mask of points.
""" """
def __init__(self): def __init__(self):
...@@ -88,10 +92,12 @@ class LoadAnnotations3D(object): ...@@ -88,10 +92,12 @@ class LoadAnnotations3D(object):
pts_instance_mask_path = results['pts_instance_mask_path'] pts_instance_mask_path = results['pts_instance_mask_path']
pts_semantic_mask_path = results['pts_semantic_mask_path'] pts_semantic_mask_path = results['pts_semantic_mask_path']
assert osp.exists(pts_instance_mask_path mmcv.check_file_exist(
), f'{pts_instance_mask_path} does not exist.' pts_instance_mask_path,
assert osp.exists(pts_semantic_mask_path msg_tmpl=f'{pts_instance_mask_path} does not exist.')
), f'{pts_semantic_mask_path} does not exist.' mmcv.check_file_exist(
pts_semantic_mask_path,
msg_tmpl=f'{pts_semantic_mask_path} does not exist.')
pts_instance_mask = np.load(pts_instance_mask_path) pts_instance_mask = np.load(pts_instance_mask_path)
pts_semantic_mask = np.load(pts_semantic_mask_path) pts_semantic_mask = np.load(pts_semantic_mask_path)
results['pts_instance_mask'] = pts_instance_mask results['pts_instance_mask'] = pts_instance_mask
......
...@@ -3,13 +3,13 @@ import os.path as osp ...@@ -3,13 +3,13 @@ import os.path as osp
import mmcv import mmcv
import numpy as np import numpy as np
from mmdet3d.datasets.pipelines.indoor_loading import (LoadAnnotations3D, from mmdet3d.datasets.pipelines.indoor_loading import ( # yapf: enable
LoadPointsFromFile) IndoorLoadAnnotations3D, IndoorLoadPointsFromFile)
def test_load_points_from_file(): def test_indoor_load_points_from_file():
sunrgbd_info = mmcv.load('./tests/data/sunrgbd/sunrgbd_infos.pkl') sunrgbd_info = mmcv.load('./tests/data/sunrgbd/sunrgbd_infos.pkl')
sunrgbd_load_points_from_file = LoadPointsFromFile(True, 6) sunrgbd_load_points_from_file = IndoorLoadPointsFromFile(True, 6)
sunrgbd_results = dict() sunrgbd_results = dict()
data_path = './tests/data/sunrgbd/sunrgbd_trainval' data_path = './tests/data/sunrgbd/sunrgbd_trainval'
sunrgbd_info = sunrgbd_info[0] sunrgbd_info = sunrgbd_info[0]
...@@ -21,7 +21,7 @@ def test_load_points_from_file(): ...@@ -21,7 +21,7 @@ def test_load_points_from_file():
assert sunrgbd_point_cloud.shape == (100, 4) assert sunrgbd_point_cloud.shape == (100, 4)
scannet_info = mmcv.load('./tests/data/scannet/scannet_infos.pkl') scannet_info = mmcv.load('./tests/data/scannet/scannet_infos.pkl')
scannet_load_data = LoadPointsFromFile(True) scannet_load_data = IndoorLoadPointsFromFile(True)
scannet_results = dict() scannet_results = dict()
data_path = './tests/data/scannet/scannet_train_instance_data' data_path = './tests/data/scannet/scannet_train_instance_data'
scannet_results['data_path'] = data_path scannet_results['data_path'] = data_path
...@@ -50,7 +50,7 @@ def test_load_annotations3D(): ...@@ -50,7 +50,7 @@ def test_load_annotations3D():
assert sunrgbd_gt_bboxes_3d_mask.shape == (3, 1) assert sunrgbd_gt_bboxes_3d_mask.shape == (3, 1)
scannet_info = mmcv.load('./tests/data/scannet/scannet_infos.pkl')[0] scannet_info = mmcv.load('./tests/data/scannet/scannet_infos.pkl')[0]
scannet_load_annotations3D = LoadAnnotations3D() scannet_load_annotations3D = IndoorLoadAnnotations3D()
scannet_results = dict() scannet_results = dict()
data_path = './tests/data/scannet/scannet_train_instance_data' data_path = './tests/data/scannet/scannet_train_instance_data'
if scannet_info['annos']['gt_num'] != 0: if scannet_info['annos']['gt_num'] != 0:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment