Commit a5df980c authored by thomwolf's avatar thomwolf
Browse files

updating distilbert test

parent 7c3a15ac
......@@ -121,7 +121,12 @@ class CommonTestCases:
model.to(torch_device)
model.eval()
first, second = model(**inputs_dict)[0], model(**inputs_dict)[0]
self.assertEqual(first.ne(second).sum().item(), 0)
out_1 = first.cpu().numpy()
out_2 = second.cpu().numpy()
out_1 = out_1[~np.isnan(out_1)]
out_2 = out_2[~np.isnan(out_2)]
max_diff = np.amax(np.abs(out_1 - out_2))
self.assertLessEqual(max_diff, 1e-5)
def test_attention_outputs(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
......
......@@ -294,7 +294,12 @@ class TFCommonTestCases:
for model_class in self.all_model_classes:
model = model_class(config)
first, second = model(inputs_dict, training=False)[0], model(inputs_dict, training=False)[0]
self.assertTrue(tf.math.equal(first, second).numpy().all())
out_1 = first.numpy()
out_2 = second.numpy()
out_1 = out_1[~np.isnan(out_1)]
out_2 = out_2[~np.isnan(out_2)]
max_diff = np.amax(np.abs(out_1 - out_2))
self.assertLessEqual(max_diff, 1e-5)
def _get_embeds(self, wte, input_ids):
# ^^ In our TF models, the input_embeddings can take slightly different forms,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment