Commit e645dcbb authored by Patrick von Platen
Browse files

add special tokens to pretrain configs of respective lm head models

parent e693cd1e
...@@ -263,14 +263,8 @@ class GPT2ModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -263,14 +263,8 @@ class GPT2ModelTest(ModelTesterMixin, unittest.TestCase):
self.assertIsNotNone(model) self.assertIsNotNone(model)
def prepare_generation_special_tokens():
return {"bos_token_id": 50256, "eos_token_id": 50256}
class GPT2ModelLanguageGenerationTest(unittest.TestCase): class GPT2ModelLanguageGenerationTest(unittest.TestCase):
special_tokens = prepare_generation_special_tokens()
@slow @slow
def test_lm_generate_gpt2(self): def test_lm_generate_gpt2(self):
model = GPT2LMHeadModel.from_pretrained("gpt2") model = GPT2LMHeadModel.from_pretrained("gpt2")
...@@ -299,11 +293,7 @@ class GPT2ModelLanguageGenerationTest(unittest.TestCase): ...@@ -299,11 +293,7 @@ class GPT2ModelLanguageGenerationTest(unittest.TestCase):
] # The dog is cute too. It likes to rub on me and is good for me (the dog ] # The dog is cute too. It likes to rub on me and is good for me (the dog
torch.manual_seed(0) torch.manual_seed(0)
output_ids = model.generate( output_ids = model.generate(input_ids)
input_ids,
bos_token_id=self.special_tokens["bos_token_id"],
eos_token_ids=self.special_tokens["eos_token_id"],
)
self.assertListEqual(output_ids[0].tolist(), expected_output_ids) self.assertListEqual(output_ids[0].tolist(), expected_output_ids)
...@@ -335,10 +325,5 @@ class GPT2ModelLanguageGenerationTest(unittest.TestCase): ...@@ -335,10 +325,5 @@ class GPT2ModelLanguageGenerationTest(unittest.TestCase):
] # The dog is cute though he can sometimes just walk in the park. It is not very nice to ] # The dog is cute though he can sometimes just walk in the park. It is not very nice to
torch.manual_seed(0) torch.manual_seed(0)
output_ids = model.generate( output_ids = model.generate(input_ids)
input_ids,
bos_token_id=self.special_tokens["bos_token_id"],
eos_token_ids=self.special_tokens["eos_token_id"],
)
self.assertListEqual(output_ids[0].tolist(), expected_output_ids) self.assertListEqual(output_ids[0].tolist(), expected_output_ids)
...@@ -214,14 +214,8 @@ class TransfoXLModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -214,14 +214,8 @@ class TransfoXLModelTest(ModelTesterMixin, unittest.TestCase):
self.assertIsNotNone(model) self.assertIsNotNone(model)
def prepare_generation_special_tokens():
return {"eos_token_id": 0}
class TransfoXLModelLanguageGenerationTest(unittest.TestCase): class TransfoXLModelLanguageGenerationTest(unittest.TestCase):
special_tokens = prepare_generation_special_tokens()
@slow @slow
def test_lm_generate_transfo_xl_wt103(self): def test_lm_generate_transfo_xl_wt103(self):
model = TransfoXLLMHeadModel.from_pretrained("transfo-xl-wt103") model = TransfoXLLMHeadModel.from_pretrained("transfo-xl-wt103")
...@@ -578,6 +572,5 @@ class TransfoXLModelLanguageGenerationTest(unittest.TestCase): ...@@ -578,6 +572,5 @@ class TransfoXLModelLanguageGenerationTest(unittest.TestCase):
torch.manual_seed(0) torch.manual_seed(0)
output_ids = model.generate(input_ids, eos_token_ids=self.special_tokens["eos_token_id"], max_length=200) output_ids = model.generate(input_ids, max_length=200)
self.assertListEqual(output_ids[0].tolist(), expected_output_ids) self.assertListEqual(output_ids[0].tolist(), expected_output_ids)
...@@ -399,14 +399,8 @@ class XLMModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -399,14 +399,8 @@ class XLMModelTest(ModelTesterMixin, unittest.TestCase):
self.assertIsNotNone(model) self.assertIsNotNone(model)
def prepare_generation_special_tokens():
return {"bos_token_id": 0, "pad_token_id": 2}
class XLMModelLanguageGenerationTest(unittest.TestCase): class XLMModelLanguageGenerationTest(unittest.TestCase):
special_tokens = prepare_generation_special_tokens()
@slow @slow
def test_lm_generate_xlm_mlm_en_2048(self): def test_lm_generate_xlm_mlm_en_2048(self):
model = XLMWithLMHeadModel.from_pretrained("xlm-mlm-en-2048") model = XLMWithLMHeadModel.from_pretrained("xlm-mlm-en-2048")
...@@ -435,10 +429,6 @@ class XLMModelLanguageGenerationTest(unittest.TestCase): ...@@ -435,10 +429,6 @@ class XLMModelLanguageGenerationTest(unittest.TestCase):
] # The dog is nothing is it!!!!!!!!!!!! TODO (PVP): this sentence (and others I tried) does not make much sense, there seems to be a problem with xlm language generation. ] # The dog is nothing is it!!!!!!!!!!!! TODO (PVP): this sentence (and others I tried) does not make much sense, there seems to be a problem with xlm language generation.
torch.manual_seed(0) torch.manual_seed(0)
output_ids = model.generate( output_ids = model.generate(input_ids)
input_ids,
bos_token_id=self.special_tokens["bos_token_id"],
pad_token_id=self.special_tokens["pad_token_id"],
)
self.assertListEqual(output_ids[0].tolist(), expected_output_ids) self.assertListEqual(output_ids[0].tolist(), expected_output_ids)
...@@ -513,14 +513,8 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -513,14 +513,8 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
self.assertIsNotNone(model) self.assertIsNotNone(model)
def prepare_generation_special_tokens():
return {"bos_token_id": 1, "pad_token_id": 5, "eos_token_id": 2}
class XLNetModelLanguageGenerationTest(unittest.TestCase): class XLNetModelLanguageGenerationTest(unittest.TestCase):
special_tokens = prepare_generation_special_tokens()
@slow @slow
def test_lm_generate_xlnet_base_cased(self): def test_lm_generate_xlnet_base_cased(self):
model = XLNetLMHeadModel.from_pretrained("xlnet-base-cased") model = XLNetLMHeadModel.from_pretrained("xlnet-base-cased")
...@@ -917,12 +911,6 @@ class XLNetModelLanguageGenerationTest(unittest.TestCase): ...@@ -917,12 +911,6 @@ class XLNetModelLanguageGenerationTest(unittest.TestCase):
# Since, however, he has had difficulty walking with Maria # Since, however, he has had difficulty walking with Maria
torch.manual_seed(0) torch.manual_seed(0)
output_ids = model.generate( output_ids = model.generate(input_ids, max_length=200)
input_ids,
bos_token_id=self.special_tokens["bos_token_id"],
pad_token_id=self.special_tokens["pad_token_id"],
eos_token_ids=self.special_tokens["eos_token_id"],
max_length=200,
)
self.assertListEqual(output_ids[0].tolist(), expected_output_ids) self.assertListEqual(output_ids[0].tolist(), expected_output_ids)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment