Commit 66d883f2 authored by ChaimZhu
Browse files

[Fix] fix s3dis dataset bug (#1683)

* fix s3dis dataset bug

* fix class

* fix ceph replace
parent 6f70f7e7
......@@ -58,6 +58,12 @@ test_pipeline = [
use_color=True,
load_dim=6,
use_dim=[0, 1, 2, 3, 4, 5]),
dict(
type='LoadAnnotations3D',
with_bbox_3d=False,
with_label_3d=False,
with_mask_3d=False,
with_seg_3d=True),
dict(type='NormalizePointsColor', color_mean=None),
dict(
# a wrapper in order to successfully call test function
......@@ -105,17 +111,14 @@ train_dataloader = dict(
dataset=dict(
type=dataset_type,
data_root=data_root,
ann_files=[
data_root + f's3dis_infos_Area_{i}.pkl' for i in train_area
],
ann_files=[f's3dis_infos_Area_{i}.pkl' for i in train_area],
metainfo=metainfo,
data_prefix=data_prefix,
pipeline=train_pipeline,
modality=input_modality,
ignore_index=len(class_names),
scene_idxs=[
data_root + f'seg_info/Area_{i}_resampled_scene_idxs.npy'
for i in train_area
f'seg_info/Area_{i}_resampled_scene_idxs.npy' for i in train_area
],
test_mode=False))
test_dataloader = dict(
......@@ -127,14 +130,13 @@ test_dataloader = dict(
dataset=dict(
type=dataset_type,
data_root=data_root,
ann_files=data_root + f's3dis_infos_Area_{test_area}.pkl',
ann_files=f's3dis_infos_Area_{test_area}.pkl',
metainfo=metainfo,
data_prefix=data_prefix,
pipeline=test_pipeline,
modality=input_modality,
ignore_index=len(class_names),
scene_idxs=data_root +
f'seg_info/Area_{test_area}_resampled_scene_idxs.npy',
scene_idxs=f'seg_info/Area_{test_area}_resampled_scene_idxs.npy',
test_mode=True))
val_dataloader = test_dataloader
......
# dataset settings
dataset_type = 'ScanNetDataset'
data_root = './data/scannet/'
data_root = 'data/scannet/'
metainfo = dict(
CLASSES=('cabinet', 'bed', 'chair', 'sofa', 'table', 'door', 'window',
......
......@@ -298,6 +298,7 @@ class S3DISSegDataset(_S3DISSegDataset):
ignore_index=ignore_index,
scene_idxs=scene_idxs[0],
test_mode=test_mode,
serialize_data=False,
**kwargs)
datasets = [
......@@ -311,6 +312,7 @@ class S3DISSegDataset(_S3DISSegDataset):
ignore_index=ignore_index,
scene_idxs=scene_idxs[i],
test_mode=test_mode,
serialize_data=False,
**kwargs) for i in range(len(ann_files))
]
......
......@@ -89,6 +89,7 @@ class Seg3DDataset(BaseDataset):
metainfo['label_mapping'] = self.label_mapping
metainfo['label2cat'] = self.label2cat
metainfo['ignore_index'] = self.ignore_index
metainfo['seg_valid_class_ids'] = seg_valid_class_ids
# generate palette if it is not defined based on
......@@ -145,7 +146,7 @@ class Seg3DDataset(BaseDataset):
tuple: The mapping from old classes in cls.METAINFO to
new classes in metainfo
"""
old_classes = self.METAINFO.get('CLASSSES', None)
old_classes = self.METAINFO.get('CLASSES', None)
if (new_classes is not None and old_classes is not None
and list(new_classes) != list(old_classes)):
if not set(new_classes).issubset(old_classes):
......@@ -199,7 +200,7 @@ class Seg3DDataset(BaseDataset):
if palette is None:
# If palette is not defined, it generate a palette according
# to the original PALETTE and classes.
old_classes = self.METAINFO.get('CLASSSES', None)
old_classes = self.METAINFO.get('CLASSES', None)
palette = [
self.METAINFO['PALETTE'][old_classes.index(cls_name)]
for cls_name in new_classes
......@@ -266,13 +267,14 @@ class Seg3DDataset(BaseDataset):
"""
if self.test_mode:
# when testing, we load one whole scene every time
return np.arange(len(self.data_list)).astype(np.int32)
return np.arange(len(self)).astype(np.int32)
# we may need to re-sample different scenes according to scene_idxs
# this is necessary for indoor scene segmentation such as ScanNet
if scene_idxs is None:
scene_idxs = np.arange(len(self.data_list))
scene_idxs = np.arange(len(self))
if isinstance(scene_idxs, str):
scene_idxs = osp.join(self.data_root, scene_idxs)
with self.file_client.get_local_path(scene_idxs) as local_path:
scene_idxs = np.load(local_path)
else:
......
......@@ -88,8 +88,8 @@ def seg_eval(gt_labels, seg_preds, label2cat, ignore_index, logger=None):
hist_list = []
for i in range(len(gt_labels)):
gt_seg = gt_labels[i].clone().numpy().astype(np.int)
pred_seg = seg_preds[i].clone().numpy().astype(np.int)
gt_seg = gt_labels[i].astype(np.int)
pred_seg = seg_preds[i].astype(np.int)
# filter out ignored points
pred_seg[gt_seg == ignore_index] = -1
......
......@@ -51,7 +51,7 @@ class SegMetric(BaseMetric):
cpu_pred_3d = dict()
for k, v in pred_3d.items():
if hasattr(v, 'to'):
cpu_pred_3d[k] = v.to('cpu')
cpu_pred_3d[k] = v.to('cpu').numpy()
else:
cpu_pred_3d[k] = v
self.results.append((eval_ann, cpu_pred_3d))
......
......@@ -4,34 +4,42 @@
def replace_ceph_backend(cfg):
cfg_pretty_text = cfg.pretty_text
replace_strs = r'''file_client_args = dict(
replace_strs = \
r'''file_client_args = dict(
backend='petrel',
path_mapping=dict({
'.data/INPLACEHOLD/':
's3://openmmlab/datasets/detection3d/INPLACEHOLD/',
'data/INPLACEHOLD/':
's3://openmmlab/datasets/detection3d/INPLACEHOLD/'
'.data/DATA/': 's3://openmmlab/datasets/detection3d/CEPH/',
'data/DATA/': 's3://openmmlab/datasets/detection3d/CEPH/'
}))
'''
if 'nuscenes' in cfg_pretty_text:
replace_strs = replace_strs.replace('INPLACEHOLD', 'nuscenes')
replace_strs = replace_strs.replace('DATA', 'nuscenes')
replace_strs = replace_strs.replace('CEPH', 'nuscenes')
elif 'lyft' in cfg_pretty_text:
replace_strs = replace_strs.replace('INPLACEHOLD', 'lyft')
replace_strs = replace_strs.replace('DATA', 'lyft')
replace_strs = replace_strs.replace('CEPH', 'lyft')
elif 'kitti' in cfg_pretty_text:
replace_strs = replace_strs.replace('INPLACEHOLD', 'kitti')
replace_strs = replace_strs.replace('DATA', 'kitti')
replace_strs = replace_strs.replace('CEPH', 'kitti')
elif 'waymo' in cfg_pretty_text:
replace_strs = replace_strs.replace('INPLACEHOLD', 'waymo')
replace_strs = replace_strs.replace('DATA', 'waymo')
replace_strs = replace_strs.replace('CEPH', 'waymo')
elif 'scannet' in cfg_pretty_text:
replace_strs = replace_strs.replace('INPLACEHOLD', 'scannet_processed')
replace_strs = replace_strs.replace('DATA', 'scannet')
replace_strs = replace_strs.replace('CEPH', 'scannet_processed')
elif 's3dis' in cfg_pretty_text:
replace_strs = replace_strs.replace('INPLACEHOLD', 's3dis_processed')
replace_strs = replace_strs.replace('DATA', 's3dis')
replace_strs = replace_strs.replace('CEPH', 's3dis_processed')
elif 'sunrgbd' in cfg_pretty_text:
replace_strs = replace_strs.replace('INPLACEHOLD', 'sunrgbd')
replace_strs = replace_strs.replace('DATA', 'sunrgbd')
replace_strs = replace_strs.replace('CEPH', 'sunrgbd_processed')
elif 'semantickitti' in cfg_pretty_text:
replace_strs = replace_strs.replace('INPLACEHOLD', 'semantickitti')
replace_strs = replace_strs.replace('DATA', 'semantickitti')
replace_strs = replace_strs.replace('CEPH', 'semantickitti')
elif 'nuimages' in cfg_pretty_text:
replace_strs = replace_strs.replace('INPLACEHOLD', 'nuimages')
replace_strs = replace_strs.replace('DATA', 'nuimages')
replace_strs = replace_strs.replace('CEPH', 'nuimages')
else:
NotImplemented('Does not support global replacement')
......@@ -42,7 +50,6 @@ def replace_ceph_backend(cfg):
# 'ann_file', replace_strs + ', ann_file')
# replace LoadImageFromFile
replace_strs = replace_strs.replace(' ', '').replace('\n', '')
cfg_pretty_text = cfg_pretty_text.replace(
'LoadImageFromFile\'', 'LoadImageFromFile\',' + replace_strs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment