Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
007be9e4
Unverified
Commit
007be9e4
authored
Jun 14, 2021
by
Patrick von Platen
Committed by
GitHub
Jun 14, 2021
Browse files
[Flax] Fix flax pt equivalence tests (#12154)
* fix_torch_device_generate_test * remove @ * upload
parent
d438eee0
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
4 additions
and
8 deletions
+4
-8
tests/test_modeling_flax_common.py
tests/test_modeling_flax_common.py
+4
-8
No files found.
tests/test_modeling_flax_common.py
View file @
007be9e4
...
@@ -181,7 +181,7 @@ class FlaxModelTesterMixin:
...
@@ -181,7 +181,7 @@ class FlaxModelTesterMixin:
fx_outputs
=
fx_model
(
**
prepared_inputs_dict
).
to_tuple
()
fx_outputs
=
fx_model
(
**
prepared_inputs_dict
).
to_tuple
()
self
.
assertEqual
(
len
(
fx_outputs
),
len
(
pt_outputs
),
"Output lengths differ between Flax and PyTorch"
)
self
.
assertEqual
(
len
(
fx_outputs
),
len
(
pt_outputs
),
"Output lengths differ between Flax and PyTorch"
)
for
fx_output
,
pt_output
in
zip
(
fx_outputs
,
pt_outputs
):
for
fx_output
,
pt_output
in
zip
(
fx_outputs
,
pt_outputs
):
self
.
assert_almost_equals
(
fx_output
,
pt_output
.
numpy
(),
1
e-
3
)
self
.
assert_almost_equals
(
fx_output
,
pt_output
.
numpy
(),
4
e-
2
)
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
pt_model
.
save_pretrained
(
tmpdirname
)
pt_model
.
save_pretrained
(
tmpdirname
)
...
@@ -192,10 +192,7 @@ class FlaxModelTesterMixin:
...
@@ -192,10 +192,7 @@ class FlaxModelTesterMixin:
len
(
fx_outputs_loaded
),
len
(
pt_outputs
),
"Output lengths differ between Flax and PyTorch"
len
(
fx_outputs_loaded
),
len
(
pt_outputs
),
"Output lengths differ between Flax and PyTorch"
)
)
for
fx_output_loaded
,
pt_output
in
zip
(
fx_outputs_loaded
,
pt_outputs
):
for
fx_output_loaded
,
pt_output
in
zip
(
fx_outputs_loaded
,
pt_outputs
):
if
not
isinstance
(
self
.
assert_almost_equals
(
fx_output_loaded
,
pt_output
.
numpy
(),
4e-2
)
fx_output_loaded
,
tuple
):
# TODO(Patrick, Daniel) - let's discard use_cache for now
self
.
assert_almost_equals
(
fx_output_loaded
,
pt_output
.
numpy
(),
1e-3
)
@
is_pt_flax_cross_test
@
is_pt_flax_cross_test
def
test_equivalence_flax_to_pt
(
self
):
def
test_equivalence_flax_to_pt
(
self
):
...
@@ -229,7 +226,7 @@ class FlaxModelTesterMixin:
...
@@ -229,7 +226,7 @@ class FlaxModelTesterMixin:
self
.
assertEqual
(
len
(
fx_outputs
),
len
(
pt_outputs
),
"Output lengths differ between Flax and PyTorch"
)
self
.
assertEqual
(
len
(
fx_outputs
),
len
(
pt_outputs
),
"Output lengths differ between Flax and PyTorch"
)
for
fx_output
,
pt_output
in
zip
(
fx_outputs
,
pt_outputs
):
for
fx_output
,
pt_output
in
zip
(
fx_outputs
,
pt_outputs
):
self
.
assert_almost_equals
(
fx_output
,
pt_output
.
numpy
(),
1
e-
3
)
self
.
assert_almost_equals
(
fx_output
,
pt_output
.
numpy
(),
4
e-
2
)
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
fx_model
.
save_pretrained
(
tmpdirname
)
fx_model
.
save_pretrained
(
tmpdirname
)
...
@@ -242,8 +239,7 @@ class FlaxModelTesterMixin:
...
@@ -242,8 +239,7 @@ class FlaxModelTesterMixin:
len
(
fx_outputs
),
len
(
pt_outputs_loaded
),
"Output lengths differ between Flax and PyTorch"
len
(
fx_outputs
),
len
(
pt_outputs_loaded
),
"Output lengths differ between Flax and PyTorch"
)
)
for
fx_output
,
pt_output
in
zip
(
fx_outputs
,
pt_outputs_loaded
):
for
fx_output
,
pt_output
in
zip
(
fx_outputs
,
pt_outputs_loaded
):
if
not
isinstance
(
fx_output
,
tuple
):
# TODO(Patrick, Daniel) - let's discard use_cache for now
self
.
assert_almost_equals
(
fx_output
,
pt_output
.
numpy
(),
4e-2
)
self
.
assert_almost_equals
(
fx_output
,
pt_output
.
numpy
(),
5e-3
)
def
test_from_pretrained_save_pretrained
(
self
):
def
test_from_pretrained_save_pretrained
(
self
):
config
,
inputs_dict
=
self
.
model_tester
.
prepare_config_and_inputs_for_common
()
config
,
inputs_dict
=
self
.
model_tester
.
prepare_config_and_inputs_for_common
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment