Unverified Commit 7246785a authored by Julien Plu's avatar Julien Plu Committed by GitHub
Browse files

Make TF CTRL compliant with XLA and AMP (#10209)

* Fix XLA and AMP

* Apply style

* Remove useless cast
parent fdb2351e
...@@ -48,7 +48,7 @@ TF_CTRL_PRETRAINED_MODEL_ARCHIVE_LIST = [ ...@@ -48,7 +48,7 @@ TF_CTRL_PRETRAINED_MODEL_ARCHIVE_LIST = [
def angle_defn(pos, i, d_model_size): def angle_defn(pos, i, d_model_size):
angle_rates = 1 / np.power(10000, (2 * (i // 2)) / np.float32(d_model_size)) angle_rates = 1 / np.power(10000, (2 * (i // 2)) / d_model_size)
return pos * angle_rates return pos * angle_rates
...@@ -58,9 +58,8 @@ def positional_encoding(position, d_model_size): ...@@ -58,9 +58,8 @@ def positional_encoding(position, d_model_size):
sines = np.sin(angle_rads[:, 0::2]) sines = np.sin(angle_rads[:, 0::2])
cosines = np.cos(angle_rads[:, 1::2]) cosines = np.cos(angle_rads[:, 1::2])
pos_encoding = tf.convert_to_tensor(np.concatenate([sines, cosines], axis=-1))
# pos_encoding = tf.cast(np.concatenate([sines, cosines], axis=-1)[np.newaxis, ...], dtype=tf.float32)
pos_encoding = tf.cast(np.concatenate([sines, cosines], axis=-1), dtype=tf.float32)
return pos_encoding return pos_encoding
...@@ -68,14 +67,15 @@ def scaled_dot_product_attention(q, k, v, mask, attention_mask=None, head_mask=N ...@@ -68,14 +67,15 @@ def scaled_dot_product_attention(q, k, v, mask, attention_mask=None, head_mask=N
# calculate attention # calculate attention
matmul_qk = tf.matmul(q, k, transpose_b=True) matmul_qk = tf.matmul(q, k, transpose_b=True)
dk = tf.cast(shape_list(k)[-1], tf.float32) dk = tf.cast(shape_list(k)[-1], dtype=matmul_qk.dtype)
scaled_attention_logits = matmul_qk / tf.math.sqrt(dk) scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)
if mask is not None: if mask is not None:
scaled_attention_logits += mask * -1e4 scaled_attention_logits += tf.cast(mask * -1e4, dtype=scaled_attention_logits.dtype)
if attention_mask is not None: if attention_mask is not None:
# Apply the attention mask # Apply the attention mask
attention_mask = tf.cast(attention_mask, dtype=scaled_attention_logits.dtype)
scaled_attention_logits = scaled_attention_logits + attention_mask scaled_attention_logits = scaled_attention_logits + attention_mask
attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1) attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)
...@@ -332,10 +332,10 @@ class TFCTRLMainLayer(tf.keras.layers.Layer): ...@@ -332,10 +332,10 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
# Since we are adding it to the raw scores before the softmax, this is # Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely. # effectively the same as removing these entirely.
inputs["attention_mask"] = tf.cast(inputs["attention_mask"], tf.float32) one_cst = tf.constant(1.0)
inputs["attention_mask"] = (1.0 - inputs["attention_mask"]) * -10000.0 ten_thousand_cst = tf.constant(-10000.0)
else: inputs["attention_mask"] = tf.cast(inputs["attention_mask"], dtype=one_cst.dtype)
inputs["attention_mask"] = None inputs["attention_mask"] = tf.multiply(tf.subtract(one_cst, inputs["attention_mask"]), ten_thousand_cst)
# Prepare head mask if needed # Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head # 1.0 in head_mask indicate we keep the head
...@@ -351,9 +351,9 @@ class TFCTRLMainLayer(tf.keras.layers.Layer): ...@@ -351,9 +351,9 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
inputs["token_type_ids"], [-1, shape_list(inputs["token_type_ids"])[-1]] inputs["token_type_ids"], [-1, shape_list(inputs["token_type_ids"])[-1]]
) )
token_type_embeds = self.w(inputs["token_type_ids"], mode="embedding") token_type_embeds = self.w(inputs["token_type_ids"], mode="embedding")
token_type_embeds *= tf.math.sqrt(tf.cast(self.d_model_size, tf.float32)) token_type_embeds *= tf.math.sqrt(tf.cast(self.d_model_size, dtype=token_type_embeds.dtype))
else: else:
token_type_embeds = 0 token_type_embeds = tf.constant(0.0)
inputs["position_ids"] = tf.reshape(inputs["position_ids"], [-1, shape_list(inputs["position_ids"])[-1]]) inputs["position_ids"] = tf.reshape(inputs["position_ids"], [-1, shape_list(inputs["position_ids"])[-1]])
if inputs["inputs_embeds"] is None: if inputs["inputs_embeds"] is None:
...@@ -361,10 +361,10 @@ class TFCTRLMainLayer(tf.keras.layers.Layer): ...@@ -361,10 +361,10 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
seq_len = input_shape[-1] seq_len = input_shape[-1]
mask = 1 - tf.linalg.band_part(tf.ones((seq_len, seq_len)), -1, 0) mask = 1 - tf.linalg.band_part(tf.ones((seq_len, seq_len)), -1, 0)
inputs["inputs_embeds"] *= tf.math.sqrt(tf.cast(self.d_model_size, tf.float32)) inputs["inputs_embeds"] *= tf.math.sqrt(tf.cast(self.d_model_size, inputs["inputs_embeds"].dtype))
pos_embeds = tf.gather(self.pos_encoding, inputs["position_ids"]) pos_embeds = tf.gather(self.pos_encoding, inputs["position_ids"])
pos_embeds = tf.cast(pos_embeds, dtype=token_type_embeds.dtype)
hidden_states = inputs["inputs_embeds"] + pos_embeds + token_type_embeds hidden_states = inputs["inputs_embeds"] + pos_embeds + token_type_embeds
hidden_states = self.dropout(hidden_states, training=inputs["training"]) hidden_states = self.dropout(hidden_states, training=inputs["training"])
...@@ -857,7 +857,6 @@ class TFCTRLForSequenceClassification(TFCTRLPreTrainedModel, TFSequenceClassific ...@@ -857,7 +857,6 @@ class TFCTRLForSequenceClassification(TFCTRLPreTrainedModel, TFSequenceClassific
hidden_states = transformer_outputs[0] hidden_states = transformer_outputs[0]
logits = self.classifier(hidden_states) logits = self.classifier(hidden_states)
logits_shape = shape_list(logits)
in_logits = None in_logits = None
if self.config.pad_token_id is None: if self.config.pad_token_id is None:
sequence_lengths = -1 sequence_lengths = -1
...@@ -865,22 +864,16 @@ class TFCTRLForSequenceClassification(TFCTRLPreTrainedModel, TFSequenceClassific ...@@ -865,22 +864,16 @@ class TFCTRLForSequenceClassification(TFCTRLPreTrainedModel, TFSequenceClassific
if inputs["input_ids"] is not None: if inputs["input_ids"] is not None:
sequence_lengths = ( sequence_lengths = (
tf.reduce_sum( tf.reduce_sum(
tf.cast(tf.math.not_equal(inputs["input_ids"], self.config.pad_token_id), tf.int32), tf.cast(
tf.math.not_equal(inputs["input_ids"], self.config.pad_token_id),
dtype=inputs["input_ids"].dtype,
),
-1, -1,
keepdims=False, keepdims=False,
) )
- 1 - 1
) )
in_logits = tf.gather(logits, sequence_lengths, batch_dims=1, axis=1)
def get_seq_element(sequence_position, input_batch):
return tf.strided_slice(
input_batch, [sequence_position, 0], [sequence_position + 1, input_batch.shape[-1]], [1, 1]
)
result = tf.map_fn(
fn=lambda t: get_seq_element(t[0], t[1]), elems=[sequence_lengths, logits], dtype="float"
)
in_logits = tf.reshape(result, [logits_shape[0], logits_shape[-1]])
else: else:
sequence_lengths = -1 sequence_lengths = -1
logger.warning( logger.warning(
......
...@@ -222,14 +222,6 @@ class TFCTRLModelTest(TFModelTesterMixin, unittest.TestCase): ...@@ -222,14 +222,6 @@ class TFCTRLModelTest(TFModelTesterMixin, unittest.TestCase):
name = model.get_bias() name = model.get_bias()
assert name is None assert name is None
def test_mixed_precision(self):
# TODO JP: Make CTRL float16 compliant
pass
def test_xla_mode(self):
# TODO JP: Make CTRL XLA compliant
pass
@slow @slow
def test_model_from_pretrained(self): def test_model_from_pretrained(self):
for model_name in TF_CTRL_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: for model_name in TF_CTRL_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment