Commit 49a2bc85 authored by liyinhao's avatar liyinhao Committed by zhangwenwei
Browse files

Change data converter

parent c42ad958
...@@ -42,7 +42,7 @@ class Custom3DDataset(Dataset): ...@@ -42,7 +42,7 @@ class Custom3DDataset(Dataset):
def get_data_info(self, index): def get_data_info(self, index):
info = self.data_infos[index] info = self.data_infos[index]
sample_idx = info['point_cloud']['lidar_idx'] sample_idx = info['point_cloud']['lidar_idx']
pts_filename = self._get_pts_filename(sample_idx) pts_filename = osp.join(self.data_root, info['pts_path'])
input_dict = dict( input_dict = dict(
pts_filename=pts_filename, pts_filename=pts_filename,
......
...@@ -143,7 +143,7 @@ class LoadAnnotations3D(LoadAnnotations): ...@@ -143,7 +143,7 @@ class LoadAnnotations3D(LoadAnnotations):
def _load_masks_3d(self, results): def _load_masks_3d(self, results):
pts_instance_mask_path = results['ann_info']['pts_instance_mask_path'] pts_instance_mask_path = results['ann_info']['pts_instance_mask_path']
mmcv.check_file_exist(pts_instance_mask_path) mmcv.check_file_exist(pts_instance_mask_path)
pts_instance_mask = np.load(pts_instance_mask_path).astype(np.int) pts_instance_mask = np.fromfile(pts_instance_mask_path, dtype=np.long)
results['pts_instance_mask'] = pts_instance_mask results['pts_instance_mask'] = pts_instance_mask
results['pts_mask_fields'].append(results['pts_instance_mask']) results['pts_mask_fields'].append(results['pts_instance_mask'])
return results return results
...@@ -151,7 +151,7 @@ class LoadAnnotations3D(LoadAnnotations): ...@@ -151,7 +151,7 @@ class LoadAnnotations3D(LoadAnnotations):
def _load_semantic_seg_3d(self, results): def _load_semantic_seg_3d(self, results):
pts_semantic_mask_path = results['ann_info']['pts_semantic_mask_path'] pts_semantic_mask_path = results['ann_info']['pts_semantic_mask_path']
mmcv.check_file_exist(pts_semantic_mask_path) mmcv.check_file_exist(pts_semantic_mask_path)
pts_semantic_mask = np.load(pts_semantic_mask_path).astype(np.int) pts_semantic_mask = np.fromfile(pts_semantic_mask_path, dtype=np.long)
results['pts_semantic_mask'] = pts_semantic_mask results['pts_semantic_mask'] = pts_semantic_mask
results['pts_seg_fields'].append(results['pts_semantic_mask']) results['pts_seg_fields'].append(results['pts_semantic_mask'])
return results return results
......
...@@ -19,12 +19,10 @@ class ScanNetDataset(Custom3DDataset): ...@@ -19,12 +19,10 @@ class ScanNetDataset(Custom3DDataset):
ann_file, ann_file,
pipeline=None, pipeline=None,
classes=None, classes=None,
modality=None,
test_mode=False): test_mode=False):
super().__init__(data_root, ann_file, pipeline, classes, test_mode) super().__init__(data_root, ann_file, pipeline, classes, modality,
test_mode)
def _get_pts_filename(self, sample_idx):
pts_filename = osp.join(self.data_root, f'{sample_idx}_vert.npy')
return pts_filename
def get_ann_info(self, index): def get_ann_info(self, index):
# Use index to get the annos, thus the evalhook could also use this api # Use index to get the annos, thus the evalhook could also use this api
...@@ -36,11 +34,10 @@ class ScanNetDataset(Custom3DDataset): ...@@ -36,11 +34,10 @@ class ScanNetDataset(Custom3DDataset):
else: else:
gt_bboxes_3d = np.zeros((0, 6), dtype=np.float32) gt_bboxes_3d = np.zeros((0, 6), dtype=np.float32)
gt_labels_3d = np.zeros((0, ), dtype=np.long) gt_labels_3d = np.zeros((0, ), dtype=np.long)
sample_idx = info['point_cloud']['lidar_idx']
pts_instance_mask_path = osp.join(self.data_root, pts_instance_mask_path = osp.join(self.data_root,
f'{sample_idx}_ins_label.npy') info['pts_instance_mask_path'])
pts_semantic_mask_path = osp.join(self.data_root, pts_semantic_mask_path = osp.join(self.data_root,
f'{sample_idx}_sem_label.npy') info['pts_semantic_mask_path'])
anns_results = dict( anns_results = dict(
gt_bboxes_3d=gt_bboxes_3d, gt_bboxes_3d=gt_bboxes_3d,
......
import os.path as osp
import numpy as np import numpy as np
from mmdet.datasets import DATASETS from mmdet.datasets import DATASETS
...@@ -17,13 +15,10 @@ class SUNRGBDDataset(Custom3DDataset): ...@@ -17,13 +15,10 @@ class SUNRGBDDataset(Custom3DDataset):
ann_file, ann_file,
pipeline=None, pipeline=None,
classes=None, classes=None,
modality=None,
test_mode=False): test_mode=False):
super().__init__(data_root, ann_file, pipeline, classes, test_mode) super().__init__(data_root, ann_file, pipeline, classes, modality,
test_mode)
def _get_pts_filename(self, sample_idx):
pts_filename = osp.join(self.data_root, 'lidar',
f'{sample_idx:06d}.npy')
return pts_filename
def get_ann_info(self, index): def get_ann_info(self, index):
# Use index to get the annos, thus the evalhook could also use this api # Use index to get the annos, thus the evalhook could also use this api
......
...@@ -27,7 +27,7 @@ def test_pointnet2_sa_ssg(): ...@@ -27,7 +27,7 @@ def test_pointnet2_sa_ssg():
assert self.FP_modules[0].mlps.layer0.conv.out_channels == 16 assert self.FP_modules[0].mlps.layer0.conv.out_channels == 16
assert self.FP_modules[1].mlps.layer0.conv.in_channels == 19 assert self.FP_modules[1].mlps.layer0.conv.in_channels == 19
xyz = np.load('tests/data/sunrgbd/sunrgbd_trainval/lidar/000001.npy') xyz = np.fromfile('tests/data/sunrgbd/points/000001.bin', dtype=np.float32)
xyz = torch.from_numpy(xyz).view(1, -1, 6).cuda() # (B, N, 6) xyz = torch.from_numpy(xyz).view(1, -1, 6).cuda() # (B, N, 6)
# test forward # test forward
ret_dict = self(xyz) ret_dict = self(xyz)
......
...@@ -6,7 +6,7 @@ from mmdet3d.datasets import ScanNetDataset ...@@ -6,7 +6,7 @@ from mmdet3d.datasets import ScanNetDataset
def test_getitem(): def test_getitem():
np.random.seed(0) np.random.seed(0)
root_path = './tests/data/scannet/scannet_train_instance_data' root_path = './tests/data/scannet/'
ann_file = './tests/data/scannet/scannet_infos.pkl' ann_file = './tests/data/scannet/scannet_infos.pkl'
class_names = ('cabinet', 'bed', 'chair', 'sofa', 'table', 'door', class_names = ('cabinet', 'bed', 'chair', 'sofa', 'table', 'door',
'window', 'bookshelf', 'picture', 'counter', 'desk', 'window', 'bookshelf', 'picture', 'counter', 'desk',
...@@ -56,7 +56,7 @@ def test_getitem(): ...@@ -56,7 +56,7 @@ def test_getitem():
rot_angle = data['img_meta']._data['rot_angle'] rot_angle = data['img_meta']._data['rot_angle']
sample_idx = data['img_meta']._data['sample_idx'] sample_idx = data['img_meta']._data['sample_idx']
assert file_name == './tests/data/scannet/' \ assert file_name == './tests/data/scannet/' \
'scannet_train_instance_data/scene0000_00_vert.npy' 'points/scene0000_00.bin'
assert flip_xz is True assert flip_xz is True
assert flip_yz is True assert flip_yz is True
assert abs(rot_angle - (-0.005471397477913809)) < 1e-5 assert abs(rot_angle - (-0.005471397477913809)) < 1e-5
......
...@@ -6,7 +6,7 @@ from mmdet3d.datasets import SUNRGBDDataset ...@@ -6,7 +6,7 @@ from mmdet3d.datasets import SUNRGBDDataset
def test_getitem(): def test_getitem():
np.random.seed(0) np.random.seed(0)
root_path = './tests/data/sunrgbd/sunrgbd_trainval' root_path = './tests/data/sunrgbd'
ann_file = './tests/data/sunrgbd/sunrgbd_infos.pkl' ann_file = './tests/data/sunrgbd/sunrgbd_infos.pkl'
class_names = ('bed', 'table', 'sofa', 'chair', 'toilet', 'desk', class_names = ('bed', 'table', 'sofa', 'chair', 'toilet', 'desk',
'dresser', 'night_stand', 'bookshelf', 'bathtub') 'dresser', 'night_stand', 'bookshelf', 'bathtub')
...@@ -45,19 +45,18 @@ def test_getitem(): ...@@ -45,19 +45,18 @@ def test_getitem():
scale_ratio = data['img_meta']._data['scale_ratio'] scale_ratio = data['img_meta']._data['scale_ratio']
rot_angle = data['img_meta']._data['rot_angle'] rot_angle = data['img_meta']._data['rot_angle']
sample_idx = data['img_meta']._data['sample_idx'] sample_idx = data['img_meta']._data['sample_idx']
assert file_name == './tests/data/sunrgbd/sunrgbd_trainval' \ assert file_name == './tests/data/sunrgbd' \
'/lidar/000001.npy' '/points/000001.bin'
assert flip_xz is False assert flip_xz is False
assert flip_yz is True assert flip_yz is True
assert abs(scale_ratio - 1.0308290128214932) < 1e-5 assert abs(scale_ratio - 1.0308290128214932) < 1e-5
assert abs(rot_angle - 0.22534577750874518) < 1e-5 assert abs(rot_angle - 0.22534577750874518) < 1e-5
assert sample_idx == 1 assert sample_idx == 1
expected_points = np.array( expected_points = np.array([[0.6512, 1.5781, 0.0710, 0.0499],
[[0.6570105, 1.5538014, 0.24514851, 1.0165423], [0.6473, 1.5701, 0.0657, 0.0447],
[0.656101, 1.558591, 0.21755838, 0.98895216], [0.6464, 1.5635, 0.0826, 0.0616],
[0.6293659, 1.5679953, -0.10004003, 0.67135376], [0.6453, 1.5603, 0.0849, 0.0638],
[0.6068739, 1.5974995, -0.41063973, 0.36075398], [0.6488, 1.5786, 0.0461, 0.0251]])
[0.6464709, 1.5573514, 0.15114647, 0.9225402]])
expected_gt_bboxes_3d = np.array([[ expected_gt_bboxes_3d = np.array([[
-2.012483, 3.9473376, -0.25446942, 2.3730404, 1.9457763, 2.0303352, -2.012483, 3.9473376, -0.25446942, 2.3730404, 1.9457763, 2.0303352,
1.2205974 1.2205974
...@@ -75,7 +74,7 @@ def test_getitem(): ...@@ -75,7 +74,7 @@ def test_getitem():
expected_gt_labels = np.array([0, 7, 6]) expected_gt_labels = np.array([0, 7, 6])
original_classes = sunrgbd_dataset.CLASSES original_classes = sunrgbd_dataset.CLASSES
assert np.allclose(points, expected_points) assert np.allclose(points, expected_points, 1e-2)
assert np.allclose(gt_bboxes_3d, expected_gt_bboxes_3d) assert np.allclose(gt_bboxes_3d, expected_gt_bboxes_3d)
assert np.all(gt_labels_3d.numpy() == expected_gt_labels) assert np.all(gt_labels_3d.numpy() == expected_gt_labels)
assert original_classes == class_names assert original_classes == class_names
......
...@@ -43,22 +43,20 @@ def test_scannet_pipeline(): ...@@ -43,22 +43,20 @@ def test_scannet_pipeline():
pipeline = Compose(pipelines) pipeline = Compose(pipelines)
info = mmcv.load('./tests/data/scannet/scannet_infos.pkl')[0] info = mmcv.load('./tests/data/scannet/scannet_infos.pkl')[0]
results = dict() results = dict()
data_path = './tests/data/scannet/scannet_train_instance_data' data_path = './tests/data/scannet'
results['data_path'] = data_path results['pts_filename'] = osp.join(data_path, info['pts_path'])
scan_name = info['point_cloud']['lidar_idx']
results['pts_filename'] = osp.join(data_path, f'{scan_name}_vert.npy')
if info['annos']['gt_num'] != 0: if info['annos']['gt_num'] != 0:
scannet_gt_bboxes_3d = info['annos']['gt_boxes_upright_depth'] scannet_gt_bboxes_3d = info['annos']['gt_boxes_upright_depth'].astype(
scannet_gt_labels_3d = info['annos']['class'] np.float32)
scannet_gt_labels_3d = info['annos']['class'].astype(np.long)
else: else:
scannet_gt_bboxes_3d = np.zeros((1, 6), dtype=np.float32) scannet_gt_bboxes_3d = np.zeros((1, 6), dtype=np.float32)
scannet_gt_labels_3d = np.zeros((1, )) scannet_gt_labels_3d = np.zeros((1, ), dtype=np.long)
scan_name = info['point_cloud']['lidar_idx']
results['ann_info'] = dict() results['ann_info'] = dict()
results['ann_info']['pts_instance_mask_path'] = osp.join( results['ann_info']['pts_instance_mask_path'] = osp.join(
data_path, f'{scan_name}_ins_label.npy') data_path, info['pts_instance_mask_path'])
results['ann_info']['pts_semantic_mask_path'] = osp.join( results['ann_info']['pts_semantic_mask_path'] = osp.join(
data_path, f'{scan_name}_sem_label.npy') data_path, info['pts_semantic_mask_path'])
results['ann_info']['gt_bboxes_3d'] = scannet_gt_bboxes_3d results['ann_info']['gt_bboxes_3d'] = scannet_gt_bboxes_3d
results['ann_info']['gt_labels_3d'] = scannet_gt_labels_3d results['ann_info']['gt_labels_3d'] = scannet_gt_labels_3d
...@@ -124,17 +122,16 @@ def test_sunrgbd_pipeline(): ...@@ -124,17 +122,16 @@ def test_sunrgbd_pipeline():
pipeline = Compose(pipelines) pipeline = Compose(pipelines)
results = dict() results = dict()
info = mmcv.load('./tests/data/sunrgbd/sunrgbd_infos.pkl')[0] info = mmcv.load('./tests/data/sunrgbd/sunrgbd_infos.pkl')[0]
data_path = './tests/data/sunrgbd/sunrgbd_trainval' data_path = './tests/data/sunrgbd'
scan_name = info['point_cloud']['lidar_idx'] results['pts_filename'] = osp.join(data_path, info['pts_path'])
results['pts_filename'] = osp.join(data_path, 'lidar',
f'{scan_name:06d}.npy')
if info['annos']['gt_num'] != 0: if info['annos']['gt_num'] != 0:
gt_bboxes_3d = info['annos']['gt_boxes_upright_depth'] gt_bboxes_3d = info['annos']['gt_boxes_upright_depth'].astype(
gt_labels_3d = info['annos']['class'] np.float32)
gt_labels_3d = info['annos']['class'].astype(np.long)
else: else:
gt_bboxes_3d = np.zeros((1, 6), dtype=np.float32) gt_bboxes_3d = np.zeros((1, 6), dtype=np.float32)
gt_labels_3d = np.zeros((1, )) gt_labels_3d = np.zeros((1, ), dtype=np.long)
# prepare input of pipeline # prepare input of pipeline
results['ann_info'] = dict() results['ann_info'] = dict()
...@@ -148,12 +145,11 @@ def test_sunrgbd_pipeline(): ...@@ -148,12 +145,11 @@ def test_sunrgbd_pipeline():
points = results['points']._data points = results['points']._data
gt_bboxes_3d = results['gt_bboxes_3d']._data gt_bboxes_3d = results['gt_bboxes_3d']._data
gt_labels_3d = results['gt_labels_3d']._data gt_labels_3d = results['gt_labels_3d']._data
expected_points = np.array( expected_points = np.array([[0.6512, 1.5781, 0.0710, 0.0499],
[[0.6570105, 1.5538014, 0.24514851, 1.0165423], [0.6473, 1.5701, 0.0657, 0.0447],
[0.656101, 1.558591, 0.21755838, 0.98895216], [0.6464, 1.5635, 0.0826, 0.0616],
[0.6293659, 1.5679953, -0.10004003, 0.67135376], [0.6453, 1.5603, 0.0849, 0.0638],
[0.6068739, 1.5974995, -0.41063973, 0.36075398], [0.6488, 1.5786, 0.0461, 0.0251]])
[0.6464709, 1.5573514, 0.15114647, 0.9225402]])
expected_gt_bboxes_3d = np.array([[ expected_gt_bboxes_3d = np.array([[
-2.012483, 3.9473376, -0.25446942, 2.3730404, 1.9457763, 2.0303352, -2.012483, 3.9473376, -0.25446942, 2.3730404, 1.9457763, 2.0303352,
1.2205974 1.2205974
...@@ -171,4 +167,4 @@ def test_sunrgbd_pipeline(): ...@@ -171,4 +167,4 @@ def test_sunrgbd_pipeline():
expected_gt_labels_3d = np.array([0, 7, 6]) expected_gt_labels_3d = np.array([0, 7, 6])
assert np.allclose(gt_bboxes_3d, expected_gt_bboxes_3d) assert np.allclose(gt_bboxes_3d, expected_gt_bboxes_3d)
assert np.allclose(gt_labels_3d.flatten(), expected_gt_labels_3d) assert np.allclose(gt_labels_3d.flatten(), expected_gt_labels_3d)
assert np.allclose(points, expected_points) assert np.allclose(points, expected_points, 1e-2)
...@@ -11,11 +11,10 @@ def test_load_points_from_indoor_file(): ...@@ -11,11 +11,10 @@ def test_load_points_from_indoor_file():
sunrgbd_info = mmcv.load('./tests/data/sunrgbd/sunrgbd_infos.pkl') sunrgbd_info = mmcv.load('./tests/data/sunrgbd/sunrgbd_infos.pkl')
sunrgbd_load_points_from_file = LoadPointsFromFile(6, shift_height=True) sunrgbd_load_points_from_file = LoadPointsFromFile(6, shift_height=True)
sunrgbd_results = dict() sunrgbd_results = dict()
data_path = './tests/data/sunrgbd/sunrgbd_trainval' data_path = './tests/data/sunrgbd'
sunrgbd_info = sunrgbd_info[0] sunrgbd_info = sunrgbd_info[0]
scan_name = sunrgbd_info['point_cloud']['lidar_idx'] sunrgbd_results['pts_filename'] = osp.join(data_path,
sunrgbd_results['pts_filename'] = osp.join(data_path, 'lidar', sunrgbd_info['pts_path'])
f'{scan_name:06d}.npy')
sunrgbd_results = sunrgbd_load_points_from_file(sunrgbd_results) sunrgbd_results = sunrgbd_load_points_from_file(sunrgbd_results)
sunrgbd_point_cloud = sunrgbd_results['points'] sunrgbd_point_cloud = sunrgbd_results['points']
assert sunrgbd_point_cloud.shape == (100, 4) assert sunrgbd_point_cloud.shape == (100, 4)
...@@ -23,13 +22,11 @@ def test_load_points_from_indoor_file(): ...@@ -23,13 +22,11 @@ def test_load_points_from_indoor_file():
scannet_info = mmcv.load('./tests/data/scannet/scannet_infos.pkl') scannet_info = mmcv.load('./tests/data/scannet/scannet_infos.pkl')
scannet_load_data = LoadPointsFromFile(shift_height=True) scannet_load_data = LoadPointsFromFile(shift_height=True)
scannet_results = dict() scannet_results = dict()
data_path = './tests/data/scannet/scannet_train_instance_data' data_path = './tests/data/scannet'
scannet_results['data_path'] = data_path
scannet_info = scannet_info[0] scannet_info = scannet_info[0]
scan_name = scannet_info['point_cloud']['lidar_idx']
scannet_results['pts_filename'] = osp.join(data_path, scannet_results['pts_filename'] = osp.join(data_path,
f'{scan_name}_vert.npy') scannet_info['pts_path'])
scannet_results = scannet_load_data(scannet_results) scannet_results = scannet_load_data(scannet_results)
scannet_point_cloud = scannet_results['points'] scannet_point_cloud = scannet_results['points']
assert scannet_point_cloud.shape == (100, 4) assert scannet_point_cloud.shape == (100, 4)
...@@ -67,7 +64,7 @@ def test_load_annotations3D(): ...@@ -67,7 +64,7 @@ def test_load_annotations3D():
with_mask_3d=True, with_mask_3d=True,
with_seg_3d=True) with_seg_3d=True)
scannet_results = dict() scannet_results = dict()
data_path = './tests/data/scannet/scannet_train_instance_data' data_path = './tests/data/scannet'
if scannet_info['annos']['gt_num'] != 0: if scannet_info['annos']['gt_num'] != 0:
scannet_gt_bboxes_3d = scannet_info['annos']['gt_boxes_upright_depth'] scannet_gt_bboxes_3d = scannet_info['annos']['gt_boxes_upright_depth']
...@@ -77,12 +74,11 @@ def test_load_annotations3D(): ...@@ -77,12 +74,11 @@ def test_load_annotations3D():
scannet_gt_labels_3d = np.zeros((1, )) scannet_gt_labels_3d = np.zeros((1, ))
# prepare input of loading pipeline # prepare input of loading pipeline
scan_name = scannet_info['point_cloud']['lidar_idx']
scannet_results['ann_info'] = dict() scannet_results['ann_info'] = dict()
scannet_results['ann_info']['pts_instance_mask_path'] = osp.join( scannet_results['ann_info']['pts_instance_mask_path'] = osp.join(
data_path, f'{scan_name}_ins_label.npy') data_path, scannet_info['pts_instance_mask_path'])
scannet_results['ann_info']['pts_semantic_mask_path'] = osp.join( scannet_results['ann_info']['pts_semantic_mask_path'] = osp.join(
data_path, f'{scan_name}_sem_label.npy') data_path, scannet_info['pts_semantic_mask_path'])
scannet_results['ann_info']['gt_bboxes_3d'] = scannet_gt_bboxes_3d scannet_results['ann_info']['gt_bboxes_3d'] = scannet_gt_bboxes_3d
scannet_results['ann_info']['gt_labels_3d'] = scannet_gt_labels_3d scannet_results['ann_info']['gt_labels_3d'] = scannet_gt_labels_3d
......
import numpy as np import numpy as np
import pytest
import torch import torch
def test_pointnet_sa_module_msg(): def test_pointnet_sa_module_msg():
if not torch.cuda.is_available():
pytest.skip()
from mmdet3d.ops import PointSAModuleMSG from mmdet3d.ops import PointSAModuleMSG
self = PointSAModuleMSG( self = PointSAModuleMSG(
...@@ -19,7 +22,7 @@ def test_pointnet_sa_module_msg(): ...@@ -19,7 +22,7 @@ def test_pointnet_sa_module_msg():
assert self.mlps[1].layer0.conv.in_channels == 12 assert self.mlps[1].layer0.conv.in_channels == 12
assert self.mlps[1].layer0.conv.out_channels == 32 assert self.mlps[1].layer0.conv.out_channels == 32
xyz = np.load('tests/data/sunrgbd/sunrgbd_trainval/lidar/000001.npy') xyz = np.fromfile('tests/data/sunrgbd/points/000001.bin', np.float32)
# (B, N, 3) # (B, N, 3)
xyz = torch.from_numpy(xyz[..., :3]).view(1, -1, 3).cuda() xyz = torch.from_numpy(xyz[..., :3]).view(1, -1, 3).cuda()
...@@ -34,6 +37,8 @@ def test_pointnet_sa_module_msg(): ...@@ -34,6 +37,8 @@ def test_pointnet_sa_module_msg():
def test_pointnet_sa_module(): def test_pointnet_sa_module():
if not torch.cuda.is_available():
pytest.skip()
from mmdet3d.ops import PointSAModule from mmdet3d.ops import PointSAModule
self = PointSAModule( self = PointSAModule(
...@@ -48,7 +53,7 @@ def test_pointnet_sa_module(): ...@@ -48,7 +53,7 @@ def test_pointnet_sa_module():
assert self.mlps[0].layer0.conv.in_channels == 15 assert self.mlps[0].layer0.conv.in_channels == 15
assert self.mlps[0].layer0.conv.out_channels == 32 assert self.mlps[0].layer0.conv.out_channels == 32
xyz = np.load('tests/data/sunrgbd/sunrgbd_trainval/lidar/000001.npy') xyz = np.fromfile('tests/data/sunrgbd/points/000001.bin', np.float32)
# (B, N, 3) # (B, N, 3)
xyz = torch.from_numpy(xyz[..., :3]).view(1, -1, 3).cuda() xyz = torch.from_numpy(xyz[..., :3]).view(1, -1, 3).cuda()
...@@ -63,13 +68,16 @@ def test_pointnet_sa_module(): ...@@ -63,13 +68,16 @@ def test_pointnet_sa_module():
def test_pointnet_fp_module(): def test_pointnet_fp_module():
if not torch.cuda.is_available():
pytest.skip()
from mmdet3d.ops import PointFPModule from mmdet3d.ops import PointFPModule
self = PointFPModule(mlp_channels=[24, 16]).cuda() self = PointFPModule(mlp_channels=[24, 16]).cuda()
assert self.mlps.layer0.conv.in_channels == 24 assert self.mlps.layer0.conv.in_channels == 24
assert self.mlps.layer0.conv.out_channels == 16 assert self.mlps.layer0.conv.out_channels == 16
xyz = np.load('tests/data/sunrgbd/sunrgbd_trainval/lidar/000001.npy') xyz = np.fromfile('tests/data/sunrgbd/points/000001.bin',
np.float32).reshape((-1, 6))
# (B, N, 3) # (B, N, 3)
xyz1 = torch.from_numpy(xyz[0::2, :3]).view(1, -1, 3).cuda() xyz1 = torch.from_numpy(xyz[0::2, :3]).view(1, -1, 3).cuda()
......
...@@ -44,12 +44,14 @@ def nuscenes_data_prep(root_path, ...@@ -44,12 +44,14 @@ def nuscenes_data_prep(root_path,
'{}/{}_infos_train.pkl'.format(out_dir, info_prefix)) '{}/{}_infos_train.pkl'.format(out_dir, info_prefix))
def scannet_data_prep(root_path, info_prefix, out_dir): def scannet_data_prep(root_path, info_prefix, out_dir, workers):
indoor.create_indoor_info_file(root_path, info_prefix, out_dir) indoor.create_indoor_info_file(
root_path, info_prefix, out_dir, workers=workers)
def sunrgbd_data_prep(root_path, info_prefix, out_dir): def sunrgbd_data_prep(root_path, info_prefix, out_dir, workers):
indoor.create_indoor_info_file(root_path, info_prefix, out_dir) indoor.create_indoor_info_file(
root_path, info_prefix, out_dir, workers=workers)
parser = argparse.ArgumentParser(description='Data converter arg parser') parser = argparse.ArgumentParser(description='Data converter arg parser')
...@@ -78,6 +80,8 @@ parser.add_argument( ...@@ -78,6 +80,8 @@ parser.add_argument(
required='False', required='False',
help='name of info pkl') help='name of info pkl')
parser.add_argument('--extra-tag', type=str, default='kitti') parser.add_argument('--extra-tag', type=str, default='kitti')
parser.add_argument(
'--workers', type=int, default=4, help='number of threads to be used')
args = parser.parse_args() args = parser.parse_args()
if __name__ == '__main__': if __name__ == '__main__':
...@@ -117,9 +121,11 @@ if __name__ == '__main__': ...@@ -117,9 +121,11 @@ if __name__ == '__main__':
scannet_data_prep( scannet_data_prep(
root_path=args.root_path, root_path=args.root_path,
info_prefix=args.extra_tag, info_prefix=args.extra_tag,
out_dir=args.out_dir) out_dir=args.out_dir,
workers=args.workers)
elif args.dataset == 'sunrgbd': elif args.dataset == 'sunrgbd':
sunrgbd_data_prep( sunrgbd_data_prep(
root_path=args.root_path, root_path=args.root_path,
info_prefix=args.extra_tag, info_prefix=args.extra_tag,
out_dir=args.out_dir) out_dir=args.out_dir,
workers=args.workers)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment