"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "a81cf9ee90c2e3c802b6454e84a4545876382b7d"
Unverified Commit ec6cd763 authored by Joao Gante's avatar Joao Gante Committed by GitHub
Browse files

TF: Add missing cast to GPT-J (#18201)

* Fix TF GPT-J tests

* add try/finally block
parent 05ed569c
......@@ -222,7 +222,7 @@ class TFGPTJAttention(tf.keras.layers.Layer):
key = self._split_heads(key, True)
value = self._split_heads(value, False)
sincos = tf.gather(self.embed_positions, position_ids, axis=0)
sincos = tf.cast(tf.gather(self.embed_positions, position_ids, axis=0), hidden_states.dtype)
sincos = tf.split(sincos, 2, axis=-1)
if self.rotary_dim is not None:
k_rot = key[:, :, :, : self.rotary_dim]
......
......@@ -274,16 +274,17 @@ class TFCoreModelTesterMixin:
def test_mixed_precision(self):
tf.keras.mixed_precision.set_global_policy("mixed_float16")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
class_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
model = model_class(config)
outputs = model(class_inputs_dict)
self.assertIsNotNone(outputs)
# try/finally block to ensure subsequent tests run in float32
try:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
class_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
model = model_class(config)
outputs = model(class_inputs_dict)
tf.keras.mixed_precision.set_global_policy("float32")
self.assertIsNotNone(outputs)
finally:
tf.keras.mixed_precision.set_global_policy("float32")
@slow
def test_train_pipeline_custom_model(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment