Commit 7a6deaef authored by ChaimZhu's avatar ChaimZhu Committed by ZwwWayne
Browse files

[Refactor] rename `CLASSES` and `PALETTE` to `classes` and `palette` in dataset metainfo (#1932)

* rame CLASS and PALETTE to class and palette

* change mmcv-full to mmcv

* fix comments
parent 48ab8e2d
...@@ -44,7 +44,7 @@ class KittiDataset(Det3DDataset): ...@@ -44,7 +44,7 @@ class KittiDataset(Det3DDataset):
""" """
# TODO: use full classes of kitti # TODO: use full classes of kitti
METAINFO = { METAINFO = {
'CLASSES': ('Pedestrian', 'Cyclist', 'Car', 'Van', 'Truck', 'classes': ('Pedestrian', 'Cyclist', 'Car', 'Van', 'Truck',
'Person_sitting', 'Tram', 'Misc') 'Person_sitting', 'Tram', 'Misc')
} }
......
...@@ -42,7 +42,7 @@ class LyftDataset(Det3DDataset): ...@@ -42,7 +42,7 @@ class LyftDataset(Det3DDataset):
""" """
METAINFO = { METAINFO = {
'CLASSES': 'classes':
('car', 'truck', 'bus', 'emergency_vehicle', 'other_vehicle', ('car', 'truck', 'bus', 'emergency_vehicle', 'other_vehicle',
'motorcycle', 'bicycle', 'pedestrian', 'animal') 'motorcycle', 'bicycle', 'pedestrian', 'animal')
} }
......
...@@ -48,7 +48,7 @@ class NuScenesDataset(Det3DDataset): ...@@ -48,7 +48,7 @@ class NuScenesDataset(Det3DDataset):
Defaults to False. Defaults to False.
""" """
METAINFO = { METAINFO = {
'CLASSES': 'classes':
('car', 'truck', 'trailer', 'bus', 'construction_vehicle', 'bicycle', ('car', 'truck', 'trailer', 'bus', 'construction_vehicle', 'bicycle',
'motorcycle', 'pedestrian', 'traffic_cone', 'barrier'), 'motorcycle', 'pedestrian', 'traffic_cone', 'barrier'),
'version': 'version':
......
...@@ -43,7 +43,7 @@ class S3DISDataset(Det3DDataset): ...@@ -43,7 +43,7 @@ class S3DISDataset(Det3DDataset):
test_mode (bool, optional): Whether the dataset is in test mode. test_mode (bool, optional): Whether the dataset is in test mode.
Defaults to False. Defaults to False.
""" """
CLASSES = ('table', 'chair', 'sofa', 'bookcase', 'board') classes = ('table', 'chair', 'sofa', 'bookcase', 'board')
def __init__(self, def __init__(self,
data_root, data_root,
...@@ -146,7 +146,7 @@ class S3DISDataset(Det3DDataset): ...@@ -146,7 +146,7 @@ class S3DISDataset(Det3DDataset):
use_dim=[0, 1, 2, 3, 4, 5]), use_dim=[0, 1, 2, 3, 4, 5]),
dict( dict(
type='DefaultFormatBundle3D', type='DefaultFormatBundle3D',
class_names=self.CLASSES, class_names=self.classes,
with_label=False), with_label=False),
dict(type='Collect3D', keys=['points']) dict(type='Collect3D', keys=['points'])
] ]
...@@ -187,10 +187,10 @@ class _S3DISSegDataset(Seg3DDataset): ...@@ -187,10 +187,10 @@ class _S3DISSegDataset(Seg3DDataset):
Defaults to False. Defaults to False.
""" """
METAINFO = { METAINFO = {
'CLASSES': 'classes':
('ceiling', 'floor', 'wall', 'beam', 'column', 'window', 'door', ('ceiling', 'floor', 'wall', 'beam', 'column', 'window', 'door',
'table', 'chair', 'sofa', 'bookcase', 'board', 'clutter'), 'table', 'chair', 'sofa', 'bookcase', 'board', 'clutter'),
'PALETTE': [[0, 255, 0], [0, 0, 255], [0, 255, 255], [255, 255, 0], 'palette': [[0, 255, 0], [0, 0, 255], [0, 255, 255], [255, 255, 0],
[255, 0, 255], [100, 100, 255], [200, 200, 100], [255, 0, 255], [100, 100, 255], [200, 200, 100],
[170, 120, 200], [255, 0, 0], [200, 100, 100], [170, 120, 200], [255, 0, 0], [200, 100, 100],
[10, 200, 100], [200, 200, 200], [50, 50, 50]], [10, 200, 100], [200, 200, 200], [50, 50, 50]],
......
...@@ -49,7 +49,7 @@ class ScanNetDataset(Det3DDataset): ...@@ -49,7 +49,7 @@ class ScanNetDataset(Det3DDataset):
Defaults to False. Defaults to False.
""" """
METAINFO = { METAINFO = {
'CLASSES': 'classes':
('cabinet', 'bed', 'chair', 'sofa', 'table', 'door', 'window', ('cabinet', 'bed', 'chair', 'sofa', 'table', 'door', 'window',
'bookshelf', 'picture', 'counter', 'desk', 'curtain', 'refrigerator', 'bookshelf', 'picture', 'counter', 'desk', 'curtain', 'refrigerator',
'showercurtrain', 'toilet', 'sink', 'bathtub', 'garbagebin'), 'showercurtrain', 'toilet', 'sink', 'bathtub', 'garbagebin'),
...@@ -204,12 +204,12 @@ class ScanNetSegDataset(Seg3DDataset): ...@@ -204,12 +204,12 @@ class ScanNetSegDataset(Seg3DDataset):
Defaults to False. Defaults to False.
""" """
METAINFO = { METAINFO = {
'CLASSES': 'classes':
('wall', 'floor', 'cabinet', 'bed', 'chair', 'sofa', 'table', 'door', ('wall', 'floor', 'cabinet', 'bed', 'chair', 'sofa', 'table', 'door',
'window', 'bookshelf', 'picture', 'counter', 'desk', 'curtain', 'window', 'bookshelf', 'picture', 'counter', 'desk', 'curtain',
'refrigerator', 'showercurtrain', 'toilet', 'sink', 'bathtub', 'refrigerator', 'showercurtrain', 'toilet', 'sink', 'bathtub',
'otherfurniture'), 'otherfurniture'),
'PALETTE': [ 'palette': [
[174, 199, 232], [174, 199, 232],
[152, 223, 138], [152, 223, 138],
[31, 119, 180], [31, 119, 180],
...@@ -278,11 +278,11 @@ class ScanNetSegDataset(Seg3DDataset): ...@@ -278,11 +278,11 @@ class ScanNetSegDataset(Seg3DDataset):
class ScanNetInstanceSegDataset(Seg3DDataset): class ScanNetInstanceSegDataset(Seg3DDataset):
METAINFO = { METAINFO = {
'CLASSES': 'classes':
('cabinet', 'bed', 'chair', 'sofa', 'table', 'door', 'window', ('cabinet', 'bed', 'chair', 'sofa', 'table', 'door', 'window',
'bookshelf', 'picture', 'counter', 'desk', 'curtain', 'refrigerator', 'bookshelf', 'picture', 'counter', 'desk', 'curtain', 'refrigerator',
'showercurtrain', 'toilet', 'sink', 'bathtub', 'garbagebin'), 'showercurtrain', 'toilet', 'sink', 'bathtub', 'garbagebin'),
'PLATTE': [ 'palette': [
[174, 199, 232], [174, 199, 232],
[152, 223, 138], [152, 223, 138],
[31, 119, 180], [31, 119, 180],
......
...@@ -31,7 +31,7 @@ class Seg3DDataset(BaseDataset): ...@@ -31,7 +31,7 @@ class Seg3DDataset(BaseDataset):
- use_lidar: bool - use_lidar: bool
Defaults to dict(use_lidar=True, use_camera=False). Defaults to dict(use_lidar=True, use_camera=False).
ignore_index (int, optional): The label index to be ignored, e.g. ignore_index (int, optional): The label index to be ignored, e.g.
unannotated points. If None is given, set to len(self.CLASSES) to unannotated points. If None is given, set to len(self.classes) to
be consistent with PointSegClassMapping function in pipeline. be consistent with PointSegClassMapping function in pipeline.
Defaults to None. Defaults to None.
scene_idxs (np.ndarray | str, optional): Precomputed index to load scene_idxs (np.ndarray | str, optional): Precomputed index to load
...@@ -50,8 +50,8 @@ class Seg3DDataset(BaseDataset): ...@@ -50,8 +50,8 @@ class Seg3DDataset(BaseDataset):
Defaults to dict(backend='disk'). Defaults to dict(backend='disk').
""" """
METAINFO = { METAINFO = {
'CLASSES': None, # names of all classes data used for the task 'classes': None, # names of all classes data used for the task
'PALETTE': None, # official color for visualization 'palette': None, # official color for visualization
'seg_valid_class_ids': None, # class_ids used for training 'seg_valid_class_ids': None, # class_ids used for training
'seg_all_class_ids': None, # all possible class_ids in loaded seg mask 'seg_all_class_ids': None, # all possible class_ids in loaded seg mask
} }
...@@ -81,11 +81,11 @@ class Seg3DDataset(BaseDataset): ...@@ -81,11 +81,11 @@ class Seg3DDataset(BaseDataset):
# TODO: We maintain the ignore_index attributes, # TODO: We maintain the ignore_index attributes,
# but we may consider to remove it in the future. # but we may consider to remove it in the future.
self.ignore_index = len(self.METAINFO['CLASSES']) if \ self.ignore_index = len(self.METAINFO['classes']) if \
ignore_index is None else ignore_index ignore_index is None else ignore_index
# Get label mapping for custom classes # Get label mapping for custom classes
new_classes = metainfo.get('CLASSES', None) new_classes = metainfo.get('classes', None)
self.label_mapping, self.label2cat, seg_valid_class_ids = \ self.label_mapping, self.label2cat, seg_valid_class_ids = \
self.get_label_mapping(new_classes) self.get_label_mapping(new_classes)
...@@ -98,10 +98,10 @@ class Seg3DDataset(BaseDataset): ...@@ -98,10 +98,10 @@ class Seg3DDataset(BaseDataset):
# generate palette if it is not defined based on # generate palette if it is not defined based on
# label mapping, otherwise directly use palette # label mapping, otherwise directly use palette
# defined in dataset config. # defined in dataset config.
palette = metainfo.get('PALETTE', None) palette = metainfo.get('palette', None)
updated_palette = self._update_palette(new_classes, palette) updated_palette = self._update_palette(new_classes, palette)
metainfo['PALETTE'] = updated_palette metainfo['palette'] = updated_palette
# construct seg_label_mapping for semantic mask # construct seg_label_mapping for semantic mask
seg_max_cat_id = len(self.METAINFO['seg_all_class_ids']) seg_max_cat_id = len(self.METAINFO['seg_all_class_ids'])
...@@ -150,13 +150,13 @@ class Seg3DDataset(BaseDataset): ...@@ -150,13 +150,13 @@ class Seg3DDataset(BaseDataset):
tuple: The mapping from old classes in cls.METAINFO to tuple: The mapping from old classes in cls.METAINFO to
new classes in metainfo new classes in metainfo
""" """
old_classes = self.METAINFO.get('CLASSES', None) old_classes = self.METAINFO.get('classes', None)
if (new_classes is not None and old_classes is not None if (new_classes is not None and old_classes is not None
and list(new_classes) != list(old_classes)): and list(new_classes) != list(old_classes)):
if not set(new_classes).issubset(old_classes): if not set(new_classes).issubset(old_classes):
raise ValueError( raise ValueError(
f'new classes {new_classes} is not a ' f'new classes {new_classes} is not a '
f'subset of CLASSES {old_classes} in METAINFO.') f'subset of classes {old_classes} in METAINFO.')
# obtain true id from valid_class_ids # obtain true id from valid_class_ids
valid_class_ids = [ valid_class_ids = [
...@@ -184,7 +184,7 @@ class Seg3DDataset(BaseDataset): ...@@ -184,7 +184,7 @@ class Seg3DDataset(BaseDataset):
# map label to category name # map label to category name
label2cat = { label2cat = {
i: cat_name i: cat_name
for i, cat_name in enumerate(self.METAINFO['CLASSES']) for i, cat_name in enumerate(self.METAINFO['classes'])
} }
valid_class_ids = self.METAINFO['seg_valid_class_ids'] valid_class_ids = self.METAINFO['seg_valid_class_ids']
...@@ -203,10 +203,10 @@ class Seg3DDataset(BaseDataset): ...@@ -203,10 +203,10 @@ class Seg3DDataset(BaseDataset):
""" """
if palette is None: if palette is None:
# If palette is not defined, it generate a palette according # If palette is not defined, it generate a palette according
# to the original PALETTE and classes. # to the original palette and classes.
old_classes = self.METAINFO.get('CLASSES', None) old_classes = self.METAINFO.get('classes', None)
palette = [ palette = [
self.METAINFO['PALETTE'][old_classes.index(cls_name)] self.METAINFO['palette'][old_classes.index(cls_name)]
for cls_name in new_classes for cls_name in new_classes
] ]
return palette return palette
...@@ -215,8 +215,8 @@ class Seg3DDataset(BaseDataset): ...@@ -215,8 +215,8 @@ class Seg3DDataset(BaseDataset):
if len(palette) == len(new_classes): if len(palette) == len(new_classes):
return palette return palette
else: else:
raise ValueError('Once PLATTE in set in metainfo, it should' raise ValueError('Once palette in set in metainfo, it should'
'match CLASSES in metainfo') 'match classes in metainfo')
def parse_data_info(self, info: dict) -> dict: def parse_data_info(self, info: dict) -> dict:
"""Process the raw data info. """Process the raw data info.
......
...@@ -41,7 +41,7 @@ class SemanticKITTIDataset(Seg3DDataset): ...@@ -41,7 +41,7 @@ class SemanticKITTIDataset(Seg3DDataset):
Defaults to False. Defaults to False.
""" """
METAINFO = { METAINFO = {
'CLASSES': ('unlabeled', 'car', 'bicycle', 'motorcycle', 'truck', 'classes': ('unlabeled', 'car', 'bicycle', 'motorcycle', 'truck',
'bus', 'person', 'bicyclist', 'motorcyclist', 'road', 'bus', 'person', 'bicyclist', 'motorcyclist', 'road',
'parking', 'sidewalk', 'other-ground', 'building', 'fence', 'parking', 'sidewalk', 'other-ground', 'building', 'fence',
'vegetation', 'trunck', 'terrian', 'pole', 'traffic-sign'), 'vegetation', 'trunck', 'terrian', 'pole', 'traffic-sign'),
......
...@@ -46,7 +46,7 @@ class SUNRGBDDataset(Det3DDataset): ...@@ -46,7 +46,7 @@ class SUNRGBDDataset(Det3DDataset):
Defaults to False. Defaults to False.
""" """
METAINFO = { METAINFO = {
'CLASSES': ('bed', 'table', 'sofa', 'chair', 'toilet', 'desk', 'classes': ('bed', 'table', 'sofa', 'chair', 'toilet', 'desk',
'dresser', 'night_stand', 'bookshelf', 'bathtub') 'dresser', 'night_stand', 'bookshelf', 'bathtub')
} }
......
...@@ -63,7 +63,7 @@ class WaymoDataset(KittiDataset): ...@@ -63,7 +63,7 @@ class WaymoDataset(KittiDataset):
Defaults to 'lidar_det'. Defaults to 'lidar_det'.
max_sweeps (int): max sweep for each frame. Defaults to 0. max_sweeps (int): max sweep for each frame. Defaults to 0.
""" """
METAINFO = {'CLASSES': ('Car', 'Pedestrian', 'Cyclist')} METAINFO = {'classes': ('Car', 'Pedestrian', 'Cyclist')}
def __init__(self, def __init__(self,
data_root: str, data_root: str,
...@@ -91,7 +91,7 @@ class WaymoDataset(KittiDataset): ...@@ -91,7 +91,7 @@ class WaymoDataset(KittiDataset):
# set loading mode for different task settings # set loading mode for different task settings
self.cam_sync_instances = cam_sync_instances self.cam_sync_instances = cam_sync_instances
# construct self.cat_ids for vision-only anns parsing # construct self.cat_ids for vision-only anns parsing
self.cat_ids = range(len(self.METAINFO['CLASSES'])) self.cat_ids = range(len(self.METAINFO['classes']))
self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)} self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)}
self.max_sweeps = max_sweeps self.max_sweeps = max_sweeps
# we do not provide file_client_args to custom_3d init # we do not provide file_client_args to custom_3d init
......
...@@ -86,7 +86,7 @@ class IndoorMetric(BaseMetric): ...@@ -86,7 +86,7 @@ class IndoorMetric(BaseMetric):
ann_infos, ann_infos,
pred_results, pred_results,
self.iou_thr, self.iou_thr,
self.dataset_meta['CLASSES'], self.dataset_meta['classes'],
logger=logger, logger=logger,
box_mode_3d=box_mode_3d) box_mode_3d=box_mode_3d)
...@@ -142,7 +142,7 @@ class Indoor2DMetric(BaseMetric): ...@@ -142,7 +142,7 @@ class Indoor2DMetric(BaseMetric):
pred_labels = pred['labels'].cpu().numpy() pred_labels = pred['labels'].cpu().numpy()
dets = [] dets = []
for label in range(len(self.dataset_meta['CLASSES'])): for label in range(len(self.dataset_meta['classes'])):
index = np.where(pred_labels == label)[0] index = np.where(pred_labels == label)[0]
pred_bbox_scores = np.hstack( pred_bbox_scores = np.hstack(
[pred_bboxes[index], pred_scores[index].reshape((-1, 1))]) [pred_bboxes[index], pred_scores[index].reshape((-1, 1))])
...@@ -171,7 +171,7 @@ class Indoor2DMetric(BaseMetric): ...@@ -171,7 +171,7 @@ class Indoor2DMetric(BaseMetric):
annotations, annotations,
scale_ranges=None, scale_ranges=None,
iou_thr=iou_thr_2d_single, iou_thr=iou_thr_2d_single,
dataset=self.dataset_meta['CLASSES'], dataset=self.dataset_meta['classes'],
logger=logger) logger=logger)
eval_results['mAP_' + str(iou_thr_2d_single)] = mean_ap eval_results['mAP_' + str(iou_thr_2d_single)] = mean_ap
return eval_results return eval_results
...@@ -64,7 +64,7 @@ class InstanceSegMetric(BaseMetric): ...@@ -64,7 +64,7 @@ class InstanceSegMetric(BaseMetric):
""" """
logger: MMLogger = MMLogger.get_current_instance() logger: MMLogger = MMLogger.get_current_instance()
self.classes = self.dataset_meta['CLASSES'] self.classes = self.dataset_meta['classes']
self.valid_class_ids = self.dataset_meta['seg_valid_class_ids'] self.valid_class_ids = self.dataset_meta['seg_valid_class_ids']
gt_semantic_masks = [] gt_semantic_masks = []
......
...@@ -167,7 +167,7 @@ class KittiMetric(BaseMetric): ...@@ -167,7 +167,7 @@ class KittiMetric(BaseMetric):
the metrics, and the values are corresponding results. the metrics, and the values are corresponding results.
""" """
logger: MMLogger = MMLogger.get_current_instance() logger: MMLogger = MMLogger.get_current_instance()
self.classes = self.dataset_meta['CLASSES'] self.classes = self.dataset_meta['classes']
# load annotations # load annotations
pkl_infos = load(self.ann_file, file_client_args=self.file_client_args) pkl_infos = load(self.ann_file, file_client_args=self.file_client_args)
......
...@@ -110,7 +110,7 @@ class LyftMetric(BaseMetric): ...@@ -110,7 +110,7 @@ class LyftMetric(BaseMetric):
""" """
logger: MMLogger = MMLogger.get_current_instance() logger: MMLogger = MMLogger.get_current_instance()
classes = self.dataset_meta['CLASSES'] classes = self.dataset_meta['classes']
self.version = self.dataset_meta['version'] self.version = self.dataset_meta['version']
# load annotations # load annotations
......
...@@ -151,7 +151,7 @@ class NuScenesMetric(BaseMetric): ...@@ -151,7 +151,7 @@ class NuScenesMetric(BaseMetric):
""" """
logger: MMLogger = MMLogger.get_current_instance() logger: MMLogger = MMLogger.get_current_instance()
classes = self.dataset_meta['CLASSES'] classes = self.dataset_meta['classes']
self.version = self.dataset_meta['version'] self.version = self.dataset_meta['version']
# load annotations # load annotations
self.data_infos = load( self.data_infos = load(
......
...@@ -100,7 +100,7 @@ class WaymoMetric(KittiMetric): ...@@ -100,7 +100,7 @@ class WaymoMetric(KittiMetric):
the metrics, and the values are corresponding results. the metrics, and the values are corresponding results.
""" """
logger: MMLogger = MMLogger.get_current_instance() logger: MMLogger = MMLogger.get_current_instance()
self.classes = self.dataset_meta['CLASSES'] self.classes = self.dataset_meta['classes']
# load annotations # load annotations
self.data_infos = load(self.ann_file)['data_list'] self.data_infos = load(self.ann_file)['data_list']
...@@ -379,7 +379,7 @@ class WaymoMetric(KittiMetric): ...@@ -379,7 +379,7 @@ class WaymoMetric(KittiMetric):
torch.from_numpy(box_dict['box3d_lidar']).cuda()) torch.from_numpy(box_dict['box3d_lidar']).cuda())
scores = torch.from_numpy(box_dict['scores']).cuda() scores = torch.from_numpy(box_dict['scores']).cuda()
labels = torch.from_numpy(box_dict['label_preds']).long().cuda() labels = torch.from_numpy(box_dict['label_preds']).long().cuda()
nms_scores = scores.new_zeros(scores.shape[0], len(self.CLASSES) + 1) nms_scores = scores.new_zeros(scores.shape[0], len(self.classes) + 1)
indices = labels.new_tensor(list(range(scores.shape[0]))) indices = labels.new_tensor(list(range(scores.shape[0])))
nms_scores[indices, labels] = scores nms_scores[indices, labels] = scores
lidar_boxes3d_for_nms = xywhr2xyxyr(lidar_boxes3d.bev) lidar_boxes3d_for_nms = xywhr2xyxyr(lidar_boxes3d.bev)
......
...@@ -708,9 +708,9 @@ class Det3DLocalVisualizer(DetLocalVisualizer): ...@@ -708,9 +708,9 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
and masks. Defaults to 0.3. and masks. Defaults to 0.3.
step (int): Global step value to record. Defaults to 0. step (int): Global step value to record. Defaults to 0.
""" """
classes = self.dataset_meta.get('CLASSES', None) classes = self.dataset_meta.get('classes', None)
# For object detection datasets, no PALETTE is saved # For object detection datasets, no palette is saved
palette = self.dataset_meta.get('PALETTE', None) palette = self.dataset_meta.get('palette', None)
ignore_index = self.dataset_meta.get('ignore_index', None) ignore_index = self.dataset_meta.get('ignore_index', None)
gt_data_3d = None gt_data_3d = None
......
...@@ -54,7 +54,7 @@ def test_getitem(): ...@@ -54,7 +54,7 @@ def test_getitem():
img='training/image_2', img='training/image_2',
), ),
pipeline=pipeline, pipeline=pipeline,
metainfo=dict(CLASSES=classes), metainfo=dict(classes=classes),
modality=modality) modality=modality)
kitti_dataset.prepare_data(0) kitti_dataset.prepare_data(0)
...@@ -94,7 +94,7 @@ def test_getitem(): ...@@ -94,7 +94,7 @@ def test_getitem():
img='training/image_2', img='training/image_2',
), ),
pipeline=pipeline, pipeline=pipeline,
metainfo=dict(CLASSES=['Car']), metainfo=dict(classes=['Car']),
modality=modality) modality=modality)
input_dict = car_kitti_dataset.get_data_info(0) input_dict = car_kitti_dataset.get_data_info(0)
...@@ -105,4 +105,4 @@ def test_getitem(): ...@@ -105,4 +105,4 @@ def test_getitem():
assert ann_info['gt_labels_3d'].dtype == np.int64 assert ann_info['gt_labels_3d'].dtype == np.int64
# all instance have been filtered by classes # all instance have been filtered by classes
assert len(ann_info['gt_labels_3d']) == 0 assert len(ann_info['gt_labels_3d']) == 0
assert len(car_kitti_dataset.metainfo['CLASSES']) == 1 assert len(car_kitti_dataset.metainfo['classes']) == 1
...@@ -48,7 +48,7 @@ def test_getitem(): ...@@ -48,7 +48,7 @@ def test_getitem():
ann_file, ann_file,
data_prefix=data_prefix, data_prefix=data_prefix,
pipeline=pipeline, pipeline=pipeline,
metainfo=dict(CLASSES=classes), metainfo=dict(classes=classes),
modality=modality) modality=modality)
lyft_dataset.prepare_data(0) lyft_dataset.prepare_data(0)
...@@ -68,4 +68,4 @@ def test_getitem(): ...@@ -68,4 +68,4 @@ def test_getitem():
assert 'gt_bboxes_3d' in ann_info assert 'gt_bboxes_3d' in ann_info
assert isinstance(ann_info['gt_bboxes_3d'], LiDARInstance3DBoxes) assert isinstance(ann_info['gt_bboxes_3d'], LiDARInstance3DBoxes)
assert len(lyft_dataset.metainfo['CLASSES']) == 9 assert len(lyft_dataset.metainfo['classes']) == 9
...@@ -51,7 +51,7 @@ def test_getitem(): ...@@ -51,7 +51,7 @@ def test_getitem():
ann_file=ann_file, ann_file=ann_file,
data_prefix=data_prefix, data_prefix=data_prefix,
pipeline=pipeline, pipeline=pipeline,
metainfo=dict(CLASSES=classes), metainfo=dict(classes=classes),
modality=modality) modality=modality)
nus_dataset.prepare_data(0) nus_dataset.prepare_data(0)
...@@ -77,7 +77,7 @@ def test_getitem(): ...@@ -77,7 +77,7 @@ def test_getitem():
assert 'gt_bboxes_3d' in ann_info assert 'gt_bboxes_3d' in ann_info
assert isinstance(ann_info['gt_bboxes_3d'], LiDARInstance3DBoxes) assert isinstance(ann_info['gt_bboxes_3d'], LiDARInstance3DBoxes)
assert len(nus_dataset.metainfo['CLASSES']) == 10 assert len(nus_dataset.metainfo['classes']) == 10
assert input_dict['token'] == 'fd8420396768425eabec9bdddf7e64b6' assert input_dict['token'] == 'fd8420396768425eabec9bdddf7e64b6'
assert input_dict['timestamp'] == 1533201470.448696 assert input_dict['timestamp'] == 1533201470.448696
...@@ -67,7 +67,7 @@ class TestS3DISDataset(unittest.TestCase): ...@@ -67,7 +67,7 @@ class TestS3DISDataset(unittest.TestCase):
s3dis_seg_dataset = S3DISSegDataset( s3dis_seg_dataset = S3DISSegDataset(
data_root, data_root,
ann_file, ann_file,
metainfo=dict(CLASSES=classes, PALETTE=palette), metainfo=dict(classes=classes, palette=palette),
data_prefix=data_prefix, data_prefix=data_prefix,
pipeline=pipeline, pipeline=pipeline,
modality=modality, modality=modality,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment