naturalqs.py 1.29 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
from . common import HFTask

class NaturalQs(HFTask):
    """Task definition for the ``natural_questions`` dataset.

    Documents are rendered as a question/answer prompt: a ``Q: `` line with
    the question text, a blank line, then ``A: `` optionally followed by the
    target answer.

    NOTE(review): the document schema used in ``doc_to_text`` is taken from
    the indexing in this class only -- confirm against the dataset card.
    """

    # Dataset identifiers; presumably consumed by HFTask to load the data.
    DATASET_PATH = "natural_questions"
    DATASET_NAME = None

    def has_training_docs(self):
        """Return True: this task reports having training documents."""
        return True

    def has_validation_docs(self):
        """Return True: this task reports having validation documents."""
        return True

    def has_test_docs(self):
        """Return False: this task reports having no test documents."""
        return False

    def fewshot_description(self):
        """Return the few-shot description (currently an empty string)."""
        # TODO: figure out description
        return ""

    def doc_to_text(self, doc, include_target=True):
        """Render ``doc`` as a ``Q: ... A: ...`` prompt string.

        Args:
            doc: Mapping read at ``doc['question']['text']`` and, when the
                target is requested, ``doc['annotations']`` and
                ``doc['document']['tokens']['token']``.
            include_target: When true, append the answer after ``A: ``.

        Returns:
            The prompt string. With ``include_target`` the answer is the
            first short answer of the first annotation if there is one,
            otherwise the long-answer tokens joined with single spaces.
        """
        text = 'Q: ' + doc['question']['text'] + '\n\n' + 'A: '

        # Annotations are only read when the target is wanted, so a bare
        # prompt can be built from a document that lacks them.
        if not include_target:
            return text

        annotations = doc['annotations']
        # Presumably a list of answer strings for the first annotation --
        # TODO confirm; it is empty/falsy when there is no short answer.
        short_answer = annotations['short_answers'][0]['text']
        if short_answer:
            return text + short_answer[0]

        # No short answer: default to the long-answer passage, rebuilt from
        # the document tokens in the annotated [start_token, end_token) span.
        long_answer = annotations['long_answer'][0]
        tokens = doc['document']['tokens']['token']
        passage = " ".join(
            tokens[long_answer['start_token']:long_answer['end_token']]
        )
        return text + passage

    def evaluate(self, docs, lm, provide_description, num_fewshot):
        """Evaluate ``lm`` on ``docs``; not implemented yet.

        Raises:
            NotImplementedError: Always.
        """
        # TODO: implement
        raise NotImplementedError()