"""
Don't Give Me the Details, Just the Summary! Topic-Aware Convolutional Neural Networks for Extreme Summarization
https://arxiv.org/pdf/1808.08745.pdf

The dataset is for abstractive summarization in its extreme form: summarizing a document in a single sentence. The paper introduces extreme summarization, a new single-document summarization task which does not favor extractive strategies and calls for an abstractive modeling approach. The idea is to create a short, one-sentence news summary answering the question "What is the article about?".

This task uses the version of the dataset that is part of the GEM benchmark.
Homepage: https://github.com/EdinburghNLP/XSum
The GEM Benchmark: Natural Language Generation, its Evaluation and Metrics
https://arxiv.org/pdf/2102.01672v3.pdf
Homepage: https://gem-benchmark.com/data_cards/XSum
"""
from lm_eval.base import PromptSourceTask
from lm_eval.base import Task, rf


# BibTeX entry for the XSum paper (Narayan, Cohen & Lapata, EMNLP 2018).
_CITATION = """
@InProceedings{xsum-emnlp,
  author =      "Shashi Narayan and Shay B. Cohen and Mirella Lapata",
  title =       "Don't Give Me the Details, Just the Summary! {T}opic-Aware Convolutional Neural Networks for Extreme Summarization",
  booktitle =   "Proceedings of the 2018 Conference on Empirical Methods in Natural Language Processing ",
  year =        "2018",
  address =     "Brussels, Belgium",
}
"""



class GEMXSUMBase(PromptSourceTask):
    """Base task for the GEM release of XSum (single-sentence summarization).

    Reads the ``GEM/xsum`` dataset and serves its ``train``, ``validation``
    and ``test`` splits. Subclasses narrow this to specific (challenge)
    splits by overriding ``SPLIT`` and the ``has_*_docs`` / ``*_docs`` hooks.
    """

    VERSION = 0
    DATASET_PATH = "GEM/xsum"
    DATASET_NAME = None
    # Split name used by subclasses that read ``self.dataset[self.SPLIT]``;
    # the methods defined here use fixed split names and ignore it.
    SPLIT = None

    def has_training_docs(self):
        return True

    def has_validation_docs(self):
        return True

    def has_test_docs(self):
        return True

    def training_docs(self):
        """Return the ``train`` split as a cached list, or ``None`` if disabled."""
        if self.has_training_docs():
            # We cache training documents in `self._training_docs` for faster
            # few-shot processing. If the data is too large to fit in memory,
            # return the training data as a generator instead of a list.
            if self._training_docs is None:
                self._training_docs = list(self.dataset["train"])
            return self._training_docs

    def validation_docs(self):
        """Return the ``validation`` split, or ``None`` if disabled."""
        if self.has_validation_docs():
            return self.dataset["validation"]

    def test_docs(self):
        """Return the ``test`` split, or ``None`` if disabled."""
        if self.has_test_docs():
            return self.dataset["test"]

    def max_generation_length(self):
        """Upper bound on generated summary length (presumably tokens; see PromptSourceTask)."""
        return 64


class GEMXSUM(GEMXSUMBase):
    """XSum on the standard GEM ``train`` / ``validation`` / ``test`` splits."""

    # Not read by the inherited *_docs methods, which use fixed split names.
    SPLIT = ''


class GEMXSUMChallgeSample(GEMXSUMBase):
    """XSum GEM challenge samples.

    Training docs come from ``challenge_train_sample`` and validation docs
    from ``challenge_validation_sample``; this configuration has no test docs.
    """

    # Label only: the methods below index explicit split names instead.
    SPLIT = 'challenge_sample'

    def has_test_docs(self):
        return False

    def training_docs(self):
        """Return ``challenge_train_sample`` as a cached list, or ``None`` if disabled."""
        if self.has_training_docs():
            # We cache training documents in `self._training_docs` for faster
            # few-shot processing. If the data is too large to fit in memory,
            # return the training data as a generator instead of a list.
            if self._training_docs is None:
                self._training_docs = list(self.dataset["challenge_train_sample"])
            return self._training_docs

    def validation_docs(self):
        """Return ``challenge_validation_sample``, or ``None`` if disabled."""
        if self.has_validation_docs():
            return self.dataset["challenge_validation_sample"]

class GEMXSUMChallgeTestBacktranslation(GEMXSUMBase):
    """XSum GEM challenge set ``challenge_test_backtranslation`` (test docs only)."""

    SPLIT = 'challenge_test_backtranslation'

    def has_training_docs(self):
        return False

    def has_validation_docs(self):
        return False

    def test_docs(self):
        """Return the split named by ``SPLIT``, or ``None`` if disabled."""
        if self.has_test_docs():
            return self.dataset[self.SPLIT]

class GEMXSUMChallgeTestBFP02(GEMXSUMBase):
    """XSum GEM challenge set ``challenge_test_bfp_02`` (test docs only)."""

    SPLIT = 'challenge_test_bfp_02'

    def has_training_docs(self):
        return False

    def has_validation_docs(self):
        return False

    def test_docs(self):
        """Return the split named by ``SPLIT``, or ``None`` if disabled."""
        if self.has_test_docs():
            return self.dataset[self.SPLIT]

class GEMXSUMChallgeTestBFP05(GEMXSUMBase):
    """XSum GEM challenge set ``challenge_test_bfp_05`` (test docs only)."""

    SPLIT = 'challenge_test_bfp_05'

    def has_training_docs(self):
        return False

    def has_validation_docs(self):
        return False

    def test_docs(self):
        """Return the split named by ``SPLIT``, or ``None`` if disabled."""
        if self.has_test_docs():
            return self.dataset[self.SPLIT]

class GEMXSUMChallgeTestNopunc(GEMXSUMBase):
    """XSum GEM challenge set ``challenge_test_nopunc`` (test docs only)."""

    SPLIT = 'challenge_test_nopunc'

    def has_training_docs(self):
        return False

    def has_validation_docs(self):
        return False

    def test_docs(self):
        """Return the split named by ``SPLIT``, or ``None`` if disabled."""
        if self.has_test_docs():
            return self.dataset[self.SPLIT]

class GEMXSUMChallgeTestCovid(GEMXSUMBase):
    """XSum GEM challenge set ``challenge_test_covid`` (test docs only)."""

    SPLIT = 'challenge_test_covid'

    def has_training_docs(self):
        return False

    def has_validation_docs(self):
        return False

    def test_docs(self):
        """Return the split named by ``SPLIT``, or ``None`` if disabled."""
        if self.has_test_docs():
            return self.dataset[self.SPLIT]