Unverified commit fde90187 authored by Sanchit Gandhi, committed by GitHub

Freeze Feature Encoder in FlaxSpeechEncoderDecoder (#15997)

* Freeze Feature Encoder in FlaxSpeechEncoderDecoder

* add backprop test
parent 65f9653e
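As a rough usage sketch of the new flag (the checkpoint names below are illustrative; any compatible speech encoder / text decoder pair works — the call signature itself matches the test added in this diff):

```python
from transformers import FlaxSpeechEncoderDecoderModel

# illustrative checkpoint pair: a Wav2Vec2-style speech encoder and a GPT2 decoder
model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
    "facebook/wav2vec2-base-960h", "gpt2"
)

def compute_logits(params, inputs, decoder_input_ids):
    # freeze_feature_encoder=True stops gradients from flowing into the
    # convolutional feature encoder of the speech encoder during backprop
    outputs = model(
        inputs=inputs,
        decoder_input_ids=decoder_input_ids,
        freeze_feature_encoder=True,
        params=params,
    )
    return outputs.logits
```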
@@ -250,13 +250,6 @@ class FlaxSpeechEncoderDecoderModule(nn.Module):
     def _get_decoder_module(self):
         return self.decoder
 
-    def freeze_feature_encoder(self):
-        """
-        Calling this function will disable the gradient computation for the feature encoder of the speech encoder in
-        order that its parameters are not updated during training.
-        """
-        self.encoder.freeze_feature_encoder()
-
     def __call__(
         self,
         inputs,
@@ -269,6 +262,7 @@ class FlaxSpeechEncoderDecoderModule(nn.Module):
         output_hidden_states: bool = False,
         return_dict: bool = True,
         deterministic: bool = True,
+        freeze_feature_encoder: bool = False,
     ):
         if encoder_outputs is None:
             encoder_outputs = self.encoder(
@@ -278,6 +272,7 @@ class FlaxSpeechEncoderDecoderModule(nn.Module):
                 output_hidden_states=output_hidden_states,
                 return_dict=return_dict,
                 deterministic=deterministic,
+                freeze_feature_encoder=freeze_feature_encoder,
             )
 
         encoder_hidden_states = encoder_outputs[0]
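Because Flax modules are stateless, freezing becomes a call-time argument rather than a mutating method (hence the removal of `freeze_feature_encoder()` above). In the FlaxWav2Vec2 encoder the flag is honoured by applying `jax.lax.stop_gradient` to the extracted features. A minimal sketch of that pattern (toy module, illustrative layer sizes, not the real encoder):

```python
import flax.linen as nn
import jax
import jax.numpy as jnp

class ToyEncoder(nn.Module):
    @nn.compact
    def __call__(self, inputs, freeze_feature_encoder: bool = False):
        features = nn.Dense(16)(inputs)  # stand-in for the conv feature encoder
        if freeze_feature_encoder:
            # same forward output, but gradients to the Dense above become zero
            features = jax.lax.stop_gradient(features)
        return nn.Dense(8)(features)  # stand-in for the transformer stack

model = ToyEncoder()
params = model.init(jax.random.PRNGKey(0), jnp.ones((2, 4)))
out = model.apply(params, jnp.ones((2, 4)), freeze_feature_encoder=True)
```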
@@ -448,6 +443,7 @@ class FlaxSpeechEncoderDecoderModel(FlaxPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         train: bool = False,
+        freeze_feature_encoder: bool = False,
         params: dict = None,
         dropout_rng: PRNGKey = None,
     ):
@@ -493,6 +489,7 @@ class FlaxSpeechEncoderDecoderModel(FlaxPreTrainedModel):
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
             deterministic=not train,
+            freeze_feature_encoder=freeze_feature_encoder,
             rngs=rngs,
             method=_encoder_forward,
         )
@@ -644,6 +641,7 @@ class FlaxSpeechEncoderDecoderModel(FlaxPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         train: bool = False,
+        freeze_feature_encoder: bool = False,
         params: dict = None,
         dropout_rng: PRNGKey = None,
     ):
@@ -705,6 +703,7 @@ class FlaxSpeechEncoderDecoderModel(FlaxPreTrainedModel):
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
             deterministic=not train,
+            freeze_feature_encoder=freeze_feature_encoder,
             rngs=rngs,
         )
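One caveat worth noting, as a hedged aside rather than part of the diff: `freeze_feature_encoder` is a plain Python bool that selects an `if` branch at trace time, so when the forward pass is wrapped in `jax.jit` it should be marked static rather than passed as a traced value. A minimal sketch:

```python
import functools
import jax
import jax.numpy as jnp

# static_argnames re-traces when the flag flips instead of tracing the bool
@functools.partial(jax.jit, static_argnames=("freeze_feature_encoder",))
def encode(params, inputs, freeze_feature_encoder=False):
    features = params * inputs  # stand-in for the feature encoder
    if freeze_feature_encoder:
        features = jax.lax.stop_gradient(features)
    return features

print(encode(jnp.array(2.0), jnp.array(3.0), freeze_feature_encoder=True))
```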
@@ -28,6 +28,10 @@ from ..wav2vec2.test_modeling_flax_wav2vec2 import FlaxWav2Vec2ModelTester
 if is_flax_available():
     import jax
     import jax.numpy as jnp
+
+    from flax.training.common_utils import onehot
+    from flax.traverse_util import flatten_dict
+
     from transformers import (
         FlaxBartForCausalLM,
         FlaxGPT2LMHeadModel,
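The two new Flax helpers used by the test behave as follows (a quick illustrative sketch):

```python
import jax.numpy as jnp
from flax.training.common_utils import onehot
from flax.traverse_util import flatten_dict

# onehot turns integer labels into one-hot vectors for the cross-entropy loss
print(onehot(jnp.array([0, 2]), num_classes=3))
# [[1. 0. 0.]
#  [0. 0. 1.]]

# flatten_dict flattens a nested param/grad tree into tuple-keyed entries,
# making it easy to select entries whose path contains a given key,
# e.g. "feature_extractor"
print(flatten_dict({"encoder": {"feature_extractor": {"kernel": 0}}}))
# {('encoder', 'feature_extractor', 'kernel'): 0}
```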
@@ -275,6 +279,84 @@ class FlaxEncoderDecoderMixin:
         generated_sequences = generated_output.sequences
         self.assertEqual(generated_sequences.shape, (inputs.shape[0],) + (decoder_config.max_length,))
 
+    def check_freeze_feature_encoder(
+        self,
+        config,
+        inputs,
+        attention_mask,
+        encoder_hidden_states,
+        decoder_config,
+        decoder_input_ids,
+        decoder_attention_mask,
+        **kwargs
+    ):
+        encoder_decoder_config = SpeechEncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)
+        enc_dec_model = FlaxSpeechEncoderDecoderModel(encoder_decoder_config)
+        params = enc_dec_model.params
+
+        def cross_entropy(logits, labels):
+            return -jnp.sum(labels * jax.nn.log_softmax(logits, axis=-1), axis=-1)
+
+        # define a dummy loss function for computing the loss over a forward pass
+        def compute_loss(
+            params,
+            inputs,
+            attention_mask,
+            decoder_input_ids,
+            decoder_attention_mask,
+            freeze_feature_encoder: bool = False,
+        ):
+            outputs_enc_dec = enc_dec_model(
+                inputs=inputs,
+                attention_mask=attention_mask,
+                decoder_input_ids=decoder_input_ids,
+                decoder_attention_mask=decoder_attention_mask,
+                freeze_feature_encoder=freeze_feature_encoder,
+                params=params,
+            )
+            logits = outputs_enc_dec.logits
+            vocab_size = logits.shape[-1]
+            loss = cross_entropy(logits, onehot(labels=decoder_input_ids, num_classes=vocab_size)).sum()
+            return loss
+
+        # transform the loss function to get the gradients
+        grad_fn = jax.value_and_grad(compute_loss)
+
+        # compute the loss and gradients for the unfrozen model
+        loss, grads = grad_fn(
+            params, inputs, attention_mask, decoder_input_ids, decoder_attention_mask, freeze_feature_encoder=False
+        )
+
+        # compare to the loss and gradients for the frozen model
+        loss_frozen, grads_frozen = grad_fn(
+            params, inputs, attention_mask, decoder_input_ids, decoder_attention_mask, freeze_feature_encoder=True
+        )
+
+        self.assert_almost_equals(loss, loss_frozen, 1e-5)
+
+        grads = flatten_dict(grads)
+        grads_frozen = flatten_dict(grads_frozen)
+
+        # ensure that the dicts of gradients contain the same keys
+        self.assertEqual(grads.keys(), grads_frozen.keys())
+
+        # ensure that the gradients of the frozen layers are precisely zero and that they differ from the gradients of the unfrozen layers
+        feature_extractor_grads = tuple(grads[k] for k in grads if "feature_extractor" in k)
+        feature_extractor_grads_frozen = tuple(grads_frozen[k] for k in grads_frozen if "feature_extractor" in k)
+
+        for feature_extractor_grad, feature_extractor_grad_frozen in zip(
+            feature_extractor_grads, feature_extractor_grads_frozen
+        ):
+            self.assertTrue((feature_extractor_grad_frozen == 0.0).all())
+            self.assert_difference(feature_extractor_grad, feature_extractor_grad_frozen, 1e-8)
+
+        # ensure that the gradients of all unfrozen layers remain equal, i.e. all layers excluding the frozen 'feature_extractor'
+        grads = tuple(grads[k] for k in grads if "feature_extractor" not in k)
+        grads_frozen = tuple(grads_frozen[k] for k in grads_frozen if "feature_extractor" not in k)
+
+        for grad, grad_frozen in zip(grads, grads_frozen):
+            self.assert_almost_equals(grad, grad_frozen, 1e-8)
+
     def check_pt_flax_equivalence(self, pt_model, fx_model, inputs_dict):
         pt_model.to(torch_device)
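The invariants this test asserts follow directly from `stop_gradient`. A self-contained toy analogue (the "feature_extractor" / "head" names are illustrative, not the real parameter tree):

```python
import jax
import jax.numpy as jnp

def loss_fn(params, x, freeze=False):
    feats = params["feature_extractor"] * x
    if freeze:
        feats = jax.lax.stop_gradient(feats)
    return jnp.sum((params["head"] * feats) ** 2)

params = {"feature_extractor": jnp.array(2.0), "head": jnp.array(3.0)}
x = jnp.array(1.5)

loss, grads = jax.value_and_grad(loss_fn)(params, x, freeze=False)
loss_f, grads_f = jax.value_and_grad(loss_fn)(params, x, freeze=True)

assert loss == loss_f                                # forward pass identical
assert grads_f["feature_extractor"] == 0.0           # frozen grads exactly zero
assert jnp.allclose(grads["head"], grads_f["head"])  # unfrozen grads unchanged
```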
@@ -367,13 +449,21 @@ class FlaxEncoderDecoderMixin:
         input_ids_dict = self.prepare_config_and_inputs()
         self.check_encoder_decoder_model_output_attentions(**input_ids_dict)
 
+    def test_freeze_feature_encoder(self):
+        input_ids_dict = self.prepare_config_and_inputs()
+        self.check_freeze_feature_encoder(**input_ids_dict)
+
     def test_encoder_decoder_model_generate(self):
         input_ids_dict = self.prepare_config_and_inputs()
         self.check_encoder_decoder_model_generate(**input_ids_dict)
 
     def assert_almost_equals(self, a: np.ndarray, b: np.ndarray, tol: float):
         diff = np.abs((a - b)).max()
-        self.assertLessEqual(diff, tol, f"Difference between torch and flax is {diff} (>= {tol}).")
+        self.assertLessEqual(diff, tol, f"Difference between arrays is {diff} (>= {tol}).")
+
+    def assert_difference(self, a: np.ndarray, b: np.ndarray, tol: float):
+        diff = np.abs((a - b)).min()
+        self.assertGreaterEqual(diff, tol, f"Difference between arrays is {diff} (<= {tol}).")
 
     @is_pt_flax_cross_test
     def test_pt_flax_equivalence(self):