import os
import os.path as osp
import warnings

import mmcv
import numpy as np
from mmdet3d.datasets.pipelines import Compose
from mmdet.datasets import DATASETS
from torch.utils.data import Dataset

from .evaluation.vector_eval import VectorEvaluate

warnings.filterwarnings('ignore')


@DATASETS.register_module()
class BaseMapDataset(Dataset):
    """Map dataset base class.

    Subclasses must implement :meth:`load_annotations` (which populates
    ``self.samples``) and :meth:`get_sample`.

    Args:
        ann_file (str): annotation file path
        root_path (str): dataset root path
        cat2id (dict): category to class id
        roi_size (tuple): bev range
        meta (dict): meta information (must contain 'output_format')
        pipeline (Config): data processing pipeline config
        interval (int): annotation load interval
        work_dir (str): path to work dir
        test_mode (bool): whether in test mode
    """

    def __init__(self,
                 ann_file,
                 root_path,
                 cat2id,
                 roi_size,
                 meta,
                 pipeline,
                 interval=1,
                 work_dir=None,
                 test_mode=False,
                 ):
        super().__init__()
        self.ann_file = ann_file
        self.meta = meta
        self.root_path = root_path

        self.classes = list(cat2id.keys())
        self.num_classes = len(self.classes)
        self.cat2id = cat2id
        self.interval = interval

        # Populates self.samples; implemented by subclasses.
        self.load_annotations(self.ann_file)

        # Bidirectional mapping between dataset index and a unique sample
        # identifier. Samples carry either a 'timestamp' or a 'token' key
        # depending on the source dataset.
        self.idx2token = {}
        for i, s in enumerate(self.samples):
            if 'timestamp' in s:
                self.idx2token[i] = s['timestamp']
            else:
                self.idx2token[i] = s['token']
        self.token2idx = {v: k for k, v in self.idx2token.items()}

        if pipeline is not None:
            self.pipeline = Compose(pipeline)
        else:
            self.pipeline = None

        # dummy flags to fit with mmdet dataset
        self.flag = np.zeros(len(self), dtype=np.uint8)

        self.roi_size = roi_size
        self.work_dir = work_dir
        self.test_mode = test_mode

    def load_annotations(self, ann_file):
        """Load annotations from ``ann_file`` and set ``self.samples``.

        Must be overridden by subclasses.
        """
        raise NotImplementedError

    def get_sample(self, idx):
        """Build and return the raw input dict for sample ``idx``.

        Must be overridden by subclasses.
        """
        raise NotImplementedError

    def format_results(self, results, denormalize=True, prefix=None):
        '''Format prediction result to submission format.

        Args:
            results (list[Tensor]): List of prediction results.
            denormalize (bool): whether to denormalize prediction from (0, 1) \
                to bev range. Default: True
            prefix (str): work dir prefix to save submission file.

        Returns:
            str: Path of the dumped submission json file.
        '''

        meta = self.meta
        submissions = {
            'meta': meta,
            'results': {},
        }

        for pred in results:
            '''
            For each case, the result should be formatted as Dict{'vectors': [], 'scores': [], 'labels': []}
            'vectors': List of vector, each vector is a array([[x1, y1], [x2, y2] ...]),
                contain all vectors predicted in this sample.
            'scores: List of score(float),
                contain scores of all instances in this sample.
            'labels': List of label(int),
                contain labels of all instances in this sample.
            '''
            if pred is None:  # empty prediction
                continue

            single_case = {'vectors': [], 'scores': [], 'labels': []}
            token = pred['token']
            roi_size = np.array(self.roi_size)
            # BEV range is centered at the ego vehicle, so the origin (bottom
            # left corner) is at minus half the range in each direction.
            origin = -np.array([self.roi_size[0] / 2, self.roi_size[1] / 2])

            for i in range(len(pred['scores'])):
                score = pred['scores'][i]
                label = pred['labels'][i]
                vector = pred['vectors'][i]

                # A line should have >=2 points
                if len(vector) < 2:
                    continue

                if denormalize:
                    # eps pads the roi slightly so normalized coords strictly
                    # inside (0, 1) can still reach the roi boundary.
                    eps = 2
                    vector = vector * (roi_size + eps) + origin

                single_case['vectors'].append(vector)
                single_case['scores'].append(score)
                single_case['labels'].append(label)

            submissions['results'][token] = single_case

        out_path = osp.join(prefix, 'submission_vector.json')
        print(f'\nsaving submissions results to {out_path}')
        os.makedirs(os.path.dirname(out_path), exist_ok=True)
        mmcv.dump(submissions, out_path)
        return out_path

    def evaluate(self, results, logger=None, **kwargs):
        '''Evaluate prediction result based on `output_format` specified by dataset.

        Args:
            results (list[Tensor]): List of prediction results.
            logger (logger): logger to print evaluation results.

        Returns:
            dict: Evaluation results.
        '''

        # Sanity check: meta must declare its output format (raises KeyError
        # early if missing); the value itself is not otherwise used here.
        output_format = self.meta['output_format']
        self.evaluator = VectorEvaluate(self.ann_file)

        print('len of the results', len(results))

        result_path = self.format_results(results, denormalize=True, prefix=self.work_dir)

        result_dict = self.evaluator.evaluate(result_path, logger=logger)
        return result_dict

    def __len__(self):
        """Return the length of data infos.

        Returns:
            int: Length of data infos.
        """
        return len(self.samples)

    def _rand_another(self, idx):
        """Randomly get another item.

        Returns:
            int: Another index of item.
        """
        # BUGFIX: np.random.choice was called with the bound method
        # `self.__len__` (a callable, not an int), which raises ValueError.
        # Pass the dataset length so a valid random index is drawn.
        return np.random.choice(len(self))

    def __getitem__(self, idx):
        """Get item from infos according to the given index.

        Returns:
            dict: Data dictionary of the corresponding index.
        """
        input_dict = self.get_sample(idx)
        data = self.pipeline(input_dict)
        return data