Commit 875c4ae4 authored by Lysandre's avatar Lysandre
Browse files

Definitive HeisenDistilBug fix

cc @julien-c @@thomwolf
parent f09f42d4
...@@ -112,8 +112,12 @@ class TFModelTesterMixin: ...@@ -112,8 +112,12 @@ class TFModelTesterMixin:
tfo = tf_model(inputs_dict, training=False) tfo = tf_model(inputs_dict, training=False)
tf_hidden_states = tfo[0].numpy() tf_hidden_states = tfo[0].numpy()
pt_hidden_states = pto[0].numpy() pt_hidden_states = pto[0].numpy()
pt_hidden_states[np.isnan(tf_hidden_states)] = 0
tf_hidden_states[np.isnan(tf_hidden_states)] = 0 tf_hidden_states[np.isnan(tf_hidden_states)] = 0
pt_hidden_states[np.isnan(pt_hidden_states)] = 0 pt_hidden_states[np.isnan(pt_hidden_states)] = 0
tf_hidden_states[np.isnan(pt_hidden_states)] = 0
max_diff = np.amax(np.abs(tf_hidden_states - pt_hidden_states)) max_diff = np.amax(np.abs(tf_hidden_states - pt_hidden_states))
# Debug info (remove when fixed) # Debug info (remove when fixed)
if max_diff >= 2e-2: if max_diff >= 2e-2:
......
...@@ -219,5 +219,5 @@ class TFDistilBertModelTest(TFModelTesterMixin, unittest.TestCase): ...@@ -219,5 +219,5 @@ class TFDistilBertModelTest(TFModelTesterMixin, unittest.TestCase):
# @slow # @slow
# def test_model_from_pretrained(self): # def test_model_from_pretrained(self):
# for model_name in list(DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: # for model_name in list(DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
# model = DistilBertModel.from_pretrained(model_name, cache_dir=CACHE_DIR) # model = DistilBertModesss.from_pretrained(model_name, cache_dir=CACHE_DIR)
# self.assertIsNotNone(model) # self.assertIsNotNone(model)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment