# e2e_nlg_cleaned.py
"""
Semantic Noise Matters for Neural Natural Language Generation
http://arxiv.org/abs/1911.03905

A cleaned version of the dataset from the E2E NLG Challenge.
The dataset contains meaning representations (MRs) with restaurant attributes
and corresponding descriptions.

Homepage: https://github.com/tuetschek/e2e-cleaning
"""
from lm_eval import metrics
from lm_eval.base import PromptSourceTask, rf

_CITATION = """
@inproceedings{dusek-etal-2019-semantic,
    title = "Semantic Noise Matters for Neural Natural Language Generation",
    author = "Du{\v{s}}ek, Ond{\v{r}}ej  and
      Howcroft, David M.  and
      Rieser, Verena",
    booktitle = "Proceedings of the 12th International Conference on Natural Language Generation",
    year = "2019",
    address = "Tokyo, Japan",
    publisher = "Association for Computational Linguistics",
    url = "https://aclanthology.org/W19-8652",
    doi = "10.18653/v1/W19-8652",
    pages = "421--426",
}
"""

# Work in progress
class E2E_NLG_Cleaned(PromptSourceTask):
    """PromptSource-based generation task over the cleaned E2E NLG dataset.

    The dataset is identified by ``DATASET_PATH`` with no sub-configuration
    (``DATASET_NAME`` is ``None``). All three splits (train / validation /
    test) are exposed. Prompts whose name ends in ``_qa`` are flagged as
    invalid, and results are aggregated under the ``bleu`` and ``rouge``
    submetrics.
    """

    VERSION = 0
    DATASET_PATH = "e2e_nlg_cleaned"
    DATASET_NAME = None

    def has_training_docs(self) -> bool:
        """Report that a training split is available."""
        return True

    def has_validation_docs(self) -> bool:
        """Report that a validation split is available."""
        return True

    def has_test_docs(self) -> bool:
        """Report that a test split is available."""
        return True

    def training_docs(self):
        """Return the training split as a list, materialising it on first use.

        :returns: the cached list of training documents, or ``None`` when
            ``has_training_docs()`` is false.
        """
        if not self.has_training_docs():
            return None
        # Cache the training documents for faster few-shot processing. If the
        # data were too large to fit in memory, a generator would have to be
        # returned instead of a list.
        # NOTE(review): relies on `self._training_docs` being initialised
        # (presumably to None by the base class) -- confirm.
        if self._training_docs is None:
            self._training_docs = list(self.dataset["train"])
        return self._training_docs

    def validation_docs(self):
        """Return the ``"validation"`` split of ``self.dataset``.

        :returns: the split object as stored in the dataset (not copied), or
            ``None`` when ``has_validation_docs()`` is false.
        """
        if not self.has_validation_docs():
            return None
        return self.dataset["validation"]

    def test_docs(self):
        """Return the ``"test"`` split of ``self.dataset``.

        :returns: the split object as stored in the dataset (not copied), or
            ``None`` when ``has_test_docs()`` is false.
        """
        if not self.has_test_docs():
            return None
        return self.dataset["test"]

    def stopping_criteria(self) -> str:
        """Return the string used as the stopping criterion: a single newline."""
        return "\n"

    def max_generation_length(self) -> int:
        """Return the maximum generation length (512)."""
        # TODO check
        return 512

    def invalid_doc_for_prompt(self, doc) -> bool:
        """Tell whether the current prompt should be skipped for ``doc``.

        The QA prompts are not applicable to all the examples, so they are
        filtered out. The decision depends only on the prompt name (whether
        it ends in ``"_qa"``); ``doc`` itself is not inspected.

        :param doc: the document under consideration (unused).
        :returns: ``True`` when the active prompt's name ends with ``"_qa"``.
        """
        return self.prompt.name.endswith("_qa")

    def doc_to_text(self, doc) -> str:
        """Render ``doc`` with the active prompt and return the input text.

        :param doc: the document to render.
        :returns: the first element of ``self.prompt.apply(doc)``.
        """
        # If the response is not defined in PromptSource, the text will be an
        # empty string.
        return self.prompt.apply(doc)[0]

    def aggregation(self):
        """
        :returns: {str: [metric_score] -> float}
            A dictionary where keys are the names of submetrics and values are
            functions that aggregate a list of metric scores
        """
        # `metrics` is the module-level `lm_eval.metrics` import.
        return {
            "bleu": metrics.bleu,
            "rouge": metrics.rouge,
        }

    def higher_is_better(self):
        """
        :returns: {str: bool}
            A dictionary where keys are the names of submetrics and values are
            whether a higher value of the submetric is better
        """
        return {
            "bleu": True,
            "rouge": True,
        }