"docs/source/vscode:/vscode.git/clone" did not exist on "4c7e8d09008ea4e46dd09dccfbd518bb2b792e75"
Unverified commit 1da84ae0, authored by Sanchit Gandhi and committed by GitHub

Fix Bug in Flax-Speech-Encoder-Decoder Test (#16041)

* Fix Bug in Flax-Speech-Encoder-Decoder Test

* change thresholds for CPU precision
parent b2a1c994
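The hunks below compare gradients with two helpers from the test mixin, assert_almost_equals and assert_difference, whose bodies sit outside this diff. A plausible sketch of their semantics follows; the class name ArrayAssertsMixin and the min/max reading of each check are assumptions for readability, not code from the repository:

import numpy as np

class ArrayAssertsMixin:
    # Hypothetical reconstruction; the real helpers live in the test file.
    def assert_almost_equals(self, a: np.ndarray, b: np.ndarray, tol: float):
        # Passes when the largest element-wise difference is within tol.
        diff = np.abs(a - b).max()
        assert diff <= tol, f"max difference {diff} exceeds tolerance {tol}"

    def assert_difference(self, a: np.ndarray, b: np.ndarray, tol: float):
        # Assumed semantics: every element differs by at least tol, so the
        # arrays are genuinely different (used against the all-zero frozen grads).
        diff = np.abs(a - b).min()
        assert diff >= tol, f"min difference {diff} is below threshold {tol}"

Under this reading, the threshold changes below tighten assert_almost_equals and relax assert_difference, presumably to match the gradient magnitudes observed when the test runs on CPU, as the commit message indicates.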
@@ -303,14 +303,12 @@ class FlaxEncoderDecoderMixin:
             inputs,
             attention_mask,
             decoder_input_ids,
-            decoder_attention_mask,
             freeze_feature_encoder: bool = False,
         ):
             outputs_enc_dec = enc_dec_model(
                 inputs=inputs,
                 attention_mask=attention_mask,
                 decoder_input_ids=decoder_input_ids,
-                decoder_attention_mask=decoder_attention_mask,
                 freeze_feature_encoder=freeze_feature_encoder,
                 params=params,
             )
@@ -323,13 +321,11 @@ class FlaxEncoderDecoderMixin:
         grad_fn = jax.value_and_grad(compute_loss)

         # compute the loss and gradients for the unfrozen model
-        loss, grads = grad_fn(
-            params, inputs, attention_mask, decoder_input_ids, decoder_attention_mask, freeze_feature_encoder=False
-        )
+        loss, grads = grad_fn(params, inputs, attention_mask, decoder_input_ids, freeze_feature_encoder=False)

         # compare to the loss and gradients for the frozen model
         loss_frozen, grads_frozen = grad_fn(
-            params, inputs, attention_mask, decoder_input_ids, decoder_attention_mask, freeze_feature_encoder=True
+            params, inputs, attention_mask, decoder_input_ids, freeze_feature_encoder=True
         )

         self.assert_almost_equals(loss, loss_frozen, 1e-5)
@@ -348,14 +344,14 @@ class FlaxEncoderDecoderMixin:
             feature_extractor_grads, feature_extractor_grads_frozen
         ):
             self.assertTrue((feature_extractor_grad_frozen == 0.0).all())
-            self.assert_difference(feature_extractor_grad, feature_extractor_grad_frozen, 1e-8)
+            self.assert_difference(feature_extractor_grad, feature_extractor_grad_frozen, 1e-10)

         # ensure that the gradients of all unfrozen layers remain equal, i.e. all layers excluding the frozen 'feature_extractor'
         grads = tuple(grads[k] for k in grads if "feature_extractor" not in k)
         grads_frozen = tuple(grads_frozen[k] for k in grads_frozen if "feature_extractor" not in k)
         for grad, grad_frozen in zip(grads, grads_frozen):
-            self.assert_almost_equals(grad, grad_frozen, 1e-8)
+            self.assert_almost_equals(grad, grad_frozen, 1e-10)

     def check_pt_flax_equivalence(self, pt_model, fx_model, inputs_dict):
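As a sanity check of the behaviour the test asserts, here is a minimal, self-contained JAX sketch. It is illustrative only, not code from this repository: jax.lax.stop_gradient stands in for the model's freeze_feature_encoder flag, and the toy layer names are invented.

import jax
import jax.numpy as jnp

def compute_loss(params, x, freeze_feature_encoder=False):
    # Toy "feature extractor": its gradients should vanish when frozen.
    feats = jnp.dot(x, params["feature_extractor"])
    if freeze_feature_encoder:
        # stop_gradient leaves forward values unchanged but blocks backprop.
        feats = jax.lax.stop_gradient(feats)
    # Toy "decoder": its gradients should be identical in both runs.
    logits = jnp.dot(feats, params["decoder"])
    return jnp.mean(logits**2)

params = {
    "feature_extractor": jnp.full((4, 8), 0.1),
    "decoder": jnp.full((8, 2), 0.1),
}
x = jnp.ones((3, 4))

grad_fn = jax.value_and_grad(compute_loss)
loss, grads = grad_fn(params, x, freeze_feature_encoder=False)
loss_frozen, grads_frozen = grad_fn(params, x, freeze_feature_encoder=True)

assert jnp.allclose(loss, loss_frozen)  # freezing leaves the forward pass intact
assert (grads_frozen["feature_extractor"] == 0.0).all()  # frozen grads are zero
assert (grads["feature_extractor"] != 0.0).any()  # unfrozen grads are not
assert jnp.allclose(grads["decoder"], grads_frozen["decoder"])  # rest unchanged

Because stop_gradient changes only the backward pass, the loss comparison at 1e-5 in the diff is effectively checking that freezing the feature encoder does not perturb the forward computation.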