Commit 7a6deaef authored by ChaimZhu's avatar ChaimZhu Committed by ZwwWayne
Browse files

[Refactor] rename `CLASSES` and `PALETTE` to `classes` and `palette` in dataset metainfo (#1932)

* rame CLASS and PALETTE to class and palette

* change mmcv-full to mmcv

* fix comments
parent 48ab8e2d
......@@ -44,7 +44,7 @@ class KittiDataset(Det3DDataset):
"""
# TODO: use full classes of kitti
METAINFO = {
'CLASSES': ('Pedestrian', 'Cyclist', 'Car', 'Van', 'Truck',
'classes': ('Pedestrian', 'Cyclist', 'Car', 'Van', 'Truck',
'Person_sitting', 'Tram', 'Misc')
}
......
......@@ -42,7 +42,7 @@ class LyftDataset(Det3DDataset):
"""
METAINFO = {
'CLASSES':
'classes':
('car', 'truck', 'bus', 'emergency_vehicle', 'other_vehicle',
'motorcycle', 'bicycle', 'pedestrian', 'animal')
}
......
......@@ -48,7 +48,7 @@ class NuScenesDataset(Det3DDataset):
Defaults to False.
"""
METAINFO = {
'CLASSES':
'classes':
('car', 'truck', 'trailer', 'bus', 'construction_vehicle', 'bicycle',
'motorcycle', 'pedestrian', 'traffic_cone', 'barrier'),
'version':
......
......@@ -43,7 +43,7 @@ class S3DISDataset(Det3DDataset):
test_mode (bool, optional): Whether the dataset is in test mode.
Defaults to False.
"""
CLASSES = ('table', 'chair', 'sofa', 'bookcase', 'board')
classes = ('table', 'chair', 'sofa', 'bookcase', 'board')
def __init__(self,
data_root,
......@@ -146,7 +146,7 @@ class S3DISDataset(Det3DDataset):
use_dim=[0, 1, 2, 3, 4, 5]),
dict(
type='DefaultFormatBundle3D',
class_names=self.CLASSES,
class_names=self.classes,
with_label=False),
dict(type='Collect3D', keys=['points'])
]
......@@ -187,10 +187,10 @@ class _S3DISSegDataset(Seg3DDataset):
Defaults to False.
"""
METAINFO = {
'CLASSES':
'classes':
('ceiling', 'floor', 'wall', 'beam', 'column', 'window', 'door',
'table', 'chair', 'sofa', 'bookcase', 'board', 'clutter'),
'PALETTE': [[0, 255, 0], [0, 0, 255], [0, 255, 255], [255, 255, 0],
'palette': [[0, 255, 0], [0, 0, 255], [0, 255, 255], [255, 255, 0],
[255, 0, 255], [100, 100, 255], [200, 200, 100],
[170, 120, 200], [255, 0, 0], [200, 100, 100],
[10, 200, 100], [200, 200, 200], [50, 50, 50]],
......
......@@ -49,7 +49,7 @@ class ScanNetDataset(Det3DDataset):
Defaults to False.
"""
METAINFO = {
'CLASSES':
'classes':
('cabinet', 'bed', 'chair', 'sofa', 'table', 'door', 'window',
'bookshelf', 'picture', 'counter', 'desk', 'curtain', 'refrigerator',
'showercurtrain', 'toilet', 'sink', 'bathtub', 'garbagebin'),
......@@ -204,12 +204,12 @@ class ScanNetSegDataset(Seg3DDataset):
Defaults to False.
"""
METAINFO = {
'CLASSES':
'classes':
('wall', 'floor', 'cabinet', 'bed', 'chair', 'sofa', 'table', 'door',
'window', 'bookshelf', 'picture', 'counter', 'desk', 'curtain',
'refrigerator', 'showercurtrain', 'toilet', 'sink', 'bathtub',
'otherfurniture'),
'PALETTE': [
'palette': [
[174, 199, 232],
[152, 223, 138],
[31, 119, 180],
......@@ -278,11 +278,11 @@ class ScanNetSegDataset(Seg3DDataset):
class ScanNetInstanceSegDataset(Seg3DDataset):
METAINFO = {
'CLASSES':
'classes':
('cabinet', 'bed', 'chair', 'sofa', 'table', 'door', 'window',
'bookshelf', 'picture', 'counter', 'desk', 'curtain', 'refrigerator',
'showercurtrain', 'toilet', 'sink', 'bathtub', 'garbagebin'),
'PLATTE': [
'palette': [
[174, 199, 232],
[152, 223, 138],
[31, 119, 180],
......
......@@ -31,7 +31,7 @@ class Seg3DDataset(BaseDataset):
- use_lidar: bool
Defaults to dict(use_lidar=True, use_camera=False).
ignore_index (int, optional): The label index to be ignored, e.g.
unannotated points. If None is given, set to len(self.CLASSES) to
unannotated points. If None is given, set to len(self.classes) to
be consistent with PointSegClassMapping function in pipeline.
Defaults to None.
scene_idxs (np.ndarray | str, optional): Precomputed index to load
......@@ -50,8 +50,8 @@ class Seg3DDataset(BaseDataset):
Defaults to dict(backend='disk').
"""
METAINFO = {
'CLASSES': None, # names of all classes data used for the task
'PALETTE': None, # official color for visualization
'classes': None, # names of all classes data used for the task
'palette': None, # official color for visualization
'seg_valid_class_ids': None, # class_ids used for training
'seg_all_class_ids': None, # all possible class_ids in loaded seg mask
}
......@@ -81,11 +81,11 @@ class Seg3DDataset(BaseDataset):
# TODO: We maintain the ignore_index attributes,
# but we may consider to remove it in the future.
self.ignore_index = len(self.METAINFO['CLASSES']) if \
self.ignore_index = len(self.METAINFO['classes']) if \
ignore_index is None else ignore_index
# Get label mapping for custom classes
new_classes = metainfo.get('CLASSES', None)
new_classes = metainfo.get('classes', None)
self.label_mapping, self.label2cat, seg_valid_class_ids = \
self.get_label_mapping(new_classes)
......@@ -98,10 +98,10 @@ class Seg3DDataset(BaseDataset):
# generate palette if it is not defined based on
# label mapping, otherwise directly use palette
# defined in dataset config.
palette = metainfo.get('PALETTE', None)
palette = metainfo.get('palette', None)
updated_palette = self._update_palette(new_classes, palette)
metainfo['PALETTE'] = updated_palette
metainfo['palette'] = updated_palette
# construct seg_label_mapping for semantic mask
seg_max_cat_id = len(self.METAINFO['seg_all_class_ids'])
......@@ -150,13 +150,13 @@ class Seg3DDataset(BaseDataset):
tuple: The mapping from old classes in cls.METAINFO to
new classes in metainfo
"""
old_classes = self.METAINFO.get('CLASSES', None)
old_classes = self.METAINFO.get('classes', None)
if (new_classes is not None and old_classes is not None
and list(new_classes) != list(old_classes)):
if not set(new_classes).issubset(old_classes):
raise ValueError(
f'new classes {new_classes} is not a '
f'subset of CLASSES {old_classes} in METAINFO.')
f'subset of classes {old_classes} in METAINFO.')
# obtain true id from valid_class_ids
valid_class_ids = [
......@@ -184,7 +184,7 @@ class Seg3DDataset(BaseDataset):
# map label to category name
label2cat = {
i: cat_name
for i, cat_name in enumerate(self.METAINFO['CLASSES'])
for i, cat_name in enumerate(self.METAINFO['classes'])
}
valid_class_ids = self.METAINFO['seg_valid_class_ids']
......@@ -203,10 +203,10 @@ class Seg3DDataset(BaseDataset):
"""
if palette is None:
# If palette is not defined, it generate a palette according
# to the original PALETTE and classes.
old_classes = self.METAINFO.get('CLASSES', None)
# to the original palette and classes.
old_classes = self.METAINFO.get('classes', None)
palette = [
self.METAINFO['PALETTE'][old_classes.index(cls_name)]
self.METAINFO['palette'][old_classes.index(cls_name)]
for cls_name in new_classes
]
return palette
......@@ -215,8 +215,8 @@ class Seg3DDataset(BaseDataset):
if len(palette) == len(new_classes):
return palette
else:
raise ValueError('Once PLATTE in set in metainfo, it should'
'match CLASSES in metainfo')
raise ValueError('Once palette in set in metainfo, it should'
'match classes in metainfo')
def parse_data_info(self, info: dict) -> dict:
"""Process the raw data info.
......
......@@ -41,7 +41,7 @@ class SemanticKITTIDataset(Seg3DDataset):
Defaults to False.
"""
METAINFO = {
'CLASSES': ('unlabeled', 'car', 'bicycle', 'motorcycle', 'truck',
'classes': ('unlabeled', 'car', 'bicycle', 'motorcycle', 'truck',
'bus', 'person', 'bicyclist', 'motorcyclist', 'road',
'parking', 'sidewalk', 'other-ground', 'building', 'fence',
'vegetation', 'trunck', 'terrian', 'pole', 'traffic-sign'),
......
......@@ -46,7 +46,7 @@ class SUNRGBDDataset(Det3DDataset):
Defaults to False.
"""
METAINFO = {
'CLASSES': ('bed', 'table', 'sofa', 'chair', 'toilet', 'desk',
'classes': ('bed', 'table', 'sofa', 'chair', 'toilet', 'desk',
'dresser', 'night_stand', 'bookshelf', 'bathtub')
}
......
......@@ -63,7 +63,7 @@ class WaymoDataset(KittiDataset):
Defaults to 'lidar_det'.
max_sweeps (int): max sweep for each frame. Defaults to 0.
"""
METAINFO = {'CLASSES': ('Car', 'Pedestrian', 'Cyclist')}
METAINFO = {'classes': ('Car', 'Pedestrian', 'Cyclist')}
def __init__(self,
data_root: str,
......@@ -91,7 +91,7 @@ class WaymoDataset(KittiDataset):
# set loading mode for different task settings
self.cam_sync_instances = cam_sync_instances
# construct self.cat_ids for vision-only anns parsing
self.cat_ids = range(len(self.METAINFO['CLASSES']))
self.cat_ids = range(len(self.METAINFO['classes']))
self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)}
self.max_sweeps = max_sweeps
# we do not provide file_client_args to custom_3d init
......
......@@ -86,7 +86,7 @@ class IndoorMetric(BaseMetric):
ann_infos,
pred_results,
self.iou_thr,
self.dataset_meta['CLASSES'],
self.dataset_meta['classes'],
logger=logger,
box_mode_3d=box_mode_3d)
......@@ -142,7 +142,7 @@ class Indoor2DMetric(BaseMetric):
pred_labels = pred['labels'].cpu().numpy()
dets = []
for label in range(len(self.dataset_meta['CLASSES'])):
for label in range(len(self.dataset_meta['classes'])):
index = np.where(pred_labels == label)[0]
pred_bbox_scores = np.hstack(
[pred_bboxes[index], pred_scores[index].reshape((-1, 1))])
......@@ -171,7 +171,7 @@ class Indoor2DMetric(BaseMetric):
annotations,
scale_ranges=None,
iou_thr=iou_thr_2d_single,
dataset=self.dataset_meta['CLASSES'],
dataset=self.dataset_meta['classes'],
logger=logger)
eval_results['mAP_' + str(iou_thr_2d_single)] = mean_ap
return eval_results
......@@ -64,7 +64,7 @@ class InstanceSegMetric(BaseMetric):
"""
logger: MMLogger = MMLogger.get_current_instance()
self.classes = self.dataset_meta['CLASSES']
self.classes = self.dataset_meta['classes']
self.valid_class_ids = self.dataset_meta['seg_valid_class_ids']
gt_semantic_masks = []
......
......@@ -167,7 +167,7 @@ class KittiMetric(BaseMetric):
the metrics, and the values are corresponding results.
"""
logger: MMLogger = MMLogger.get_current_instance()
self.classes = self.dataset_meta['CLASSES']
self.classes = self.dataset_meta['classes']
# load annotations
pkl_infos = load(self.ann_file, file_client_args=self.file_client_args)
......
......@@ -110,7 +110,7 @@ class LyftMetric(BaseMetric):
"""
logger: MMLogger = MMLogger.get_current_instance()
classes = self.dataset_meta['CLASSES']
classes = self.dataset_meta['classes']
self.version = self.dataset_meta['version']
# load annotations
......
......@@ -151,7 +151,7 @@ class NuScenesMetric(BaseMetric):
"""
logger: MMLogger = MMLogger.get_current_instance()
classes = self.dataset_meta['CLASSES']
classes = self.dataset_meta['classes']
self.version = self.dataset_meta['version']
# load annotations
self.data_infos = load(
......
......@@ -100,7 +100,7 @@ class WaymoMetric(KittiMetric):
the metrics, and the values are corresponding results.
"""
logger: MMLogger = MMLogger.get_current_instance()
self.classes = self.dataset_meta['CLASSES']
self.classes = self.dataset_meta['classes']
# load annotations
self.data_infos = load(self.ann_file)['data_list']
......@@ -379,7 +379,7 @@ class WaymoMetric(KittiMetric):
torch.from_numpy(box_dict['box3d_lidar']).cuda())
scores = torch.from_numpy(box_dict['scores']).cuda()
labels = torch.from_numpy(box_dict['label_preds']).long().cuda()
nms_scores = scores.new_zeros(scores.shape[0], len(self.CLASSES) + 1)
nms_scores = scores.new_zeros(scores.shape[0], len(self.classes) + 1)
indices = labels.new_tensor(list(range(scores.shape[0])))
nms_scores[indices, labels] = scores
lidar_boxes3d_for_nms = xywhr2xyxyr(lidar_boxes3d.bev)
......
......@@ -708,9 +708,9 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
and masks. Defaults to 0.3.
step (int): Global step value to record. Defaults to 0.
"""
classes = self.dataset_meta.get('CLASSES', None)
# For object detection datasets, no PALETTE is saved
palette = self.dataset_meta.get('PALETTE', None)
classes = self.dataset_meta.get('classes', None)
# For object detection datasets, no palette is saved
palette = self.dataset_meta.get('palette', None)
ignore_index = self.dataset_meta.get('ignore_index', None)
gt_data_3d = None
......
......@@ -54,7 +54,7 @@ def test_getitem():
img='training/image_2',
),
pipeline=pipeline,
metainfo=dict(CLASSES=classes),
metainfo=dict(classes=classes),
modality=modality)
kitti_dataset.prepare_data(0)
......@@ -94,7 +94,7 @@ def test_getitem():
img='training/image_2',
),
pipeline=pipeline,
metainfo=dict(CLASSES=['Car']),
metainfo=dict(classes=['Car']),
modality=modality)
input_dict = car_kitti_dataset.get_data_info(0)
......@@ -105,4 +105,4 @@ def test_getitem():
assert ann_info['gt_labels_3d'].dtype == np.int64
# all instance have been filtered by classes
assert len(ann_info['gt_labels_3d']) == 0
assert len(car_kitti_dataset.metainfo['CLASSES']) == 1
assert len(car_kitti_dataset.metainfo['classes']) == 1
......@@ -48,7 +48,7 @@ def test_getitem():
ann_file,
data_prefix=data_prefix,
pipeline=pipeline,
metainfo=dict(CLASSES=classes),
metainfo=dict(classes=classes),
modality=modality)
lyft_dataset.prepare_data(0)
......@@ -68,4 +68,4 @@ def test_getitem():
assert 'gt_bboxes_3d' in ann_info
assert isinstance(ann_info['gt_bboxes_3d'], LiDARInstance3DBoxes)
assert len(lyft_dataset.metainfo['CLASSES']) == 9
assert len(lyft_dataset.metainfo['classes']) == 9
......@@ -51,7 +51,7 @@ def test_getitem():
ann_file=ann_file,
data_prefix=data_prefix,
pipeline=pipeline,
metainfo=dict(CLASSES=classes),
metainfo=dict(classes=classes),
modality=modality)
nus_dataset.prepare_data(0)
......@@ -77,7 +77,7 @@ def test_getitem():
assert 'gt_bboxes_3d' in ann_info
assert isinstance(ann_info['gt_bboxes_3d'], LiDARInstance3DBoxes)
assert len(nus_dataset.metainfo['CLASSES']) == 10
assert len(nus_dataset.metainfo['classes']) == 10
assert input_dict['token'] == 'fd8420396768425eabec9bdddf7e64b6'
assert input_dict['timestamp'] == 1533201470.448696
......@@ -67,7 +67,7 @@ class TestS3DISDataset(unittest.TestCase):
s3dis_seg_dataset = S3DISSegDataset(
data_root,
ann_file,
metainfo=dict(CLASSES=classes, PALETTE=palette),
metainfo=dict(classes=classes, palette=palette),
data_prefix=data_prefix,
pipeline=pipeline,
modality=modality,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment