import statistics

import jieba
from bert_score import score
from nltk.translate.bleu_score import sentence_bleu
from rouge_chinese import Rouge as Rouge_cn
from sklearn.metrics import f1_score, precision_score, recall_score


def bleu_score(preds: list, targets: list) -> dict:
    """Calculate BLEU Score Metric

    The calculation includes BLEU-1 for unigrams, BLEU-2 for bigrams,
    BLEU-3 for trigrams and BLEU-4 for 4-grams. The unigram score reflects
    word-level accuracy, while the higher-order n-gram scores reflect
    sentence-level fluency.
    """
    bleu_scores = {"bleu1": 0, "bleu2": 0, "bleu3": 0, "bleu4": 0}
    cumulative_bleu = [0] * 4
    weights = [(1. / 1., 0., 0., 0.), (1. / 2., 1. / 2., 0., 0.), (1. / 3., 1. / 3., 1. / 3., 0.),
               (1. / 4., 1. / 4., 1. / 4., 1. / 4.)]

    for pred, target in zip(preds, targets):
        pred_list = (' '.join(jieba.cut(pred))).split()
        target_list = [(' '.join(jieba.cut(target))).split()]

        bleu = sentence_bleu(target_list, pred_list, weights=weights)
        cumulative_bleu = [a + b for a, b in zip(cumulative_bleu, bleu)]

    for i in range(len(cumulative_bleu)):
        bleu_scores[f"bleu{i+1}"] = cumulative_bleu[i] / len(preds)

    return bleu_scores
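
# Usage sketch (made-up example strings; exact values depend on the installed
# NLTK version):
#
#   >>> bleu_score(["今天天气很好"], ["今天天气不错"])  # doctest: +SKIP
#   {'bleu1': ..., 'bleu2': ..., 'bleu3': ..., 'bleu4': ...}
#
# When `weights` is a list of tuples, recent NLTK releases of `sentence_bleu`
# return one score per tuple, so `bleu` inside the loop is a list of the four
# cumulative BLEU values that the loop accumulates.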


def rouge_cn_score(preds: list, targets: list) -> dict:
    """Calculate Chinese ROUGE Score Metric

    The calculation includes ROUGE-1 for unigrams, ROUGE-2 for bigrams
    and ROUGE-L. ROUGE-N counts the matching n-grams between the preds
    and targets, while ROUGE-L measures the longest common subsequence
    (LCS) between preds and targets.
    """
    rouge_scores = {"rouge1": {}, "rouge2": {}, "rougeL": {}}
    all_preds = []
    all_targets = []

    for pred, target in zip(preds, targets):
        all_preds.append(' '.join(jieba.cut(pred)))
        all_targets.append(' '.join(jieba.cut(target)))

    rouge_cn = Rouge_cn()
    rouge_avg = rouge_cn.get_scores(all_preds, all_targets, avg=True)

    rouge_scores["rouge1"] = rouge_avg["rouge-1"]["f"]
    rouge_scores["rouge2"] = rouge_avg["rouge-2"]["f"]
    rouge_scores["rougeL"] = rouge_avg["rouge-l"]["f"]

    return rouge_scores
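
# Usage sketch (made-up example strings): `rouge_chinese` expects whitespace-
# separated tokens, which is why each sentence is segmented with jieba and
# re-joined with spaces before scoring.
#
#   >>> rouge_cn_score(["今天天气很好"], ["今天天气不错"])  # doctest: +SKIP
#   {'rouge1': ..., 'rouge2': ..., 'rougeL': ...}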


def distinct_score(preds: list) -> dict:
    """Calculate Distinct Score Metric

    This metric refers to https://arxiv.org/abs/1510.03055.
    It evaluates the diversity of the generated text via the
    ratio of unique characters in each prediction.
    """
    distinct_score = {"distinct": 0}
    cumulative_distinct = []

    for pred in preds:
        pred_seg_list = list(' '.join(jieba.cut(pred)))
        count_segs = len(pred_seg_list)
        unique_segs = set(pred_seg_list)
        count_unique_chars = len(unique_segs)

        cumulative_distinct.append(count_unique_chars / count_segs)

    distinct_score["distinct"] = statistics.mean(cumulative_distinct)

    return distinct_score
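
# Note: `list(' '.join(jieba.cut(pred)))` splits the segmented sentence into
# individual characters (spaces included), so the score above is a character-
# level distinct-1 ratio rather than a word-level n-gram statistic. An empty
# prediction would divide by zero, so callers are assumed to filter out empty
# strings beforehand.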


def bert_score(preds: list, targets: list) -> dict:
    """Calculate BERTScore Metric

    BERTScore evaluates the semantic similarity between the tokens
    of preds and targets using contextual BERT embeddings.
    """
    bert_score = {"bert_score": 0}
    pred_list = []
    target_list = []

    for pred, target in zip(preds, targets):
        pred_list.append(' '.join(jieba.cut(pred)))
        target_list.append(' '.join(jieba.cut(target)))

    _, _, F = score(pred_list, target_list, lang="zh", verbose=True)

    bert_score["bert_score"] = F.mean().item()

    return bert_score
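
# Usage sketch (made-up example strings): `bert_score.score` returns precision,
# recall and F1 tensors with one entry per sentence pair; only the mean F1 is
# reported here. The first call downloads a pretrained Chinese BERT checkpoint,
# so this metric needs network access (or a cached model) to run.
#
#   >>> bert_score(["今天天气很好"], ["今天天气不错"])  # doctest: +SKIP
#   {'bert_score': ...}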


def calculate_precision_recall_f1(preds: list, targets: list) -> dict:
    """Precision, Recall and F1-Score Calculation

    The calculation of precision, recall and f1-score is realized by counting
    the number of overlaps between the preds and targets. The comparison length
    is limited by the shorter of the two. This design is mainly intended for
    the classification and extraction categories.
    """
    precision_recall_f1 = {"precision": 0, "recall": 0, "f1_score": 0}
    precision_scores = []
    recall_scores = []
    f1_scores = []

    for pred, target in zip(preds, targets):
        pred_list = [char for char in pred]
        target_list = [char for char in target]

        target_labels = [1] * min(len(target_list), len(pred_list))
        pred_labels = [int(pred_list[i] == target_list[i]) for i in range(0, min(len(target_list), len(pred_list)))]

        precision_scores.append(precision_score(target_labels, pred_labels, zero_division=0))
        recall_scores.append(recall_score(target_labels, pred_labels, zero_division=0))
        f1_scores.append(f1_score(target_labels, pred_labels, zero_division=0))

    precision_recall_f1["precision"] = statistics.mean(precision_scores)
    precision_recall_f1["recall"] = statistics.mean(recall_scores)
    precision_recall_f1["f1_score"] = statistics.mean(f1_scores)

    return precision_recall_f1
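
# Worked example (made-up strings): with pred "abc" and target "abd", the
# position-wise labels are target_labels = [1, 1, 1] and pred_labels = [1, 1, 0],
# giving precision = 1.0 (every predicted match is correct), recall = 2/3 and
# f1_score = 0.8 for that pair before averaging over all pairs.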


def precision(preds: list, targets: list) -> dict:
    """Calculate Precision Metric
    (designed for the classification and extraction categories)

    Precision is calculated by counting the number of overlaps between the preds and targets.
    """
    precision = {"precision": 0}
    precision["precision"] = calculate_precision_recall_f1(preds, targets)["precision"]
    return precision


def recall(preds: list, targets: list) -> dict:
    """Calculate Recall Metric
    (designed for the classification and extraction categories)

    Recall is calculated by counting the number of overlaps between the preds and targets.
    """
    recall = {"recall": 0}
    recall["recall"] = calculate_precision_recall_f1(preds, targets)["recall"]
    return recall


def F1_score(preds: list, targets: list) -> dict:
    """Calculate F1-score Metric
    (designed for the classification and extraction categories)

    The f1-score is calculated by counting the number of overlaps between the preds and targets.
    """
    f1 = {"f1_score": 0}
    f1["f1_score"] = calculate_precision_recall_f1(preds, targets)["f1_score"]
    return f1
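

if __name__ == "__main__":
    # Minimal smoke test with made-up Chinese sentences; the strings and the
    # selection of metrics are illustrative only. bert_score() is skipped here
    # because it downloads a pretrained checkpoint on first use.
    demo_preds = ["今天天气很好", "我喜欢读书"]
    demo_targets = ["今天天气不错", "我喜欢看书"]

    print(bleu_score(demo_preds, demo_targets))
    print(rouge_cn_score(demo_preds, demo_targets))
    print(distinct_score(demo_preds))
    print(calculate_precision_recall_f1(demo_preds, demo_targets))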