coqa.py 6.15 KB
Newer Older
Leo Gao's avatar
Leo Gao committed
1
2
# REMINDER: this code needs to be rewritten for the new framework. Remove this comment when the code is fully converted.

3
import json
4
from lm_eval.base import Task, rf, mean
sdtblck's avatar
sdtblck committed
5
from ..utils import sh
6
from itertools import zip_longest
7
import transformers.data.metrics.squad_metrics as squad_metrics
thefazzer's avatar
thefazzer committed
8
9
10
11
12
13
14
import collections
import datasets
import numpy as np
from lm_eval.base import rf, mean
from . common import HFTask
from tqdm import tqdm
import string, re
15

16
class CoQA(Task):
thefazzer's avatar
thefazzer committed
17

sdtblck's avatar
sdtblck committed
18
    def download(self):
        """Fetch the CoQA v1.0 train and dev JSON files into ``data/coqa``.

        The commands are executed through the ``sh`` helper imported at the
        top of this file.
        """
        # wget ignores -N (timestamping) when it is combined with -O, because
        # the -O target is always freshly created. -P only sets the output
        # directory and keeps the remote file name, so -N can actually skip
        # files that have not changed upstream.
        sh("""
            mkdir -p data/coqa
            wget -N -P data/coqa http://downloads.cs.stanford.edu/nlp/data/coqa/coqa-train-v1.0.json
            wget -N -P data/coqa http://downloads.cs.stanford.edu/nlp/data/coqa/coqa-dev-v1.0.json
            """)

27
28
29
30
    def has_training_docs(self):
        """Return True: this task exposes a training split (see ``training_docs``)."""
        return True

    def has_validation_docs(self):
        """Return True: this task exposes a validation split (see ``validation_docs``)."""
        return True
Jason Phang's avatar
Jason Phang committed
32
33
34
35

    def has_test_docs(self):
        """Return False: no test split is provided (``test_docs`` yields nothing)."""
        return False

36
    def training_docs(self):
        """Load the CoQA training conversations with answers mapped to categories.

        Reads ``data/coqa/coqa-train-v1.0.json`` and, for every entry in each
        document's ``answers`` list, overwrites ``input_text`` in place with
        the category label returned by ``get_answer_choice``.

        :return:
            The list stored under the ``data`` key of the training file.
        """
        # The context manager closes the handle promptly instead of leaving it
        # to the garbage collector; JSON is read as UTF-8 regardless of the
        # platform's default encoding.
        with open('data/coqa/coqa-train-v1.0.json', encoding='utf-8') as f:
            doc_data = json.load(f)['data']
        for doc in doc_data:
            for answer in doc['answers']:
                answer['input_text'] = self.get_answer_choice(answer['input_text'])
        return doc_data
42
43

    def validation_docs(self):
        """Load the CoQA dev conversations.

        :return:
            The list stored under the ``data`` key of
            ``data/coqa/coqa-dev-v1.0.json``, unmodified.
        """
        # Close the handle deterministically and read the JSON as UTF-8
        # regardless of the platform's default encoding.
        with open('data/coqa/coqa-dev-v1.0.json', encoding='utf-8') as f:
            return json.load(f)['data']
45
46

    def test_docs(self):
Leo Gao's avatar
Leo Gao committed
47
        pass
48
49
    
    def fewshot_description(self):
        """Return the natural-language task description used in few-shot prompts."""
        return "Given a passage and a conversation so far, answer the next question in the conversation."
51
    
Leo Gao's avatar
Leo Gao committed
52
    def doc_to_text(self, doc):
        """Render the passage and the conversation so far as a prompt.

        Given a passage p, the conversation history {q1, a1, ..., q(i-1), a(i-1)}
        and a question qi, the task is to predict the answer ai.

        :param doc:
            Mapping with a ``story`` string and parallel ``questions`` /
            ``answers`` lists whose items carry an ``input_text`` string.
        :return: str
            The story followed by ``Q:``/``A:`` turns. When ``questions`` and
            ``answers`` have the same length, the final turn is left open as
            ``"A: "`` for the model to complete.
        """
        pieces = [doc["story"] + '\n\n']
        # The last answer is the prediction target, so it is dropped;
        # zip_longest then pairs the final question with None.
        for q, a in zip_longest(doc["questions"], doc["answers"][:-1]):
            pieces.append(f"Q: {q['input_text']}" + '\n\n')
            if a is None:
                pieces.append("A: ")
            else:
                pieces.append(f"A: {a['input_text']}" + '\n\n')
        # No printing here: this runs once per document (and per few-shot
        # example), so a debug print would flood stdout.
        return "".join(pieces)
thefazzer's avatar
thefazzer committed
62
        
63
64
    @classmethod
    def get_answers(cls, doc, turn_id):
thefazzer's avatar
thefazzer committed
65
        # Returns unique answers and valid alternatives (Some questions in CoQA have multiple valid answers).
66
67
68
69
        answers = []
        answer_forturn = doc["answers"][turn_id - 1]["input_text"]
        answers.append(answer_forturn)
        
thefazzer's avatar
thefazzer committed
70
71
72
73
74
        additional_answers = doc.get("additional_answers")
        if additional_answers:
            for key in additional_answers:
                additional_answer_for_turn = additional_answers[key][turn_id - 1]["input_text"]
                if additional_answer_for_turn.lower() not in map(str.lower, answers):
75
76
                    answers.append(additional_answer_for_turn)
        return answers
thefazzer's avatar
thefazzer committed
77
    
thefazzer's avatar
thefazzer committed
78
79
80
81
82
83
84
85
86
87
88
89
90
    @classmethod
    def get_answer_choice(cls, raw_text):
        """Map a raw CoQA answer string to one of four category labels.

        ~ 1/5 of the CoQA answers are Yes/No
        ~ 2/3 of the CoQA answers are span-based
        (answers overlap with the passage ignoring punctuation and case mismatch)

        :param raw_text: str
            The answer text as it appears in the dataset.
        :return: str
            ``'0'`` for the literal ``"unknown"``, ``'1'`` / ``'2'`` when the
            SQuAD-normalized text is ``"yes"`` / ``"no"``, otherwise ``'3'``.
        """
        if raw_text == "unknown":
            return '0'
        # Normalize once and reuse the result for both comparisons.
        normalized = squad_metrics.normalize_answer(raw_text)
        if normalized == "yes":
            return '1'
        if normalized == "no":
            return '2'
        return '3'  # Not a yes/no question
Leo Gao's avatar
Leo Gao committed
91

92
93
    @staticmethod
    def compute_scores(gold_list, pred):
        """Score one prediction against the reference answers.

        Exact match is tested on the normalized answer (``compute_exact``) and
        token overlap with ``compute_f1``, both from ``squad_metrics``.

        :param gold_list: list of str
            Reference answers for the turn.
        :param pred: str
            The predicted answer.
        :return: dict
            ``{'em': float, 'f1': float}``. With several references, each one
            is left out in turn, the prediction is scored against the best of
            the remaining ones, and the results are averaged over the number
            of references. An empty ``gold_list`` scores 0.0 for both metrics.
        """
        if len(gold_list) > 1:
            # Leave-one-out: compare against the other (n - 1) references.
            reference_sets = [
                gold_list[:i] + gold_list[i + 1:] for i in range(len(gold_list))
            ]
        else:
            reference_sets = [gold_list]

        em_sum = 0.0
        f1_sum = 0.0
        for references in reference_sets:
            # default=0 keeps an empty reference list from raising ValueError;
            # the divisor below already guards against that case.
            em_sum += max(
                (squad_metrics.compute_exact(a, pred) for a in references), default=0
            )
            f1_sum += max(
                (squad_metrics.compute_f1(a, pred) for a in references), default=0
            )

        denominator = max(1, len(gold_list))
        return {'em': em_sum / denominator, 'f1': f1_sum / denominator}

thefazzer's avatar
thefazzer committed
110
111
112
113
114
115
116
117
    def doc_to_target(self, doc, turnid=None):
        """Return the answer-category label for one turn of ``doc``.

        :param doc:
            Mapping with parallel ``questions`` and ``answers`` lists.
        :param turnid: int or None
            1-based turn number; when omitted, the last turn (the number of
            questions) is used.
        :return: str
            The label produced by ``get_answer_choice`` for that turn's answer.
        """
        turn = len(doc["questions"]) if turnid is None else turnid
        return self.get_answer_choice(doc['answers'][turn - 1]["input_text"])

Leo Gao's avatar
Leo Gao committed
118
119
120
121
122
123
124
125
126
127
128
    def construct_requests(self, doc, ctx):
        """ Uses RequestFactory to construct Requests and returns an iterable of
        Requests which will be sent to the LM.

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param ctx: str
            The context string, generated by fewshot_context. This includes the natural
            language description, as well as the few shot examples, and the question
            part of the document for `doc`.
        :return: list
            One loglikelihood request per category label, ordered '0'..'3' so
            a request's position equals its label.
        """
        # NOTE(review): doc_to_text ends the prompt with "A: " and a space is
        # prepended here as well, giving "A:  <label>" -- confirm that the
        # double space is intended.
        return [
            rf.loglikelihood(ctx, " " + label)
            for label in ('0', '1', '2', '3')
        ]

Leo Gao's avatar
Leo Gao committed
135
136
137
138
139
140
141
142
143
144
    def process_results(self, doc, results):
        """Take a single document and the LM results and evaluates, returning a
        dict where keys are the names of submetrics and values are the values of
        the metric for that one document

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param results:
            The results of the requests created in construct_requests.
        :return: dict
            ``{"f1": float, "em": float}`` for the last turn of ``doc``.
        """
        turn_id = len(doc["questions"])
        gold_list = [
            self.get_answer_choice(text) for text in self.get_answers(doc, turn_id)
        ]

        # NOTE(review): if the harness hands back (loglikelihood, is_greedy)
        # pairs for requests that were not unpacked, np.argmax would index a
        # flattened 4x2 array. Keep only the score of such pairs; plain scalars
        # pass through unchanged. Confirm against rf.loglikelihood's contract.
        scores_per_label = [
            r[0] if isinstance(r, (tuple, list)) else r for r in results
        ]
        # Requests are ordered by label, so the winning index is the label.
        pred = str(np.argmax(scores_per_label))

        scores = self.compute_scores(gold_list, pred)

        return {
            "f1": scores['f1'],
            "em": scores['em'],
        }
155
156

    def higher_is_better(self):
        """Return, per submetric name, whether a larger value is better."""
        return {
            "f1": True,
            "em": True,
        }
Leo Gao's avatar
Leo Gao committed
161

162
    def aggregation(self):
        """Return, per submetric name, the function used to aggregate
        per-document values; both metrics use ``mean``."""
        return {
            "f1": mean,
            "em": mean,
        }