"notebooks/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "b743cdef713100cc63201e83a78e170c5a68668d"
Unverified Commit f588cf40 authored by Suraj Patil's avatar Suraj Patil Committed by GitHub
Browse files

[Flax tests/FlaxBert] make from_pretrained test faster (#15561)

parent 70292409
...@@ -141,7 +141,8 @@ class FlaxBertModelTest(FlaxModelTesterMixin, unittest.TestCase): ...@@ -141,7 +141,8 @@ class FlaxBertModelTest(FlaxModelTesterMixin, unittest.TestCase):
@slow @slow
def test_model_from_pretrained(self): def test_model_from_pretrained(self):
for model_class_name in self.all_model_classes: # Only check this for base model, not necessary for all model classes.
model = model_class_name.from_pretrained("bert-base-cased", from_pt=True) # This will also help speed-up tests.
outputs = model(np.ones((1, 1))) model = FlaxBertModel.from_pretrained("bert-base-cased")
self.assertIsNotNone(outputs) outputs = model(np.ones((1, 1)))
self.assertIsNotNone(outputs)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment