"...git@developer.sourcefind.cn:modelzoo/speecht5_pytorch.git" did not exist on "12c906397f40c1533b263d056e89faaf4dc3944f"
Commit 5de1517d authored by Tim Rault's avatar Tim Rault
Browse files

WIP modeling_test_pytorch.py

parent 1ba5b58c
......@@ -94,11 +94,10 @@ class BertModelTest(unittest.TestCase):
model = modeling.BertModel(config=config)
all_encoder_layers, pooled_output, embedding_output, sequence_output = model(input_ids, token_type_ids, input_mask)
all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
outputs = {
"embedding_output": embedding_output,
"sequence_output": sequence_output,
"sequence_output": all_encoder_layers[-1],
"pooled_output": pooled_output,
"all_encoder_layers": all_encoder_layers,
}
......@@ -106,13 +105,10 @@ class BertModelTest(unittest.TestCase):
def check_output(self, result):
self.parent.assertListEqual(
result["embedding_output"].shape,
[self.batch_size, self.seq_length, self.hidden_size])
self.parent.assertListEqual(
result["sequence_output"].shape,
list(result["sequence_output"].size()),
[self.batch_size, self.seq_length, self.hidden_size])
self.parent.assertListEqual(result["pooled_output"].shape, [self.batch_size, self.hidden_size])
self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size])
def test_default(self):
self.run_tester(BertModelTest.BertModelTester(self))
......@@ -144,6 +140,7 @@ class BertModelTest(unittest.TestCase):
for _ in range(total_dims):
values.append(rng.randint(0, vocab_size - 1))
# TODO Solve : the returned tensors provoke index out of range errors when passed to the model
return torch.tensor(data=values, dtype=torch.int32)
def assert_all_tensors_reachable(self, sess, outputs):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment