Commit 795b7912 authored by Zdenek Kasner
Browse files

e2e_nlg_cleaned: Add custom code for filtering out invalid examples

parent 3eb3d8b7
...@@ -8,6 +8,7 @@ The dataset contains MR with restaurant attributes and corresponding description ...@@ -8,6 +8,7 @@ The dataset contains MR with restaurant attributes and corresponding description
Homepage: https://github.com/tuetschek/e2e-cleaning Homepage: https://github.com/tuetschek/e2e-cleaning
""" """
from lm_eval.base import PromptSourceTask, rf from lm_eval.base import PromptSourceTask, rf
from lm_eval import metrics
_CITATION = """ _CITATION = """
@inproceedings{dusek-etal-2019-semantic, @inproceedings{dusek-etal-2019-semantic,
...@@ -69,11 +70,36 @@ class E2E_NLG_Cleaned(PromptSourceTask): ...@@ -69,11 +70,36 @@ class E2E_NLG_Cleaned(PromptSourceTask):
return self.prompt.name.endswith("_qa") return self.prompt.name.endswith("_qa")
def doc_to_text(self, doc) -> str: def doc_to_text(self, doc) -> str:
# if the response is not defined in PS, the text will be an empty string # if the response is not defined in PS, the text will be a single-element list containing an empty string
text = self.prompt.apply(doc)[0] text = self.prompt.apply(doc)[0]
return text return text
def construct_requests(self, doc, ctx):
    """Build the requests that will be sent to the LM for a single document.

    :param doc:
        The document as returned from training_docs, validation_docs, or test_docs.
        It is not read by this method.
    :param ctx: str
        The context string, generated by fewshot_context. This includes the natural
        language description, as well as the few shot examples, and the question
        part of the document for `doc`.
    :returns: list
        A list holding one greedy-generation request built with RequestFactory,
        or an empty list when `ctx` is the empty string.
    """
    # NOTE: In the future, target will be a list of strings.
    # The generation options are gathered up front, stopping criteria first,
    # regardless of whether a request ends up being issued.
    generation_options = {
        "stopping_criteria": self.stopping_criteria(),
        "max_generation_length": self.max_generation_length(),
    }
    # An empty context marks an example for which the templates are not
    # applicable; such examples are skipped by producing no request at all.
    return [rf.greedy_until(ctx, generation_options)] if ctx != "" else []
def aggregation(self): def aggregation(self):
""" """
:returns: {str: [metric_score] -> float} :returns: {str: [metric_score] -> float}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment