gem_xsum.py 4.87 KB
Newer Older
ken's avatar
ken committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
"""
Don’t Give Me the Details, Just the Summary! Topic-Aware Convolutional Neural Networks for Extreme Summarization
https://arxiv.org/pdf/1808.08745.pdf

The dataset is for the task of abstractive summarization in its extreme form: summarizing a document in a single sentence. It introduces extreme summarization, a new single-document summarization task which does not favor extractive strategies and calls for an abstractive modeling approach. The idea is to create a short, one-sentence news summary answering the question "What is the article about?".

This particularly uses the dataset that is part of the GEM benchmark
Homepage: https://github.com/EdinburghNLP/XSum
The GEM Benchmark: Natural Language Generation, its Evaluation and Metrics
https://arxiv.org/pdf/2102.01672v3.pdf
This module defines evaluation tasks for one-sentence abstractive summarization on the GEM version of XSum, including its challenge splits.
Homepage: https://gem-benchmark.com/data_cards/XSum
"""
from lm_eval.base import PromptSourceTask
from lm_eval.base import Task, rf


# BibTeX citation for the XSum dataset paper (Narayan et al., EMNLP 2018).
_CITATION = """
@InProceedings{xsum-emnlp,
  author =      "Shashi Narayan and Shay B. Cohen and Mirella Lapata",
  title =       "Don't Give Me the Details, Just the Summary! {T}opic-Aware Convolutional Neural Networks for Extreme Summarization",
  booktitle =   "Proceedings of the 2018 Conference on Empirical Methods in Natural Language Processing ",
  year =        "2018",
  address =     "Brussels, Belgium",
}
"""



class GEMXSUMBase(PromptSourceTask):
    """Shared base class for all GEM/XSum summarization task variants.

    Subclasses select which dataset split(s) they use by overriding
    ``SPLIT`` and the ``has_*_docs`` / ``*_docs`` accessors.
    """

    VERSION = 0
    DATASET_PATH = "GEM/xsum"
    DATASET_NAME = None
    # Name of the dataset split this task variant targets; set by subclasses.
    SPLIT = None

    def has_training_docs(self):
        """Return True: the base task exposes a training split."""
        return True

    def has_validation_docs(self):
        """Return True: the base task exposes a validation split."""
        return True

    def has_test_docs(self):
        """Return True: the base task exposes a test split."""
        return True

    def stopping_criteria(self):
        # XSum targets are single sentences, so generation stops at the
        # first period.
        return '.'

    def training_docs(self):
        """Return the training documents, caching them on first access."""
        if self.has_training_docs():
            # We cache training documents in `self._training_docs` for faster
            # few-shot processing. If the data is too large to fit in memory,
            # return the training data as a generator instead of a list.
            if self._training_docs is None:
                self._training_docs = list(self.dataset["train"])
            return self._training_docs

    def validation_docs(self):
        """Return the validation documents (None if the split is absent)."""
        if self.has_validation_docs():
            return self.dataset["validation"]

    def test_docs(self):
        """Return the test documents (None if the split is absent)."""
        if self.has_test_docs():
            return self.dataset["test"]

class GEMXSUM(GEMXSUMBase):
    """GEM/XSum task over the standard train/validation/test splits."""

    SPLIT = ''

class GEMXSUMChallgeSample(GEMXSUMBase):
    """GEM/XSum challenge subsets: challenge_train_sample and
    challenge_validation_sample (no test split)."""

    SPLIT = 'challenge_sample'

    def has_test_docs(self):
        # The sample challenge subsets provide no test split.
        return False

    def training_docs(self):
        """Return the challenge training sample, caching it on first access."""
        if self.has_training_docs():
            # We cache training documents in `self._training_docs` for faster
            # few-shot processing. If the data is too large to fit in memory,
            # return the training data as a generator instead of a list.
            if self._training_docs is None:
                self._training_docs = list(self.dataset["challenge_train_sample"])
            return self._training_docs

    def validation_docs(self):
        """Return the challenge validation sample."""
        if self.has_validation_docs():
            return self.dataset["challenge_validation_sample"]

class GEMXSUMChallgeTestBacktranslation(GEMXSUMBase):
    """GEM/XSum challenge set: challenge_test_backtranslation (test only)."""

    SPLIT = 'challenge_test_backtranslation'

    def has_training_docs(self):
        return False

    def has_validation_docs(self):
        return False

    def test_docs(self):
        """Return the backtranslation challenge test documents."""
        if self.has_test_docs():
            # Fix: the original referenced bare `SPLIT`, which is a class
            # attribute and raised NameError; it must be `self.SPLIT`.
            return self.dataset[self.SPLIT]

class GEMXSUMChallgeTestBFP02(GEMXSUMBase):
    """GEM/XSum challenge set: challenge_test_bfp_02 (test only)."""

    SPLIT = 'challenge_test_bfp_02'

    def has_training_docs(self):
        return False

    def has_validation_docs(self):
        return False

    def test_docs(self):
        """Return the bfp_02 challenge test documents."""
        if self.has_test_docs():
            # Fix: the original referenced bare `SPLIT`, which is a class
            # attribute and raised NameError; it must be `self.SPLIT`.
            return self.dataset[self.SPLIT]

class GEMXSUMChallgeTestBFP05(GEMXSUMBase):
    """GEM/XSum challenge set: challenge_test_bfp_05 (test only)."""

    SPLIT = 'challenge_test_bfp_05'

    def has_training_docs(self):
        return False

    def has_validation_docs(self):
        return False

    def test_docs(self):
        """Return the bfp_05 challenge test documents."""
        if self.has_test_docs():
            # Fix: the original referenced bare `SPLIT`, which is a class
            # attribute and raised NameError; it must be `self.SPLIT`.
            return self.dataset[self.SPLIT]

class GEMXSUMChallgeTestNopunc(GEMXSUMBase):
    """GEM/XSum challenge set: challenge_test_nopunc (test only)."""

    SPLIT = 'challenge_test_nopunc'

    def has_training_docs(self):
        return False

    def has_validation_docs(self):
        return False

    def test_docs(self):
        """Return the no-punctuation challenge test documents."""
        if self.has_test_docs():
            # Fix: the original referenced bare `SPLIT`, which is a class
            # attribute and raised NameError; it must be `self.SPLIT`.
            return self.dataset[self.SPLIT]

class GEMXSUMChallgeTestCovid(GEMXSUMBase):
    """GEM/XSum challenge set: challenge_test_covid (test only)."""

    SPLIT = 'challenge_test_covid'

    def has_training_docs(self):
        return False

    def has_validation_docs(self):
        return False

    def test_docs(self):
        """Return the COVID challenge test documents."""
        if self.has_test_docs():
            # Fix: the original referenced bare `SPLIT`, which is a class
            # attribute and raised NameError; it must be `self.SPLIT`.
            return self.dataset[self.SPLIT]