Commit afcea267 authored by jordiclive
Browse files

webnlg

parent 1cd4ec01
...@@ -53,6 +53,7 @@ from . import asdiv ...@@ -53,6 +53,7 @@ from . import asdiv
from . import gsm8k from . import gsm8k
from . import storycloze from . import storycloze
from . import hans from . import hans
from . import gem_webnlg
# from . import e2e_nlg_cleaned # from . import e2e_nlg_cleaned
...@@ -108,6 +109,7 @@ TASK_REGISTRY = { ...@@ -108,6 +109,7 @@ TASK_REGISTRY = {
"wsc": superglue.SGWinogradSchemaChallenge, "wsc": superglue.SGWinogradSchemaChallenge,
# Order by benchmark/genre? # Order by benchmark/genre?
"coqa": coqa.CoQA, "coqa": coqa.CoQA,
"GEM/web_nlg": gem_webnlg.WebNLG,
"drop": drop.DROP, "drop": drop.DROP,
"lambada": lambada.LAMBADA, "lambada": lambada.LAMBADA,
"lambada_cloze": lambada_cloze.LAMBADA_cloze, "lambada_cloze": lambada_cloze.LAMBADA_cloze,
......
from lm_eval.base import PromptSourceTask
class WebNLG(PromptSourceTask):
    """PromptSource task over the ``GEM/web_nlg`` dataset, ``en`` configuration.

    The validation and test splits are exposed; the training split is
    reported as unavailable, so ``training_docs`` yields ``None``.
    """

    VERSION = 0
    DATASET_PATH = "GEM/web_nlg"
    DATASET_NAME = "en"

    def has_training_docs(self):
        """Report that this task offers no training documents."""
        return False

    def has_validation_docs(self):
        """Report that this task offers validation documents."""
        return True

    def has_test_docs(self):
        """Report that this task offers test documents."""
        return True

    def training_docs(self):
        """Return the cached ``train`` split, or ``None`` when it is disabled."""
        if not self.has_training_docs():
            return None
        # NOTE(review): `_training_docs` is presumably initialised to None by
        # the base class -- confirm against PromptSourceTask.
        if self._training_docs is None:
            # Materialise once so later calls reuse the same list.
            self._training_docs = list(self.dataset["train"])
        return self._training_docs

    def validation_docs(self):
        """Return the ``validation`` split, or ``None`` when it is disabled."""
        if not self.has_validation_docs():
            return None
        return self.dataset["validation"]

    def test_docs(self):
        """Return the ``test`` split, or ``None`` when it is disabled."""
        if not self.has_test_docs():
            return None
        return self.dataset["test"]

    def stopping_criteria(self):
        """Return the string ``*`` as this task's stopping criterion.

        NOTE(review): presumably the sequence at which the harness halts
        generation -- confirm against PromptSourceTask.
        """
        return "*"

    def max_generation_length(self):
        """Return 250 as this task's maximum generation length.

        NOTE(review): the unit (presumably tokens) is decided by the caller --
        confirm against PromptSourceTask.
        """
        return 250
...@@ -56,7 +56,7 @@ def main(): ...@@ -56,7 +56,7 @@ def main():
docs = join_iters(iters) docs = join_iters(iters)
description = description_dict[task_name] if description_dict and task_name in description_dict else "" description = description_dict[task_name] if description_dict and task_name in description_dict else ""
task_name = task_name.replace('/','_')
with open(os.path.join(args.output_base_path, task_name), "w") as f: with open(os.path.join(args.output_base_path, task_name), "w") as f:
for i, doc in zip(range(args.num_examples), docs) if args.num_examples > 0 else enumerate(docs): for i, doc in zip(range(args.num_examples), docs) if args.num_examples > 0 else enumerate(docs):
f.write(EXAMPLE_DIVIDER.format(i=i)) f.write(EXAMPLE_DIVIDER.format(i=i))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment