".github/vscode:/vscode.git/clone" did not exist on "aa08a34669711295eccc4b2e7e41bde16f1af20d"
Unverified Commit 925fc57b authored by Sanchit Gandhi, committed by GitHub

[Flax] Improve Robustness of Back-Prop Tests (#16418)

* [Flax] Improve Robustness of Back-Prop Tests

* check equality of logits/outputs

* make fixup
parent 7ecbb9c5
@@ -360,20 +360,24 @@ class FlaxEncoderDecoderMixin:
             logits = outputs_enc_dec.logits
             vocab_size = logits.shape[-1]
             loss = cross_entropy(logits, onehot(labels=decoder_input_ids, num_classes=vocab_size)).sum()
-            return loss
+            return (loss, logits)

         # transform the loss function to get the gradients
-        grad_fn = jax.value_and_grad(compute_loss)
+        grad_fn = jax.value_and_grad(compute_loss, has_aux=True)

-        # compute the loss and gradients for the unfrozen model
-        loss, grads = grad_fn(params, inputs, attention_mask, decoder_input_ids, freeze_feature_encoder=False)
+        # compute the loss, logits, and gradients for the unfrozen model
+        (loss, logits), grads = grad_fn(
+            params, inputs, attention_mask, decoder_input_ids, freeze_feature_encoder=False
+        )

-        # compare to the loss and gradients for the frozen model
-        loss_frozen, grads_frozen = grad_fn(
+        # compare to the loss, logits and gradients for the frozen model
+        (loss_frozen, logits_frozen), grads_frozen = grad_fn(
             params, inputs, attention_mask, decoder_input_ids, freeze_feature_encoder=True
         )

-        self.assert_almost_equals(loss, loss_frozen, 1e-5)
+        # ensure that the logits and losses remain precisely equal
+        self.assertTrue((logits == logits_frozen).all())
+        self.assertEqual(loss, loss_frozen)

         grads = flatten_dict(grads)
         grads_frozen = flatten_dict(grads_frozen)
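For readers unfamiliar with the pattern, here is a minimal, self-contained sketch of jax.value_and_grad with has_aux=True, using a made-up toy model (none of these names come from the test above): the wrapped function returns (loss, aux), and the transformed function returns ((loss, aux), grads), which is how the test recovers the logits from the same forward/backward pass that produces the gradients.

import jax
import jax.numpy as jnp

def compute_loss(w, x, targets):
    # toy "model": a single linear projection
    logits = x @ w
    loss = jnp.mean((logits - targets) ** 2)
    # with has_aux=True, only the first element is differentiated; the second is passed through
    return loss, logits

grad_fn = jax.value_and_grad(compute_loss, has_aux=True)

w = jnp.ones((4, 2))
x = jnp.ones((3, 4))
targets = jnp.zeros((3, 2))

# ((loss, aux), grads): the auxiliary output rides along with the loss and the gradients
(loss, logits), grads = grad_fn(w, x, targets)
print(loss.shape, logits.shape, grads.shape)  # (), (3, 2), (4, 2)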
@@ -381,7 +385,7 @@ class FlaxEncoderDecoderMixin:
         # ensure that the dicts of gradients contain the same keys
         self.assertEqual(grads.keys(), grads_frozen.keys())

-        # ensure that the gradients of the frozen layers are precisely zero and that they differ to the gradients of the unfrozen layers
+        # ensure that the gradients of the feature extractor layers are precisely zero when frozen and contain non-zero entries when unfrozen
         feature_extractor_grads = tuple(grads[k] for k in grads if "feature_extractor" in k)
         feature_extractor_grads_frozen = tuple(grads_frozen[k] for k in grads_frozen if "feature_extractor" in k)
@@ -389,14 +393,14 @@ class FlaxEncoderDecoderMixin:
             feature_extractor_grads, feature_extractor_grads_frozen
         ):
             self.assertTrue((feature_extractor_grad_frozen == 0.0).all())
-            self.assert_difference(feature_extractor_grad, feature_extractor_grad_frozen, 1e-5)
+            self.assertTrue((feature_extractor_grad > 0.0).any())

-        # ensure that the gradients of all unfrozen layers remain equal, i.e. all layers excluding the frozen 'feature_extractor'
+        # ensure that the gradients of all unfrozen layers remain precisely equal, i.e. all layers excluding the frozen 'feature_extractor'
         grads = tuple(grads[k] for k in grads if "feature_extractor" not in k)
         grads_frozen = tuple(grads_frozen[k] for k in grads_frozen if "feature_extractor" not in k)

         for grad, grad_frozen in zip(grads, grads_frozen):
-            self.assert_almost_equals(grad, grad_frozen, 1e-5)
+            self.assertTrue((grad == grad_frozen).all())

     def check_pt_flax_equivalence(self, pt_model, fx_model, inputs_dict):
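For context, a minimal toy sketch of the gradient bookkeeping these assertions rely on, assuming (as in the models under test) that freezing is applied with jax.lax.stop_gradient; the two-layer "model" and parameter names below are invented for illustration. The frozen layer's gradients come out exactly zero rather than merely small, and flax.traverse_util.flatten_dict turns the nested parameter dict into tuple-keyed entries that can be filtered by layer name, much as the test filters on "feature_extractor".

import jax
import jax.numpy as jnp
from flax.traverse_util import flatten_dict

def loss_fn(params, x, freeze_feature_encoder=False):
    # toy feature extractor followed by a toy head
    feats = x @ params["feature_extractor"]["kernel"]
    if freeze_feature_encoder:
        # identity on the forward pass, blocks the backward pass
        feats = jax.lax.stop_gradient(feats)
    return jnp.sum((feats @ params["head"]["kernel"]) ** 2)

params = {
    "feature_extractor": {"kernel": jnp.ones((4, 3))},
    "head": {"kernel": jnp.ones((3, 2))},
}
x = jnp.ones((5, 4))

grads = flatten_dict(jax.grad(loss_fn)(params, x, freeze_feature_encoder=False))
grads_frozen = flatten_dict(jax.grad(loss_fn)(params, x, freeze_feature_encoder=True))

# flattened keys are tuples such as ("feature_extractor", "kernel")
assert (grads_frozen[("feature_extractor", "kernel")] == 0.0).all()
assert (grads[("feature_extractor", "kernel")] != 0.0).any()
# layers outside the frozen module receive identical gradients in both runs
assert (grads[("head", "kernel")] == grads_frozen[("head", "kernel")]).all()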
@@ -504,11 +508,7 @@ class FlaxEncoderDecoderMixin:
     def assert_almost_equals(self, a: np.ndarray, b: np.ndarray, tol: float):
         diff = np.abs((a - b)).max()
-        self.assertLessEqual(diff, tol, f"Difference between arrays is {diff} (>= {tol}).")
-
-    def assert_difference(self, a: np.ndarray, b: np.ndarray, tol: float):
-        diff = np.abs((a - b)).max()
-        self.assertGreaterEqual(diff, tol, f"Difference between arrays is {diff} (<= {tol}).")
+        self.assertLessEqual(diff, tol, f"Difference between torch and flax is {diff} (>= {tol}).")

     @is_pt_flax_cross_test
     def test_pt_flax_equivalence(self):
@@ -254,18 +254,23 @@ class FlaxWav2Vec2ModelTest(FlaxModelTesterMixin, unittest.TestCase):
                 outputs.projected_states, outputs.projected_quantized_states, epsilon=epsilon
             )
             loss = cosine_sim.sum()
-            return loss
+            return loss, outputs.to_tuple()

         # transform the loss function to get the gradients
-        grad_fn = jax.value_and_grad(compute_loss)
+        grad_fn = jax.value_and_grad(compute_loss, has_aux=True)

-        # compute loss and gradients for unfrozen model
-        loss, grads = grad_fn(params, input_values, attention_mask, freeze_feature_encoder=False)
+        # compute loss, outputs and gradients for unfrozen model
+        (loss, outputs), grads = grad_fn(params, input_values, attention_mask, freeze_feature_encoder=False)

-        # compare to loss and gradients for frozen model
-        loss_frozen, grads_frozen = grad_fn(params, input_values, attention_mask, freeze_feature_encoder=True)
+        # compare to loss, outputs and gradients for frozen model
+        (loss_frozen, outputs_frozen), grads_frozen = grad_fn(
+            params, input_values, attention_mask, freeze_feature_encoder=True
+        )

-        self.assert_almost_equals(loss, loss_frozen, 1e-5)
+        # ensure that the outputs and losses remain precisely equal
+        for output, output_frozen in zip(outputs, outputs_frozen):
+            self.assertTrue((output == output_frozen).all())
+        self.assertEqual(loss, loss_frozen)

         grads = flatten_dict(grads)
         grads_frozen = flatten_dict(grads_frozen)
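For context, a minimal toy sketch of why these checks can demand exact equality rather than an approximate tolerance, assuming the feature-encoder freeze is implemented with jax.lax.stop_gradient (the toy function below is not from the test): stop_gradient is the identity on the forward pass, so the frozen and unfrozen runs produce identical outputs and loss, and only the gradients change.

import jax
import jax.numpy as jnp

x = jnp.linspace(-1.0, 1.0, 5)

# identical forward values: stop_gradient does nothing on the forward pass
forward = jnp.tanh(x)
forward_frozen = jnp.tanh(jax.lax.stop_gradient(x))
assert (forward == forward_frozen).all()

# but the backward pass through the stopped value is blocked entirely
grad = jax.grad(lambda v: jnp.tanh(v).sum())(x)
grad_frozen = jax.grad(lambda v: jnp.tanh(jax.lax.stop_gradient(v)).sum())(x)
assert (grad != 0.0).any()
assert (grad_frozen == 0.0).all()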
@@ -273,7 +278,7 @@ class FlaxWav2Vec2ModelTest(FlaxModelTesterMixin, unittest.TestCase):
         # ensure that the dicts of gradients contain the same keys
         self.assertEqual(grads.keys(), grads_frozen.keys())

-        # ensure that the gradients of the frozen layers are precisely zero and that they differ to the gradients of the unfrozen layers
+        # ensure that the gradients of the feature extractor layers are precisely zero when frozen and contain non-zero entries when unfrozen
         feature_extractor_grads = tuple(grads[k] for k in grads if "feature_extractor" in k)
         feature_extractor_grads_frozen = tuple(grads_frozen[k] for k in grads_frozen if "feature_extractor" in k)
@@ -281,22 +286,14 @@ class FlaxWav2Vec2ModelTest(FlaxModelTesterMixin, unittest.TestCase):
             feature_extractor_grads, feature_extractor_grads_frozen
         ):
             self.assertTrue((feature_extractor_grad_frozen == 0.0).all())
-            self.assert_difference(feature_extractor_grad, feature_extractor_grad_frozen, 1e-7)
+            self.assertTrue((feature_extractor_grad > 0.0).any())

         # ensure that the gradients of all unfrozen layers remain equal, i.e. all layers excluding the frozen 'feature_extractor'
         grads = tuple(grads[k] for k in grads if "feature_extractor" not in k)
         grads_frozen = tuple(grads_frozen[k] for k in grads_frozen if "feature_extractor" not in k)

         for grad, grad_frozen in zip(grads, grads_frozen):
-            self.assert_almost_equals(grad, grad_frozen, 1e-7)
-
-    def assert_difference(self, a, b, tol: float):
-        diff = jnp.abs((a - b)).min()
-        self.assertGreaterEqual(diff, tol, f"Difference between arrays is {diff} (<= {tol}).")
-
-    def assert_almost_equals(self, a, b, tol: float):
-        diff = jnp.abs((a - b)).max()
-        self.assertLessEqual(diff, tol, f"Difference between arrays is {diff} (>= {tol}).")
+            self.assertTrue((grad == grad_frozen).all())

     @slow
     def test_model_from_pretrained(self):