Commit 66d883f2 authored by ChaimZhu

[Fix] fix s3dis dataset bug (#1683)

* fix s3dis dataset bug

* fix class

* fix ceph replace
parent 6f70f7e7
@@ -58,6 +58,12 @@ test_pipeline = [
         use_color=True,
         load_dim=6,
         use_dim=[0, 1, 2, 3, 4, 5]),
+    dict(
+        type='LoadAnnotations3D',
+        with_bbox_3d=False,
+        with_label_3d=False,
+        with_mask_3d=False,
+        with_seg_3d=True),
     dict(type='NormalizePointsColor', color_mean=None),
     dict(
         # a wrapper in order to successfully call test function
@@ -105,17 +111,14 @@ train_dataloader = dict(
     dataset=dict(
         type=dataset_type,
         data_root=data_root,
-        ann_files=[
-            data_root + f's3dis_infos_Area_{i}.pkl' for i in train_area
-        ],
+        ann_files=[f's3dis_infos_Area_{i}.pkl' for i in train_area],
         metainfo=metainfo,
         data_prefix=data_prefix,
         pipeline=train_pipeline,
         modality=input_modality,
         ignore_index=len(class_names),
         scene_idxs=[
-            data_root + f'seg_info/Area_{i}_resampled_scene_idxs.npy'
-            for i in train_area
+            f'seg_info/Area_{i}_resampled_scene_idxs.npy' for i in train_area
         ],
         test_mode=False))
 test_dataloader = dict(
@@ -127,14 +130,13 @@ test_dataloader = dict(
     dataset=dict(
         type=dataset_type,
         data_root=data_root,
-        ann_files=data_root + f's3dis_infos_Area_{test_area}.pkl',
+        ann_files=f's3dis_infos_Area_{test_area}.pkl',
         metainfo=metainfo,
         data_prefix=data_prefix,
         pipeline=test_pipeline,
         modality=input_modality,
         ignore_index=len(class_names),
-        scene_idxs=data_root +
-        f'seg_info/Area_{test_area}_resampled_scene_idxs.npy',
+        scene_idxs=f'seg_info/Area_{test_area}_resampled_scene_idxs.npy',
         test_mode=True))
 val_dataloader = test_dataloader
......
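For reference, a minimal sketch of an S3DIS dataset entry with `ann_files` and `scene_idxs` given relative to `data_root`, as in the hunks above. The `data_root` value, the `train_area` list, and the explicit `type='S3DISSegDataset'` are assumptions following the usual S3DIS setup, not part of this diff:

data_root = 'data/s3dis/'        # assumed default location
train_area = [1, 2, 3, 4, 6]     # assumed training areas (Area_5 held out)

train_dataset = dict(
    type='S3DISSegDataset',
    data_root=data_root,
    # info files are relative to data_root; the dataset joins them itself
    ann_files=[f's3dis_infos_Area_{i}.pkl' for i in train_area],
    # resampled scene indices, also relative to data_root
    scene_idxs=[
        f'seg_info/Area_{i}_resampled_scene_idxs.npy' for i in train_area
    ],
    test_mode=False)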
 # dataset settings
 dataset_type = 'ScanNetDataset'
-data_root = './data/scannet/'
+data_root = 'data/scannet/'
 metainfo = dict(
     CLASSES=('cabinet', 'bed', 'chair', 'sofa', 'table', 'door', 'window',
......
@@ -298,6 +298,7 @@ class S3DISSegDataset(_S3DISSegDataset):
             ignore_index=ignore_index,
             scene_idxs=scene_idxs[0],
             test_mode=test_mode,
+            serialize_data=False,
             **kwargs)
         datasets = [
@@ -311,6 +312,7 @@ class S3DISSegDataset(_S3DISSegDataset):
                 ignore_index=ignore_index,
                 scene_idxs=scene_idxs[i],
                 test_mode=test_mode,
+                serialize_data=False,
                 **kwargs) for i in range(len(ann_files))
         ]
......
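S3DISSegDataset builds one underlying dataset per annotation file (one per area) and then combines their per-sample info lists, which is why the per-area datasets are constructed with `serialize_data=False`. An illustrative sketch of the idea, not the actual mmdet3d code; `build_area_dataset` is a hypothetical helper:

# Illustrative only: merge per-area sample lists after construction.
# With serialize_data=True, mmengine's BaseDataset freezes data_list into a
# serialized buffer at init, so edits made afterwards would not be visible.
datasets = [build_area_dataset(f, serialize_data=False) for f in ann_files]

merged = []
for ds in datasets:
    merged.extend(ds.data_list)  # plain Python lists, safe to concatenate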
@@ -89,6 +89,7 @@ class Seg3DDataset(BaseDataset):
         metainfo['label_mapping'] = self.label_mapping
         metainfo['label2cat'] = self.label2cat
+        metainfo['ignore_index'] = self.ignore_index
         metainfo['seg_valid_class_ids'] = seg_valid_class_ids
         # generate palette if it is not defined based on
@@ -145,7 +146,7 @@ class Seg3DDataset(BaseDataset):
             tuple: The mapping from old classes in cls.METAINFO to
                 new classes in metainfo
         """
-        old_classes = self.METAINFO.get('CLASSSES', None)
+        old_classes = self.METAINFO.get('CLASSES', None)
         if (new_classes is not None and old_classes is not None
                 and list(new_classes) != list(old_classes)):
             if not set(new_classes).issubset(old_classes):
@@ -199,7 +200,7 @@ class Seg3DDataset(BaseDataset):
         if palette is None:
             # If palette is not defined, it generate a palette according
             # to the original PALETTE and classes.
-            old_classes = self.METAINFO.get('CLASSSES', None)
+            old_classes = self.METAINFO.get('CLASSES', None)
             palette = [
                 self.METAINFO['PALETTE'][old_classes.index(cls_name)]
                 for cls_name in new_classes
@@ -266,13 +267,14 @@ class Seg3DDataset(BaseDataset):
         """
         if self.test_mode:
             # when testing, we load one whole scene every time
-            return np.arange(len(self.data_list)).astype(np.int32)
+            return np.arange(len(self)).astype(np.int32)
         # we may need to re-sample different scenes according to scene_idxs
         # this is necessary for indoor scene segmentation such as ScanNet
         if scene_idxs is None:
-            scene_idxs = np.arange(len(self.data_list))
+            scene_idxs = np.arange(len(self))
         if isinstance(scene_idxs, str):
+            scene_idxs = osp.join(self.data_root, scene_idxs)
             with self.file_client.get_local_path(scene_idxs) as local_path:
                 scene_idxs = np.load(local_path)
         else:
......
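Condensed, `Seg3DDataset.get_scene_idxs` now resolves a relative string path against `data_root` before loading it. A simplified stand-alone sketch of that logic; the function name and signature here are illustrative, not the actual method:

import os.path as osp
import numpy as np

def resolve_scene_idxs(data_root, scene_idxs, file_client, num_samples):
    # None means every scene is used exactly once
    if scene_idxs is None:
        return np.arange(num_samples).astype(np.int32)
    # a string is treated as a path relative to data_root
    if isinstance(scene_idxs, str):
        scene_idxs = osp.join(data_root, scene_idxs)
        with file_client.get_local_path(scene_idxs) as local_path:
            scene_idxs = np.load(local_path)
    return np.asarray(scene_idxs).astype(np.int32)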
@@ -88,8 +88,8 @@ def seg_eval(gt_labels, seg_preds, label2cat, ignore_index, logger=None):
     hist_list = []
     for i in range(len(gt_labels)):
-        gt_seg = gt_labels[i].clone().numpy().astype(np.int)
-        pred_seg = seg_preds[i].clone().numpy().astype(np.int)
+        gt_seg = gt_labels[i].astype(np.int)
+        pred_seg = seg_preds[i].astype(np.int)
         # filter out ignored points
         pred_seg[gt_seg == ignore_index] = -1
......
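For context, `gt_seg` and `pred_seg` are plain integer numpy arrays at this point, which is all the per-scene confusion histogram needs. A minimal sketch of that standard computation; the actual helper inside `seg_eval` may differ in details:

import numpy as np

def confusion_hist(pred_seg, gt_seg, num_classes):
    # keep only points whose ground truth is a real class id
    valid = (gt_seg >= 0) & (gt_seg < num_classes)
    hist = np.bincount(
        num_classes * gt_seg[valid] + pred_seg[valid],
        minlength=num_classes**2)
    return hist.reshape(num_classes, num_classes)  # rows: gt, cols: pred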
@@ -51,7 +51,7 @@ class SegMetric(BaseMetric):
             cpu_pred_3d = dict()
             for k, v in pred_3d.items():
                 if hasattr(v, 'to'):
-                    cpu_pred_3d[k] = v.to('cpu')
+                    cpu_pred_3d[k] = v.to('cpu').numpy()
                 else:
                     cpu_pred_3d[k] = v
             self.results.append((eval_ann, cpu_pred_3d))
......
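Downstream, the metric's compute step hands these collected numpy predictions and the matching annotations to `seg_eval`; `label2cat` and `ignore_index` are available from the dataset metainfo (see the Seg3DDataset hunk above). A condensed sketch of that hand-off; the `'pts_semantic_mask'` key is an assumption here:

# Each stored pair is (eval_ann, cpu_pred_3d), as appended above.
gt_labels = [ann['pts_semantic_mask'] for ann, _ in results]
seg_preds = [pred['pts_semantic_mask'] for _, pred in results]
ret_dict = seg_eval(gt_labels, seg_preds, label2cat, ignore_index)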
@@ -4,34 +4,42 @@
 def replace_ceph_backend(cfg):
     cfg_pretty_text = cfg.pretty_text
-    replace_strs = r'''file_client_args = dict(
+    replace_strs = \
+        r'''file_client_args = dict(
         backend='petrel',
         path_mapping=dict({
-            '.data/INPLACEHOLD/':
-            's3://openmmlab/datasets/detection3d/INPLACEHOLD/',
-            'data/INPLACEHOLD/':
-            's3://openmmlab/datasets/detection3d/INPLACEHOLD/'
+            '.data/DATA/': 's3://openmmlab/datasets/detection3d/CEPH/',
+            'data/DATA/': 's3://openmmlab/datasets/detection3d/CEPH/'
         }))
     '''
     if 'nuscenes' in cfg_pretty_text:
-        replace_strs = replace_strs.replace('INPLACEHOLD', 'nuscenes')
+        replace_strs = replace_strs.replace('DATA', 'nuscenes')
+        replace_strs = replace_strs.replace('CEPH', 'nuscenes')
     elif 'lyft' in cfg_pretty_text:
-        replace_strs = replace_strs.replace('INPLACEHOLD', 'lyft')
+        replace_strs = replace_strs.replace('DATA', 'lyft')
+        replace_strs = replace_strs.replace('CEPH', 'lyft')
     elif 'kitti' in cfg_pretty_text:
-        replace_strs = replace_strs.replace('INPLACEHOLD', 'kitti')
+        replace_strs = replace_strs.replace('DATA', 'kitti')
+        replace_strs = replace_strs.replace('CEPH', 'kitti')
     elif 'waymo' in cfg_pretty_text:
-        replace_strs = replace_strs.replace('INPLACEHOLD', 'waymo')
+        replace_strs = replace_strs.replace('DATA', 'waymo')
+        replace_strs = replace_strs.replace('CEPH', 'waymo')
     elif 'scannet' in cfg_pretty_text:
-        replace_strs = replace_strs.replace('INPLACEHOLD', 'scannet_processed')
+        replace_strs = replace_strs.replace('DATA', 'scannet')
+        replace_strs = replace_strs.replace('CEPH', 'scannet_processed')
     elif 's3dis' in cfg_pretty_text:
-        replace_strs = replace_strs.replace('INPLACEHOLD', 's3dis_processed')
+        replace_strs = replace_strs.replace('DATA', 's3dis')
+        replace_strs = replace_strs.replace('CEPH', 's3dis_processed')
     elif 'sunrgbd' in cfg_pretty_text:
-        replace_strs = replace_strs.replace('INPLACEHOLD', 'sunrgbd')
+        replace_strs = replace_strs.replace('DATA', 'sunrgbd')
+        replace_strs = replace_strs.replace('CEPH', 'sunrgbd_processed')
     elif 'semantickitti' in cfg_pretty_text:
-        replace_strs = replace_strs.replace('INPLACEHOLD', 'semantickitti')
+        replace_strs = replace_strs.replace('DATA', 'semantickitti')
+        replace_strs = replace_strs.replace('CEPH', 'semantickitti')
     elif 'nuimages' in cfg_pretty_text:
-        replace_strs = replace_strs.replace('INPLACEHOLD', 'nuimages')
+        replace_strs = replace_strs.replace('DATA', 'nuimages')
+        replace_strs = replace_strs.replace('CEPH', 'nuimages')
     else:
         NotImplemented('Does not support global replacement')
@@ -42,7 +50,6 @@ def replace_ceph_backend(cfg):
     #     'ann_file', replace_strs + ', ann_file')
     # replace LoadImageFromFile
-    replace_strs = replace_strs.replace(' ', '').replace('\n', '')
     cfg_pretty_text = cfg_pretty_text.replace(
         'LoadImageFromFile\'', 'LoadImageFromFile\',' + replace_strs)
......
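Usage-wise, the helper rewrites the config's pretty text so that file clients read from the petrel (Ceph) backend. A hypothetical calling sketch; the config path and the assumption that the helper returns the rewritten config are not part of this diff:

from mmengine.config import Config

# Hypothetical usage: load a dataset config, then swap local data paths
# for the Ceph backend before building dataloaders.
cfg = Config.fromfile('configs/_base_/datasets/s3dis-seg.py')  # example path (assumed)
cfg = replace_ceph_backend(cfg)  # assumed to return the rewritten config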