Unverified commit bdf1669e authored by Julien Plu, committed by GitHub

Making TF GPT2 compliant with XLA and AMP (#10230)

* Fix XLA and AMP

* Fix AMP and XLA

* Apply style

* Apply Patrick's comment
parent 5da7c78e
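
For context, a minimal smoke test of what this commit targets: running TFGPT2Model under Keras mixed precision (AMP) and under an XLA-compiled tf.function. This sketch is not part of the commit; the checkpoint name and the jit_compile flag (TF >= 2.5, older releases expose it as experimental_compile) are assumptions.

import tensorflow as tf
from transformers import GPT2Tokenizer, TFGPT2Model

# AMP: compute in float16 while the variables stay float32.
tf.keras.mixed_precision.set_global_policy("mixed_float16")

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = TFGPT2Model.from_pretrained("gpt2")
inputs = tokenizer("Hello world", return_tensors="tf")

# XLA: compile the forward pass.
@tf.function(jit_compile=True)
def forward(input_ids, attention_mask):
    return model(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state

hidden = forward(inputs["input_ids"], inputs["attention_mask"])
print(hidden.shape, hidden.dtype)  # (1, seq_len, 768), float16 under the AMP policy
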
@@ -1331,119 +1331,6 @@ class TFConv1D(tf.keras.layers.Layer):
         return x
 
 
-class WordEmbeddings(tf.keras.layers.Layer):
-    def __init__(self, vocab_size: int, hidden_size: int, initializer_range: float, **kwargs):
-        super().__init__(**kwargs)
-
-        self.vocab_size = vocab_size
-        self.hidden_size = hidden_size
-        self.initializer_range = initializer_range
-
-    def build(self, input_shape):
-        self.word_embeddings = self.add_weight(
-            name="weight",
-            shape=[self.vocab_size, self.hidden_size],
-            initializer=get_initializer(initializer_range=self.initializer_range),
-        )
-
-        super().build(input_shape=input_shape)
-
-    def get_config(self):
-        config = {
-            "vocab_size": self.vocab_size,
-            "hidden_size": self.hidden_size,
-            "initializer_range": self.initializer_range,
-        }
-        base_config = super().get_config()
-
-        return dict(list(base_config.items()) + list(config.items()))
-
-    def call(self, input_ids):
-        flat_input_ids = tf.reshape(tensor=input_ids, shape=[-1])
-        embeddings = tf.gather(params=self.word_embeddings, indices=flat_input_ids)
-        embeddings = tf.reshape(
-            tensor=embeddings, shape=tf.concat(values=[shape_list(tensor=input_ids), [self.hidden_size]], axis=0)
-        )
-
-        embeddings.set_shape(shape=input_ids.shape.as_list() + [self.hidden_size])
-
-        return embeddings
-
-
-class TokenTypeEmbeddings(tf.keras.layers.Layer):
-    def __init__(self, type_vocab_size: int, hidden_size: int, initializer_range: float, **kwargs):
-        super().__init__(**kwargs)
-
-        self.type_vocab_size = type_vocab_size
-        self.hidden_size = hidden_size
-        self.initializer_range = initializer_range
-
-    def build(self, input_shape):
-        self.token_type_embeddings = self.add_weight(
-            name="embeddings",
-            shape=[self.type_vocab_size, self.hidden_size],
-            initializer=get_initializer(initializer_range=self.initializer_range),
-        )
-
-        super().build(input_shape=input_shape)
-
-    def get_config(self):
-        config = {
-            "type_vocab_size": self.type_vocab_size,
-            "hidden_size": self.hidden_size,
-            "initializer_range": self.initializer_range,
-        }
-        base_config = super().get_config()
-
-        return dict(list(base_config.items()) + list(config.items()))
-
-    def call(self, token_type_ids):
-        flat_token_type_ids = tf.reshape(tensor=token_type_ids, shape=[-1])
-        one_hot_data = tf.one_hot(indices=flat_token_type_ids, depth=self.type_vocab_size, dtype=self._compute_dtype)
-        embeddings = tf.matmul(a=one_hot_data, b=self.token_type_embeddings)
-        embeddings = tf.reshape(
-            tensor=embeddings, shape=tf.concat(values=[shape_list(tensor=token_type_ids), [self.hidden_size]], axis=0)
-        )
-
-        embeddings.set_shape(shape=token_type_ids.shape.as_list() + [self.hidden_size])
-
-        return embeddings
-
-
-class PositionEmbeddings(tf.keras.layers.Layer):
-    def __init__(self, max_position_embeddings: int, hidden_size: int, initializer_range: float, **kwargs):
-        super().__init__(**kwargs)
-
-        self.max_position_embeddings = max_position_embeddings
-        self.hidden_size = hidden_size
-        self.initializer_range = initializer_range
-
-    def build(self, input_shape):
-        self.position_embeddings = self.add_weight(
-            name="embeddings",
-            shape=[self.max_position_embeddings, self.hidden_size],
-            initializer=get_initializer(initializer_range=self.initializer_range),
-        )
-
-        super().build(input_shape)
-
-    def get_config(self):
-        config = {
-            "max_position_embeddings": self.max_position_embeddings,
-            "hidden_size": self.hidden_size,
-            "initializer_range": self.initializer_range,
-        }
-        base_config = super().get_config()
-
-        return dict(list(base_config.items()) + list(config.items()))
-
-    def call(self, position_ids):
-        input_shape = shape_list(tensor=position_ids)
-        position_embeddings = self.position_embeddings[: input_shape[1], :]
-
-        return tf.broadcast_to(input=position_embeddings, shape=input_shape)
-
-
 class TFSharedEmbeddings(tf.keras.layers.Layer):
     r"""
     Construct shared token embeddings.
...
@@ -112,6 +112,7 @@ class TFAttention(tf.keras.layers.Layer):
         if attention_mask is not None:
             # Apply the attention mask
+            attention_mask = tf.cast(attention_mask, dtype=w.dtype)
             w = w + attention_mask
 
         w = tf.nn.softmax(w, axis=-1)
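
The cast added above is the AMP fix for this block: under a mixed_float16 policy the attention scores w are float16 while the additive mask built in TFGPT2MainLayer is float32, and TensorFlow does not implicitly promote dtypes when adding tensors. A standalone illustration, with made-up shapes and values:

import tensorflow as tf

scores = tf.zeros((1, 1, 4, 4), dtype=tf.float16)      # what AMP produces for w
mask = tf.constant([[0.0, 0.0, -10000.0, -10000.0]])   # additive mask, float32

# scores + mask would fail with a dtype mismatch; casting first, as in the diff, works.
masked = scores + tf.cast(mask, dtype=scores.dtype)
print(masked.dtype)  # float16
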
@@ -224,20 +225,26 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
         self.num_hidden_layers = config.n_layer
         self.vocab_size = config.vocab_size
         self.n_embd = config.n_embd
+        self.n_positions = config.n_positions
+        self.initializer_range = config.initializer_range
 
         self.wte = TFSharedEmbeddings(
             config.vocab_size, config.hidden_size, initializer_range=config.initializer_range, name="wte"
         )
-        self.wpe = tf.keras.layers.Embedding(
-            config.n_positions,
-            config.n_embd,
-            embeddings_initializer=get_initializer(config.initializer_range),
-            name="wpe",
-        )
         self.drop = tf.keras.layers.Dropout(config.embd_pdrop)
         self.h = [TFBlock(config.n_ctx, config, scale=True, name="h_._{}".format(i)) for i in range(config.n_layer)]
         self.ln_f = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name="ln_f")
 
+    def build(self, input_shape):
+        with tf.name_scope("wpe"):
+            self.wpe = self.add_weight(
+                name="embeddings",
+                shape=[self.n_positions, self.n_embd],
+                initializer=get_initializer(self.initializer_range),
+            )
+
+        super().build(input_shape)
+
     def get_input_embeddings(self):
         return self.wte
@@ -302,9 +309,7 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
             past_length = shape_list(inputs["past"][0][0])[-2]
 
         if inputs["position_ids"] is None:
-            inputs["position_ids"] = tf.expand_dims(
-                tf.range(past_length, input_shape[-1] + past_length, dtype=tf.int32), axis=0
-            )
+            inputs["position_ids"] = tf.expand_dims(tf.range(past_length, input_shape[-1] + past_length), axis=0)
 
         if inputs["attention_mask"] is not None:
             # We create a 3D attention mask from a 2D tensor mask.
@@ -322,11 +327,11 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
             # positions we want to attend and -10000.0 for masked positions.
            # Since we are adding it to the raw scores before the softmax, this is
            # effectively the same as removing these entirely.
-            inputs["attention_mask"] = tf.cast(inputs["attention_mask"], tf.float32)
-            inputs["attention_mask"] = (1.0 - inputs["attention_mask"]) * -10000.0
-        else:
-            inputs["attention_mask"] = None
+            one_cst = tf.constant(1.0)
+            inputs["attention_mask"] = tf.cast(inputs["attention_mask"], dtype=one_cst.dtype)
+            inputs["attention_mask"] = tf.multiply(
+                tf.subtract(one_cst, inputs["attention_mask"]), tf.constant(-10000.0)
+            )
 
         # Prepare head mask if needed
         # 1.0 in head_mask indicate we keep the head
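
As the comment in the hunk above says, (1 - mask) * -10000.0 turns a 0/1 padding mask into large negative logits, so masked positions get near-zero weight after the softmax. A tiny numeric check with illustrative values:

import tensorflow as tf

pad_mask = tf.constant([[1.0, 1.0, 0.0]])      # 1 = attend, 0 = padding
additive = (1.0 - pad_mask) * -10000.0         # [[0., 0., -10000.]]

logits = tf.constant([[2.0, 1.0, 3.0]])
probs = tf.nn.softmax(logits + additive, axis=-1)
print(probs.numpy().round(3))                  # ~[[0.731, 0.269, 0.   ]]
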
@@ -344,7 +349,7 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
         if inputs["inputs_embeds"] is None:
             inputs["inputs_embeds"] = self.wte(inputs["input_ids"], mode="embedding")
 
-        position_embeds = self.wpe(inputs["position_ids"])
+        position_embeds = tf.gather(self.wpe, inputs["position_ids"])
 
         if inputs["token_type_ids"] is not None:
             inputs["token_type_ids"] = tf.reshape(
@@ -352,7 +357,7 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
             )
             token_type_embeds = self.wte(inputs["token_type_ids"], mode="embedding")
         else:
-            token_type_embeds = 0
+            token_type_embeds = tf.constant(0.0)
 
         position_embeds = tf.cast(position_embeds, dtype=inputs["inputs_embeds"].dtype)
         token_type_embeds = tf.cast(token_type_embeds, dtype=inputs["inputs_embeds"].dtype)
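
Because wpe is now a plain weight matrix rather than a tf.keras.layers.Embedding, the position lookup becomes an explicit tf.gather, which returns the same rows as the layer's lookup. A small check with made-up sizes, not taken from the diff:

import tensorflow as tf

n_positions, n_embd = 8, 4
table = tf.random.normal((n_positions, n_embd))   # stands in for self.wpe
position_ids = tf.constant([[0, 1, 2, 3]])        # (batch, seq_len)

gathered = tf.gather(table, position_ids)         # shape (1, 4, 4)

# Same result as an Embedding layer carrying the same weights.
emb = tf.keras.layers.Embedding(n_positions, n_embd)
emb.build(position_ids.shape)
emb.set_weights([table.numpy()])
print(tf.reduce_max(tf.abs(gathered - emb(position_ids))).numpy())  # 0.0
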
@@ -1024,7 +1029,10 @@ class TFGPT2ForSequenceClassification(TFGPT2PreTrainedModel, TFSequenceClassific
         if inputs["input_ids"] is not None:
             sequence_lengths = (
                 tf.reduce_sum(
-                    tf.cast(tf.math.not_equal(inputs["input_ids"], self.config.pad_token_id), tf.int32),
+                    tf.cast(
+                        tf.math.not_equal(inputs["input_ids"], self.config.pad_token_id),
+                        dtype=inputs["input_ids"].dtype,
+                    ),
                     -1,
                     keepdims=False,
                 )
...
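
The reworked cast above keeps the count of non-padding tokens in the same dtype as input_ids instead of hard-coding tf.int32, which avoids dtype mismatches when the graph is traced for XLA. The counting itself, on toy data with an assumed pad_token_id of 0:

import tensorflow as tf

pad_token_id = 0
input_ids = tf.constant([[15, 27, 31, 0, 0],
                         [12,  0,  0, 0, 0]], dtype=tf.int32)

lengths = tf.reduce_sum(
    tf.cast(tf.math.not_equal(input_ids, pad_token_id), dtype=input_ids.dtype),
    axis=-1,
    keepdims=False,
)
print(lengths.numpy(), lengths.dtype)  # [3 1], int32 like input_ids
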
@@ -389,14 +389,6 @@ class TFGPT2ModelTest(TFModelTesterMixin, unittest.TestCase):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_gpt2_for_sequence_classification(*config_and_inputs)
 
-    def test_mixed_precision(self):
-        # TODO JP: Make GPT2 float16 compliant
-        pass
-
-    def test_xla_mode(self):
-        # TODO JP: Make GPT2 XLA compliant
-        pass
-
     @slow
     def test_model_from_pretrained(self):
         for model_name in TF_GPT2_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
...