Commit b4a3a647 authored by patrickvonplaten's avatar patrickvonplaten
Browse files

fix xlnet & transfotests

parent 66c82765
...@@ -519,20 +519,10 @@ class TFTransfoXLModelLanguageGenerationTest(unittest.TestCase): ...@@ -519,20 +519,10 @@ class TFTransfoXLModelLanguageGenerationTest(unittest.TestCase):
24, 24,
24, 24,
0, 0,
29546, 33,
40,
1092,
18,
8,
5854,
7,
1143,
2,
7,
1, 1,
159, 1857,
99, 2,
16,
1, 1,
1009, 1009,
4, 4,
......
...@@ -760,20 +760,10 @@ class TFXLNetModelLanguageGenerationTest(unittest.TestCase): ...@@ -760,20 +760,10 @@ class TFXLNetModelLanguageGenerationTest(unittest.TestCase):
9, 9,
4, 4,
3, 3,
1722,
19,
24,
6348,
61,
977,
176,
1772,
33,
45,
970,
19,
4185,
19, 19,
12943,
4354,
153,
27, 27,
442, 442,
22, 22,
......
...@@ -376,6 +376,7 @@ class TransfoXLModelLanguageGenerationTest(unittest.TestCase): ...@@ -376,6 +376,7 @@ class TransfoXLModelLanguageGenerationTest(unittest.TestCase):
# father initially slaps him for making such an accusation , Rasputin watches as the # father initially slaps him for making such an accusation , Rasputin watches as the
# man is chased outside and beaten . Twenty years later , Rasputin sees a vision of # man is chased outside and beaten . Twenty years later , Rasputin sees a vision of
# the Virgin Mary , prompting him to become a priest . Rasputin quickly becomes famous , # the Virgin Mary , prompting him to become a priest . Rasputin quickly becomes famous ,
# with people , even a bishop , begging for his blessing . <eod> </s> <eos> # with people , even a bishop , begging for his blessing . <eod> </s> <eos>
expected_output_ids = [ expected_output_ids = [
...@@ -520,20 +521,10 @@ class TransfoXLModelLanguageGenerationTest(unittest.TestCase): ...@@ -520,20 +521,10 @@ class TransfoXLModelLanguageGenerationTest(unittest.TestCase):
24, 24,
24, 24,
0, 0,
29546, 33,
40,
1092,
18,
8,
5854,
7,
1143,
2,
7,
1, 1,
159, 1857,
99, 2,
16,
1, 1,
1009, 1009,
4, 4,
......
...@@ -119,11 +119,11 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -119,11 +119,11 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
input_ids_q = ids_tensor([self.batch_size, self.seq_length + 1], self.vocab_size) input_ids_q = ids_tensor([self.batch_size, self.seq_length + 1], self.vocab_size)
perm_mask = torch.zeros( perm_mask = torch.zeros(
self.batch_size, self.seq_length + 1, self.seq_length + 1, dtype=torch.float, device=torch_device self.batch_size, self.seq_length + 1, self.seq_length + 1, dtype=torch.float, device=torch_device,
) )
perm_mask[:, :, -1] = 1.0 # Previous tokens don't see last token perm_mask[:, :, -1] = 1.0 # Previous tokens don't see last token
target_mapping = torch.zeros( target_mapping = torch.zeros(
self.batch_size, 1, self.seq_length + 1, dtype=torch.float, device=torch_device self.batch_size, 1, self.seq_length + 1, dtype=torch.float, device=torch_device,
) )
target_mapping[:, 0, -1] = 1.0 # predict last token target_mapping[:, 0, -1] = 1.0 # predict last token
...@@ -212,7 +212,7 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -212,7 +212,7 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
self.parent.assertEqual(len(no_mems_outputs), 1) self.parent.assertEqual(len(no_mems_outputs), 1)
self.parent.assertListEqual( self.parent.assertListEqual(
list(result["outputs"].size()), [self.batch_size, self.seq_length, self.hidden_size] list(result["outputs"].size()), [self.batch_size, self.seq_length, self.hidden_size],
) )
self.parent.assertListEqual( self.parent.assertListEqual(
list(list(mem.size()) for mem in result["mems_1"]), list(list(mem.size()) for mem in result["mems_1"]),
...@@ -283,7 +283,7 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -283,7 +283,7 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
self.parent.assertListEqual(list(result["loss_1"].size()), []) self.parent.assertListEqual(list(result["loss_1"].size()), [])
self.parent.assertListEqual( self.parent.assertListEqual(
list(result["all_logits_1"].size()), [self.batch_size, self.seq_length, self.vocab_size] list(result["all_logits_1"].size()), [self.batch_size, self.seq_length, self.vocab_size],
) )
self.parent.assertListEqual( self.parent.assertListEqual(
list(list(mem.size()) for mem in result["mems_1"]), list(list(mem.size()) for mem in result["mems_1"]),
...@@ -292,7 +292,7 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -292,7 +292,7 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
self.parent.assertListEqual(list(result["loss_2"].size()), []) self.parent.assertListEqual(list(result["loss_2"].size()), [])
self.parent.assertListEqual( self.parent.assertListEqual(
list(result["all_logits_2"].size()), [self.batch_size, self.seq_length, self.vocab_size] list(result["all_logits_2"].size()), [self.batch_size, self.seq_length, self.vocab_size],
) )
self.parent.assertListEqual( self.parent.assertListEqual(
list(list(mem.size()) for mem in result["mems_2"]), list(list(mem.size()) for mem in result["mems_2"]),
...@@ -319,7 +319,7 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -319,7 +319,7 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
model.eval() model.eval()
outputs = model(input_ids_1) outputs = model(input_ids_1)
start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits, mems = outputs (start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits, mems,) = outputs
outputs = model( outputs = model(
input_ids_1, input_ids_1,
...@@ -340,7 +340,7 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -340,7 +340,7 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
total_loss, mems = outputs total_loss, mems = outputs
outputs = model(input_ids_1, start_positions=sequence_labels, end_positions=sequence_labels) outputs = model(input_ids_1, start_positions=sequence_labels, end_positions=sequence_labels,)
total_loss, mems = outputs total_loss, mems = outputs
...@@ -356,10 +356,10 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -356,10 +356,10 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
self.parent.assertListEqual(list(result["loss"].size()), []) self.parent.assertListEqual(list(result["loss"].size()), [])
self.parent.assertListEqual( self.parent.assertListEqual(
list(result["start_top_log_probs"].size()), [self.batch_size, model.config.start_n_top] list(result["start_top_log_probs"].size()), [self.batch_size, model.config.start_n_top],
) )
self.parent.assertListEqual( self.parent.assertListEqual(
list(result["start_top_index"].size()), [self.batch_size, model.config.start_n_top] list(result["start_top_index"].size()), [self.batch_size, model.config.start_n_top],
) )
self.parent.assertListEqual( self.parent.assertListEqual(
list(result["end_top_log_probs"].size()), list(result["end_top_log_probs"].size()),
...@@ -405,7 +405,7 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -405,7 +405,7 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
self.parent.assertListEqual(list(result["loss"].size()), []) self.parent.assertListEqual(list(result["loss"].size()), [])
self.parent.assertListEqual( self.parent.assertListEqual(
list(result["logits"].size()), [self.batch_size, self.seq_length, self.type_sequence_label_size] list(result["logits"].size()), [self.batch_size, self.seq_length, self.type_sequence_label_size],
) )
self.parent.assertListEqual( self.parent.assertListEqual(
list(list(mem.size()) for mem in result["mems_1"]), list(list(mem.size()) for mem in result["mems_1"]),
...@@ -442,7 +442,7 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -442,7 +442,7 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
self.parent.assertListEqual(list(result["loss"].size()), []) self.parent.assertListEqual(list(result["loss"].size()), [])
self.parent.assertListEqual( self.parent.assertListEqual(
list(result["logits"].size()), [self.batch_size, self.type_sequence_label_size] list(result["logits"].size()), [self.batch_size, self.type_sequence_label_size],
) )
self.parent.assertListEqual( self.parent.assertListEqual(
list(list(mem.size()) for mem in result["mems_1"]), list(list(mem.size()) for mem in result["mems_1"]),
...@@ -859,20 +859,10 @@ class XLNetModelLanguageGenerationTest(unittest.TestCase): ...@@ -859,20 +859,10 @@ class XLNetModelLanguageGenerationTest(unittest.TestCase):
9, 9,
4, 4,
3, 3,
1722,
19,
24,
6348,
61,
977,
176,
1772,
33,
45,
970,
19,
4185,
19, 19,
12943,
4354,
153,
27, 27,
442, 442,
22, 22,
...@@ -922,5 +912,4 @@ class XLNetModelLanguageGenerationTest(unittest.TestCase): ...@@ -922,5 +912,4 @@ class XLNetModelLanguageGenerationTest(unittest.TestCase):
# the men are forced to leave the monastery. Rasputin is forced to return to # the men are forced to leave the monastery. Rasputin is forced to return to
output_ids = model.generate(input_ids, max_length=200, do_sample=False) output_ids = model.generate(input_ids, max_length=200, do_sample=False)
self.assertListEqual(output_ids[0].tolist(), expected_output_ids) self.assertListEqual(output_ids[0].tolist(), expected_output_ids)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment