naturalqs.py 1.38 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
from . common import HFTask

class NaturalQs(HFTask):
    """Natural Questions task backed by the HF ``natural_questions`` dataset.

    Uses the train and validation splits; there is no test split.
    Evaluation is not implemented yet.
    """

    DATASET_PATH = "natural_questions"
    DATASET_NAME = None

    def has_training_docs(self):
        return True

    def has_validation_docs(self):
        return True

    def has_test_docs(self):
        return False

    def fewshot_description(self):
        # TODO: figure out description
        return ""

    def training_docs(self):
        """Return the training split as provided by ``self.data``."""
        # NOTE(review): an earlier comment said "Cache training for faster
        # few-shot", but nothing is cached here -- the split object is
        # returned directly rather than materialized into a list, presumably
        # because the training data is too large to fit in memory.
        return self.data["train"]

    def doc_to_text(self, doc, include_target=True):
        """Format one example as a ``Q: ...`` / ``A: ...`` prompt.

        Args:
            doc: One Natural Questions example; reads ``doc['question']['text']``
                and, when ``include_target`` is true, the first annotation's
                long answer span.
            include_target: If true, append the long-answer text after ``A: ``.

        Returns:
            ``'Q: <question>\\n\\nA: '`` optionally followed by the long answer.
        """
        text = 'Q: ' + doc['question']['text'] + '\n\n' + 'A: '

        if include_target:
            # There's a short answer and a long answer. Based on the paper, the
            # long answer is used. For the short answer instead, use
            # doc['annotations']['short_answers'][0]['text'][0].
            text += self._long_answer(doc)

        return text

    @staticmethod
    def _long_answer(doc):
        """Return the first annotation's long answer as space-joined plain-text tokens.

        Tokens flagged as HTML markup (``document.tokens.is_html``) are dropped
        so the answer does not contain raw tags. If the dataset row carries no
        ``is_html`` field, all tokens in the span are kept.
        """
        long_answer = doc['annotations']['long_answer'][0]
        start = long_answer['start_token']
        end = long_answer['end_token']
        # A negative start_token marks "no long answer" (NQ uses -1 -- confirm
        # against the dataset docs); the original slice [-1:-1] was empty too.
        if start < 0:
            return ""

        tokens = doc['document']['tokens']
        words = tokens['token'][start:end]
        is_html = tokens.get('is_html')
        if is_html is None:
            return " ".join(words)
        html_flags = is_html[start:end]
        return " ".join(word for word, html in zip(words, html_flags) if not html)

    def evaluate(self, docs, lm, provide_description, num_fewshot):
        # TODO: implement
        raise NotImplementedError()