Unverified Commit 5e428b71 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[BigBirdFlaxTests] Make tests slow (#17658)

* [BigBirdFlaxTests] Make tests slow

* up

* correct black with new version
parent 3114df41
......@@ -40,17 +40,17 @@ class FlaxBigBirdModelTester(unittest.TestCase):
def __init__(
self,
parent,
batch_size=13,
batch_size=2,
seq_length=56,
is_training=True,
use_attention_mask=True,
use_token_type_ids=True,
use_labels=True,
vocab_size=99,
hidden_size=32,
num_hidden_layers=5,
num_attention_heads=4,
intermediate_size=37,
hidden_size=4,
num_hidden_layers=2,
num_attention_heads=2,
intermediate_size=7,
hidden_act="gelu_new",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
......@@ -62,7 +62,7 @@ class FlaxBigBirdModelTester(unittest.TestCase):
attention_type="block_sparse",
use_bias=True,
rescale_embeddings=False,
block_size=4,
block_size=2,
num_random_blocks=3,
):
self.parent = parent
......@@ -156,10 +156,30 @@ class FlaxBigBirdModelTest(FlaxModelTesterMixin, unittest.TestCase):
def setUp(self):
self.model_tester = FlaxBigBirdModelTester(self)
@slow
# copied from `test_modeling_flax_common` because it takes much longer than other models
def test_from_pretrained_save_pretrained(self):
super().test_from_pretrained_save_pretrained()
@slow
# copied from `test_modeling_flax_common` because it takes much longer than other models
def test_from_pretrained_with_no_automatic_init(self):
super().test_from_pretrained_with_no_automatic_init()
@slow
# copied from `test_modeling_flax_common` because it takes much longer than other models
def test_no_automatic_init(self):
super().test_no_automatic_init()
@slow
# copied from `test_modeling_flax_common` because it takes much longer than other models
def test_hidden_states_output(self):
super().test_hidden_states_output()
@slow
def test_model_from_pretrained(self):
for model_class_name in self.all_model_classes:
model = model_class_name.from_pretrained("google/bigbird-roberta-base", from_pt=True)
model = model_class_name.from_pretrained("google/bigbird-roberta-base")
outputs = model(np.ones((1, 1)))
self.assertIsNotNone(outputs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment