"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "c51dc4f92755c67a83f3fc8a0bd6b3e64df199e4"
Unverified Commit 0a42b61e authored by Yih-Dar's avatar Yih-Dar Committed by GitHub
Browse files

Fix `test_save_load` for `TFViTMAEModelTest` (#19040)



* Fix test_save_load for TFViTMAEModelTest
Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
parent 30a28f52
...@@ -375,7 +375,6 @@ class TFViTMAEModelTest(TFModelTesterMixin, unittest.TestCase): ...@@ -375,7 +375,6 @@ class TFViTMAEModelTest(TFModelTesterMixin, unittest.TestCase):
# overwrite from common since TFViTMAEForPretraining has random masking, we need to fix the noise # overwrite from common since TFViTMAEForPretraining has random masking, we need to fix the noise
# to generate masks during test # to generate masks during test
@slow
def test_save_load(self): def test_save_load(self):
# make mask reproducible # make mask reproducible
np.random.seed(2) np.random.seed(2)
...@@ -398,9 +397,8 @@ class TFViTMAEModelTest(TFModelTesterMixin, unittest.TestCase): ...@@ -398,9 +397,8 @@ class TFViTMAEModelTest(TFModelTesterMixin, unittest.TestCase):
out_2[np.isnan(out_2)] = 0 out_2[np.isnan(out_2)] = 0
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname, saved_model=True) model.save_pretrained(tmpdirname, saved_model=False)
saved_model_dir = os.path.join(tmpdirname, "saved_model", "1") model = model_class.from_pretrained(tmpdirname)
model = tf.keras.models.load_model(saved_model_dir)
after_outputs = model(model_input, noise=noise) after_outputs = model(model_input, noise=noise)
if model_class.__name__ == "TFViTMAEModel": if model_class.__name__ == "TFViTMAEModel":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment