Unverified Commit 3c4fbc61 authored by Sanchit Gandhi, committed by GitHub

Freeze FlaxWav2Vec2 Feature Encoder (#15873)

* Freeze FlaxWav2Vec2 Feature Encoder

* add to all module apply

* add backprop test
parent 7b3bd1f2
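Background (not part of the commit): the freeze works because `jax.lax.stop_gradient` leaves the forward value untouched while treating its input as a constant under differentiation. A minimal, illustrative JAX sketch (all names hypothetical):

```python
import jax
import jax.numpy as jnp

# Toy stand-in for "feature encoder + head": w_enc mimics the conv feature
# encoder's parameters, w_head the parameters of the downstream network.
def loss_fn(w_enc, w_head, x, freeze_encoder=False):
    features = w_enc * x
    if freeze_encoder:
        # identity in the forward pass, constant in the backward pass
        features = jax.lax.stop_gradient(features)
    return jnp.sum(w_head * features)

x = jnp.ones(3)
grad_enc, grad_head = jax.grad(loss_fn, argnums=(0, 1))(2.0, 3.0, x, freeze_encoder=True)
print(grad_enc, grad_head)  # 0.0 6.0 -> the encoder is frozen, the head still trains
```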
@@ -404,9 +404,11 @@ class FlaxWav2Vec2FeatureEncoder(nn.Module):
     def setup(self):
         self.conv_layers = FlaxConvLayersCollection(self.config, dtype=self.dtype)
 
-    def __call__(self, input_values):
+    def __call__(self, input_values, freeze_feature_encoder=False):
         hidden_states = input_values[:, :, None]
         hidden_states = self.conv_layers(hidden_states)
+        if freeze_feature_encoder:
+            hidden_states = jax.lax.stop_gradient(hidden_states)
         return hidden_states
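Note: `jax.lax.stop_gradient` is the identity in the forward pass and treats its argument as a constant under differentiation, so every parameter upstream of this point (i.e. the conv layers of the feature encoder) receives a zero gradient, while all downstream modules continue to train normally.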
@@ -875,6 +877,7 @@ class FlaxWav2Vec2PreTrainedModel(FlaxPreTrainedModel):
         train: bool = False,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
+        freeze_feature_encoder: bool = False,
         return_dict: Optional[bool] = None,
     ):
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -903,6 +906,7 @@ class FlaxWav2Vec2PreTrainedModel(FlaxPreTrainedModel):
             not train,
             output_attentions,
             output_hidden_states,
+            freeze_feature_encoder,
             return_dict,
             rngs=rngs,
         )
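Since the flag is forwarded positionally through `self.module.apply`, its position here has to line up with the parameter order of the bound module's `__call__`; that is why `freeze_feature_encoder` is inserted immediately before `return_dict` in every signature touched below.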
@@ -939,9 +943,10 @@ class FlaxWav2Vec2Module(nn.Module):
         deterministic=True,
         output_attentions=None,
         output_hidden_states=None,
+        freeze_feature_encoder=False,
         return_dict=None,
     ):
-        extract_features = self.feature_extractor(input_values)
+        extract_features = self.feature_extractor(input_values, freeze_feature_encoder=freeze_feature_encoder)
 
         # make sure that no loss is computed on padded inputs
         if attention_mask is not None:
@@ -1101,6 +1106,7 @@ class FlaxWav2Vec2ForCTCModule(nn.Module):
         deterministic=True,
         output_attentions=None,
         output_hidden_states=None,
+        freeze_feature_encoder=False,
         return_dict=None,
     ):
         outputs = self.wav2vec2(
@@ -1110,6 +1116,7 @@ class FlaxWav2Vec2ForCTCModule(nn.Module):
             deterministic=deterministic,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
+            freeze_feature_encoder=freeze_feature_encoder,
             return_dict=return_dict,
         )
@@ -1232,6 +1239,7 @@ class FlaxWav2Vec2ForPreTrainingModule(nn.Module):
         deterministic: bool = True,
         output_attentions=None,
         output_hidden_states=None,
+        freeze_feature_encoder=False,
         return_dict=None,
     ):
         r"""
@@ -1252,6 +1260,7 @@ class FlaxWav2Vec2ForPreTrainingModule(nn.Module):
             output_hidden_states=output_hidden_states,
             mask_time_indices=mask_time_indices,
             deterministic=deterministic,
+            freeze_feature_encoder=freeze_feature_encoder,
             return_dict=return_dict,
         )
@@ -1310,6 +1319,7 @@ class FlaxWav2Vec2ForPreTraining(FlaxWav2Vec2PreTrainedModel):
         train: bool = False,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
+        freeze_feature_encoder: bool = False,
         return_dict: Optional[bool] = None,
     ):
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -1342,6 +1352,7 @@ class FlaxWav2Vec2ForPreTraining(FlaxWav2Vec2PreTrainedModel):
             not train,
             output_attentions,
             output_hidden_states,
+            freeze_feature_encoder,
             return_dict,
             rngs=rngs,
         )
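For orientation (not part of the diff): a hedged sketch of how the new flag might be used when fine-tuning with a frozen feature encoder. The checkpoint name, dummy batch, and objective are placeholders; only the `freeze_feature_encoder` argument comes from this commit.

```python
import jax
import jax.numpy as jnp
from transformers import FlaxWav2Vec2ForPreTraining

# Placeholder checkpoint; any FlaxWav2Vec2 checkpoint would do.
model = FlaxWav2Vec2ForPreTraining.from_pretrained("facebook/wav2vec2-base")

# Dummy batch: 2 clips of one second of audio at 16 kHz.
batch = {"input_values": jnp.ones((2, 16000))}

def loss_fn(params):
    outputs = model(
        batch["input_values"],
        params=params,
        freeze_feature_encoder=True,  # conv feature encoder receives zero gradient
    )
    # Placeholder objective over the pre-training outputs.
    return -jnp.mean(outputs.projected_states * outputs.projected_quantized_states)

loss, grads = jax.value_and_grad(loss_fn)(model.params)
```

Because the cut happens at the encoder output, `grads` still contains entries for the feature-encoder parameters, but they are all zeros, so any optimizer update leaves those weights unchanged.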
@@ -229,6 +229,47 @@ class FlaxWav2Vec2ModelTest(FlaxModelTesterMixin, unittest.TestCase):
         self.assertEqual(jitted_output.shape, output.shape)
 
+    def test_freeze_feature_encoder(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        input_values = inputs_dict["input_values"]
+        attention_mask = inputs_dict["attention_mask"]
+
+        model = FlaxWav2Vec2ForPreTraining(config)
+
+        outputs = model(
+            input_values,
+            attention_mask=attention_mask,
+            freeze_feature_encoder=False,
+        )
+
+        outputs_frozen = model(
+            input_values,
+            attention_mask=attention_mask,
+            freeze_feature_encoder=True,
+        )
+
+        # dummy loss function
+        def compute_loss(projected_states, projected_quantized_states, epsilon=1e-8):
+            # compute cosine similarity of projected and projected_quantized states
+            cosine_sim = optax.cosine_similarity(projected_states, projected_quantized_states, epsilon=epsilon)
+            loss = cosine_sim.sum()
+            return loss
+
+        # transform the loss function to get the gradients
+        grad_fn = jax.value_and_grad(compute_loss)
+
+        # compute loss and gradients for unfrozen model
+        loss, grads = grad_fn(outputs.projected_states, outputs.projected_quantized_states)
+
+        # compare to loss and gradients for frozen model
+        loss_frozen, grads_frozen = grad_fn(outputs_frozen.projected_states, outputs_frozen.projected_quantized_states)
+
+        self.assertLessEqual(np.abs(loss - loss_frozen), 1e-5)
+
+        self.assertEqual(grads.shape, grads_frozen.shape)
+        max_diff = np.amax(np.abs(grads - grads_frozen))
+        self.assertLessEqual(max_diff, 1e-5)
+
     @slow
     def test_model_from_pretrained(self):
         for model_class_name in self.all_model_classes:
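The test above verifies that freezing leaves the forward pass, and hence the loss and the gradients with respect to the model outputs, bit-for-bit unchanged. A hedged sketch of a complementary check on parameter gradients, reusing `model`, `input_values`, and `attention_mask` from the test; the param-tree path `["wav2vec2"]["feature_extractor"]` is an assumption about the Flax parameter layout:

```python
import jax
import jax.numpy as jnp

def param_loss(params):
    # Run the frozen model against explicit parameters so jax.grad can trace them.
    out = model(input_values, attention_mask=attention_mask, params=params, freeze_feature_encoder=True)
    return out.projected_states.sum()

param_grads = jax.grad(param_loss)(model.params)

# With the encoder frozen, every feature-encoder gradient leaf should be exactly zero.
for g in jax.tree_util.tree_leaves(param_grads["wav2vec2"]["feature_extractor"]):
    assert float(jnp.abs(g).max()) == 0.0
```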