Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
8d57c424
Unverified
Commit
8d57c424
authored
Apr 06, 2022
by
Sanchit Gandhi
Committed by
GitHub
Apr 06, 2022
Browse files
[FlaxSpeechEncoderDecoderModel] More Rigorous PT-Flax Equivalence Tests (#16589)
parent
c6563315
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
10 additions
and
17 deletions
+10
-17
tests/speech_encoder_decoder/test_modeling_flax_speech_encoder_decoder.py
...oder_decoder/test_modeling_flax_speech_encoder_decoder.py
+10
-17
No files found.
tests/speech_encoder_decoder/test_modeling_flax_speech_encoder_decoder.py
View file @
8d57c424
...
...
@@ -413,28 +413,22 @@ class FlaxEncoderDecoderMixin:
pt_inputs
=
{
k
:
torch
.
tensor
(
v
.
tolist
())
for
k
,
v
in
flax_inputs
.
items
()}
with
torch
.
no_grad
():
pt_outputs
=
pt_model
(
**
pt_inputs
)
pt_logits
=
pt_outputs
.
logits
pt_outputs
=
pt_outputs
.
to_tuple
()
fx_outputs
=
fx_model
(
**
inputs_dict
)
fx_logits
=
fx_outputs
.
logits
fx_outputs
=
fx_outputs
.
to_tuple
()
pt_outputs
=
pt_model
(
**
pt_inputs
).
to_tuple
()
fx_outputs
=
fx_model
(
**
inputs_dict
).
to_tuple
()
self
.
assertEqual
(
len
(
fx_outputs
),
len
(
pt_outputs
),
"Output lengths differ between Flax and PyTorch"
)
self
.
assert_almost_equals
(
fx_logits
,
pt_logits
.
numpy
(),
4e-2
)
for
fx_output
,
pt_output
in
zip
(
fx_outputs
,
pt_outputs
):
self
.
assert_almost_equals
(
fx_output
,
pt_output
.
numpy
(),
1e-5
)
# PT -> Flax
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
pt_model
.
save_pretrained
(
tmpdirname
)
fx_model_loaded
=
FlaxSpeechEncoderDecoderModel
.
from_pretrained
(
tmpdirname
,
from_pt
=
True
)
fx_outputs_loaded
=
fx_model_loaded
(
**
inputs_dict
)
fx_logits_loaded
=
fx_outputs_loaded
.
logits
fx_outputs_loaded
=
fx_outputs_loaded
.
to_tuple
()
fx_outputs_loaded
=
fx_model_loaded
(
**
inputs_dict
).
to_tuple
()
self
.
assertEqual
(
len
(
fx_outputs_loaded
),
len
(
pt_outputs
),
"Output lengths differ between Flax and PyTorch"
)
self
.
assert_almost_equals
(
fx_logits_loaded
,
pt_logits
.
numpy
(),
4e-2
)
for
fx_output_loaded
,
pt_output
in
zip
(
fx_outputs_loaded
,
pt_outputs
):
self
.
assert_almost_equals
(
fx_output_loaded
,
pt_output
.
numpy
(),
1e-5
)
# Flax -> PT
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
...
...
@@ -445,12 +439,11 @@ class FlaxEncoderDecoderMixin:
pt_model_loaded
.
eval
()
with
torch
.
no_grad
():
pt_outputs_loaded
=
pt_model_loaded
(
**
pt_inputs
)
pt_logits_loaded
=
pt_outputs_loaded
.
logits
pt_outputs_loaded
=
pt_outputs_loaded
.
to_tuple
()
pt_outputs_loaded
=
pt_model_loaded
(
**
pt_inputs
).
to_tuple
()
self
.
assertEqual
(
len
(
fx_outputs
),
len
(
pt_outputs_loaded
),
"Output lengths differ between Flax and PyTorch"
)
self
.
assert_almost_equals
(
fx_logits
,
pt_logits_loaded
.
numpy
(),
4e-2
)
for
fx_output
,
pt_output_loaded
in
zip
(
fx_outputs
,
pt_outputs_loaded
):
self
.
assert_almost_equals
(
fx_output
,
pt_output_loaded
.
numpy
(),
1e-5
)
def
check_equivalence_pt_to_flax
(
self
,
config
,
decoder_config
,
inputs_dict
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment