Unverified Commit 34df26ec authored by Julien Plu's avatar Julien Plu Committed by GitHub
Browse files

Making TF OpenAI GPT model compliant with AMP and XLA (#10261)

* Fix AMP and XLA

* Remove useless var
parent 3e116ed3
......@@ -81,7 +81,7 @@ class TFAttention(tf.keras.layers.Layer):
pass
@staticmethod
def causal_attention_mask(nd, ns, dtype):
def causal_attention_mask(nd, ns):
"""
1's in the lower triangle, counting from the lower right corner. Same as tf.matrix_band_part(tf.ones([nd, ns]),
-1, ns-nd), but doesn't produce garbage on TPUs.
......@@ -89,23 +89,24 @@ class TFAttention(tf.keras.layers.Layer):
i = tf.range(nd)[:, None]
j = tf.range(ns)
m = i >= j - ns + nd
return tf.cast(m, dtype)
return m
def _attn(self, q, k, v, attention_mask, head_mask, output_attentions, training=False):
# q, k, v have shape [batch, heads, sequence, features]
w = tf.matmul(q, k, transpose_b=True)
if self.scale:
dk = tf.cast(shape_list(k)[-1], tf.float32) # scale attention_scores
dk = tf.cast(shape_list(k)[-1], dtype=w.dtype) # scale attention_scores
w = w / tf.math.sqrt(dk)
# w has shape [batch, heads, dst_sequence, src_sequence], where information flows from src to dst.
_, _, nd, ns = shape_list(w)
b = self.causal_attention_mask(nd, ns, dtype=w.dtype)
b = tf.cast(self.causal_attention_mask(nd, ns), dtype=w.dtype)
b = tf.reshape(b, [1, 1, nd, ns])
w = w * b - 1e4 * (1 - b)
if attention_mask is not None:
# Apply the attention mask
attention_mask = tf.cast(attention_mask, dtype=w.dtype)
w = w + attention_mask
w = tf.nn.softmax(w, axis=-1)
......@@ -201,19 +202,25 @@ class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):
self.num_hidden_layers = config.n_layer
self.vocab_size = config.vocab_size
self.n_embd = config.n_embd
self.n_positions = config.n_positions
self.initializer_range = config.initializer_range
self.tokens_embed = TFSharedEmbeddings(
config.vocab_size, config.n_embd, initializer_range=config.initializer_range, name="tokens_embed"
)
self.positions_embed = tf.keras.layers.Embedding(
config.n_positions,
config.n_embd,
embeddings_initializer=get_initializer(config.initializer_range),
name="positions_embed",
)
self.drop = tf.keras.layers.Dropout(config.embd_pdrop)
self.h = [TFBlock(config.n_ctx, config, scale=True, name="h_._{}".format(i)) for i in range(config.n_layer)]
def build(self, input_shape):
with tf.name_scope("positions_embed"):
self.positions_embed = self.add_weight(
name="embeddings",
shape=[self.n_positions, self.n_embd],
initializer=get_initializer(self.initializer_range),
)
super().build(input_shape)
def get_input_embeddings(self):
return self.tokens_embed
......@@ -268,7 +275,7 @@ class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):
raise ValueError("You have to specify either input_ids or inputs_embeds")
if inputs["position_ids"] is None:
inputs["position_ids"] = tf.expand_dims(tf.range(input_shape[-1], dtype=tf.int32), axis=0)
inputs["position_ids"] = tf.expand_dims(tf.range(input_shape[-1]), axis=0)
if inputs["attention_mask"] is not None:
# We create a 3D attention mask from a 2D tensor mask.
......@@ -284,8 +291,11 @@ class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
inputs["attention_mask"] = tf.cast(inputs["attention_mask"], tf.float32)
inputs["attention_mask"] = (1.0 - inputs["attention_mask"]) * -10000.0
one_cst = tf.constant(1.0)
inputs["attention_mask"] = tf.cast(inputs["attention_mask"], dtype=one_cst.dtype)
inputs["attention_mask"] = tf.multiply(
tf.subtract(one_cst, inputs["attention_mask"]), tf.constant(-10000.0)
)
else:
inputs["attention_mask"] = None
......@@ -304,7 +314,7 @@ class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):
if inputs["inputs_embeds"] is None:
inputs["inputs_embeds"] = self.tokens_embed(inputs["input_ids"], mode="embedding")
position_embeds = self.positions_embed(inputs["position_ids"])
position_embeds = tf.gather(self.positions_embed, inputs["position_ids"])
if inputs["token_type_ids"] is not None:
inputs["token_type_ids"] = tf.reshape(
inputs["token_type_ids"], [-1, shape_list(inputs["token_type_ids"])[-1]]
......@@ -903,7 +913,6 @@ class TFOpenAIGPTForSequenceClassification(TFOpenAIGPTPreTrainedModel, TFSequenc
hidden_states = transformer_outputs[0]
logits = self.score(hidden_states)
logits_shape = shape_list(logits)
in_logits = None
if self.config.pad_token_id is None:
sequence_lengths = -1
......@@ -911,22 +920,16 @@ class TFOpenAIGPTForSequenceClassification(TFOpenAIGPTPreTrainedModel, TFSequenc
if inputs["input_ids"] is not None:
sequence_lengths = (
tf.reduce_sum(
tf.cast(tf.math.not_equal(inputs["input_ids"], self.config.pad_token_id), tf.int32),
tf.cast(
tf.math.not_equal(inputs["input_ids"], self.config.pad_token_id),
dtype=inputs["input_ids"].dtype,
),
-1,
keepdims=False,
)
- 1
)
def get_seq_element(sequence_position, input_batch):
return tf.strided_slice(
input_batch, [sequence_position, 0], [sequence_position + 1, input_batch.shape[-1]], [1, 1]
)
result = tf.map_fn(
fn=lambda t: get_seq_element(t[0], t[1]), elems=[sequence_lengths, logits], dtype="float"
)
in_logits = tf.reshape(result, [logits_shape[0], logits_shape[-1]])
in_logits = tf.gather(logits, sequence_lengths, batch_dims=1, axis=1)
else:
sequence_lengths = -1
logger.warning(
......
......@@ -246,14 +246,6 @@ class TFOpenAIGPTModelTest(TFModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_openai_gpt_for_sequence_classification(*config_and_inputs)
def test_mixed_precision(self):
# TODO JP: Make OpenAIGPT float16 compliant
pass
def test_xla_mode(self):
# TODO JP: Make OpenAIGPT XLA compliant
pass
@slow
def test_model_from_pretrained(self):
for model_name in TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment