"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "6fe8a693ebbfa6e70b880f7c24e0cf524be6fb25"
Commit e645dcbb authored by Patrick von Platen's avatar Patrick von Platen
Browse files

add special tokens to pretrain configs of respective lm head models

parent e693cd1e
...@@ -263,14 +263,8 @@ class GPT2ModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -263,14 +263,8 @@ class GPT2ModelTest(ModelTesterMixin, unittest.TestCase):
self.assertIsNotNone(model) self.assertIsNotNone(model)
def prepare_generation_special_tokens():
return {"bos_token_id": 50256, "eos_token_id": 50256}
class GPT2ModelLanguageGenerationTest(unittest.TestCase): class GPT2ModelLanguageGenerationTest(unittest.TestCase):
special_tokens = prepare_generation_special_tokens()
@slow @slow
def test_lm_generate_gpt2(self): def test_lm_generate_gpt2(self):
model = GPT2LMHeadModel.from_pretrained("gpt2") model = GPT2LMHeadModel.from_pretrained("gpt2")
...@@ -299,11 +293,7 @@ class GPT2ModelLanguageGenerationTest(unittest.TestCase): ...@@ -299,11 +293,7 @@ class GPT2ModelLanguageGenerationTest(unittest.TestCase):
] # The dog is cute too. It likes to rub on me and is good for me (the dog ] # The dog is cute too. It likes to rub on me and is good for me (the dog
torch.manual_seed(0) torch.manual_seed(0)
output_ids = model.generate( output_ids = model.generate(input_ids)
input_ids,
bos_token_id=self.special_tokens["bos_token_id"],
eos_token_ids=self.special_tokens["eos_token_id"],
)
self.assertListEqual(output_ids[0].tolist(), expected_output_ids) self.assertListEqual(output_ids[0].tolist(), expected_output_ids)
...@@ -335,10 +325,5 @@ class GPT2ModelLanguageGenerationTest(unittest.TestCase): ...@@ -335,10 +325,5 @@ class GPT2ModelLanguageGenerationTest(unittest.TestCase):
] # The dog is cute though he can sometimes just walk in the park. It is not very nice to ] # The dog is cute though he can sometimes just walk in the park. It is not very nice to
torch.manual_seed(0) torch.manual_seed(0)
output_ids = model.generate( output_ids = model.generate(input_ids)
input_ids,
bos_token_id=self.special_tokens["bos_token_id"],
eos_token_ids=self.special_tokens["eos_token_id"],
)
self.assertListEqual(output_ids[0].tolist(), expected_output_ids) self.assertListEqual(output_ids[0].tolist(), expected_output_ids)
...@@ -214,14 +214,8 @@ class TransfoXLModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -214,14 +214,8 @@ class TransfoXLModelTest(ModelTesterMixin, unittest.TestCase):
self.assertIsNotNone(model) self.assertIsNotNone(model)
def prepare_generation_special_tokens():
return {"eos_token_id": 0}
class TransfoXLModelLanguageGenerationTest(unittest.TestCase): class TransfoXLModelLanguageGenerationTest(unittest.TestCase):
special_tokens = prepare_generation_special_tokens()
@slow @slow
def test_lm_generate_transfo_xl_wt103(self): def test_lm_generate_transfo_xl_wt103(self):
model = TransfoXLLMHeadModel.from_pretrained("transfo-xl-wt103") model = TransfoXLLMHeadModel.from_pretrained("transfo-xl-wt103")
...@@ -578,6 +572,5 @@ class TransfoXLModelLanguageGenerationTest(unittest.TestCase): ...@@ -578,6 +572,5 @@ class TransfoXLModelLanguageGenerationTest(unittest.TestCase):
torch.manual_seed(0) torch.manual_seed(0)
output_ids = model.generate(input_ids, eos_token_ids=self.special_tokens["eos_token_id"], max_length=200) output_ids = model.generate(input_ids, max_length=200)
self.assertListEqual(output_ids[0].tolist(), expected_output_ids) self.assertListEqual(output_ids[0].tolist(), expected_output_ids)
...@@ -399,14 +399,8 @@ class XLMModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -399,14 +399,8 @@ class XLMModelTest(ModelTesterMixin, unittest.TestCase):
self.assertIsNotNone(model) self.assertIsNotNone(model)
def prepare_generation_special_tokens():
return {"bos_token_id": 0, "pad_token_id": 2}
class XLMModelLanguageGenerationTest(unittest.TestCase): class XLMModelLanguageGenerationTest(unittest.TestCase):
special_tokens = prepare_generation_special_tokens()
@slow @slow
def test_lm_generate_xlm_mlm_en_2048(self): def test_lm_generate_xlm_mlm_en_2048(self):
model = XLMWithLMHeadModel.from_pretrained("xlm-mlm-en-2048") model = XLMWithLMHeadModel.from_pretrained("xlm-mlm-en-2048")
...@@ -435,10 +429,6 @@ class XLMModelLanguageGenerationTest(unittest.TestCase): ...@@ -435,10 +429,6 @@ class XLMModelLanguageGenerationTest(unittest.TestCase):
] # The dog is nothing is it!!!!!!!!!!!! TODO (PVP): this sentence (and others I tried) does not make much sense, there seems to be a problem with xlm language generation. ] # The dog is nothing is it!!!!!!!!!!!! TODO (PVP): this sentence (and others I tried) does not make much sense, there seems to be a problem with xlm language generation.
torch.manual_seed(0) torch.manual_seed(0)
output_ids = model.generate( output_ids = model.generate(input_ids)
input_ids,
bos_token_id=self.special_tokens["bos_token_id"],
pad_token_id=self.special_tokens["pad_token_id"],
)
self.assertListEqual(output_ids[0].tolist(), expected_output_ids) self.assertListEqual(output_ids[0].tolist(), expected_output_ids)
...@@ -513,14 +513,8 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -513,14 +513,8 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
self.assertIsNotNone(model) self.assertIsNotNone(model)
def prepare_generation_special_tokens():
return {"bos_token_id": 1, "pad_token_id": 5, "eos_token_id": 2}
class XLNetModelLanguageGenerationTest(unittest.TestCase): class XLNetModelLanguageGenerationTest(unittest.TestCase):
special_tokens = prepare_generation_special_tokens()
@slow @slow
def test_lm_generate_xlnet_base_cased(self): def test_lm_generate_xlnet_base_cased(self):
model = XLNetLMHeadModel.from_pretrained("xlnet-base-cased") model = XLNetLMHeadModel.from_pretrained("xlnet-base-cased")
...@@ -917,12 +911,6 @@ class XLNetModelLanguageGenerationTest(unittest.TestCase): ...@@ -917,12 +911,6 @@ class XLNetModelLanguageGenerationTest(unittest.TestCase):
# Since, however, he has had difficulty walking with Maria # Since, however, he has had difficulty walking with Maria
torch.manual_seed(0) torch.manual_seed(0)
output_ids = model.generate( output_ids = model.generate(input_ids, max_length=200)
input_ids,
bos_token_id=self.special_tokens["bos_token_id"],
pad_token_id=self.special_tokens["pad_token_id"],
eos_token_ids=self.special_tokens["eos_token_id"],
max_length=200,
)
self.assertListEqual(output_ids[0].tolist(), expected_output_ids) self.assertListEqual(output_ids[0].tolist(), expected_output_ids)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment