rec_metric.py 2.53 KB
Newer Older
WenmuZhou's avatar
WenmuZhou committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Max Bachmann's avatar
Max Bachmann committed
15
from rapidfuzz.distance import Levenshtein
tink2123's avatar
tink2123 committed
16
import string
WenmuZhou's avatar
WenmuZhou committed
17
18
19


class RecMetric(object):
    """Accumulate text-recognition accuracy and normalized edit distance.

    Each ``__call__`` scores one batch and adds its counts to running
    totals; ``get_metric`` reports the accumulated totals and then resets
    them.
    """

    # Characters kept by ``_normalize_text`` (used when ``is_filter`` is on).
    # A frozenset gives O(1) membership instead of scanning a 62-char string
    # once per input character.
    _ALLOWED_CHARS = frozenset(string.digits + string.ascii_letters)

    def __init__(self,
                 main_indicator='acc',
                 is_filter=False,
                 ignore_space=True,
                 **kwargs):
        """
        Args:
            main_indicator: name of the primary metric key; only stored here
                (presumably read by the training/eval loop — not shown).
            is_filter: if True, keep only ASCII letters/digits and lowercase
                both prediction and target before comparing them.
            ignore_space: if True, strip every ' ' from both prediction and
                target before comparing them.
            **kwargs: accepted and ignored so extra config keys do not break
                construction.
        """
        self.main_indicator = main_indicator
        self.is_filter = is_filter
        self.ignore_space = ignore_space
        # No longer used as a division guard (see ``_ratio``); kept so any
        # external code reading ``metric.eps`` keeps working.
        self.eps = 1e-5
        self.reset()

    @staticmethod
    def _ratio(numerator, denominator):
        """Return ``numerator / denominator``, or 0.0 when denominator is 0.

        Unlike adding a small epsilon to the denominator, this yields exact
        ratios (a fully correct batch reports 1.0, not 0.99999).
        """
        return numerator / denominator if denominator else 0.0

    def _normalize_text(self, text):
        """Return ``text`` reduced to ASCII letters/digits, lowercased."""
        allowed = self._ALLOWED_CHARS
        return ''.join(ch for ch in text if ch in allowed).lower()

    def __call__(self, pred_label, *args, **kwargs):
        """Score one batch and add its counts to the running totals.

        Args:
            pred_label: a pair ``(preds, labels)``. Each element of ``preds``
                and ``labels`` is unpacked as a 2-tuple whose first item is
                the text; the second item (e.g. a confidence) is ignored.
                ``zip`` stops at the shorter of the two sequences.

        Returns:
            dict with batch-only values: ``'acc'`` (exact-match ratio) and
            ``'norm_edit_dis'`` (1 minus the mean normalized Levenshtein
            distance). An empty batch yields ``acc`` 0.0 and
            ``norm_edit_dis`` 1.0.
        """
        preds, labels = pred_label
        correct_num = 0
        all_num = 0
        norm_edit_dis = 0.0
        for (pred, _), (target, _) in zip(preds, labels):
            if self.ignore_space:
                pred = pred.replace(" ", "")
                target = target.replace(" ", "")
            if self.is_filter:
                pred = self._normalize_text(pred)
                target = self._normalize_text(target)
            norm_edit_dis += Levenshtein.normalized_distance(pred, target)
            if pred == target:
                correct_num += 1
            all_num += 1
        self.correct_num += correct_num
        self.all_num += all_num
        self.norm_edit_dis += norm_edit_dis
        return {
            'acc': self._ratio(correct_num, all_num),
            'norm_edit_dis': 1 - self._ratio(norm_edit_dis, all_num)
        }

    def get_metric(self):
        """Return metrics accumulated since the last reset, then reset.

        Returns:
            ``{'acc': float, 'norm_edit_dis': float}`` over every sample
            scored by ``__call__`` since the previous reset. With no samples
            this is ``{'acc': 0.0, 'norm_edit_dis': 1.0}``.
        """
        acc = self._ratio(self.correct_num, self.all_num)
        norm_edit_dis = 1 - self._ratio(self.norm_edit_dis, self.all_num)
        self.reset()
        return {'acc': acc, 'norm_edit_dis': norm_edit_dis}

    def reset(self):
        """Zero the running totals."""
        self.correct_num = 0
        self.all_num = 0
        self.norm_edit_dis = 0.0