Commit b4a3a647 authored by patrickvonplaten's avatar patrickvonplaten
Browse files

fix xlnet & transfotests

parent 66c82765
......@@ -519,20 +519,10 @@ class TFTransfoXLModelLanguageGenerationTest(unittest.TestCase):
24,
24,
0,
29546,
40,
1092,
18,
8,
5854,
7,
1143,
2,
7,
33,
1,
159,
99,
16,
1857,
2,
1,
1009,
4,
......
......@@ -760,20 +760,10 @@ class TFXLNetModelLanguageGenerationTest(unittest.TestCase):
9,
4,
3,
1722,
19,
24,
6348,
61,
977,
176,
1772,
33,
45,
970,
19,
4185,
19,
12943,
4354,
153,
27,
442,
22,
......
......@@ -376,6 +376,7 @@ class TransfoXLModelLanguageGenerationTest(unittest.TestCase):
# father initially slaps him for making such an accusation , Rasputin watches as the
# man is chased outside and beaten . Twenty years later , Rasputin sees a vision of
# the Virgin Mary , prompting him to become a priest . Rasputin quickly becomes famous ,
# with people , even a bishop , begging for his blessing . <eod> </s> <eos>
expected_output_ids = [
......@@ -520,20 +521,10 @@ class TransfoXLModelLanguageGenerationTest(unittest.TestCase):
24,
24,
0,
29546,
40,
1092,
18,
8,
5854,
7,
1143,
2,
7,
33,
1,
159,
99,
16,
1857,
2,
1,
1009,
4,
......
......@@ -119,11 +119,11 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
input_ids_q = ids_tensor([self.batch_size, self.seq_length + 1], self.vocab_size)
perm_mask = torch.zeros(
self.batch_size, self.seq_length + 1, self.seq_length + 1, dtype=torch.float, device=torch_device
self.batch_size, self.seq_length + 1, self.seq_length + 1, dtype=torch.float, device=torch_device,
)
perm_mask[:, :, -1] = 1.0 # Previous tokens don't see last token
target_mapping = torch.zeros(
self.batch_size, 1, self.seq_length + 1, dtype=torch.float, device=torch_device
self.batch_size, 1, self.seq_length + 1, dtype=torch.float, device=torch_device,
)
target_mapping[:, 0, -1] = 1.0 # predict last token
......@@ -212,7 +212,7 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
self.parent.assertEqual(len(no_mems_outputs), 1)
self.parent.assertListEqual(
list(result["outputs"].size()), [self.batch_size, self.seq_length, self.hidden_size]
list(result["outputs"].size()), [self.batch_size, self.seq_length, self.hidden_size],
)
self.parent.assertListEqual(
list(list(mem.size()) for mem in result["mems_1"]),
......@@ -283,7 +283,7 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
self.parent.assertListEqual(list(result["loss_1"].size()), [])
self.parent.assertListEqual(
list(result["all_logits_1"].size()), [self.batch_size, self.seq_length, self.vocab_size]
list(result["all_logits_1"].size()), [self.batch_size, self.seq_length, self.vocab_size],
)
self.parent.assertListEqual(
list(list(mem.size()) for mem in result["mems_1"]),
......@@ -292,7 +292,7 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
self.parent.assertListEqual(list(result["loss_2"].size()), [])
self.parent.assertListEqual(
list(result["all_logits_2"].size()), [self.batch_size, self.seq_length, self.vocab_size]
list(result["all_logits_2"].size()), [self.batch_size, self.seq_length, self.vocab_size],
)
self.parent.assertListEqual(
list(list(mem.size()) for mem in result["mems_2"]),
......@@ -319,7 +319,7 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
model.eval()
outputs = model(input_ids_1)
start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits, mems = outputs
(start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits, mems,) = outputs
outputs = model(
input_ids_1,
......@@ -340,7 +340,7 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
total_loss, mems = outputs
outputs = model(input_ids_1, start_positions=sequence_labels, end_positions=sequence_labels)
outputs = model(input_ids_1, start_positions=sequence_labels, end_positions=sequence_labels,)
total_loss, mems = outputs
......@@ -356,10 +356,10 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
self.parent.assertListEqual(list(result["loss"].size()), [])
self.parent.assertListEqual(
list(result["start_top_log_probs"].size()), [self.batch_size, model.config.start_n_top]
list(result["start_top_log_probs"].size()), [self.batch_size, model.config.start_n_top],
)
self.parent.assertListEqual(
list(result["start_top_index"].size()), [self.batch_size, model.config.start_n_top]
list(result["start_top_index"].size()), [self.batch_size, model.config.start_n_top],
)
self.parent.assertListEqual(
list(result["end_top_log_probs"].size()),
......@@ -405,7 +405,7 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
self.parent.assertListEqual(list(result["loss"].size()), [])
self.parent.assertListEqual(
list(result["logits"].size()), [self.batch_size, self.seq_length, self.type_sequence_label_size]
list(result["logits"].size()), [self.batch_size, self.seq_length, self.type_sequence_label_size],
)
self.parent.assertListEqual(
list(list(mem.size()) for mem in result["mems_1"]),
......@@ -442,7 +442,7 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
self.parent.assertListEqual(list(result["loss"].size()), [])
self.parent.assertListEqual(
list(result["logits"].size()), [self.batch_size, self.type_sequence_label_size]
list(result["logits"].size()), [self.batch_size, self.type_sequence_label_size],
)
self.parent.assertListEqual(
list(list(mem.size()) for mem in result["mems_1"]),
......@@ -859,20 +859,10 @@ class XLNetModelLanguageGenerationTest(unittest.TestCase):
9,
4,
3,
1722,
19,
24,
6348,
61,
977,
176,
1772,
33,
45,
970,
19,
4185,
19,
12943,
4354,
153,
27,
442,
22,
......@@ -922,5 +912,4 @@ class XLNetModelLanguageGenerationTest(unittest.TestCase):
# the men are forced to leave the monastery. Rasputin is forced to return to
output_ids = model.generate(input_ids, max_length=200, do_sample=False)
self.assertListEqual(output_ids[0].tolist(), expected_output_ids)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment