gem_xsum.py 4.72 KB
Newer Older
ken's avatar
ken committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
"""
Don’t Give Me the Details, Just the Summary! Topic-Aware Convolutional Neural Networks for Extreme Summarization
https://arxiv.org/pdf/1808.08745.pdf

The dataset is for the task of abstractive summarization in its extreme form: summarizing a document in a single sentence. It introduces extreme summarization, a new single-document summarization task which does not favor extractive strategies and calls for an abstractive modeling approach. The idea is to create a short, one-sentence news summary answering the question "What is the article about?".

This task specifically uses the version of the dataset that is part of the GEM benchmark.
Homepage: https://github.com/EdinburghNLP/XSum
The GEM Benchmark: Natural Language Generation, its Evaluation and Metrics
https://arxiv.org/pdf/2102.01672v3.pdf
GEM data card for XSum:
Homepage: https://gem-benchmark.com/data_cards/XSum
"""
from lm_eval.base import PromptSourceTask
from lm_eval.base import Task, rf


# BibTeX entry for the XSum paper (Narayan, Cohen and Lapata, EMNLP 2018).
# The string is kept verbatim, including its leading and trailing newlines.
_CITATION = """
@InProceedings{xsum-emnlp,
  author =      "Shashi Narayan and Shay B. Cohen and Mirella Lapata",
  title =       "Don't Give Me the Details, Just the Summary! {T}opic-Aware Convolutional Neural Networks for Extreme Summarization",
  booktitle =   "Proceedings of the 2018 Conference on Empirical Methods in Natural Language Processing ",
  year =        "2018",
  address =     "Brussels, Belgium",
}
"""



class GEMXSUMBase(PromptSourceTask):
    """Common task definition for the ``GEM/xsum`` dataset.

    Serves the ``train``, ``validation`` and ``test`` splits of
    ``self.dataset``. Subclasses override the ``has_*_docs`` flags and the
    ``*_docs`` accessors to expose other splits of the same dataset.
    """

    VERSION = 0
    DATASET_PATH = "GEM/xsum"
    DATASET_NAME = None

    def has_training_docs(self):
        """Return True: the ``train`` split is served by `training_docs`."""
        return True

    def has_validation_docs(self):
        """Return True: the ``validation`` split is served by `validation_docs`."""
        return True

    def has_test_docs(self):
        """Return True: the ``test`` split is served by `test_docs`."""
        return True

    def stopping_criteria(self):
        """Return the stop string for this task, a single period.

        NOTE(review): presumably the harness cuts generation at this string,
        which would fit one-sentence summaries -- confirm against
        PromptSourceTask.
        """
        return "."

    def training_docs(self):
        """Return the ``train`` split as a list, or None if there is none.

        The list is built on first use and stored in ``self._training_docs``
        so repeated few-shot sampling does not re-read the dataset. If the
        data is too large to fit in memory, return a generator instead.
        """
        if not self.has_training_docs():
            return None
        if self._training_docs is None:
            train_split = self.dataset["train"]
            self._training_docs = list(train_split)
        return self._training_docs

    def validation_docs(self):
        """Return the ``validation`` split, or None if there is none."""
        if not self.has_validation_docs():
            return None
        return self.dataset["validation"]

    def test_docs(self):
        """Return the ``test`` split, or None if there is none."""
        if not self.has_test_docs():
            return None
        return self.dataset["test"]

class GEMXSUM(GEMXSUMBase):
    """GEM/xsum task over the standard ``train``/``validation``/``test`` splits.

    Everything is inherited unchanged from `GEMXSUMBase`.
    """

class GEMXSUMChallgeSample(GEMXSUMBase):
    """GEM/xsum variant over the challenge sample splits.

    Training documents come from ``challenge_train_sample`` and validation
    documents from ``challenge_validation_sample``; no test split is declared.
    """

    def has_test_docs(self):
        """Return False: this variant declares no test split."""
        return False

    def training_docs(self):
        """Return ``challenge_train_sample`` as a list, or None if disabled.

        The list is built on first use and stored in ``self._training_docs``
        for faster few-shot processing. If the data is too large to fit in
        memory, return a generator instead of a list.
        """
        if not self.has_training_docs():
            return None
        if self._training_docs is None:
            sample_split = self.dataset["challenge_train_sample"]
            self._training_docs = list(sample_split)
        return self._training_docs

    def validation_docs(self):
        """Return the ``challenge_validation_sample`` split, or None if disabled."""
        if not self.has_validation_docs():
            return None
        return self.dataset["challenge_validation_sample"]

class GEMXSUMChallgeTestBacktranslation(GEMXSUMBase):
    """GEM/xsum variant evaluated on ``challenge_test_backtranslation``.

    Declares no training or validation split; only test documents are served.
    """

    def has_training_docs(self):
        """Return False: this variant declares no training split."""
        return False

    def has_validation_docs(self):
        """Return False: this variant declares no validation split."""
        return False

    def test_docs(self):
        """Return the ``challenge_test_backtranslation`` split, or None if disabled."""
        if not self.has_test_docs():
            return None
        return self.dataset["challenge_test_backtranslation"]

class GEMXSUMChallgeTestBFP02(GEMXSUMBase):
    """GEM/xsum variant evaluated on ``challenge_test_bfp_02``.

    Declares no training or validation split; only test documents are served.
    """

    def has_training_docs(self):
        """Return False: this variant declares no training split."""
        return False

    def has_validation_docs(self):
        """Return False: this variant declares no validation split."""
        return False

    def test_docs(self):
        """Return the ``challenge_test_bfp_02`` split, or None if disabled."""
        if not self.has_test_docs():
            return None
        return self.dataset["challenge_test_bfp_02"]

class GEMXSUMChallgeTestBFP05(GEMXSUMBase):
    """GEM/xsum variant evaluated on ``challenge_test_bfp_05``.

    Declares no training or validation split; only test documents are served.
    """

    def has_training_docs(self):
        """Return False: this variant declares no training split."""
        return False

    def has_validation_docs(self):
        """Return False: this variant declares no validation split."""
        return False

    def test_docs(self):
        """Return the ``challenge_test_bfp_05`` split, or None if disabled."""
        if not self.has_test_docs():
            return None
        return self.dataset["challenge_test_bfp_05"]

class GEMXSUMChallgeTestNopunc(GEMXSUMBase):
    """GEM/xsum variant evaluated on ``challenge_test_nopunc``.

    Declares no training or validation split; only test documents are served.
    """

    def has_training_docs(self):
        """Return False: this variant declares no training split."""
        return False

    def has_validation_docs(self):
        """Return False: this variant declares no validation split."""
        return False

    def test_docs(self):
        """Return the ``challenge_test_nopunc`` split, or None if disabled."""
        if not self.has_test_docs():
            return None
        return self.dataset["challenge_test_nopunc"]

class GEMXSUMChallgeTestCovid(GEMXSUMBase):
    """GEM/xsum variant evaluated on ``challenge_test_covid``.

    Declares no training or validation split; only test documents are served.
    """

    def has_training_docs(self):
        """Return False: this variant declares no training split."""
        return False

    def has_validation_docs(self):
        """Return False: this variant declares no validation split."""
        return False

    def test_docs(self):
        """Return the ``challenge_test_covid`` split, or None if disabled."""
        if not self.has_test_docs():
            return None
        return self.dataset["challenge_test_covid"]