Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
e231c729
Unverified
Commit
e231c729
authored
Mar 25, 2022
by
Sanchit Gandhi
Committed by
GitHub
Mar 25, 2022
Browse files
[FlaxSpeechEncoderDecoder] Fix feature extractor gradient test (#16407)
parent
a97f3150
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
3 additions
and
3 deletions
+3
-3
tests/speech_encoder_decoder/test_modeling_flax_speech_encoder_decoder.py
...oder_decoder/test_modeling_flax_speech_encoder_decoder.py
+3
-3
No files found.
tests/speech_encoder_decoder/test_modeling_flax_speech_encoder_decoder.py
View file @
e231c729
...
...
@@ -389,14 +389,14 @@ class FlaxEncoderDecoderMixin:
feature_extractor_grads
,
feature_extractor_grads_frozen
):
self
.
assertTrue
((
feature_extractor_grad_frozen
==
0.0
).
all
())
self
.
assert_difference
(
feature_extractor_grad
,
feature_extractor_grad_frozen
,
1e-
10
)
self
.
assert_difference
(
feature_extractor_grad
,
feature_extractor_grad_frozen
,
1e-
5
)
# ensure that the gradients of all unfrozen layers remain equal, i.e. all layers excluding the frozen 'feature_extractor'
grads
=
tuple
(
grads
[
k
]
for
k
in
grads
if
"feature_extractor"
not
in
k
)
grads_frozen
=
tuple
(
grads_frozen
[
k
]
for
k
in
grads_frozen
if
"feature_extractor"
not
in
k
)
for
grad
,
grad_frozen
in
zip
(
grads
,
grads_frozen
):
self
.
assert_almost_equals
(
grad
,
grad_frozen
,
1e-
10
)
self
.
assert_almost_equals
(
grad
,
grad_frozen
,
1e-
5
)
def
check_pt_flax_equivalence
(
self
,
pt_model
,
fx_model
,
inputs_dict
):
...
...
@@ -507,7 +507,7 @@ class FlaxEncoderDecoderMixin:
self
.
assertLessEqual
(
diff
,
tol
,
f
"Difference between arrays is
{
diff
}
(>=
{
tol
}
)."
)
def
assert_difference
(
self
,
a
:
np
.
ndarray
,
b
:
np
.
ndarray
,
tol
:
float
):
diff
=
np
.
abs
((
a
-
b
)).
m
in
()
diff
=
np
.
abs
((
a
-
b
)).
m
ax
()
self
.
assertGreaterEqual
(
diff
,
tol
,
f
"Difference between arrays is
{
diff
}
(<=
{
tol
}
)."
)
@
is_pt_flax_cross_test
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment