import collections
import re

from lm_eval.utils import weighted_f1_score


def doc_to_target(doc):
    """Return the gold span string for *doc*.

    Reads the raw ``ner_tags`` field and converts it into the
    ``"TAG: entity $$ TAG: entity"`` format produced by ``transform_text``.
    """
    raw_tags = doc["ner_tags"]
    return transform_text(raw_tags)


def transform_text(text):
    """Convert newline-separated ``"word: TAG"`` lines into entity spans.

    Input lines follow BIO tagging, e.g. ``"John: B-PER"``. Consecutive
    B-/I- tokens of the same type are merged into one entity, and each
    entity is rendered as ``"TAG: entity"``; entities are joined with
    ``" $$ "`` (the separator `span_f1_agg` splits on).

    Args:
        text: newline-separated "word: TAG" pairs; empty lines are skipped.

    Returns:
        A single string like ``"PER: john smith $$ LOC: paris"`` (empty
        string when no entity is found).
    """
    entities = []
    current_entity = ""
    current_tag = ""

    for pair in text.split("\n"):
        if not pair:  # skip empty lines
            continue
        word, tag = pair.strip().split(": ")
        tag = tag.upper()
        word = word.lower().strip(",.").strip()

        if tag.startswith("B-"):
            # A new entity starts; flush any entity in progress.
            if current_entity:
                entities.append(f"{current_tag}: {current_entity}")
            current_tag = tag.split("-")[1]
            current_entity = word
        elif tag.startswith("I-") and tag.split("-")[1] == current_tag:
            # BUGFIX: continuation tokens must be joined with a space,
            # otherwise "New York" becomes "newyork" and can never match
            # a whitespace-normalized prediction in span_f1_agg.
            if current_entity:
                current_entity += " " + word
            else:
                current_entity = word
        else:
            # "O" tag or a mismatched I- tag ends the current entity.
            if current_entity:
                entities.append(f"{current_tag}: {current_entity}")
                current_entity = ""
                current_tag = ""
    if current_entity:  # flush the trailing entity
        entities.append(f"{current_tag}: {current_entity}")

    return " $$ ".join(entities)


def span_f1_agg(items):
    """Compute span-based F1 over (target, prediction) string pairs.

    Adapted from the multilingual-t5 span F1 metric:
    https://github.com/google-research/multilingual-t5/blob/master/multilingual_t5/evaluation/metrics.py

    Args:
        items: iterable of ``(target, prediction)`` pairs, where each element
            is a ``"TAG: entity $$ TAG: entity"`` string (see transform_text).

    Returns:
        Micro-averaged span F1 across all pairs (based on the CoNLL script).
    """
    if not items:
        # No examples: nothing to score. Avoids IndexError on zip(*items).
        return 0.0

    targets, predictions = zip(*items)

    true_positives = collections.defaultdict(int)
    false_positives = collections.defaultdict(int)
    false_negatives = collections.defaultdict(int)

    # Patterns are loop-invariant; compile once.
    blank_pattern = re.compile(r"\s{3,}|\t")
    article_pattern = re.compile(r"\b(a|an|the)\b", re.UNICODE)
    # Character class of punctuation to strip (includes unicode quotes/bullet).
    my_punctuation = '!"$%&\'()*+,-./:;<=>?[\\]^_`{|}~•@.""-,`'
    punct_pattern = re.compile("[" + my_punctuation + "]+")

    def normalize_text(strings):
        """Lowercase, strip punctuation and articles, collapse whitespace."""
        strings = punct_pattern.sub(" ", str(strings))  # strip punctuation
        strings = article_pattern.sub(" ", strings)
        strings = blank_pattern.sub("", strings)
        strings = re.sub(r"\s+", " ", strings)
        return strings.lower()

    def tags_to_spans(tag_sequence, delimiter="$$"):
        """Extract (tag, entity) pairs from a "TAG: entity $$ ..." string."""
        if isinstance(tag_sequence, list):
            tag_sequence = " ".join(i.strip() for i in tag_sequence)
        # Split on the $$ delimiter (tolerating a single $), then further on
        # ". " / ", " so loosely formatted model output still yields spans.
        pieces = [
            item.strip()
            for sub in tag_sequence.strip().split(delimiter)
            for item in sub.split("$")
            if item
        ]
        pieces = [
            item.strip()
            for value in pieces
            for sub in value.split(". ")
            for item in sub.split(", ")
        ]
        spans = []
        for piece in pieces:
            parts = piece.split(": ")
            if len(parts) != 2:
                # Malformed span (missing or extra "tag: entity" structure).
                continue
            tag = normalize_text(parts[0].strip())
            entity = normalize_text(parts[1].strip())
            spans.append((tag, entity))
        return spans

    def compute_f1_metrics(true_positive, false_positive, false_negative):
        """Precision/recall/F1 with epsilon guards against division by zero."""
        precision = float(true_positive) / float(true_positive + false_positive + 1e-13)
        recall = float(true_positive) / float(true_positive + false_negative + 1e-13)
        f1_measures = 2.0 * ((precision * recall) / (precision + recall + 1e-13))
        return precision, recall, f1_measures

    for target, pred in zip(targets, predictions):
        gold_spans = tags_to_spans(target)
        predicted_spans = tags_to_spans(pred)

        for span in predicted_spans:
            if span in gold_spans:
                true_positives[span[0]] += 1
                # Each gold span may be matched at most once.
                gold_spans.remove(span)
            else:
                false_positives[span[0]] += 1
        # Remaining gold spans were never predicted.
        for span in gold_spans:
            false_negatives[span[0]] += 1

    _, _, f1_measure = compute_f1_metrics(
        sum(true_positives.values()),
        sum(false_positives.values()),
        sum(false_negatives.values()),
    )
    return f1_measure