t5_utils.py 3.35 KB
Newer Older
lintangsutawika's avatar
lintangsutawika committed
1
2
3
import re
from lm_eval.api.filter import Filter

lintangsutawika's avatar
lintangsutawika committed
4

lintangsutawika's avatar
lintangsutawika committed
5
def doc_to_text(x):
    """Build the ``wsc:`` prompt for a WSC example.

    ``_wsc_inputs`` replaces the pronoun at ``x["span2_index"]`` with a literal
    ``X`` placeholder. Every `` X `` in that result is then rewritten as the
    pronoun wrapped in asterisks, e.g. ``"... Good for *him* , ..."``.

    Args:
        x: A WSC example; reads ``"text"``, ``"span2_index"`` and
            ``"span2_text"``.

    Returns:
        The prompt string, prefixed with ``"wsc: "``.
    """
    # str.replace instead of re.sub: re.sub processes backslash escapes and
    # group references in a string replacement, so a span2_text containing "\"
    # could be mangled or raise re.error. The pattern " X " contains no regex
    # metacharacters, so the matches (all non-overlapping, left to right) are
    # the same as before.
    marked_pronoun = " *" + x["span2_text"] + "* "
    text = _wsc_inputs(x).replace(" X ", marked_pronoun)
    return "wsc: " + text
lintangsutawika's avatar
lintangsutawika committed
8
9
10


# Examples whose pronoun cannot be isolated by splitting on spaces, mapped to
# hand-written inputs in which the pronoun is already replaced by ``X``.
_WSC_SPECIAL_CASES = {
    # The pronoun is fused with punctuation ('him,"'), so a plain word lookup
    # cannot isolate it.
    (
        "The boy continued to whip the pony , and eventually the pony threw him "
        'over. John laughed out quite loud. "Good for him," he said. '
    ): (
        "The boy continued to whip the pony , and eventually the pony threw "
        'him over. John laughed out quite loud. "Good for X ," he said.'
    ),
    # Using the span2_index, we get 'use' instead of 'it'.
    (
        "When they had eventually calmed down a bit , and had gotten home, "
        "Mr. Farley put the magic pebble in an iron safe . Some day they might "
        "want to use it , but really for now, what more could they wish for?"
    ): (
        "When they had eventually calmed down a bit , and had gotten home, "
        "Mr. Farley put the magic pebble in an iron safe . Some day they might "
        "want to use X , but really for now, what more could they wish for?"
    ),
}


def _wsc_inputs(x):
    """Return ``x["text"]`` with the ambiguous pronoun replaced by ``X``.

    The text is split on single spaces and the word at ``x["span2_index"]`` is
    replaced by the placeholder ``X``. Texts listed in ``_WSC_SPECIAL_CASES``
    are returned from that table instead.

    Args:
        x: A WSC example; reads ``"text"``, ``"span2_index"`` and
            ``"span2_text"``.

    Returns:
        The text with the pronoun replaced by ``X``. The placeholder is always
        followed by a space (a pronoun in last position leaves a trailing
        space), so callers can match it as ``" X "``.

    Raises:
        ValueError: If ``span2_index`` points at the first word or past the end
            of the text, or if the word at that index is not ``span2_text``.
    """
    words = x["text"].split(" ")
    pronoun_index = x["span2_index"]

    # A pronoun in first position would need special logic, because nothing
    # would precede the "X" placeholder. None of the WSC examples seem to need
    # it, so it is rejected. Explicit raises, unlike ``assert``, remain active
    # under ``python -O``.
    if not 0 < pronoun_index < len(words):
        raise ValueError(
            f"span2_index {pronoun_index} is out of range for a text of "
            f"{len(words)} words (a pronoun in first position is not supported)"
        )

    # Handle some special cases.
    special = _WSC_SPECIAL_CASES.get(x["text"])
    if special is not None:
        return special

    if words[pronoun_index] != x["span2_text"]:
        raise ValueError(
            f"word {words[pronoun_index]!r} at span2_index {pronoun_index} "
            f"does not match span2_text {x['span2_text']!r}"
        )

    before = " ".join(words[:pronoun_index])
    after = " ".join(words[pronoun_index + 1 :])
    return " ".join([before, "X", after])


class WSCPostprocess(Filter):
    """Decide whether each WSC generation names the referent ``span1_text``.

    The generation and the reference are both lower-cased and stripped of
    determiners before they are compared; see ``_names_referent`` for the
    matching rule.
    """

    def __init__(self, **kwargs):
        # Extra keyword arguments are accepted and ignored.
        super().__init__()
        # Words ignored when comparing a prediction with the referent.
        self.determiners = {
            "a",
            "an",
            "few",
            "her",
            "his",
            "each",
            "every",
            "many",
            "much",
            "my",
            "our",
            "some",
            "that",
            "the",
            "their",
            "these",
            "this",
            "those",
            "which",
            "whose",
            "your",
        }

    def clean(self, s):
        """Ignore capitalization and determiners.

        The split is on any whitespace, not just single spaces. Newlines or
        repeated spaces in a generation therefore neither glue words together
        nor produce empty "words".
        """
        return " ".join(w for w in s.lower().split() if w not in self.determiners)

    def apply(self, resps, docs):
        """Return, per instance, whether its first response names ``span1_text``.

        Args:
            resps: One sequence of generations per instance; only the first
                generation (``prediction[0]``) is scored.
            docs: The matching documents, either column-indexable
                (``docs["span1_text"]``) or a sequence of per-example dicts.

        Returns:
            A list of bools, in the same order as ``resps``.
        """
        return [
            self._names_referent(prediction[0], reference)
            for prediction, reference in zip(resps, self._span1_texts(docs))
        ]

    @staticmethod
    def _span1_texts(docs):
        """Return the ``span1_text`` value of every doc, in order."""
        try:
            # Column access, e.g. a dataset object indexed by field name.
            return docs["span1_text"]
        except TypeError:
            # A plain list/tuple of per-example dicts cannot be indexed by str.
            return [doc["span1_text"] for doc in docs]

    def _names_referent(self, prediction, reference):
        """Return True if ``prediction`` names the same thing as ``reference``."""
        prediction = self.clean(prediction)
        reference = self.clean(reference)

        if not prediction or not reference:
            # Nothing is left to compare (e.g. the generation was only "The").
            # An empty side would otherwise match spuriously.
            return False

        if ("'" in prediction) != ("'" in reference):
            # Make sure a prediction of "Bob" is not marked as naming the
            # referent "Bob's hat" (or vice versa).
            return False

        prediction_words = set(prediction.split())
        referent_words = set(reference.split())
        # Handle cases where the prediction is "fuzzy bunny" and the referent
        # is "bunny" (and the reverse).
        return prediction_words <= referent_words or referent_words <= prediction_words