Unverified Commit 1a62b25c authored by Sanchit Gandhi, committed by GitHub

Backprop Test for Freeze FlaxWav2Vec2 Feature Encoder (#15938)

* Backprop Test for Freeze FlaxWav2Vec2 Feature Encoder

* remove jnp.ndarray type suggestion

* assert frozen grads are precisely zero
parent 544fd987
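
The behaviour under test: when freeze_feature_encoder=True, no gradient should flow into the feature encoder, so its parameter gradients come out exactly zero while every other parameter's gradient is unchanged. Below is a minimal sketch of that pattern, assuming the freezing is implemented by stopping gradients on the encoder output (as with jax.lax.stop_gradient); the model structure and parameter names are illustrative, not the real FlaxWav2Vec2 ones.

import jax
import jax.numpy as jnp
from flax.traverse_util import flatten_dict

# Toy two-part model: a "feature_extractor" followed by a "head".
# When frozen, the extractor output is wrapped in stop_gradient, so its
# parameters receive a gradient of exactly zero.
def apply_model(params, x, freeze_feature_encoder=False):
    features = x @ params["feature_extractor"]["kernel"]
    if freeze_feature_encoder:
        features = jax.lax.stop_gradient(features)
    return features @ params["head"]["kernel"]

def loss_fn(params, x, freeze_feature_encoder=False):
    return apply_model(params, x, freeze_feature_encoder).sum()

params = {
    "feature_extractor": {"kernel": jnp.ones((4, 3))},
    "head": {"kernel": jnp.ones((3, 2))},
}
x = jnp.ones((5, 4))

grads_frozen = jax.grad(loss_fn)(params, x, freeze_feature_encoder=True)
flat = flatten_dict(grads_frozen)
# the frozen sub-tree receives precisely zero gradient ...
assert (flat[("feature_extractor", "kernel")] == 0.0).all()
# ... while the unfrozen head still receives non-zero gradients
assert (flat[("head", "kernel")] != 0.0).any()

The test added in this commit checks the same property on the real model's parameter tree.
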
@@ -37,6 +37,7 @@ if is_flax_available():
     import jax
     import jax.numpy as jnp
     import optax
+    from flax.traverse_util import flatten_dict
 
     from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2Processor
     from transformers.models.wav2vec2.modeling_flax_wav2vec2 import (
         FlaxWav2Vec2ForCTC,
@@ -236,23 +237,22 @@ class FlaxWav2Vec2ModelTest(FlaxModelTesterMixin, unittest.TestCase):
         attention_mask = inputs_dict["attention_mask"]
 
         model = FlaxWav2Vec2ForPreTraining(config)
+        params = model.params
 
-        outputs = model(
-            input_values,
-            attention_mask=attention_mask,
-            freeze_feature_encoder=False,
-        )
-
-        outputs_frozen = model(
-            input_values,
-            attention_mask=attention_mask,
-            freeze_feature_encoder=True,
-        )
-
         # dummy loss function
-        def compute_loss(projected_states, projected_quantized_states, epsilon=1e-8):
+        def compute_loss(
+            params, input_values, attention_mask, freeze_feature_encoder: bool = False, epsilon: float = 1e-8
+        ):
+            outputs = model(
+                input_values,
+                attention_mask=attention_mask,
+                freeze_feature_encoder=freeze_feature_encoder,
+                params=params,
+            )
             # compute cosine similarity of projected and projected_quantized states
-            cosine_sim = optax.cosine_similarity(projected_states, projected_quantized_states, epsilon=epsilon)
+            cosine_sim = optax.cosine_similarity(
+                outputs.projected_states, outputs.projected_quantized_states, epsilon=epsilon
+            )
             loss = cosine_sim.sum()
             return loss
@@ -260,15 +260,43 @@ class FlaxWav2Vec2ModelTest(FlaxModelTesterMixin, unittest.TestCase):
         grad_fn = jax.value_and_grad(compute_loss)
 
         # compute loss and gradients for unfrozen model
-        loss, grads = grad_fn(outputs.projected_states, outputs.projected_quantized_states)
+        loss, grads = grad_fn(params, input_values, attention_mask, freeze_feature_encoder=False)
 
         # compare to loss and gradients for frozen model
-        loss_frozen, grads_frozen = grad_fn(outputs_frozen.projected_states, outputs_frozen.projected_quantized_states)
+        loss_frozen, grads_frozen = grad_fn(params, input_values, attention_mask, freeze_feature_encoder=True)
 
-        self.assertLessEqual(np.abs(loss - loss_frozen), 1e-5)
-        self.assertEqual(grads.shape, grads_frozen.shape)
-        max_diff = np.amax(np.abs(grads - grads_frozen))
-        self.assertLessEqual(max_diff, 1e-5)
+        self.assert_almost_equals(loss, loss_frozen, 1e-5)
+
+        grads = flatten_dict(grads)
+        grads_frozen = flatten_dict(grads_frozen)
+
+        # ensure that the dicts of gradients contain the same keys
+        self.assertEqual(grads.keys(), grads_frozen.keys())
+
+        # ensure that the gradients of the frozen layers are precisely zero and that they differ from the gradients of the unfrozen layers
+        feature_extractor_grads = tuple(grads[k] for k in grads if "feature_extractor" in k)
+        feature_extractor_grads_frozen = tuple(grads_frozen[k] for k in grads_frozen if "feature_extractor" in k)
+
+        for feature_extractor_grad, feature_extractor_grad_frozen in zip(
+            feature_extractor_grads, feature_extractor_grads_frozen
+        ):
+            self.assertTrue((feature_extractor_grad_frozen == 0.0).all())
+            self.assert_difference(feature_extractor_grad, feature_extractor_grad_frozen, 1e-7)
+
+        # ensure that the gradients of all unfrozen layers remain equal, i.e. all layers excluding the frozen 'feature_extractor'
+        grads = tuple(grads[k] for k in grads if "feature_extractor" not in k)
+        grads_frozen = tuple(grads_frozen[k] for k in grads_frozen if "feature_extractor" not in k)
+
+        for grad, grad_frozen in zip(grads, grads_frozen):
+            self.assert_almost_equals(grad, grad_frozen, 1e-7)
+
+    def assert_difference(self, a, b, tol: float):
+        diff = jnp.abs((a - b)).min()
+        self.assertGreaterEqual(diff, tol, f"Difference between arrays is {diff} (<= {tol}).")
+
+    def assert_almost_equals(self, a, b, tol: float):
+        diff = jnp.abs((a - b)).max()
+        self.assertLessEqual(diff, tol, f"Difference between arrays is {diff} (>= {tol}).")
 
     @slow
     def test_model_from_pretrained(self):
...
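
A note on the design choice in this diff: the dummy loss now takes params as its first argument and runs the forward pass itself, so jax.value_and_grad differentiates with respect to the model parameters rather than the projected output tensors, and the returned gradient pytree mirrors model.params. flatten_dict then turns that nested tree into a flat dict keyed by path tuples, which is what lets the test pick out the feature-extractor sub-tree by key. A rough sketch of that filtering step, with made-up parameter paths:

import jax.numpy as jnp
from flax.traverse_util import flatten_dict

# hypothetical gradient pytree, shaped like model.params would be
grads = {
    "wav2vec2": {
        "feature_extractor": {"conv_layers_0": {"kernel": jnp.zeros((3, 1, 4))}},
        "encoder": {"layers_0": {"attention": {"kernel": jnp.full((4, 4), 0.5)}}},
    }
}

flat = flatten_dict(grads)
# keys are tuples of path components, e.g.
# ("wav2vec2", "feature_extractor", "conv_layers_0", "kernel")
feature_extractor_grads = tuple(flat[k] for k in flat if "feature_extractor" in k)
other_grads = tuple(flat[k] for k in flat if "feature_extractor" not in k)

# with the feature extractor frozen, its gradients should be exactly zero,
# while the remaining gradients are in general non-zero
assert all((g == 0.0).all() for g in feature_extractor_grads)
assert all((g != 0.0).any() for g in other_grads)
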