class_names.py 784 Bytes
Newer Older
zhangwenwei's avatar
zhangwenwei committed
1
2
import mmcv

zhangwenwei's avatar
zhangwenwei committed
3
from mmdet.core.evaluation import dataset_aliases
zhangwenwei's avatar
zhangwenwei committed
4
5
6
7
8
9
10
11
12
13
14
15


# Canonical KITTI class names, kept immutable so callers cannot alter them.
_KITTI_CLASS_NAMES = (
    'Car',
    'Pedestrian',
    'Cyclist',
    'Van',
    'Person_sitting',
)


def kitti_classes():
    """Return the class names of the KITTI dataset.

    Returns:
        list[str]: A new list on every call, so callers may mutate it freely.
    """
    return list(_KITTI_CLASS_NAMES)


zhangwenwei's avatar
zhangwenwei committed
16
# Register KITTI in the shared alias registry so that get_classes() can
# resolve both spellings to the canonical name 'kitti'.
dataset_aliases['kitti'] = ['KITTI', 'kitti']
zhangwenwei's avatar
zhangwenwei committed
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33


def get_classes(dataset):
    """Get class names of a dataset.

    Args:
        dataset (str): Dataset name or one of its aliases as registered in
            ``dataset_aliases`` (e.g. ``'kitti'`` or ``'KITTI'``).

    Returns:
        list[str]: Class names returned by the matching
        ``<name>_classes()`` function defined in this module.

    Raises:
        TypeError: If ``dataset`` is not a string.
        ValueError: If ``dataset`` is not a registered alias, or if this
            module defines no ``<name>_classes`` function for it.
    """
    if not mmcv.is_str(dataset):
        raise TypeError('dataset must be a str, but got {}'.format(
            type(dataset)))

    # Invert the name -> aliases mapping so any alias resolves to its
    # canonical dataset name.
    alias2name = {
        alias: name
        for name, aliases in dataset_aliases.items()
        for alias in aliases
    }
    if dataset not in alias2name:
        raise ValueError('Unrecognized dataset: {}'.format(dataset))

    name = alias2name[dataset]
    # Resolve ``<name>_classes`` with an explicit lookup instead of ``eval``.
    # ``dataset_aliases`` is a shared, mutable registry, so building source
    # text from its contents and evaluating it is unsafe. Aliases registered
    # elsewhere may also have no matching function in this module, in which
    # case ``eval`` failed with an opaque NameError.
    class_fn = globals().get(name + '_classes')
    if not callable(class_fn):
        raise ValueError(
            'No class names defined in this module for dataset: {}'.format(
                dataset))
    return class_fn()