Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
6900dded
Unverified
Commit
6900dded
authored
Aug 12, 2021
by
Patrick von Platen
Committed by
GitHub
Aug 12, 2021
Browse files
[Flax/JAX] Run jitted tests at every commit (#13090)
* up * up * up
parent
773d3860
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
27 additions
and
5 deletions
+27
-5
tests/test_modeling_flax_big_bird.py
tests/test_modeling_flax_big_bird.py
+27
-0
tests/test_modeling_flax_clip.py
tests/test_modeling_flax_clip.py
+0
-1
tests/test_modeling_flax_common.py
tests/test_modeling_flax_common.py
+0
-2
tests/test_modeling_flax_vit.py
tests/test_modeling_flax_vit.py
+0
-1
tests/test_modeling_flax_wav2vec2.py
tests/test_modeling_flax_wav2vec2.py
+0
-1
No files found.
tests/test_modeling_flax_big_bird.py
View file @
6900dded
...
@@ -23,6 +23,7 @@ from .test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random_
...
@@ -23,6 +23,7 @@ from .test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random_
if
is_flax_available
():
if
is_flax_available
():
import
jax
from
transformers.models.big_bird.modeling_flax_big_bird
import
(
from
transformers.models.big_bird.modeling_flax_big_bird
import
(
FlaxBigBirdForMaskedLM
,
FlaxBigBirdForMaskedLM
,
FlaxBigBirdForMultipleChoice
,
FlaxBigBirdForMultipleChoice
,
...
@@ -162,3 +163,29 @@ class FlaxBigBirdModelTest(FlaxModelTesterMixin, unittest.TestCase):
...
@@ -162,3 +163,29 @@ class FlaxBigBirdModelTest(FlaxModelTesterMixin, unittest.TestCase):
def
test_attention_outputs
(
self
):
def
test_attention_outputs
(
self
):
if
self
.
test_attn_probs
:
if
self
.
test_attn_probs
:
super
().
test_attention_outputs
()
super
().
test_attention_outputs
()
@
slow
# copied from `test_modeling_flax_common` because it takes much longer than other models
def
test_jit_compilation
(
self
):
config
,
inputs_dict
=
self
.
model_tester
.
prepare_config_and_inputs_for_common
()
for
model_class
in
self
.
all_model_classes
:
with
self
.
subTest
(
model_class
.
__name__
):
prepared_inputs_dict
=
self
.
_prepare_for_class
(
inputs_dict
,
model_class
)
model
=
model_class
(
config
)
@
jax
.
jit
def
model_jitted
(
input_ids
,
attention_mask
=
None
,
**
kwargs
):
return
model
(
input_ids
=
input_ids
,
attention_mask
=
attention_mask
,
**
kwargs
)
with
self
.
subTest
(
"JIT Enabled"
):
jitted_outputs
=
model_jitted
(
**
prepared_inputs_dict
).
to_tuple
()
with
self
.
subTest
(
"JIT Disabled"
):
with
jax
.
disable_jit
():
outputs
=
model_jitted
(
**
prepared_inputs_dict
).
to_tuple
()
self
.
assertEqual
(
len
(
outputs
),
len
(
jitted_outputs
))
for
jitted_output
,
output
in
zip
(
jitted_outputs
,
outputs
):
self
.
assertEqual
(
jitted_output
.
shape
,
output
.
shape
)
tests/test_modeling_flax_clip.py
View file @
6900dded
...
@@ -378,7 +378,6 @@ class FlaxCLIPModelTest(FlaxModelTesterMixin, unittest.TestCase):
...
@@ -378,7 +378,6 @@ class FlaxCLIPModelTest(FlaxModelTesterMixin, unittest.TestCase):
def
test_hidden_states_output
(
self
):
def
test_hidden_states_output
(
self
):
pass
pass
@
slow
def
test_jit_compilation
(
self
):
def
test_jit_compilation
(
self
):
config
,
inputs_dict
=
self
.
model_tester
.
prepare_config_and_inputs_for_common
()
config
,
inputs_dict
=
self
.
model_tester
.
prepare_config_and_inputs_for_common
()
...
...
tests/test_modeling_flax_common.py
View file @
6900dded
...
@@ -34,7 +34,6 @@ from transformers.testing_utils import (
...
@@ -34,7 +34,6 @@ from transformers.testing_utils import (
is_pt_flax_cross_test
,
is_pt_flax_cross_test
,
is_staging_test
,
is_staging_test
,
require_flax
,
require_flax
,
slow
,
)
)
from
transformers.utils
import
logging
from
transformers.utils
import
logging
...
@@ -391,7 +390,6 @@ class FlaxModelTesterMixin:
...
@@ -391,7 +390,6 @@ class FlaxModelTesterMixin:
max_diff
=
(
base_params
[
key
]
-
base_params_from_head
[
key
]).
sum
().
item
()
max_diff
=
(
base_params
[
key
]
-
base_params_from_head
[
key
]).
sum
().
item
()
self
.
assertLessEqual
(
max_diff
,
1e-3
,
msg
=
f
"
{
key
}
not identical"
)
self
.
assertLessEqual
(
max_diff
,
1e-3
,
msg
=
f
"
{
key
}
not identical"
)
@
slow
def
test_jit_compilation
(
self
):
def
test_jit_compilation
(
self
):
config
,
inputs_dict
=
self
.
model_tester
.
prepare_config_and_inputs_for_common
()
config
,
inputs_dict
=
self
.
model_tester
.
prepare_config_and_inputs_for_common
()
...
...
tests/test_modeling_flax_vit.py
View file @
6900dded
...
@@ -179,7 +179,6 @@ class FlaxViTModelTest(FlaxModelTesterMixin, unittest.TestCase):
...
@@ -179,7 +179,6 @@ class FlaxViTModelTest(FlaxModelTesterMixin, unittest.TestCase):
self
.
assertListEqual
(
arg_names
[:
1
],
expected_arg_names
)
self
.
assertListEqual
(
arg_names
[:
1
],
expected_arg_names
)
# We neeed to override this test because ViT expects pixel_values instead of input_ids
# We neeed to override this test because ViT expects pixel_values instead of input_ids
@
slow
def
test_jit_compilation
(
self
):
def
test_jit_compilation
(
self
):
config
,
inputs_dict
=
self
.
model_tester
.
prepare_config_and_inputs_for_common
()
config
,
inputs_dict
=
self
.
model_tester
.
prepare_config_and_inputs_for_common
()
...
...
tests/test_modeling_flax_wav2vec2.py
View file @
6900dded
...
@@ -187,7 +187,6 @@ class FlaxWav2Vec2ModelTest(FlaxModelTesterMixin, unittest.TestCase):
...
@@ -187,7 +187,6 @@ class FlaxWav2Vec2ModelTest(FlaxModelTesterMixin, unittest.TestCase):
expected_arg_names
=
[
"input_values"
,
"attention_mask"
]
expected_arg_names
=
[
"input_values"
,
"attention_mask"
]
self
.
assertListEqual
(
arg_names
[:
2
],
expected_arg_names
)
self
.
assertListEqual
(
arg_names
[:
2
],
expected_arg_names
)
@
slow
# overwrite because of `input_values`
# overwrite because of `input_values`
def
test_jit_compilation
(
self
):
def
test_jit_compilation
(
self
):
config
,
inputs_dict
=
self
.
model_tester
.
prepare_config_and_inputs_for_common
()
config
,
inputs_dict
=
self
.
model_tester
.
prepare_config_and_inputs_for_common
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment