"model_cards/vscode:/vscode.git/clone" did not exist on "aaab9ab1872d59aa23e68ebccd1ebb884b02f491"
Commit b12541c4 authored by Patrick von Platen's avatar Patrick von Platen
Browse files

test ctrl

parent b73dd1a0
......@@ -220,30 +220,30 @@ class CTRLModelLanguageGenerationTest(unittest.TestCase):
def test_lm_generate_ctrl(self):
model = CTRLLMHeadModel.from_pretrained("ctrl")
input_ids = torch.tensor(
[[11858, 586, 20984, 8]], dtype=torch.long, device=torch_device
) # Legal My neighbor is
[[11859, 0, 1611, 8]], dtype=torch.long, device=torch_device
) # Legal the president is
expected_output_ids = [
11859,
586,
20984,
0,
1611,
8,
13391,
3,
980,
8258,
72,
327,
148,
5,
150,
26449,
2,
53,
29,
226,
3,
780,
49,
19,
348,
469,
3,
980,
] # Legal My neighbor is refusing to pay rent after 2 years and we are having to force him to pay
2595,
48,
20740,
246533,
246533,
19,
30,
5,
] # Legal the president is a good guy and I don't want to lose my job. \n \n I have a
output_ids = model.generate(input_ids, do_sample=False)
self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids)
self.assertListEqual(output_ids[0].tolist(), expected_output_ids)
......@@ -209,29 +209,29 @@ class TFCTRLModelLanguageGenerationTest(unittest.TestCase):
@slow
def test_lm_generate_ctrl(self):
model = TFCTRLLMHeadModel.from_pretrained("ctrl")
input_ids = tf.convert_to_tensor([[11858, 586, 20984, 8]], dtype=tf.int32)
input_ids = tf.convert_to_tensor([[11859, 0, 1611, 8]], dtype=tf.int32) # Legal the president is
expected_output_ids = [
11859,
586,
20984,
0,
1611,
8,
13391,
3,
980,
8258,
72,
327,
148,
5,
150,
26449,
2,
53,
29,
226,
3,
780,
49,
19,
348,
469,
3,
980,
] # Legal My neighbor is refusing to pay rent after 2 years and we are having to force him to pay
2595,
48,
20740,
246533,
246533,
19,
30,
5,
] # Legal the president is a good guy and I don't want to lose my job. \n \n I have a
output_ids = model.generate(input_ids, do_sample=False)
self.assertListEqual(output_ids[0].tolist(), expected_output_ids)
self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment