"vscode:/vscode.git/clone" did not exist on "f55b60b9ee9bc8f7f8ecf04f5d53d0417fbce3d8"
Commit 18a3cef7 authored by thomwolf's avatar thomwolf
Browse files

no nans

parent 1f5d9513
......@@ -88,7 +88,13 @@ class CommonTestCases:
model = model_class.from_pretrained(tmpdirname)
with torch.no_grad():
after_outputs = model(**inputs_dict)
max_diff = np.amax(np.abs(after_outputs[0].numpy() - outputs[0].numpy()))
# Make sure we don't have nans
out_1 = after_outputs[0].numpy()
out_2 = outputs[0].numpy()
out_1 = out_1[~np.isnan(out_1)]
out_2 = out_2[~np.isnan(out_2)]
max_diff = np.amax(np.abs(out_1 - out_2))
self.assertLessEqual(max_diff, 1e-5)
def test_initialization(self):
......
......@@ -92,7 +92,13 @@ class TFCommonTestCases:
model.save_pretrained(tmpdirname)
model = model_class.from_pretrained(tmpdirname)
after_outputs = model(inputs_dict)
max_diff = np.amax(np.abs(after_outputs[0].numpy() - outputs[0].numpy()))
# Make sure we don't have nans
out_1 = after_outputs[0].numpy()
out_2 = outputs[0].numpy()
out_1 = out_1[~np.isnan(out_1)]
out_2 = out_2[~np.isnan(out_2)]
max_diff = np.amax(np.abs(out_1 - out_2))
self.assertLessEqual(max_diff, 1e-5)
def test_pt_tf_model_equivalence(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment