Commit 8c83a821 authored by ken's avatar ken
Browse files

Add SPLIT

parent 6a2b94b2
...@@ -31,6 +31,7 @@ class GEMXSUMBase(PromptSourceTask): ...@@ -31,6 +31,7 @@ class GEMXSUMBase(PromptSourceTask):
VERSION = 0 VERSION = 0
DATASET_PATH = "GEM/xsum" DATASET_PATH = "GEM/xsum"
DATASET_NAME = None DATASET_NAME = None
SPLIT = None
def has_training_docs(self): def has_training_docs(self):
return True return True
...@@ -62,9 +63,11 @@ class GEMXSUMBase(PromptSourceTask): ...@@ -62,9 +63,11 @@ class GEMXSUMBase(PromptSourceTask):
class GEMXSUM(GEMXSUMBase): class GEMXSUM(GEMXSUMBase):
'''this is for train/validation/test''' '''this is for train/validation/test'''
SPLIT = ''
class GEMXSUMChallgeSample(GEMXSUMBase): class GEMXSUMChallgeSample(GEMXSUMBase):
'''this is for challenge_train_sample/challenge_validation_sample''' '''this is for challenge_train_sample/challenge_validation_sample'''
SPLIT = 'challenge_sample'
def has_test_docs(self): def has_test_docs(self):
return False return False
...@@ -84,6 +87,7 @@ class GEMXSUMChallgeSample(GEMXSUMBase): ...@@ -84,6 +87,7 @@ class GEMXSUMChallgeSample(GEMXSUMBase):
class GEMXSUMChallgeTestBacktranslation(GEMXSUMBase): class GEMXSUMChallgeTestBacktranslation(GEMXSUMBase):
'''this is for challenge_test_backtranslation''' '''this is for challenge_test_backtranslation'''
SPLIT = 'challenge_test_backtranslation'
def has_training_docs(self): def has_training_docs(self):
return False return False
...@@ -93,10 +97,11 @@ class GEMXSUMChallgeTestBacktranslation(GEMXSUMBase): ...@@ -93,10 +97,11 @@ class GEMXSUMChallgeTestBacktranslation(GEMXSUMBase):
def test_docs(self): def test_docs(self):
if self.has_test_docs(): if self.has_test_docs():
return self.dataset["challenge_test_backtranslation"] return self.dataset[SPLIT]
class GEMXSUMChallgeTestBFP02(GEMXSUMBase): class GEMXSUMChallgeTestBFP02(GEMXSUMBase):
'''this is for challenge_test_bfp_02''' '''this is for challenge_test_bfp_02'''
SPLIT = 'challenge_test_bfp_02'
def has_training_docs(self): def has_training_docs(self):
return False return False
...@@ -106,10 +111,11 @@ class GEMXSUMChallgeTestBFP02(GEMXSUMBase): ...@@ -106,10 +111,11 @@ class GEMXSUMChallgeTestBFP02(GEMXSUMBase):
def test_docs(self): def test_docs(self):
if self.has_test_docs(): if self.has_test_docs():
return self.dataset["challenge_test_bfp_02"] return self.dataset[SPLIT]
class GEMXSUMChallgeTestBFP05(GEMXSUMBase): class GEMXSUMChallgeTestBFP05(GEMXSUMBase):
'''this is for challenge_test_bfp_05''' '''this is for challenge_test_bfp_05'''
SPLIT = 'challenge_test_bfp_05'
def has_training_docs(self): def has_training_docs(self):
return False return False
...@@ -118,10 +124,11 @@ class GEMXSUMChallgeTestBFP05(GEMXSUMBase): ...@@ -118,10 +124,11 @@ class GEMXSUMChallgeTestBFP05(GEMXSUMBase):
return False return False
def test_docs(self): def test_docs(self):
if self.has_test_docs(): if self.has_test_docs():
return self.dataset["challenge_test_bfp_05"] return self.dataset[SPLIT]
class GEMXSUMChallgeTestNopunc(GEMXSUMBase): class GEMXSUMChallgeTestNopunc(GEMXSUMBase):
'''this is for challenge_test_nopunc''' '''this is for challenge_test_nopunc'''
SPLIT = 'challenge_test_nopunc'
def has_training_docs(self): def has_training_docs(self):
return False return False
...@@ -130,10 +137,11 @@ class GEMXSUMChallgeTestNopunc(GEMXSUMBase): ...@@ -130,10 +137,11 @@ class GEMXSUMChallgeTestNopunc(GEMXSUMBase):
return False return False
def test_docs(self): def test_docs(self):
if self.has_test_docs(): if self.has_test_docs():
return self.dataset["challenge_test_nopunc"] return self.dataset[SPLIT]
class GEMXSUMChallgeTestCovid(GEMXSUMBase): class GEMXSUMChallgeTestCovid(GEMXSUMBase):
'''this is for challenge_test_covid''' '''this is for challenge_test_covid'''
SPLIT = 'challenge_test_covid'
def has_training_docs(self): def has_training_docs(self):
return False return False
...@@ -142,4 +150,4 @@ class GEMXSUMChallgeTestCovid(GEMXSUMBase): ...@@ -142,4 +150,4 @@ class GEMXSUMChallgeTestCovid(GEMXSUMBase):
return False return False
def test_docs(self): def test_docs(self):
if self.has_test_docs(): if self.has_test_docs():
return self.dataset["challenge_test_covid"] return self.dataset[SPLIT]
\ No newline at end of file \ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment