naturalqs.py 2.16 KB
Newer Older
Leo Gao's avatar
Leo Gao committed
1
2
# REMINDER: this code needs to be rewritten for the new framework. Remove this comment when the code is fully converted.

3
import random
from itertools import islice

from . common import HFTask
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22

class NaturalQs(HFTask):
    """Question-answering task over the ``natural_questions`` dataset.

    Prompts are built from the question text and the target is the
    long-answer span with HTML tokens stripped out.

    NOTE(review): ``self.data`` is read but never assigned in this class;
    presumably the ``HFTask`` base class loads it from ``DATASET_PATH`` /
    ``DATASET_NAME`` -- confirm against ``common.py``.
    """

    DATASET_PATH = "natural_questions"
    DATASET_NAME = None

    # How many leading training documents are materialised as the pool that
    # few-shot examples are sampled from.
    _FEWSHOT_POOL_SIZE = 100000

    def has_training_docs(self):
        """Return True; ``training_docs`` serves the ``"train"`` split."""
        return True

    def has_validation_docs(self):
        """Return True to advertise validation documents.

        NOTE(review): no ``validation_docs`` is defined in this class;
        presumably the base class supplies it -- confirm.
        """
        return True

    def has_test_docs(self):
        """Return False; this task advertises no test documents."""
        return False

    def fewshot_description(self):
        """Return the text placed before few-shot examples (currently empty)."""
        # TODO: figure out description
        return ""

    def training_docs(self):
        """Return the ``"train"`` split of ``self.data`` as-is."""
        # The data is too large to fit in memory, so the split is handed back
        # directly instead of being copied into a list here.
        return self.data["train"]

    def fewshot_examples(self, k):
        """Return ``k`` training documents sampled without replacement.

        Only the first ``_FEWSHOT_POOL_SIZE`` training documents are
        considered, because the full training data is too large to hold in
        memory. That pool is built on the first call and reused afterwards.

        Raises:
            ValueError: from ``random.sample`` if ``k`` is negative or larger
                than the number of pooled documents.
        """
        # Nothing in this class initialises ``_traindocs``, so read it
        # defensively instead of assuming the attribute already exists.
        if getattr(self, "_traindocs", None) is None:
            self._traindocs = list(
                islice(self.training_docs(), 0, self._FEWSHOT_POOL_SIZE)
            )

        return random.sample(self._traindocs, k)

    def doc_to_text(self, doc):
        """Return the prompt for ``doc``: the question framed as ``Q: ... A: ``."""
        return 'Q: ' + doc['question']['text'] + '\n\n' + 'A: '

    def doc_to_target(self, doc):
        """Return the long answer for ``doc`` as a space-joined token string.

        The span is taken from the first entry of
        ``doc['annotations']['long_answer']``; tokens flagged ``is_html`` are
        dropped before joining.
        """
        # There's a short answer and a long answer. Based on the paper, the
        # long answer is used. The short answer lives under
        # doc['annotations']['short_answers'][0]['text'] should it be needed.
        first_long_answer = doc['annotations']['long_answer'][0]
        start = first_long_answer['start_token']
        end = first_long_answer['end_token']

        document_tokens = doc['document']['tokens']
        span_tokens = document_tokens['token'][start:end]
        span_is_html = document_tokens['is_html'][start:end]

        # Keep only the textual tokens; markup tokens are not part of the answer.
        text_tokens = [
            tok for (tok, is_html) in zip(span_tokens, span_is_html) if not is_html
        ]
        return " ".join(text_tokens)
48

49
50
51
52
53
    # TODO: Implement evaluation code

    # ***IMPORTANT***: this evaluation function needs to be written for the new framework. 
    # For more info, check out the interface in base.py and the example BoolQ implementation in superglue.py. 
    # Remove this comment when the evaluation code is implemented.