Commit b12541c4 authored by Patrick von Platen's avatar Patrick von Platen
Browse files

test ctrl

parent b73dd1a0
...@@ -220,30 +220,30 @@ class CTRLModelLanguageGenerationTest(unittest.TestCase): ...@@ -220,30 +220,30 @@ class CTRLModelLanguageGenerationTest(unittest.TestCase):
def test_lm_generate_ctrl(self): def test_lm_generate_ctrl(self):
model = CTRLLMHeadModel.from_pretrained("ctrl") model = CTRLLMHeadModel.from_pretrained("ctrl")
input_ids = torch.tensor( input_ids = torch.tensor(
[[11858, 586, 20984, 8]], dtype=torch.long, device=torch_device [[11859, 0, 1611, 8]], dtype=torch.long, device=torch_device
) # Legal My neighbor is ) # Legal the president is
expected_output_ids = [ expected_output_ids = [
11859, 11859,
586, 0,
20984, 1611,
8, 8,
13391, 5,
3, 150,
980, 26449,
8258,
72,
327,
148,
2, 2,
53, 19,
29, 348,
226, 469,
3,
780,
49,
3, 3,
980, 2595,
] # Legal My neighbor is refusing to pay rent after 2 years and we are having to force him to pay 48,
20740,
246533,
246533,
19,
30,
5,
] # Legal the president is a good guy and I don't want to lose my job. \n \n I have a
output_ids = model.generate(input_ids, do_sample=False) output_ids = model.generate(input_ids, do_sample=False)
self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids) self.assertListEqual(output_ids[0].tolist(), expected_output_ids)
...@@ -209,29 +209,29 @@ class TFCTRLModelLanguageGenerationTest(unittest.TestCase): ...@@ -209,29 +209,29 @@ class TFCTRLModelLanguageGenerationTest(unittest.TestCase):
@slow @slow
def test_lm_generate_ctrl(self): def test_lm_generate_ctrl(self):
model = TFCTRLLMHeadModel.from_pretrained("ctrl") model = TFCTRLLMHeadModel.from_pretrained("ctrl")
input_ids = tf.convert_to_tensor([[11858, 586, 20984, 8]], dtype=tf.int32) input_ids = tf.convert_to_tensor([[11859, 0, 1611, 8]], dtype=tf.int32) # Legal the president is
expected_output_ids = [ expected_output_ids = [
11859, 11859,
586, 0,
20984, 1611,
8, 8,
13391, 5,
3, 150,
980, 26449,
8258,
72,
327,
148,
2, 2,
53, 19,
29, 348,
226, 469,
3,
780,
49,
3, 3,
980, 2595,
] # Legal My neighbor is refusing to pay rent after 2 years and we are having to force him to pay 48,
20740,
246533,
246533,
19,
30,
5,
] # Legal the president is a good guy and I don't want to lose my job. \n \n I have a
output_ids = model.generate(input_ids, do_sample=False) output_ids = model.generate(input_ids, do_sample=False)
self.assertListEqual(output_ids[0].tolist(), expected_output_ids) self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment