Commit ea2600bd authored by Lysandre's avatar Lysandre
Browse files

Absolute definitive HeisenDistilBug solve

cc @julien-c @thomwolf
parent 5c3d441e
......@@ -113,10 +113,13 @@ class TFModelTesterMixin:
tf_hidden_states = tfo[0].numpy()
pt_hidden_states = pto[0].numpy()
pt_hidden_states[np.isnan(tf_hidden_states)] = 0
tf_hidden_states[np.isnan(tf_hidden_states)] = 0
pt_hidden_states[np.isnan(pt_hidden_states)] = 0
tf_hidden_states[np.isnan(pt_hidden_states)] = 0
tf_nans = np.copy(np.isnan(tf_hidden_states))
pt_nans = np.copy(np.isnan(pt_hidden_states))
pt_hidden_states[tf_nans] = 0
tf_hidden_states[tf_nans] = 0
pt_hidden_states[pt_nans] = 0
tf_hidden_states[pt_nans] = 0
max_diff = np.amax(np.abs(tf_hidden_states - pt_hidden_states))
# Debug info (remove when fixed)
......@@ -148,8 +151,14 @@ class TFModelTesterMixin:
tfo = tf_model(inputs_dict)
tfo = tfo[0].numpy()
pto = pto[0].numpy()
tfo[np.isnan(tfo)] = 0
pto[np.isnan(pto)] = 0
tf_nans = np.copy(np.isnan(tfo))
pt_nans = np.copy(np.isnan(pto))
pto[tf_nans] = 0
tfo[tf_nans] = 0
pto[pt_nans] = 0
tfo[pt_nans] = 0
max_diff = np.amax(np.abs(tfo - pto))
self.assertLessEqual(max_diff, 2e-2)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment