Commit 9abb1814 authored by Shashi456's avatar Shashi456
Browse files

Update gem_mlsum

parent 7e25c240
...@@ -53,9 +53,10 @@ from . import asdiv ...@@ -53,9 +53,10 @@ from . import asdiv
from . import gsm8k from . import gsm8k
from . import storycloze from . import storycloze
from . import hans from . import hans
from . import mlsum
from . import gem_webnlg from . import gem_webnlg
from . import gem_mlsum
# from . import e2e_nlg_cleaned # from . import e2e_nlg_cleaned
######################################## ########################################
...@@ -288,8 +289,16 @@ TASK_REGISTRY = { ...@@ -288,8 +289,16 @@ TASK_REGISTRY = {
"blimp_wh_vs_that_no_gap_long_distance": blimp.BlimpWhVsThatNoGapLongDistance, "blimp_wh_vs_that_no_gap_long_distance": blimp.BlimpWhVsThatNoGapLongDistance,
"blimp_wh_vs_that_with_gap": blimp.BlimpWhVsThatWithGap, "blimp_wh_vs_that_with_gap": blimp.BlimpWhVsThatWithGap,
"blimp_wh_vs_that_with_gap_long_distance": blimp.BlimpWhVsThatWithGapLongDistance, "blimp_wh_vs_that_with_gap_long_distance": blimp.BlimpWhVsThatWithGapLongDistance,
"mlsum_es":mlsum.MLSUMEs,
"mlsum_de":mlsum.MLSUMDe,
#GEM/mlsum
"mlsum_es":gem_mlsum.MLSUMEs,
"mlsum_de":gem_mlsum.MLSUMDe,
"mlsum_es_covid_challenge_set":gem_mlsum.GEMMLSUMEsChallgeTestCovid,
"mlsum_de_covid_challenge_set":gem_mlsum.GEMMLSUMDeChallgeTestCovid,
# Requires manual download of data. # Requires manual download of data.
# "storycloze_2016": storycloze.StoryCloze2016, # "storycloze_2016": storycloze.StoryCloze2016,
# "storycloze_2018": storycloze.StoryCloze2018, # "storycloze_2018": storycloze.StoryCloze2018,
......
...@@ -22,7 +22,7 @@ _CITATION = """ ...@@ -22,7 +22,7 @@ _CITATION = """
""" """
class MLSUMEs(PromptSourceTask): class GEMMLSUMEsBase(PromptSourceTask):
VERSION = 0 VERSION = 0
DATASET_PATH = "GEM/mlsum" DATASET_PATH = "GEM/mlsum"
DATASET_NAME = "es" DATASET_NAME = "es"
...@@ -53,12 +53,24 @@ class MLSUMEs(PromptSourceTask): ...@@ -53,12 +53,24 @@ class MLSUMEs(PromptSourceTask):
def stopping_criteria(self): def stopping_criteria(self):
return "." return "."
def max_generation_length(self): class GEMMLSUMEs(GEMMLSUMEsBase):
return 120 '''this is for train/validation/test'''
SPLIT = ''
class GEMMLSUMEsChallgeTestCovid(GEMMLSUMEsBase):
'''this is for challenge_test_covid'''
SPLIT = 'challenge_test_covid'
def has_training_docs(self):
return False
class MLSUMDe(PromptSourceTask): def has_validation_docs(self):
return False
def test_docs(self):
if self.has_test_docs():
return self.dataset[self.SPLIT]
class GEMMLSUMDeBase(PromptSourceTask):
VERSION = 0 VERSION = 0
DATASET_PATH = "GEM/mlsum" DATASET_PATH = "GEM/mlsum"
DATASET_NAME = "de" DATASET_NAME = "de"
...@@ -89,5 +101,19 @@ class MLSUMDe(PromptSourceTask): ...@@ -89,5 +101,19 @@ class MLSUMDe(PromptSourceTask):
def stopping_criteria(self): def stopping_criteria(self):
return "." return "."
def max_generation_length(self): class GEMMLSUMDe(GEMMLSUMDeBase):
return 120 '''this is for train/validation/test'''
SPLIT = ''
class GEMMLSUMDeChallgeTestCovid(GEMMLSUMDeBase):
    """Variant of ``GEMMLSUMDeBase`` that reads the ``challenge_test_covid`` split.

    This variant reports no training and no validation documents; only test
    documents are served, taken from the split named by ``SPLIT``.
    """

    # Key used to index ``self.dataset`` in ``test_docs``.
    SPLIT = "challenge_test_covid"

    def has_training_docs(self):
        """Return ``False``: this variant exposes no training documents."""
        return False

    def has_validation_docs(self):
        """Return ``False``: this variant exposes no validation documents."""
        return False

    def test_docs(self):
        """Return ``self.dataset[self.SPLIT]``, or ``None`` without test docs.

        ``has_test_docs`` is not defined in this class; it is presumably
        inherited from the base task class -- confirm there.
        """
        if not self.has_test_docs():
            return None
        return self.dataset[self.SPLIT]
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment