Commit fbd02d46 authored by patrickvonplaten

fixed all tests, still need to check ctrl tf and pt and xlm tf

parent b4a3a647
@@ -246,4 +246,4 @@ class CTRLModelLanguageGenerationTest(unittest.TestCase):
         ]  # Legal My neighbor is refusing to pay rent after 2 years and we are having to force him to pay
         output_ids = model.generate(input_ids, do_sample=False)
-        self.assertListEqual(output_ids[0].tolist(), expected_output_ids)
+        self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids)
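The recurring edit across these hunks swaps output_ids[0].tolist() for output_ids[0].numpy().tolist(). A minimal sketch of why the extra .numpy() call is needed in the TF tests (assuming TF 2.x eager execution; the token values below are stand-ins, not taken from the tests):

import tensorflow as tf

output_ids = tf.constant([[464, 3290, 373]])  # stand-in for a model.generate(...) result
# A TF EagerTensor has no .tolist() method, so convert through NumPy first;
# a torch.Tensor, by contrast, exposes .tolist() directly.
as_list = output_ids[0].numpy().tolist()  # -> [464, 3290, 373]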
@@ -356,7 +356,7 @@ class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
             3290,
         ]  # The dog was found in a field near the intersection of West and West Streets.\n\nThe dog
         output_ids = model.generate(input_ids, do_sample=False)
-        self.assertListEqual(output_ids[0].tolist(), expected_output_ids)
+        self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids)

     @slow
     def test_lm_generate_distilgpt2(self):
@@ -269,4 +269,4 @@ class TFOPENAIGPTModelLanguageGenerationTest(unittest.TestCase):
         ]  # the president is a very good man. " \n " i\'m sure he is, " said the
         output_ids = model.generate(input_ids, do_sample=False)
-        self.assertListEqual(output_ids[0].tolist(), expected_output_ids)
+        self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids)
@@ -573,4 +573,4 @@ class TFTransfoXLModelLanguageGenerationTest(unittest.TestCase):
         # TODO: add this test when trasnfo-xl-lmhead is implemented
         with self.assertRaises(NotImplementedError):
             model.generate(input_ids, max_length=200, do_sample=False)
-        # self.assertListEqual(output_ids[0].tolist(), expected_output_ids)  TODO: (PVP) to add when transfo-xl is implemented
+        # self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids)  TODO: (PVP) to add when transfo-xl is implemented
@@ -317,28 +317,29 @@ class TFXLMModelLanguageGenerationTest(unittest.TestCase):
     @slow
     def test_lm_generate_xlm_mlm_en_2048(self):
         model = TFXLMWithLMHeadModel.from_pretrained("xlm-mlm-en-2048")
-        input_ids = tf.convert_to_tensor([[1, 14, 2232, 26, 1]], dtype=tf.int32)  # the dog is cute
+        input_ids = tf.convert_to_tensor([[14, 447]], dtype=tf.int32)  # the president
         expected_output_ids = [
-            1,
-            14,
-            2232,
-            26,
-            1,
-            567,
-            26,
-            32,
-            149,
-            149,
-            149,
-            149,
-            149,
-            149,
-            149,
-            149,
-            149,
-            149,
-            149,
-            149,
-        ]  # The dog is nothing is it!!!!!!!!!!!! TODO (PVP): this sentence (and others I tried) does not make much sense, there seems to be a problem with xlm language generation.
+            14,
+            447,
+            14,
+            447,
+            14,
+            447,
+            14,
+            447,
+            14,
+            447,
+            14,
+            447,
+            14,
+            447,
+            14,
+            447,
+            14,
+            447,
+            14,
+            447,
+        ]  # the president the president the president the president the president the president the president the president the president the president
+        # TODO(PVP): this and other input_ids I tried for generation give pretty bad results. Not sure why. Model might just not be made for auto-regressive inference
         output_ids = model.generate(input_ids)
         self.assertListEqual(output_ids[0].tolist(), expected_output_ids, do_sample=False)
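The new TODO above notes that greedy generation from this checkpoint collapses into a loop of "the president". A small sketch, assuming the transformers XLMTokenizer for the same checkpoint, that decodes the asserted ids back to text to make the degeneration visible (token 14 is "the" and 447 is "president" per the comments in the diff):

from transformers import XLMTokenizer

tokenizer = XLMTokenizer.from_pretrained("xlm-mlm-en-2048")
expected_output_ids = [14, 447] * 10  # the alternating ids asserted in the test
print(tokenizer.decode(expected_output_ids))  # roughly "the president the president ..."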
@@ -814,4 +814,4 @@ class TFXLNetModelLanguageGenerationTest(unittest.TestCase):
         output_ids = model.generate(input_ids, max_length=200, do_sample=False)
-        self.assertListEqual(output_ids[0].tolist(), expected_output_ids)
+        self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids)
@@ -403,28 +403,29 @@ class XLMModelLanguageGenerationTest(unittest.TestCase):
     @slow
     def test_lm_generate_xlm_mlm_en_2048(self):
         model = XLMWithLMHeadModel.from_pretrained("xlm-mlm-en-2048")
-        input_ids = torch.tensor([[1, 14, 2232, 26, 1]], dtype=torch.long, device=torch_device)  # The dog is cute
+        input_ids = torch.tensor([[14, 447]], dtype=torch.long, device=torch_device)  # the president
         expected_output_ids = [
-            1,
-            14,
-            2232,
-            26,
-            1,
-            567,
-            26,
-            32,
-            149,
-            149,
-            149,
-            149,
-            149,
-            149,
-            149,
-            149,
-            149,
-            149,
-            149,
-            149,
-        ]  # The dog is nothing is it!!!!!!!!!!!! TODO (PVP): this sentence (and others I tried) does not make much sense, there seems to be a problem with xlm language generation.
-        output_ids = model.generate(input_ids)
-        self.assertListEqual(output_ids[0].tolist(), expected_output_ids, do_sample=False)
+            14,
+            447,
+            14,
+            447,
+            14,
+            447,
+            14,
+            447,
+            14,
+            447,
+            14,
+            447,
+            14,
+            447,
+            14,
+            447,
+            14,
+            447,
+            14,
+            447,
+        ]  # the president the president the president the president the president the president the president the president the president the president
+        # TODO(PVP): this and other input_ids I tried for generation give pretty bad results. Not sure why. Model might just not be made for auto-regressive inference
+        output_ids = model.generate(input_ids, do_sample=False)
+        self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids)
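For reference, a minimal sketch of running the updated PyTorch check outside the unittest harness, assuming the xlm-mlm-en-2048 weights can be downloaded; do_sample=False makes generate() decode greedily, which is what keeps expected_output_ids deterministic:

import torch
from transformers import XLMWithLMHeadModel

model = XLMWithLMHeadModel.from_pretrained("xlm-mlm-en-2048")
input_ids = torch.tensor([[14, 447]], dtype=torch.long)  # "the president"
output_ids = model.generate(input_ids, do_sample=False)  # greedy decoding
# on CPU either .tolist() or .numpy().tolist() turns the tensor into a plain list
print(output_ids[0].tolist())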