"vscode:/vscode.git/clone" did not exist on "a565d720bcde6e5c77d0993a4efc30f3e7891350"
Commit e645dcbb authored by Patrick von Platen's avatar Patrick von Platen
Browse files

add special tokens to pretrain configs of respective lm head models

parent e693cd1e
......@@ -263,14 +263,8 @@ class GPT2ModelTest(ModelTesterMixin, unittest.TestCase):
self.assertIsNotNone(model)
def prepare_generation_special_tokens():
return {"bos_token_id": 50256, "eos_token_id": 50256}
class GPT2ModelLanguageGenerationTest(unittest.TestCase):
special_tokens = prepare_generation_special_tokens()
@slow
def test_lm_generate_gpt2(self):
model = GPT2LMHeadModel.from_pretrained("gpt2")
......@@ -299,11 +293,7 @@ class GPT2ModelLanguageGenerationTest(unittest.TestCase):
] # The dog is cute too. It likes to rub on me and is good for me (the dog
torch.manual_seed(0)
output_ids = model.generate(
input_ids,
bos_token_id=self.special_tokens["bos_token_id"],
eos_token_ids=self.special_tokens["eos_token_id"],
)
output_ids = model.generate(input_ids)
self.assertListEqual(output_ids[0].tolist(), expected_output_ids)
......@@ -335,10 +325,5 @@ class GPT2ModelLanguageGenerationTest(unittest.TestCase):
] # The dog is cute though he can sometimes just walk in the park. It is not very nice to
torch.manual_seed(0)
output_ids = model.generate(
input_ids,
bos_token_id=self.special_tokens["bos_token_id"],
eos_token_ids=self.special_tokens["eos_token_id"],
)
output_ids = model.generate(input_ids)
self.assertListEqual(output_ids[0].tolist(), expected_output_ids)
......@@ -214,14 +214,8 @@ class TransfoXLModelTest(ModelTesterMixin, unittest.TestCase):
self.assertIsNotNone(model)
def prepare_generation_special_tokens():
return {"eos_token_id": 0}
class TransfoXLModelLanguageGenerationTest(unittest.TestCase):
special_tokens = prepare_generation_special_tokens()
@slow
def test_lm_generate_transfo_xl_wt103(self):
model = TransfoXLLMHeadModel.from_pretrained("transfo-xl-wt103")
......@@ -578,6 +572,5 @@ class TransfoXLModelLanguageGenerationTest(unittest.TestCase):
torch.manual_seed(0)
output_ids = model.generate(input_ids, eos_token_ids=self.special_tokens["eos_token_id"], max_length=200)
output_ids = model.generate(input_ids, max_length=200)
self.assertListEqual(output_ids[0].tolist(), expected_output_ids)
......@@ -399,14 +399,8 @@ class XLMModelTest(ModelTesterMixin, unittest.TestCase):
self.assertIsNotNone(model)
def prepare_generation_special_tokens():
return {"bos_token_id": 0, "pad_token_id": 2}
class XLMModelLanguageGenerationTest(unittest.TestCase):
special_tokens = prepare_generation_special_tokens()
@slow
def test_lm_generate_xlm_mlm_en_2048(self):
model = XLMWithLMHeadModel.from_pretrained("xlm-mlm-en-2048")
......@@ -435,10 +429,6 @@ class XLMModelLanguageGenerationTest(unittest.TestCase):
] # The dog is nothing is it!!!!!!!!!!!! TODO (PVP): this sentence (and others I tried) does not make much sense, there seems to be a problem with xlm language generation.
torch.manual_seed(0)
output_ids = model.generate(
input_ids,
bos_token_id=self.special_tokens["bos_token_id"],
pad_token_id=self.special_tokens["pad_token_id"],
)
output_ids = model.generate(input_ids)
self.assertListEqual(output_ids[0].tolist(), expected_output_ids)
......@@ -513,14 +513,8 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
self.assertIsNotNone(model)
def prepare_generation_special_tokens():
return {"bos_token_id": 1, "pad_token_id": 5, "eos_token_id": 2}
class XLNetModelLanguageGenerationTest(unittest.TestCase):
special_tokens = prepare_generation_special_tokens()
@slow
def test_lm_generate_xlnet_base_cased(self):
model = XLNetLMHeadModel.from_pretrained("xlnet-base-cased")
......@@ -917,12 +911,6 @@ class XLNetModelLanguageGenerationTest(unittest.TestCase):
# Since, however, he has had difficulty walking with Maria
torch.manual_seed(0)
output_ids = model.generate(
input_ids,
bos_token_id=self.special_tokens["bos_token_id"],
pad_token_id=self.special_tokens["pad_token_id"],
eos_token_ids=self.special_tokens["eos_token_id"],
max_length=200,
)
output_ids = model.generate(input_ids, max_length=200)
self.assertListEqual(output_ids[0].tolist(), expected_output_ids)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment