Unverified Commit f588cf40 authored by Suraj Patil's avatar Suraj Patil Committed by GitHub
Browse files

[Flax tests/FlaxBert] make from_pretrained test faster (#15561)

parent 70292409
...@@ -141,7 +141,8 @@ class FlaxBertModelTest(FlaxModelTesterMixin, unittest.TestCase): ...@@ -141,7 +141,8 @@ class FlaxBertModelTest(FlaxModelTesterMixin, unittest.TestCase):
@slow @slow
def test_model_from_pretrained(self): def test_model_from_pretrained(self):
for model_class_name in self.all_model_classes: # Only check this for base model, not necessary for all model classes.
model = model_class_name.from_pretrained("bert-base-cased", from_pt=True) # This will also help speed-up tests.
model = FlaxBertModel.from_pretrained("bert-base-cased")
outputs = model(np.ones((1, 1))) outputs = model(np.ones((1, 1)))
self.assertIsNotNone(outputs) self.assertIsNotNone(outputs)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment