Commit cbb368ca authored by thomwolf's avatar thomwolf
Browse files

distilbert tests

parent 5c00e344
...@@ -96,9 +96,7 @@ class CommonTestCases: ...@@ -96,9 +96,7 @@ class CommonTestCases:
# Make sure we don't have nans # Make sure we don't have nans
out_1 = after_outputs[0].cpu().numpy() out_1 = after_outputs[0].cpu().numpy()
out_2 = outputs[0].cpu().numpy() out_1[np.isnan(out_1)] = 0
out_1 = out_1[~np.isnan(out_1)]
out_2 = out_2[~np.isnan(out_2)]
max_diff = np.amax(np.abs(out_1 - out_2)) max_diff = np.amax(np.abs(out_1 - out_2))
self.assertLessEqual(max_diff, 1e-5) self.assertLessEqual(max_diff, 1e-5)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment