Unverified Commit 5a2b77a6 authored by Gerald Cuder's avatar Gerald Cuder Committed by GitHub
Browse files

Fix error in mixed precision training of `TFCvtModel` (#22267)



* Make sure CVT can be trained using mixed precision

* Add test for keras-fit with mixed-precision

* Update tests/models/cvt/test_modeling_tf_cvt.py
Co-authored-by: default avatarMatt <Rocketknight1@users.noreply.github.com>

---------
Co-authored-by: default avatargcuder <Gerald.Cuder@iacapps.com>
Co-authored-by: default avatarMatt <Rocketknight1@users.noreply.github.com>
parent 330d8b99
......@@ -93,7 +93,7 @@ class TFCvtDropPath(tf.keras.layers.Layer):
return x
keep_prob = 1 - self.drop_prob
shape = (tf.shape(x)[0],) + (1,) * (len(tf.shape(x)) - 1)
random_tensor = keep_prob + tf.random.uniform(shape, 0, 1)
random_tensor = keep_prob + tf.random.uniform(shape, 0, 1, dtype=self.compute_dtype)
random_tensor = tf.floor(random_tensor)
return (x / keep_prob) * random_tensor
......
......@@ -186,6 +186,12 @@ class TFCvtModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
def test_keras_fit(self):
super().test_keras_fit()
def test_keras_fit_mixed_precision(self):
policy = tf.keras.mixed_precision.Policy("mixed_float16")
tf.keras.mixed_precision.set_global_policy(policy)
super().test_keras_fit()
tf.keras.mixed_precision.set_global_policy("float32")
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment