"sgl-kernel/git@developer.sourcefind.cn:change/sglang.git" did not exist on "930746d93c1b3ef96a510b6f88284583f8fdb766"
Commit 6060ff80 authored by liyinhao's avatar liyinhao
Browse files

add pts masks to the DC, remove __init__.py

parent fd2e572e
......@@ -40,7 +40,8 @@ class DefaultFormatBundle(object):
results['img'] = DC(to_tensor(img), stack=True)
for key in [
'proposals', 'gt_bboxes', 'gt_bboxes_3d', 'gt_bboxes_ignore',
'gt_labels', 'gt_labels_3d'
'gt_labels', 'gt_labels_3d', 'pts_instance_mask',
'pts_semantic_mask'
]:
if key not in results:
continue
......
......@@ -92,8 +92,8 @@ class IndoorLoadAnnotations3D(object):
mmcv.check_file_exist(pts_instance_mask_path)
mmcv.check_file_exist(pts_semantic_mask_path)
pts_instance_mask = np.load(pts_instance_mask_path)
pts_semantic_mask = np.load(pts_semantic_mask_path)
pts_instance_mask = np.load(pts_instance_mask_path).astype(np.int)
pts_semantic_mask = np.load(pts_semantic_mask_path).astype(np.int)
results['pts_instance_mask'] = pts_instance_mask
results['pts_semantic_mask'] = pts_semantic_mask
......
......@@ -66,8 +66,8 @@ def test_scannet_pipeline():
points = results['points']._data
gt_bboxes_3d = results['gt_bboxes_3d']._data
gt_labels_3d = results['gt_labels_3d']._data
pts_semantic_mask = results['pts_semantic_mask']
pts_instance_mask = results['pts_instance_mask']
pts_semantic_mask = results['pts_semantic_mask']._data
pts_instance_mask = results['pts_instance_mask']._data
expected_points = np.array(
[[-2.9078157, -1.9569951, 2.3543026, 2.389488],
[-0.71360034, -3.4359822, 2.1330001, 2.1681855],
......@@ -90,8 +90,8 @@ def test_scannet_pipeline():
assert np.allclose(points, expected_points)
assert np.allclose(gt_bboxes_3d[:5, :], expected_gt_bboxes_3d)
assert np.all(gt_labels_3d.numpy() == expected_gt_labels_3d)
assert np.all(pts_semantic_mask == expected_pts_semantic_mask)
assert np.all(pts_instance_mask == expected_pts_instance_mask)
assert np.all(pts_semantic_mask.numpy() == expected_pts_semantic_mask)
assert np.all(pts_instance_mask.numpy() == expected_pts_instance_mask)
def test_sunrgbd_pipeline():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment