# t5_utils.py: WSC prompt construction and response post-processing helpers.
import re
from lm_eval.api.filter import Filter

def doc_to_text(x):
    """Build the "wsc: ..." prompt for one WSC document.

    The standalone ``X`` placeholder produced by ``_wsc_inputs`` is replaced
    by the pronoun text wrapped in asterisks, and the result is prefixed
    with ``"wsc: "``.

    Args:
        x: Mapping that provides ``"span2_text"`` (the pronoun) plus the keys
            read by ``_wsc_inputs``.

    Returns:
        The prompt string, e.g. ``"wsc: ... because *they* feared ..."``.
    """
    highlighted = " *" + x["span2_text"] + "* "
    # Literal replacement on purpose: passing the pronoun to re.sub as the
    # replacement would treat it as a template, so a backslash in the data
    # would be expanded as an escape / group reference or raise re.error.
    text = _wsc_inputs(x).replace(" X ", highlighted)
    return "wsc: " + text


def _wsc_inputs(x):

lintangsutawika's avatar
lintangsutawika committed
12
    words = x["text"].split(" ")
lintangsutawika's avatar
lintangsutawika committed
13
14
15
16

    # We would need some special logic to handle the case where the pronoun is the
    # first or last word in the text. None of the examples in WSC seem to have
    # this, so we are ignoring these cases.
lintangsutawika's avatar
lintangsutawika committed
17
18
19
    assert x["span2_index"] > 0
    assert x["span2_index"] < len(words)
    pronoun_index = x["span2_index"]
lintangsutawika's avatar
lintangsutawika committed
20
21

    def create_input():
lintangsutawika's avatar
lintangsutawika committed
22
23
24
25
26
27
28
29
30
        assert words[pronoun_index] == x["span2_text"]

        return " ".join(
            [
                " ".join(words[:pronoun_index]),
                "X",
                " ".join(words[pronoun_index + 1 :]),
            ]
        )
lintangsutawika's avatar
lintangsutawika committed
31
32

    # Handle some special cases.
lintangsutawika's avatar
lintangsutawika committed
33
34
35
36
    if (
        x["text"]
        == 'The boy continued to whip the pony , and eventually the pony threw him over. John laughed out quite loud. "Good for him," he said. '
    ):
lintangsutawika's avatar
lintangsutawika committed
37
        return (
lintangsutawika's avatar
lintangsutawika committed
38
            "The boy continued to whip the pony , and eventually the pony threw "
lintangsutawika's avatar
lintangsutawika committed
39
40
41
42
            'him over. John laughed out quite loud. "Good for X ," he said.'
        )

    # Using the span2_index, we get 'use' instead of 'it'.
lintangsutawika's avatar
lintangsutawika committed
43
44
45
46
    if (
        x["text"]
        == "When they had eventually calmed down a bit , and had gotten home, Mr. Farley put the magic pebble in an iron safe . Some day they might want to use it , but really for now, what more could they wish for?"
    ):
lintangsutawika's avatar
lintangsutawika committed
47
        return (
lintangsutawika's avatar
lintangsutawika committed
48
49
50
            "When they had eventually calmed down a bit , and had gotten home, "
            "Mr. Farley put the magic pebble in an iron safe . Some day they might "
            "want to use X , but really for now, what more could they wish for?"
lintangsutawika's avatar
lintangsutawika committed
51
52
53
54
55
56
57
58
59
        )

    return create_input()


class WSCPostprocess(Filter):
    """Filter that turns WSC generations into booleans.

    For every (response, reference) pair, the result says whether the first
    generation names the same referent as ``span1_text``, ignoring case and
    determiners.
    """

    def __init__(self, **kwargs):
        """Create the filter and its determiner set.

        Args:
            **kwargs: Accepted but not used by this class.
        """

        # Words dropped from both prediction and reference before comparison.
        self.determiners = {
            "a",
            "an",
            "few",
            "her",
            "his",
            "each",
            "every",
            "many",
            "much",
            "my",
            "our",
            "some",
            "that",
            "the",
            "their",
            "these",
            "this",
            "those",
            "which",
            "whose",
            "your",
        }

    def clean(self, s):
        """Ignore capitalization and determiners."""
        s = s.strip().lower()
        return " ".join(w for w in s.split(" ") if w not in self.determiners)

    def apply(self, resps, docs):
        """Score each response against its document's ``span1_text``.

        Args:
            resps: Iterable of per-document responses; only element ``[0]`` of
                each is read (presumably the first generation; confirm
                against the caller).
            docs: Object whose ``["span1_text"]`` item yields the reference
                strings in the same order as ``resps``.

        Returns:
            A list of booleans, one per (response, reference) pair.
        """

        filtered_resps = []
        for prediction, reference in zip(resps, docs["span1_text"]):

            prediction = self.clean(prediction[0])
            reference = self.clean(reference)

            if ("'" in prediction) != ("'" in reference):
                # Make sure we don't mark cases where the prediction is "Bob"
                # and the referent is "Bob's hat" as predicting the referent.
                predicted_referent = False
            else:
                # split(" ") maps an empty string to {""} rather than the empty
                # set, so an empty prediction is not a subset of every referent.
                prediction_words = set(prediction.split(" "))
                referent_words = set(reference.split(" "))

                # Handle cases where the prediction is "fuzzy bunny" and the referent is
                # "bunny".
                predicted_referent = prediction_words.issubset(
                    referent_words
                ) or referent_words.issubset(prediction_words)

            filtered_resps.append(predicted_referent)

        return filtered_resps