"""
Don’t Give Me the Details, Just the Summary! Topic-Aware Convolutional Neural Networks for Extreme Summarization
https://arxiv.org/pdf/1808.08745.pdf

The dataset is for the task of abstractive summarization in its extreme form, its about summarizing a document in a single sentence. It introduces extreme summarization, a new single-document summarization task which does not favor extractive strategies and calls for an abstractive modeling approach. The idea is to create a short, one-sentence news summary answering the question "What is the article about?". 

This particularly uses the dataset that is part of the GEM benchmark
Homepage: https://github.com/EdinburghNLP/XSum
The GEM Benchmark: Natural Language Generation, its Evaluation and Metrics
https://arxiv.org/pdf/2102.01672v3.pdf
Write a Short Description of the task.
Homepage: https://gem-benchmark.com/data_cards/XSum
"""
from lm_eval.base import PromptSourceTask


_CITATION = """
@InProceedings{xsum-emnlp,
  author =      "Shashi Narayan and Shay B. Cohen and Mirella Lapata",
  title =       "Don't Give Me the Details, Just the Summary! {T}opic-Aware Convolutional Neural Networks for Extreme Summarization",
  booktitle =   "Proceedings of the 2018 Conference on Empirical Methods in Natural Language Processing",
  year =        "2018",
  address =     "Brussels, Belgium",
}
"""



class GEMXSUMBase(PromptSourceTask):
    VERSION = 0
    DATASET_PATH = "GEM/xsum"
    DATASET_NAME = None
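
    # SPLIT names the dataset split(s) a subclass covers. The challenge-test
    # subclasses below use it directly to index `self.dataset[self.SPLIT]`;
    # for the other subclasses it is only a label.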
    SPLIT = None

    def has_training_docs(self):
        return True

    def has_validation_docs(self):
        return True

    def has_test_docs(self):
        return True

    def training_docs(self):
        if self.has_training_docs():
            # We cache training documents in `self._training_docs` for faster
            # few-shot processing. If the data is too large to fit in memory,
            # return the training data as a generator instead of a list.
            if self._training_docs is None:
                self._training_docs = list(self.dataset["train"])
            return self._training_docs

    def validation_docs(self):
        if self.has_validation_docs():
            return self.dataset["validation"]

    def test_docs(self):
        if self.has_test_docs():
            return self.dataset["test"]


class GEMXSUM(GEMXSUMBase):
    """Standard train/validation/test splits."""
    SPLIT = ''

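
# Per the GEM data card, the challenge *sample* splits are small random
# samples drawn from the train and validation sets.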
class GEMXSUMChallengeSample(GEMXSUMBase):
    """Sampled challenge subsets: challenge_train_sample / challenge_validation_sample."""
    SPLIT = 'challenge_sample'

    def has_test_docs(self):
        return False

    def training_docs(self):
        if self.has_training_docs():
            # We cache training documents in `self._training_docs` for faster
            # few-shot processing. If the data is too large to fit in memory,
            # return the training data as a generator instead of a list.
            if self._training_docs is None:
                self._training_docs = list(self.dataset["challenge_train_sample"])
            return self._training_docs

    def validation_docs(self):
        if self.has_validation_docs():
            return self.dataset["challenge_validation_sample"]

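
# Per the GEM paper, this challenge set perturbs the test inputs via
# round-trip (back-)translation through another language.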
class GEMXSUMChallengeTestBacktranslation(GEMXSUMBase):
    """Challenge test split: challenge_test_backtranslation."""
    SPLIT = 'challenge_test_backtranslation'

    def has_training_docs(self):
        return False

    def has_validation_docs(self):
        return False

    def test_docs(self):
        if self.has_test_docs():
            return self.dataset[self.SPLIT]

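
# Per the GEM paper, the "bfp" challenge sets add "butter fingers" typo noise,
# randomly replacing characters with keyboard neighbors; here the perturbation
# probability is 0.02.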
class GEMXSUMChallengeTestBFP02(GEMXSUMBase):
    """Challenge test split: challenge_test_bfp_02."""
    SPLIT = 'challenge_test_bfp_02'

    def has_training_docs(self):
        return False

    def has_validation_docs(self):
        return False

    def test_docs(self):
        if self.has_test_docs():
            return self.dataset[self.SPLIT]

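
# Same "butter fingers" perturbation as above, with probability 0.05.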
class GEMXSUMChallengeTestBFP05(GEMXSUMBase):
    """Challenge test split: challenge_test_bfp_05."""
    SPLIT = 'challenge_test_bfp_05'

    def has_training_docs(self):
        return False

    def has_validation_docs(self):
        return False

    def test_docs(self):
        if self.has_test_docs():
            return self.dataset[self.SPLIT]

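
# Per the GEM paper, this challenge set strips punctuation from the test
# inputs.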
class GEMXSUMChallengeTestNopunc(GEMXSUMBase):
    """Challenge test split: challenge_test_nopunc."""
    SPLIT = 'challenge_test_nopunc'

    def has_training_docs(self):
        return False

    def has_validation_docs(self):
        return False

    def test_docs(self):
        if self.has_test_docs():
            return self.dataset[self.SPLIT]

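
# Per the GEM data card, this challenge set contains COVID-19-era news
# articles collected after the original XSum crawl (a temporal/topical shift).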
class GEMXSUMChallengeTestCovid(GEMXSUMBase):
    """Challenge test split: challenge_test_covid."""
    SPLIT = 'challenge_test_covid'

    def has_training_docs(self):
        return False

    def has_validation_docs(self):
        return False

    def test_docs(self):
        if self.has_test_docs():
            return self.dataset[self.SPLIT]
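

# Minimal smoke-test sketch (not part of the task API): load GEM/xsum directly
# and print the size of every split used above. Assumes the HuggingFace
# `datasets` library and network access to the Hub are available.
if __name__ == "__main__":
    import datasets

    xsum = datasets.load_dataset("GEM/xsum")
    for split_name, split in xsum.items():
        print(f"{split_name}: {len(split)} examples")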