Commit 9abb1814 authored by Shashi456
Browse files

Update gem_mlsum

parent 7e25c240
......@@ -53,9 +53,10 @@ from . import asdiv
from . import gsm8k
from . import storycloze
from . import hans
from . import mlsum
from . import gem_webnlg
from . import gem_mlsum
# from . import e2e_nlg_cleaned
########################################
......@@ -288,8 +289,16 @@ TASK_REGISTRY = {
"blimp_wh_vs_that_no_gap_long_distance": blimp.BlimpWhVsThatNoGapLongDistance,
"blimp_wh_vs_that_with_gap": blimp.BlimpWhVsThatWithGap,
"blimp_wh_vs_that_with_gap_long_distance": blimp.BlimpWhVsThatWithGapLongDistance,
"mlsum_es":mlsum.MLSUMEs,
"mlsum_de":mlsum.MLSUMDe,
#GEM/mlsum
"mlsum_es":gem_mlsum.MLSUMEs,
"mlsum_de":gem_mlsum.MLSUMDe,
"mlsum_es_covid_challenge_set":gem_mlsum.GEMMLSUMEsChallgeTestCovid,
"mlsum_de_covid_challenge_set":gem_mlsum.GEMMLSUMDeChallgeTestCovid,
# Requires manual download of data.
# "storycloze_2016": storycloze.StoryCloze2016,
# "storycloze_2018": storycloze.StoryCloze2018,
......
......@@ -22,7 +22,7 @@ _CITATION = """
"""
class MLSUMEs(PromptSourceTask):
class GEMMLSUMEsBase(PromptSourceTask):
VERSION = 0
DATASET_PATH = "GEM/mlsum"
DATASET_NAME = "es"
......@@ -53,12 +53,24 @@ class MLSUMEs(PromptSourceTask):
def stopping_criteria(self):
    """Return the stopping-criteria string for this task: a single period.

    NOTE(review): presumably the harness ends generation when this string
    is produced -- confirm against the PromptSourceTask base class.
    """
    return "."
def max_generation_length(self):
    """Return the generation-length limit configured for this task (120).

    NOTE(review): the unit (presumably tokens) is defined by the caller --
    confirm against the PromptSourceTask base class.
    """
    return 120
class GEMMLSUMEs(GEMMLSUMEsBase):
    """Variant of ``GEMMLSUMEsBase`` for the train/validation/test splits."""

    # Deliberately empty for this variant.
    # NOTE(review): presumably the standard splits are resolved by the base
    # class rather than through SPLIT -- confirm.
    SPLIT = ""
class GEMMLSUMEsChallgeTestCovid(GEMMLSUMEsBase):
    """Variant of ``GEMMLSUMEsBase`` for the ``challenge_test_covid`` split."""

    # Name of the dataset split this variant targets.
    SPLIT = "challenge_test_covid"

    def has_training_docs(self):
        """Report that this variant provides no training documents."""
        return False
class MLSUMDe(PromptSourceTask):
def has_validation_docs(self):
    """Report that this variant provides no validation documents."""
    return False
def test_docs(self):
if self.has_test_docs():
return self.dataset[self.SPLIT]
class GEMMLSUMDeBase(PromptSourceTask):
VERSION = 0
DATASET_PATH = "GEM/mlsum"
DATASET_NAME = "de"
......@@ -89,5 +101,19 @@ class MLSUMDe(PromptSourceTask):
def stopping_criteria(self):
    """Return the stopping-criteria string for this task: a single period.

    NOTE(review): presumably the harness ends generation when this string
    is produced -- confirm against the PromptSourceTask base class.
    """
    return "."
def max_generation_length(self):
    """Return the generation-length limit configured for this task (120).

    NOTE(review): the unit (presumably tokens) is defined by the caller --
    confirm against the PromptSourceTask base class.
    """
    return 120
class GEMMLSUMDe(GEMMLSUMDeBase):
    """Variant of ``GEMMLSUMDeBase`` for the train/validation/test splits."""

    # Deliberately empty for this variant.
    # NOTE(review): presumably the standard splits are resolved by the base
    # class rather than through SPLIT -- confirm.
    SPLIT = ""
class GEMMLSUMDeChallgeTestCovid(GEMMLSUMDeBase):
    """Variant of ``GEMMLSUMDeBase`` for the ``challenge_test_covid`` split.

    Training and validation documents are reported as unavailable; test
    documents are read from the split named by ``SPLIT``.
    """

    # Name of the dataset split this variant targets.
    SPLIT = "challenge_test_covid"

    def has_training_docs(self):
        """Report that this variant provides no training documents."""
        return False

    def has_validation_docs(self):
        """Report that this variant provides no validation documents."""
        return False

    def test_docs(self):
        """Return ``self.dataset[self.SPLIT]`` as the test documents.

        Returns ``None`` when ``self.has_test_docs()`` is falsy.
        NOTE(review): ``has_test_docs`` is not defined in this class and is
        presumably inherited from the base class -- confirm.
        """
        if self.has_test_docs():
            return self.dataset[self.SPLIT]
        # No test documents available: keep the original fall-through
        # result, but make it explicit.
        return None
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment