panoptic.py 8.03 KB
Newer Older
Sugon_ldc's avatar
Sugon_ldc committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# ------------------------------------------------------------------------------
# Reference: https://github.com/mcordts/cityscapesScripts/blob/aeb7b82531f86185ce287705be28f452ba3ddbb8/cityscapesscripts/evaluation/evalPanopticSemanticLabeling.py
# Modified by Guowei Chen
# ------------------------------------------------------------------------------

from collections import defaultdict, OrderedDict

import numpy as np

OFFSET = 256 * 256 * 256


class PQStatCat():
    """Accumulates panoptic-quality counters for a single category.

    Tracks the summed IoU over true-positive matches plus the raw
    TP / FP / FN counts needed to derive PQ, SQ and RQ.
    """

    def __init__(self):
        self.iou = 0.0  # summed IoU over all matched (true-positive) segments
        self.tp = 0  # true positives
        self.fp = 0  # false positives
        self.fn = 0  # false negatives

    def __iadd__(self, pd_stat_cat):
        """Merge another category's counters into this one, in place."""
        self.iou = self.iou + pd_stat_cat.iou
        self.tp = self.tp + pd_stat_cat.tp
        self.fp = self.fp + pd_stat_cat.fp
        self.fn = self.fn + pd_stat_cat.fn
        return self

    def __repr__(self):
        return f'iou: {self.iou} tp: {self.tp} fp: {self.fp} fn: {self.fn}'


class PQStat():
    """Aggregates per-category panoptic statistics and averages them
    into PQ / SQ / RQ metrics."""

    def __init__(self, num_classes):
        # defaultdict lazily creates a PQStatCat the first time a
        # category id is touched.
        self.pq_per_cat = defaultdict(PQStatCat)
        self.num_classes = num_classes

    def __getitem__(self, i):
        return self.pq_per_cat[i]

    def __iadd__(self, pd_stat):
        """Merge another PQStat's per-category counters into this one."""
        for label, pq_stat_cat in pd_stat.pq_per_cat.items():
            # Bug fix: original read `self.pd_per_cat[label]`, which raised
            # AttributeError — the attribute is named `pq_per_cat`.
            self.pq_per_cat[label] += pq_stat_cat
        return self

    def pq_average(self, isthing=None, thing_list=None):
        """
        Calculate the average pq for all and every class.

        Args:
            isthing (bool|None, optional): calculate average pq for thing classes if isthing is True,
                for stuff classes if isthing is False, and for all classes if isthing is None. Default: None.
            thing_list (list|None, optional): A list of thing class ids. It should be provided when
                isthing is equal to True or False. Default: None.

        Returns:
            tuple: A dict with averaged 'pq', 'sq', 'rq' and the number of
                contributing classes 'n', plus a dict of per-class results.
        """
        pq, sq, rq, n = 0, 0, 0, 0
        per_class_results = {}
        for label in range(self.num_classes):
            # Restrict to thing or stuff classes when requested.
            if isthing is not None:
                if isthing:
                    if label not in thing_list:
                        continue
                else:
                    if label in thing_list:
                        continue
            iou = self.pq_per_cat[label].iou
            tp = self.pq_per_cat[label].tp
            fp = self.pq_per_cat[label].fp
            fn = self.pq_per_cat[label].fn
            if tp + fp + fn == 0:
                # Class never appeared in gt or prediction: report zeros and
                # exclude it from the average (n is not incremented).
                per_class_results[label] = {'pq': 0.0, 'sq': 0.0, 'rq': 0.0}
                continue
            n += 1
            pq_class = iou / (tp + 0.5 * fp + 0.5 * fn)
            sq_class = iou / tp if tp != 0 else 0
            rq_class = tp / (tp + 0.5 * fp + 0.5 * fn)

            per_class_results[label] = {
                'pq': pq_class,
                'sq': sq_class,
                'rq': rq_class
            }
            pq += pq_class
            sq += sq_class
            rq += rq_class

        if n == 0:
            # Robustness: avoid ZeroDivisionError when no class contributed
            # (e.g. evaluation on empty statistics).
            return {'pq': 0.0, 'sq': 0.0, 'rq': 0.0, 'n': 0}, per_class_results
        return {
            'pq': pq / n,
            'sq': sq / n,
            'rq': rq / n,
            'n': n
        }, per_class_results


class PanopticEvaluator:
    """
    Evaluate panoptic segmentation quality (PQ) over pairs of predicted and
    ground-truth panoptic label maps.

    Each pixel of a panoptic label map holds one integer id. Ids greater than
    ``label_divisor`` are decoded as ``category_id * label_divisor +
    instance_id``; smaller ids are taken as bare category ids (stuff regions,
    and — per the crowd handling below — crowd regions of thing classes).

    Args:
        num_classes (int): Number of semantic categories.
        thing_list (list): Category ids belonging to "thing" classes.
        ignore_index (int, optional): Label value excluded from evaluation.
            Default: 255.
        label_divisor (int, optional): Divisor used to encode category and
            instance ids into a single panoptic id. Default: 1000.
    """

    def __init__(self,
                 num_classes,
                 thing_list,
                 ignore_index=255,
                 label_divisor=1000):
        self.pq_stat = PQStat(num_classes)
        self.num_classes = num_classes
        self.thing_list = thing_list
        self.ignore_index = ignore_index
        self.label_divisor = label_divisor

    def update(self, pred, gt):
        """
        Accumulate PQ statistics (tp/fp/fn/iou per category) for one
        (pred, gt) pair of panoptic label maps of identical shape.
        """
        # get the labels and counts for the pred and gt.
        gt_labels, gt_labels_counts = np.unique(gt, return_counts=True)
        pred_labels, pred_labels_counts = np.unique(pred, return_counts=True)
        gt_segms = defaultdict(dict)
        pred_segms = defaultdict(dict)
        for label, label_count in zip(gt_labels, gt_labels_counts):
            # NOTE(review): an id exactly equal to label_divisor falls into the
            # `else` branch and is kept as-is — confirm `>` (not `>=`) is the
            # intended comparison for decoding the category id.
            category_id = label // self.label_divisor if label > self.label_divisor else label
            gt_segms[label]['area'] = label_count
            gt_segms[label]['category_id'] = category_id
            # A bare thing-category id (no instance part encoded) is treated
            # as a crowd region; crowd segments are excluded from matching.
            gt_segms[label]['iscrowd'] = 1 if label in self.thing_list else 0
        for label, label_count in zip(pred_labels, pred_labels_counts):
            category_id = label // self.label_divisor if label > self.label_divisor else label
            pred_segms[label]['area'] = label_count
            pred_segms[label]['category_id'] = category_id

        # confusion matrix calculation: pack each pixel's (gt_id, pred_id)
        # pair into a single uint64 so one np.unique call counts the
        # intersection area of every gt/pred segment combination at once.
        pan_gt_pred = gt.astype(np.uint64) * OFFSET + pred.astype(np.uint64)
        gt_pred_map = {}
        labels, labels_cnt = np.unique(pan_gt_pred, return_counts=True)
        for label, intersection in zip(labels, labels_cnt):
            gt_id = label // OFFSET
            pred_id = label % OFFSET
            gt_pred_map[(gt_id, pred_id)] = intersection

        # count all matched pairs: a gt/pred segment pair is a true positive
        # when the categories agree and IoU > 0.5 (the 0.5 threshold also
        # guarantees the matching is one-to-one).
        gt_matched = set()
        pred_matched = set()
        for label_tuple, intersection in gt_pred_map.items():
            gt_label, pred_label = label_tuple
            if gt_label == self.ignore_index or pred_label == self.ignore_index:
                continue
            if gt_segms[gt_label]['iscrowd'] == 1:
                continue
            if gt_segms[gt_label]['category_id'] != pred_segms[pred_label][
                    'category_id']:
                continue
            # The union discounts the part of the predicted segment that
            # overlaps ignored ground-truth pixels.
            union = pred_segms[pred_label]['area'] + gt_segms[gt_label][
                'area'] - intersection - gt_pred_map.get(
                    (self.ignore_index, pred_label), 0)
            iou = intersection / union
            if iou > 0.5:
                self.pq_stat[gt_segms[gt_label]['category_id']].tp += 1
                self.pq_stat[gt_segms[gt_label]['category_id']].iou += iou
                gt_matched.add(gt_label)
                pred_matched.add(pred_label)

        # count false negatives: unmatched, non-crowd, non-ignored gt segments.
        crowd_labels_dict = {}
        for gt_label, gt_info in gt_segms.items():
            if gt_label in gt_matched:
                continue
            if gt_label == self.ignore_index:
                continue
            # ignore crowd segments, but remember their label per category so
            # crowd overlap can excuse unmatched predictions below.
            if gt_info['iscrowd'] == 1:
                crowd_labels_dict[gt_info['category_id']] = gt_label
                continue
            self.pq_stat[gt_info['category_id']].fn += 1

        # count false positives: unmatched predicted segments.
        for pred_label, pred_info in pred_segms.items():
            if pred_label in pred_matched:
                continue
            if pred_label == self.ignore_index:
                continue
            # intersection of the segment with self.ignore_index
            intersection = gt_pred_map.get((self.ignore_index, pred_label), 0)
            # Overlap with a same-category crowd region also counts as ignored.
            if pred_info['category_id'] in crowd_labels_dict:
                intersection += gt_pred_map.get(
                    (crowd_labels_dict[pred_info['category_id']], pred_label),
                    0)
            # predicted segment is ignored if more than half of the segment correspond to self.ignore_index regions
            if intersection / pred_info['area'] > 0.5:
                continue
            self.pq_stat[pred_info['category_id']].fp += 1

    def evaluate(self):
        """
        Return the accumulated metrics.

        Returns:
            OrderedDict: {'pan_seg': results} where results maps 'All',
                'Things' and 'Stuff' to averaged pq/sq/rq/n dicts, plus
                'per_class' with per-class results over all classes.
        """
        metrics = [("All", None), ("Things", True), ("Stuff", False)]
        results = {}
        for name, isthing in metrics:
            results[name], per_class_results = self.pq_stat.pq_average(
                isthing=isthing, thing_list=self.thing_list)
            if name == 'All':
                results['per_class'] = per_class_results
        return OrderedDict(pan_seg=results)