Unverified commit c2f8eaf6, authored by Joao Gante and committed by GitHub

TF: unpack inputs on Convbert, GPTJ, LED, and templates (#16491)

* Add unpack_inputs to remaining models

* remove stray use of inputs in the templates; fix tf.debugging of attn masks
parent ae189ef9
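
Across the touched models the change is mechanical: the boilerplate `input_processing(...)` call at the top of `call` is dropped and the `unpack_inputs` decorator takes over normalizing the arguments, so the body reads the named parameters directly instead of indexing an `inputs` dict. Below is a minimal sketch of the pattern using a hypothetical toy layer (the real models keep their full signatures); it assumes the decorator needs `self.config` to resolve config-driven defaults.

import tensorflow as tf
from transformers.modeling_tf_utils import unpack_inputs

class TFToyMainLayer(tf.keras.layers.Layer):
    # Hypothetical layer, for illustration only.
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.config = config  # assumed to be read by unpack_inputs when resolving defaults

    @unpack_inputs
    def call(self, input_ids=None, attention_mask=None, return_dict=None, training=False, **kwargs):
        # Before this commit the body would start with
        #     inputs = input_processing(func=self.call, config=self.config, ...)
        # and then index inputs["input_ids"], inputs["attention_mask"], etc.
        # With the decorator the named arguments arrive already normalized.
        if input_ids is None:
            raise ValueError("You have to specify input_ids")
        return input_ids

Either calling convention (keyword arguments or a packed input) should land in the same place, which is what lets the `inputs[...]` indexing in the bodies below be replaced one-for-one with the plain argument names.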
@@ -943,12 +943,12 @@ class TFBartDecoder(tf.keras.layers.Layer):
         # check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
         # The tf.debugging asserts are not compliant with XLA then they
         # have to be disabled in other modes than eager.
-        for attn_mask in [head_mask, cross_attn_head_mask]:
+        for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]:
             if attn_mask is not None and tf.executing_eagerly():
                 tf.debugging.assert_equal(
                     shape_list(attn_mask)[0],
                     len(self.layers),
-                    message=f"The {attn_mask} should be specified for {len(self.layers)} layers, but it is for {shape_list(attn_mask)[0]}.",
+                    message=f"The {attn_mask_name} should be specified for {len(self.layers)} layers, but it is for {shape_list(attn_mask)[0]}.",
                 )

         for idx, decoder_layer in enumerate(self.layers):
...
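
The "fix tf.debugging of attn masks" part of the commit message is the hunk above: the old assertion interpolated the mask tensor itself into the f-string, so a failing check printed the tensor's contents rather than naming which argument was at fault. Carrying the name alongside the tensor keeps the message readable. A small self-contained sketch of the corrected pattern with toy values (plain TensorFlow shapes stand in for the library's shape_list helper):

import tensorflow as tf

num_layers = 4
head_mask = tf.ones((num_layers, 12))  # one entry per decoder layer
cross_attn_head_mask = None            # optional, skipped by the None check below

for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]:
    if attn_mask is not None and tf.executing_eagerly():
        tf.debugging.assert_equal(
            attn_mask.shape[0],
            num_layers,
            message=f"The {attn_mask_name} should be specified for {num_layers} layers, but it is for {attn_mask.shape[0]}.",
        )
# A mismatched first dimension now raises an error naming "head_mask" instead of dumping the tensor.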
@@ -40,7 +40,6 @@ from ...modeling_tf_utils import (
     TFSequenceClassificationLoss,
     TFSharedEmbeddings,
     get_initializer,
-    input_processing,
     keras_serializable,
     unpack_inputs,
 )
@@ -376,6 +375,7 @@ class TFGPTJMainLayer(tf.keras.layers.Layer):
         """
         raise NotImplementedError

+    @unpack_inputs
     def call(
         self,
         input_ids=None,
@@ -392,53 +392,34 @@ class TFGPTJMainLayer(tf.keras.layers.Layer):
         training=False,
         **kwargs,
     ):
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
-            input_ids=input_ids,
-            past_key_values=past_key_values,
-            attention_mask=attention_mask,
-            token_type_ids=token_type_ids,
-            position_ids=position_ids,
-            head_mask=head_mask,
-            inputs_embeds=inputs_embeds,
-            use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-            training=training,
-            kwargs_call=kwargs,
-        )
-
-        if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
+        if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
-        elif inputs["input_ids"] is not None:
-            input_shape = shape_list(inputs["input_ids"])
-            inputs["input_ids"] = tf.reshape(inputs["input_ids"], [-1, input_shape[-1]])
-        elif inputs["inputs_embeds"] is not None:
-            input_shape = shape_list(inputs["inputs_embeds"])[:-1]
+        elif input_ids is not None:
+            input_shape = shape_list(input_ids)
+            input_ids = tf.reshape(input_ids, [-1, input_shape[-1]])
+        elif inputs_embeds is not None:
+            input_shape = shape_list(inputs_embeds)[:-1]
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")

-        if inputs["past_key_values"] is None:
+        if past_key_values is None:
             past_length = 0
-            inputs["past_key_values"] = [None] * len(self.h)
+            past_key_values = [None] * len(self.h)
         else:
-            past_length = shape_list(inputs["past_key_values"][0][0])[-2]
+            past_length = shape_list(past_key_values[0][0])[-2]

-        if inputs["position_ids"] is None:
-            inputs["position_ids"] = tf.expand_dims(tf.range(past_length, input_shape[-1] + past_length), axis=0)
+        if position_ids is None:
+            position_ids = tf.expand_dims(tf.range(past_length, input_shape[-1] + past_length), axis=0)

-        if inputs["attention_mask"] is not None:
+        if attention_mask is not None:
             # We create a 3D attention mask from a 2D tensor mask.
             # Sizes are [batch_size, 1, 1, to_seq_length]
             # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
             # this attention mask is more simple than the triangular masking of causal attention
             # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
-            attention_mask_shape = shape_list(inputs["attention_mask"])
-            inputs["attention_mask"] = tf.reshape(
-                inputs["attention_mask"], (attention_mask_shape[0], 1, 1, attention_mask_shape[1])
-            )
+            attention_mask_shape = shape_list(attention_mask)
+            attention_mask = tf.reshape(attention_mask, (attention_mask_shape[0], 1, 1, attention_mask_shape[1]))

             # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
             # masked positions, this operation will create a tensor which is 0.0 for
@@ -446,78 +427,74 @@ class TFGPTJMainLayer(tf.keras.layers.Layer):
             # Since we are adding it to the raw scores before the softmax, this is
             # effectively the same as removing these entirely.
             one_cst = tf.constant(1.0)
-            inputs["attention_mask"] = tf.cast(inputs["attention_mask"], dtype=one_cst.dtype)
-            inputs["attention_mask"] = tf.multiply(
-                tf.subtract(one_cst, inputs["attention_mask"]), tf.constant(-10000.0)
-            )
+            attention_mask = tf.cast(attention_mask, dtype=one_cst.dtype)
+            attention_mask = tf.multiply(tf.subtract(one_cst, attention_mask), tf.constant(-10000.0))

         # Prepare head mask if needed
         # 1.0 in head_mask indicate we keep the head
         # attention_probs has shape bsz x n_heads x N x N
         # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
         # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
-        if inputs["head_mask"] is not None:
+        if head_mask is not None:
             raise NotImplementedError
         else:
-            inputs["head_mask"] = [None] * self.num_hidden_layers
+            head_mask = [None] * self.num_hidden_layers
             # head_mask = tf.constant([0] * self.num_hidden_layers)

-        inputs["position_ids"] = tf.reshape(inputs["position_ids"], [-1, shape_list(inputs["position_ids"])[-1]])
+        position_ids = tf.reshape(position_ids, [-1, shape_list(position_ids)[-1]])

-        if inputs["inputs_embeds"] is None:
-            inputs["inputs_embeds"] = self.wte(inputs["input_ids"], mode="embedding")
+        if inputs_embeds is None:
+            inputs_embeds = self.wte(input_ids, mode="embedding")

-        if inputs["token_type_ids"] is not None:
-            inputs["token_type_ids"] = tf.reshape(
-                inputs["token_type_ids"], [-1, shape_list(inputs["token_type_ids"])[-1]]
-            )
-            token_type_embeds = self.wte(inputs["token_type_ids"], mode="embedding")
+        if token_type_ids is not None:
+            token_type_ids = tf.reshape(token_type_ids, [-1, shape_list(token_type_ids)[-1]])
+            token_type_embeds = self.wte(token_type_ids, mode="embedding")
         else:
             token_type_embeds = tf.constant(0.0)

-        token_type_embeds = tf.cast(token_type_embeds, dtype=inputs["inputs_embeds"].dtype)
-        hidden_states = inputs["inputs_embeds"] + token_type_embeds
-        hidden_states = self.drop(hidden_states, training=inputs["training"])
+        token_type_embeds = tf.cast(token_type_embeds, dtype=inputs_embeds.dtype)
+        hidden_states = inputs_embeds + token_type_embeds
+        hidden_states = self.drop(hidden_states, training=training)

         output_shape = input_shape + [shape_list(hidden_states)[-1]]

-        presents = () if inputs["use_cache"] else None
-        all_attentions = () if inputs["output_attentions"] else None
-        all_hidden_states = () if inputs["output_hidden_states"] else None
-        for i, (block, layer_past) in enumerate(zip(self.h, inputs["past_key_values"])):
-            if inputs["output_hidden_states"]:
+        presents = () if use_cache else None
+        all_attentions = () if output_attentions else None
+        all_hidden_states = () if output_hidden_states else None
+        for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
+            if output_hidden_states:
                 all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),)

             outputs = block(
                 hidden_states,
                 layer_past,
-                inputs["attention_mask"],
-                inputs["head_mask"][i],
-                inputs["use_cache"],
-                inputs["output_attentions"],
-                training=inputs["training"],
+                attention_mask,
+                head_mask[i],
+                use_cache,
+                output_attentions,
+                training=training,
             )

             hidden_states = outputs[0]
-            if inputs["use_cache"]:
+            if use_cache:
                 presents = presents + (outputs[1],)

-            if inputs["output_attentions"]:
-                all_attentions = all_attentions + (outputs[2 if inputs["use_cache"] else 1],)
+            if output_attentions:
+                all_attentions = all_attentions + (outputs[2 if use_cache else 1],)

         hidden_states = self.ln_f(hidden_states)

         hidden_states = tf.reshape(hidden_states, output_shape)
         # Add last hidden state
-        if inputs["output_hidden_states"]:
+        if output_hidden_states:
             all_hidden_states = all_hidden_states + (hidden_states,)

-        if inputs["output_attentions"]:
+        if output_attentions:
             # let the number of heads free (-1) so we can extract attention even after head pruning
             attention_output_shape = input_shape[:-1] + [-1] + shape_list(all_attentions[0])[-2:]
             all_attentions = tuple(tf.reshape(t, attention_output_shape) for t in all_attentions)

-        if not inputs["return_dict"]:
+        if not return_dict:
             return tuple(v for v in [hidden_states, presents, all_hidden_states, all_attentions] if v is not None)

         return TFBaseModelOutputWithPast(
...
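
One detail worth calling out in the GPT-J hunk above (unchanged in behavior, only renamed): the 2D padding mask of shape [batch_size, seq_length] is reshaped to [batch_size, 1, 1, seq_length] and converted to an additive mask that is 0.0 where attention is allowed and -10000.0 where it is not, so adding it to the raw scores before the softmax effectively removes the masked positions. A standalone sketch of that conversion with toy values:

import tensorflow as tf

# 1.0 = attend, 0.0 = padding; batch_size=2, seq_length=4
attention_mask = tf.constant([[1.0, 1.0, 1.0, 0.0],
                              [1.0, 1.0, 0.0, 0.0]])

# Broadcastable shape [batch_size, 1, 1, seq_length]: the same mask applies to every head and query position.
batch_size, seq_length = attention_mask.shape
extended_mask = tf.reshape(attention_mask, (batch_size, 1, 1, seq_length))

# 0.0 for kept positions, -10000.0 for masked ones; added to the attention scores before the softmax.
one_cst = tf.constant(1.0)
extended_mask = tf.multiply(tf.subtract(one_cst, tf.cast(extended_mask, one_cst.dtype)), tf.constant(-10000.0))

print(extended_mask[0, 0, 0].numpy())  # ~0.0 for the first three positions, -10000.0 for the padded one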
@@ -965,12 +965,12 @@ class TFMBartDecoder(tf.keras.layers.Layer):
         # check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
         # The tf.debugging asserts are not compliant with XLA then they
         # have to be disabled in other modes than eager.
-        for attn_mask in [head_mask, cross_attn_head_mask]:
+        for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]:
             if attn_mask is not None and tf.executing_eagerly():
                 tf.debugging.assert_equal(
                     shape_list(attn_mask)[0],
                     len(self.layers),
-                    message=f"The {attn_mask} should be specified for {len(self.layers)} layers, but it is for {shape_list(attn_mask)[0]}.",
+                    message=f"The {attn_mask_name} should be specified for {len(self.layers)} layers, but it is for {shape_list(attn_mask)[0]}.",
                 )

         for idx, decoder_layer in enumerate(self.layers):
...
@@ -1060,12 +1060,12 @@ class TFSpeech2TextDecoder(tf.keras.layers.Layer):
         # check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
         # The tf.debugging asserts are not compliant with XLA then they have to be disabled in other modes than eager.
-        for attn_mask in [head_mask, cross_attn_head_mask]:
+        for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]:
             if attn_mask is not None and tf.executing_eagerly():
                 tf.debugging.assert_equal(
                     shape_list(attn_mask)[0],
                     len(self.layers),
-                    message=f"The {attn_mask} should be specified for {len(self.layers)} layers, but it is for {shape_list(attn_mask)[0]}.",
+                    message=f"The {attn_mask_name} should be specified for {len(self.layers)} layers, but it is for {shape_list(attn_mask)[0]}.",
                 )

         for idx, decoder_layer in enumerate(self.layers):
...