Unverified Commit e983da0e authored by Stas Bekman's avatar Stas Bekman Committed by GitHub
Browse files

cleanup tf unittests: part 2 (#6260)

* cleanup torch unittests: part 2

* remove trailing comma added by isort, and which breaks flake

* one more comma

* revert odd balls

* part 3: odd cases

* more ["key"] -> .key refactoring

* .numpy() is not needed

* more unncessary .numpy() removed

* more simplification
parent bc820476
...@@ -192,8 +192,8 @@ class XLNetModelTester: ...@@ -192,8 +192,8 @@ class XLNetModelTester:
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
self.parent.assertListEqual( self.parent.assertListEqual(
list(list(mem.size()) for mem in result["mems"]), [mem.shape for mem in result.mems],
[[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers, [(self.seq_length, self.batch_size, self.hidden_size)] * self.num_hidden_layers,
) )
def create_and_check_xlnet_model_use_cache( def create_and_check_xlnet_model_use_cache(
...@@ -305,22 +305,22 @@ class XLNetModelTester: ...@@ -305,22 +305,22 @@ class XLNetModelTester:
result1 = model(input_ids_1, token_type_ids=segment_ids, labels=lm_labels) result1 = model(input_ids_1, token_type_ids=segment_ids, labels=lm_labels)
result2 = model(input_ids_2, token_type_ids=segment_ids, labels=lm_labels, mems=result1["mems"]) result2 = model(input_ids_2, token_type_ids=segment_ids, labels=lm_labels, mems=result1.mems)
_ = model(input_ids_q, perm_mask=perm_mask, target_mapping=target_mapping) _ = model(input_ids_q, perm_mask=perm_mask, target_mapping=target_mapping)
self.parent.assertEqual(result1.loss.shape, ()) self.parent.assertEqual(result1.loss.shape, ())
self.parent.assertEqual(result1.logits.shape, (self.batch_size, self.seq_length, self.vocab_size)) self.parent.assertEqual(result1.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
self.parent.assertListEqual( self.parent.assertListEqual(
list(list(mem.size()) for mem in result1["mems"]), [mem.shape for mem in result1.mems],
[[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers, [(self.seq_length, self.batch_size, self.hidden_size)] * self.num_hidden_layers,
) )
self.parent.assertEqual(result2.loss.shape, ()) self.parent.assertEqual(result2.loss.shape, ())
self.parent.assertEqual(result2.logits.shape, (self.batch_size, self.seq_length, self.vocab_size)) self.parent.assertEqual(result2.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
self.parent.assertListEqual( self.parent.assertListEqual(
list(list(mem.size()) for mem in result2["mems"]), [mem.shape for mem in result2.mems],
[[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers, [(self.mem_len, self.batch_size, self.hidden_size)] * self.num_hidden_layers,
) )
def create_and_check_xlnet_qa( def create_and_check_xlnet_qa(
...@@ -378,8 +378,8 @@ class XLNetModelTester: ...@@ -378,8 +378,8 @@ class XLNetModelTester:
) )
self.parent.assertEqual(result.cls_logits.shape, (self.batch_size,)) self.parent.assertEqual(result.cls_logits.shape, (self.batch_size,))
self.parent.assertListEqual( self.parent.assertListEqual(
list(list(mem.size()) for mem in result["mems"]), [mem.shape for mem in result.mems],
[[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers, [(self.seq_length, self.batch_size, self.hidden_size)] * self.num_hidden_layers,
) )
def create_and_check_xlnet_token_classif( def create_and_check_xlnet_token_classif(
...@@ -407,8 +407,8 @@ class XLNetModelTester: ...@@ -407,8 +407,8 @@ class XLNetModelTester:
self.parent.assertEqual(result.loss.shape, ()) self.parent.assertEqual(result.loss.shape, ())
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.type_sequence_label_size)) self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.type_sequence_label_size))
self.parent.assertListEqual( self.parent.assertListEqual(
list(list(mem.size()) for mem in result["mems"]), [mem.shape for mem in result.mems],
[[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers, [(self.seq_length, self.batch_size, self.hidden_size)] * self.num_hidden_layers,
) )
def create_and_check_xlnet_sequence_classif( def create_and_check_xlnet_sequence_classif(
...@@ -436,8 +436,8 @@ class XLNetModelTester: ...@@ -436,8 +436,8 @@ class XLNetModelTester:
self.parent.assertEqual(result.loss.shape, ()) self.parent.assertEqual(result.loss.shape, ())
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size)) self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
self.parent.assertListEqual( self.parent.assertListEqual(
list(list(mem.size()) for mem in result["mems"]), [mem.shape for mem in result.mems],
[[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers, [(self.seq_length, self.batch_size, self.hidden_size)] * self.num_hidden_layers,
) )
def prepare_config_and_inputs_for_common(self): def prepare_config_and_inputs_for_common(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment