Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
8d57c424
Unverified
Commit
8d57c424
authored
Apr 06, 2022
by
Sanchit Gandhi
Committed by
GitHub
Apr 06, 2022
Browse files
[FlaxSpeechEncoderDecoderModel] More Rigorous PT-Flax Equivalence Tests (#16589)
parent
c6563315
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
10 additions
and
17 deletions
+10
-17
tests/speech_encoder_decoder/test_modeling_flax_speech_encoder_decoder.py
...oder_decoder/test_modeling_flax_speech_encoder_decoder.py
+10
-17
No files found.
tests/speech_encoder_decoder/test_modeling_flax_speech_encoder_decoder.py
View file @
8d57c424
...
@@ -413,28 +413,22 @@ class FlaxEncoderDecoderMixin:
...
@@ -413,28 +413,22 @@ class FlaxEncoderDecoderMixin:
pt_inputs
=
{
k
:
torch
.
tensor
(
v
.
tolist
())
for
k
,
v
in
flax_inputs
.
items
()}
pt_inputs
=
{
k
:
torch
.
tensor
(
v
.
tolist
())
for
k
,
v
in
flax_inputs
.
items
()}
with
torch
.
no_grad
():
with
torch
.
no_grad
():
pt_outputs
=
pt_model
(
**
pt_inputs
)
pt_outputs
=
pt_model
(
**
pt_inputs
).
to_tuple
()
pt_logits
=
pt_outputs
.
logits
pt_outputs
=
pt_outputs
.
to_tuple
()
fx_outputs
=
fx_model
(
**
inputs_dict
)
fx_logits
=
fx_outputs
.
logits
fx_outputs
=
fx_outputs
.
to_tuple
()
fx_outputs
=
fx_model
(
**
inputs_dict
).
to_tuple
()
self
.
assertEqual
(
len
(
fx_outputs
),
len
(
pt_outputs
),
"Output lengths differ between Flax and PyTorch"
)
self
.
assertEqual
(
len
(
fx_outputs
),
len
(
pt_outputs
),
"Output lengths differ between Flax and PyTorch"
)
self
.
assert_almost_equals
(
fx_logits
,
pt_logits
.
numpy
(),
4e-2
)
for
fx_output
,
pt_output
in
zip
(
fx_outputs
,
pt_outputs
):
self
.
assert_almost_equals
(
fx_output
,
pt_output
.
numpy
(),
1e-5
)
# PT -> Flax
# PT -> Flax
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
pt_model
.
save_pretrained
(
tmpdirname
)
pt_model
.
save_pretrained
(
tmpdirname
)
fx_model_loaded
=
FlaxSpeechEncoderDecoderModel
.
from_pretrained
(
tmpdirname
,
from_pt
=
True
)
fx_model_loaded
=
FlaxSpeechEncoderDecoderModel
.
from_pretrained
(
tmpdirname
,
from_pt
=
True
)
fx_outputs_loaded
=
fx_model_loaded
(
**
inputs_dict
)
fx_outputs_loaded
=
fx_model_loaded
(
**
inputs_dict
).
to_tuple
()
fx_logits_loaded
=
fx_outputs_loaded
.
logits
fx_outputs_loaded
=
fx_outputs_loaded
.
to_tuple
()
self
.
assertEqual
(
len
(
fx_outputs_loaded
),
len
(
pt_outputs
),
"Output lengths differ between Flax and PyTorch"
)
self
.
assertEqual
(
len
(
fx_outputs_loaded
),
len
(
pt_outputs
),
"Output lengths differ between Flax and PyTorch"
)
self
.
assert_almost_equals
(
fx_logits_loaded
,
pt_logits
.
numpy
(),
4e-2
)
for
fx_output_loaded
,
pt_output
in
zip
(
fx_outputs_loaded
,
pt_outputs
):
self
.
assert_almost_equals
(
fx_output_loaded
,
pt_output
.
numpy
(),
1e-5
)
# Flax -> PT
# Flax -> PT
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
...
@@ -445,12 +439,11 @@ class FlaxEncoderDecoderMixin:
...
@@ -445,12 +439,11 @@ class FlaxEncoderDecoderMixin:
pt_model_loaded
.
eval
()
pt_model_loaded
.
eval
()
with
torch
.
no_grad
():
with
torch
.
no_grad
():
pt_outputs_loaded
=
pt_model_loaded
(
**
pt_inputs
)
pt_outputs_loaded
=
pt_model_loaded
(
**
pt_inputs
).
to_tuple
()
pt_logits_loaded
=
pt_outputs_loaded
.
logits
pt_outputs_loaded
=
pt_outputs_loaded
.
to_tuple
()
self
.
assertEqual
(
len
(
fx_outputs
),
len
(
pt_outputs_loaded
),
"Output lengths differ between Flax and PyTorch"
)
self
.
assertEqual
(
len
(
fx_outputs
),
len
(
pt_outputs_loaded
),
"Output lengths differ between Flax and PyTorch"
)
self
.
assert_almost_equals
(
fx_logits
,
pt_logits_loaded
.
numpy
(),
4e-2
)
for
fx_output
,
pt_output_loaded
in
zip
(
fx_outputs
,
pt_outputs_loaded
):
self
.
assert_almost_equals
(
fx_output
,
pt_output_loaded
.
numpy
(),
1e-5
)
def
check_equivalence_pt_to_flax
(
self
,
config
,
decoder_config
,
inputs_dict
):
def
check_equivalence_pt_to_flax
(
self
,
config
,
decoder_config
,
inputs_dict
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment