"docs/source/vscode:/vscode.git/clone" did not exist on "387217bd3e9a564cd84d4c4cc3c2f25ce30966bc"
Commit 1761d209 authored by LysandreJik
Browse files

Check that the models produce the same results on repeated forward passes when in eval mode (PyTorch) or when called with training=False (TensorFlow)

parent 128bdd4c
...@@ -68,6 +68,16 @@ class CommonTestCases: ...@@ -68,6 +68,16 @@ class CommonTestCases:
self.assertIn(param.data.mean().item(), [0.0, 1.0], self.assertIn(param.data.mean().item(), [0.0, 1.0],
msg="Parameter {} of model {} seems not properly initialized".format(name, model_class)) msg="Parameter {} of model {} seems not properly initialized".format(name, model_class))
def test_determinism(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
model.eval()
first, second = model(inputs_dict["input_ids"])[0], model(inputs_dict["input_ids"])[0]
self.assertEqual(first.ne(second).sum().item(), 0)
def test_attention_outputs(self): def test_attention_outputs(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
......
...@@ -298,6 +298,14 @@ class TFCommonTestCases: ...@@ -298,6 +298,14 @@ class TFCommonTestCases:
# self.assertGreater(len(params_not_tied), len(params_tied)) # self.assertGreater(len(params_not_tied), len(params_tied))
# self.assertEqual(len(params_tied_2), len(params_tied)) # self.assertEqual(len(params_tied_2), len(params_tied))
def test_determinism(self):
    """Check that each model gives identical outputs on two inference passes.

    For every class in ``self.all_model_classes`` a model is built from the
    common config and called twice on ``inputs_dict`` with
    ``training=False``. The first element of each result must be equal
    everywhere.
    """
    config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
    for model_class in self.all_model_classes:
        model = model_class(config)
        # Two independent forward passes over the very same inputs.
        first_output = model(inputs_dict, training=False)[0]
        second_output = model(inputs_dict, training=False)[0]
        # Element-wise comparison, reduced to a single boolean.
        elementwise_match = tf.math.equal(first_output, second_output)
        self.assertTrue(elementwise_match.numpy().all())
def ids_tensor(shape, vocab_size, rng=None, name=None, dtype=None): def ids_tensor(shape, vocab_size, rng=None, name=None, dtype=None):
"""Creates a random int32 tensor of the shape within the vocab size.""" """Creates a random int32 tensor of the shape within the vocab size."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment