Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
6900dded
Unverified
Commit
6900dded
authored
Aug 12, 2021
by
Patrick von Platen
Committed by
GitHub
Aug 12, 2021
Browse files
[Flax/JAX] Run jitted tests at every commit (#13090)
* up * up * up
parent
773d3860
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
27 additions
and
5 deletions
+27
-5
tests/test_modeling_flax_big_bird.py
tests/test_modeling_flax_big_bird.py
+27
-0
tests/test_modeling_flax_clip.py
tests/test_modeling_flax_clip.py
+0
-1
tests/test_modeling_flax_common.py
tests/test_modeling_flax_common.py
+0
-2
tests/test_modeling_flax_vit.py
tests/test_modeling_flax_vit.py
+0
-1
tests/test_modeling_flax_wav2vec2.py
tests/test_modeling_flax_wav2vec2.py
+0
-1
No files found.
tests/test_modeling_flax_big_bird.py
View file @
6900dded
...
...
@@ -23,6 +23,7 @@ from .test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random_
if
is_flax_available
():
import
jax
from
transformers.models.big_bird.modeling_flax_big_bird
import
(
FlaxBigBirdForMaskedLM
,
FlaxBigBirdForMultipleChoice
,
...
...
@@ -162,3 +163,29 @@ class FlaxBigBirdModelTest(FlaxModelTesterMixin, unittest.TestCase):
def
test_attention_outputs
(
self
):
if
self
.
test_attn_probs
:
super
().
test_attention_outputs
()
@
slow
# copied from `test_modeling_flax_common` because it takes much longer than other models
def
test_jit_compilation
(
self
):
config
,
inputs_dict
=
self
.
model_tester
.
prepare_config_and_inputs_for_common
()
for
model_class
in
self
.
all_model_classes
:
with
self
.
subTest
(
model_class
.
__name__
):
prepared_inputs_dict
=
self
.
_prepare_for_class
(
inputs_dict
,
model_class
)
model
=
model_class
(
config
)
@
jax
.
jit
def
model_jitted
(
input_ids
,
attention_mask
=
None
,
**
kwargs
):
return
model
(
input_ids
=
input_ids
,
attention_mask
=
attention_mask
,
**
kwargs
)
with
self
.
subTest
(
"JIT Enabled"
):
jitted_outputs
=
model_jitted
(
**
prepared_inputs_dict
).
to_tuple
()
with
self
.
subTest
(
"JIT Disabled"
):
with
jax
.
disable_jit
():
outputs
=
model_jitted
(
**
prepared_inputs_dict
).
to_tuple
()
self
.
assertEqual
(
len
(
outputs
),
len
(
jitted_outputs
))
for
jitted_output
,
output
in
zip
(
jitted_outputs
,
outputs
):
self
.
assertEqual
(
jitted_output
.
shape
,
output
.
shape
)
tests/test_modeling_flax_clip.py
View file @
6900dded
...
...
@@ -378,7 +378,6 @@ class FlaxCLIPModelTest(FlaxModelTesterMixin, unittest.TestCase):
def
test_hidden_states_output
(
self
):
pass
@
slow
def
test_jit_compilation
(
self
):
config
,
inputs_dict
=
self
.
model_tester
.
prepare_config_and_inputs_for_common
()
...
...
tests/test_modeling_flax_common.py
View file @
6900dded
...
...
@@ -34,7 +34,6 @@ from transformers.testing_utils import (
is_pt_flax_cross_test
,
is_staging_test
,
require_flax
,
slow
,
)
from
transformers.utils
import
logging
...
...
@@ -391,7 +390,6 @@ class FlaxModelTesterMixin:
max_diff
=
(
base_params
[
key
]
-
base_params_from_head
[
key
]).
sum
().
item
()
self
.
assertLessEqual
(
max_diff
,
1e-3
,
msg
=
f
"
{
key
}
not identical"
)
@
slow
def
test_jit_compilation
(
self
):
config
,
inputs_dict
=
self
.
model_tester
.
prepare_config_and_inputs_for_common
()
...
...
tests/test_modeling_flax_vit.py
View file @
6900dded
...
...
@@ -179,7 +179,6 @@ class FlaxViTModelTest(FlaxModelTesterMixin, unittest.TestCase):
self
.
assertListEqual
(
arg_names
[:
1
],
expected_arg_names
)
# We neeed to override this test because ViT expects pixel_values instead of input_ids
@
slow
def
test_jit_compilation
(
self
):
config
,
inputs_dict
=
self
.
model_tester
.
prepare_config_and_inputs_for_common
()
...
...
tests/test_modeling_flax_wav2vec2.py
View file @
6900dded
...
...
@@ -187,7 +187,6 @@ class FlaxWav2Vec2ModelTest(FlaxModelTesterMixin, unittest.TestCase):
expected_arg_names
=
[
"input_values"
,
"attention_mask"
]
self
.
assertListEqual
(
arg_names
[:
2
],
expected_arg_names
)
@
slow
# overwrite because of `input_values`
def
test_jit_compilation
(
self
):
config
,
inputs_dict
=
self
.
model_tester
.
prepare_config_and_inputs_for_common
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment