".github/vscode:/vscode.git/clone" did not exist on "4be8b95a9f0665a94589231cf4dd50d379ac1710"
Unverified commit ee18d4d2, authored by Christopher Akiki and committed by GitHub

TF GPT2: clearer model variable naming with @unpack_inputs (#16311)

* add unpack_inputs decorator to Main Layer

* add unpack_inputs decorator to Model

* add unpack_inputs decorator to LMHead Model

* add unpack_inputs decorator to Double Head Model

* add unpack_inputs decorator to Sequence Classification Model

* run fixup recipe

* make unpack_inputs the first decorator
parent d7c8ce57
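
For readers unfamiliar with the pattern being swapped in: before this commit, each call() routed its arguments through input_processing and then read them back out of a dict (inputs["input_ids"], inputs["training"], ...). The @unpack_inputs decorator performs that normalization before the body runs, so the body can use the named parameters directly. The sketch below is a toy illustration of the calling pattern only, not the transformers implementation; toy_unpack_inputs and ToyMainLayer are invented names, and the real decorator in modeling_tf_utils additionally handles packed tuple/dict inputs and config-backed boolean defaults.

    import functools
    import inspect

    def toy_unpack_inputs(func):
        """Resolve positional/keyword arguments and defaults, then call the body with plain names."""
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            bound = inspect.signature(func).bind(self, *args, **kwargs)
            bound.apply_defaults()
            return func(*bound.args, **bound.kwargs)
        return wrapper

    class ToyMainLayer:
        # Before: inputs = input_processing(func=self.call, config=self.config, ...)
        #         followed by inputs["input_ids"], inputs["training"], ... lookups.
        @toy_unpack_inputs
        def call(self, input_ids=None, attention_mask=None, training=False):
            # After: arguments arrive already resolved under their own names.
            return input_ids, attention_mask, training

    print(ToyMainLayer().call([1, 2, 3], training=True))
    # -> ([1, 2, 3], None, True)
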
@@ -37,8 +37,8 @@ from ...modeling_tf_utils import (
TFSequenceSummary,
TFSharedEmbeddings,
get_initializer,
- input_processing,
keras_serializable,
+ unpack_inputs,
)
from ...tf_utils import shape_list
from ...utils import (
@@ -350,6 +350,7 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
"""
raise NotImplementedError
+ @unpack_inputs
def call(
self,
input_ids: Optional[TFModelInputType] = None,
@@ -368,55 +369,34 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
training: Optional[bool] = False,
**kwargs,
) -> Union[TFBaseModelOutputWithPastAndCrossAttentions, Tuple[tf.Tensor]]:
- inputs = input_processing(
- func=self.call,
- config=self.config,
- input_ids=input_ids,
- past=past,
- attention_mask=attention_mask,
- token_type_ids=token_type_ids,
- position_ids=position_ids,
- head_mask=head_mask,
- inputs_embeds=inputs_embeds,
- encoder_hidden_states=encoder_hidden_states,
- encoder_attention_mask=encoder_attention_mask,
- use_cache=use_cache,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- return_dict=return_dict,
- training=training,
- kwargs_call=kwargs,
- )
- if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
+ if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif inputs["input_ids"] is not None:
input_shape = shape_list(inputs["input_ids"])
inputs["input_ids"] = tf.reshape(inputs["input_ids"], [-1, input_shape[-1]])
elif inputs["inputs_embeds"] is not None:
input_shape = shape_list(inputs["inputs_embeds"])[:-1]
elif input_ids is not None:
input_shape = shape_list(input_ids)
input_ids = tf.reshape(input_ids, [-1, input_shape[-1]])
elif inputs_embeds is not None:
input_shape = shape_list(inputs_embeds)[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if inputs["past"] is None:
if past is None:
past_length = 0
inputs["past"] = [None] * len(self.h)
past = [None] * len(self.h)
else:
past_length = shape_list(inputs["past"][0][0])[-2]
past_length = shape_list(past[0][0])[-2]
if inputs["position_ids"] is None:
inputs["position_ids"] = tf.expand_dims(tf.range(past_length, input_shape[-1] + past_length), axis=0)
if position_ids is None:
position_ids = tf.expand_dims(tf.range(past_length, input_shape[-1] + past_length), axis=0)
if inputs["attention_mask"] is not None:
if attention_mask is not None:
# We create a 3D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length]
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# this attention mask is more simple than the triangular masking of causal attention
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
attention_mask_shape = shape_list(inputs["attention_mask"])
inputs["attention_mask"] = tf.reshape(
inputs["attention_mask"], (attention_mask_shape[0], 1, 1, attention_mask_shape[1])
)
attention_mask_shape = shape_list(attention_mask)
attention_mask = tf.reshape(attention_mask, (attention_mask_shape[0], 1, 1, attention_mask_shape[1]))
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
@@ -424,24 +404,20 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
one_cst = tf.constant(1.0)
inputs["attention_mask"] = tf.cast(inputs["attention_mask"], dtype=one_cst.dtype)
inputs["attention_mask"] = tf.multiply(
tf.subtract(one_cst, inputs["attention_mask"]), tf.constant(-10000.0)
)
attention_mask = tf.cast(attention_mask, dtype=one_cst.dtype)
attention_mask = tf.multiply(tf.subtract(one_cst, attention_mask), tf.constant(-10000.0))
# Copied from `modeling_tf_t5.py` with -1e9 -> -10000
- if self.config.add_cross_attention and inputs["encoder_attention_mask"] is not None:
+ if self.config.add_cross_attention and encoder_attention_mask is not None:
# If a 2D ou 3D attention mask is provided for the cross-attention
- # we need to make broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length]
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
- inputs["encoder_attention_mask"] = tf.cast(
- inputs["encoder_attention_mask"], dtype=inputs["encoder_hidden_states"].dtype
- )
- num_dims_encoder_attention_mask = len(shape_list(inputs["encoder_attention_mask"]))
+ encoder_attention_mask = tf.cast(encoder_attention_mask, dtype=encoder_hidden_states.dtype)
+ num_dims_encoder_attention_mask = len(shape_list(encoder_attention_mask))
if num_dims_encoder_attention_mask == 3:
- encoder_extended_attention_mask = inputs["encoder_attention_mask"][:, None, :, :]
+ encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
if num_dims_encoder_attention_mask == 2:
- encoder_extended_attention_mask = inputs["encoder_attention_mask"][:, None, None, :]
+ encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
# T5 has a mask that can compare sequence ids, we can simulate this here with this transposition
# Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270
@@ -452,66 +428,64 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
else:
encoder_extended_attention_mask = None
inputs["encoder_attention_mask"] = encoder_extended_attention_mask
encoder_attention_mask = encoder_extended_attention_mask
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
if inputs["head_mask"] is not None:
if head_mask is not None:
raise NotImplementedError
else:
inputs["head_mask"] = [None] * self.num_hidden_layers
head_mask = [None] * self.num_hidden_layers
# head_mask = tf.constant([0] * self.num_hidden_layers)
inputs["position_ids"] = tf.reshape(inputs["position_ids"], [-1, shape_list(inputs["position_ids"])[-1]])
position_ids = tf.reshape(position_ids, [-1, shape_list(position_ids)[-1]])
if inputs["inputs_embeds"] is None:
inputs["inputs_embeds"] = self.wte(inputs["input_ids"], mode="embedding")
if inputs_embeds is None:
inputs_embeds = self.wte(input_ids, mode="embedding")
position_embeds = tf.gather(self.wpe, inputs["position_ids"])
position_embeds = tf.gather(self.wpe, position_ids)
if inputs["token_type_ids"] is not None:
inputs["token_type_ids"] = tf.reshape(
inputs["token_type_ids"], [-1, shape_list(inputs["token_type_ids"])[-1]]
)
token_type_embeds = self.wte(inputs["token_type_ids"], mode="embedding")
if token_type_ids is not None:
token_type_ids = tf.reshape(token_type_ids, [-1, shape_list(token_type_ids)[-1]])
token_type_embeds = self.wte(token_type_ids, mode="embedding")
else:
token_type_embeds = tf.constant(0.0)
- position_embeds = tf.cast(position_embeds, dtype=inputs["inputs_embeds"].dtype)
- token_type_embeds = tf.cast(token_type_embeds, dtype=inputs["inputs_embeds"].dtype)
- hidden_states = inputs["inputs_embeds"] + position_embeds + token_type_embeds
- hidden_states = self.drop(hidden_states, training=inputs["training"])
+ position_embeds = tf.cast(position_embeds, dtype=inputs_embeds.dtype)
+ token_type_embeds = tf.cast(token_type_embeds, dtype=inputs_embeds.dtype)
+ hidden_states = inputs_embeds + position_embeds + token_type_embeds
+ hidden_states = self.drop(hidden_states, training=training)
output_shape = input_shape + [shape_list(hidden_states)[-1]]
presents = () if inputs["use_cache"] else None
all_attentions = () if inputs["output_attentions"] else None
all_cross_attentions = () if inputs["output_attentions"] and self.config.add_cross_attention else None
all_hidden_states = () if inputs["output_hidden_states"] else None
for i, (block, layer_past) in enumerate(zip(self.h, inputs["past"])):
if inputs["output_hidden_states"]:
presents = () if use_cache else None
all_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
all_hidden_states = () if output_hidden_states else None
for i, (block, layer_past) in enumerate(zip(self.h, past)):
if output_hidden_states:
all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),)
outputs = block(
hidden_states,
layer_past,
inputs["attention_mask"],
inputs["head_mask"][i],
inputs["encoder_hidden_states"],
inputs["encoder_attention_mask"],
inputs["use_cache"],
inputs["output_attentions"],
training=inputs["training"],
attention_mask,
head_mask[i],
encoder_hidden_states,
encoder_attention_mask,
use_cache,
output_attentions,
training=training,
)
hidden_states, present = outputs[:2]
if inputs["use_cache"]:
if use_cache:
presents = presents + (present,)
if inputs["output_attentions"]:
if output_attentions:
all_attentions = all_attentions + (outputs[2],)
if self.config.add_cross_attention and encoder_hidden_states is not None:
all_cross_attentions = all_cross_attentions + (outputs[3],)
@@ -520,15 +494,15 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
hidden_states = tf.reshape(hidden_states, output_shape)
# Add last hidden state
if inputs["output_hidden_states"]:
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if inputs["output_attentions"]:
if output_attentions:
# let the number of heads free (-1) so we can extract attention even after head pruning
attention_output_shape = input_shape[:-1] + [-1] + shape_list(all_attentions[0])[-2:]
all_attentions = tuple(tf.reshape(t, attention_output_shape) for t in all_attentions)
if not inputs["return_dict"]:
if not return_dict:
return tuple(
v
for v in [hidden_states, presents, all_hidden_states, all_attentions, all_cross_attentions]
@@ -732,6 +706,7 @@ class TFGPT2Model(TFGPT2PreTrainedModel):
super().__init__(config, *inputs, **kwargs)
self.transformer = TFGPT2MainLayer(config, name="transformer")
+ @unpack_inputs
@add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
@@ -777,9 +752,8 @@ class TFGPT2Model(TFGPT2PreTrainedModel):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past`). Set to `False` during training, `True` during generation
"""
- inputs = input_processing(
- func=self.call,
- config=self.config,
+ outputs = self.transformer(
input_ids=input_ids,
past=past,
attention_mask=attention_mask,
@@ -794,23 +768,6 @@ class TFGPT2Model(TFGPT2PreTrainedModel):
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
- kwargs_call=kwargs,
- )
- outputs = self.transformer(
- input_ids=inputs["input_ids"],
- past=inputs["past"],
- attention_mask=inputs["attention_mask"],
- token_type_ids=inputs["token_type_ids"],
- position_ids=inputs["position_ids"],
- head_mask=inputs["head_mask"],
- inputs_embeds=inputs["inputs_embeds"],
- encoder_hidden_states=inputs["encoder_hidden_states"],
- encoder_attention_mask=inputs["encoder_attention_mask"],
- use_cache=inputs["use_cache"],
- output_attentions=inputs["output_attentions"],
- output_hidden_states=inputs["output_hidden_states"],
- return_dict=inputs["return_dict"],
- training=inputs["training"],
)
return outputs
@@ -938,6 +895,7 @@ class TFGPT2LMHeadModel(TFGPT2PreTrainedModel, TFCausalLanguageModelingLoss):
return model_kwargs
+ @unpack_inputs
@add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
@@ -987,9 +945,8 @@ class TFGPT2LMHeadModel(TFGPT2PreTrainedModel, TFCausalLanguageModelingLoss):
Labels for computing the cross entropy classification loss. Indices should be in `[0, ...,
config.vocab_size - 1]`.
"""
- inputs = input_processing(
- func=self.call,
- config=self.config,
+ transformer_outputs = self.transformer(
input_ids=input_ids,
past=past,
attention_mask=attention_mask,
@@ -1003,37 +960,19 @@ class TFGPT2LMHeadModel(TFGPT2PreTrainedModel, TFCausalLanguageModelingLoss):
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
- labels=labels,
training=training,
- kwargs_call=kwargs,
- )
- transformer_outputs = self.transformer(
- input_ids=inputs["input_ids"],
- past=inputs["past"],
- attention_mask=inputs["attention_mask"],
- token_type_ids=inputs["token_type_ids"],
- position_ids=inputs["position_ids"],
- head_mask=inputs["head_mask"],
- inputs_embeds=inputs["inputs_embeds"],
- encoder_hidden_states=inputs["encoder_hidden_states"],
- encoder_attention_mask=inputs["encoder_attention_mask"],
- use_cache=inputs["use_cache"],
- output_attentions=inputs["output_attentions"],
- output_hidden_states=inputs["output_hidden_states"],
- return_dict=inputs["return_dict"],
- training=inputs["training"],
)
hidden_states = transformer_outputs[0]
logits = self.transformer.wte(hidden_states, mode="linear")
loss = None
if inputs["labels"] is not None:
if labels is not None:
# shift labels to the left and cut last logit token
shifted_logits = logits[:, :-1]
- labels = inputs["labels"][:, 1:]
+ labels = labels[:, 1:]
loss = self.hf_compute_loss(labels, shifted_logits)
if not inputs["return_dict"]:
if not return_dict:
output = (logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output
@@ -1081,6 +1020,7 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
config, initializer_range=config.initializer_range, name="multiple_choice_head"
)
+ @unpack_inputs
@add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=TFGPT2DoubleHeadsModelOutput, config_class=_CONFIG_FOR_DOC)
def call(
@@ -1133,64 +1073,40 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
>>> outputs = model(input_ids, mc_token_ids=mc_token_ids)
>>> lm_prediction_scores, mc_prediction_scores = outputs[:2]
```"""
- inputs = input_processing(
- func=self.call,
- config=self.config,
- input_ids=input_ids,
- past=past,
- attention_mask=attention_mask,
- token_type_ids=token_type_ids,
- position_ids=position_ids,
- head_mask=head_mask,
- inputs_embeds=inputs_embeds,
- mc_token_ids=mc_token_ids,
- use_cache=use_cache,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- return_dict=return_dict,
- training=training,
- kwargs_call=kwargs,
- )
- if inputs["input_ids"] is not None:
- input_shapes = shape_list(inputs["input_ids"])
+ if input_ids is not None:
+ input_shapes = shape_list(input_ids)
else:
- input_shapes = shape_list(inputs["inputs_embeds"])[:-1]
+ input_shapes = shape_list(inputs_embeds)[:-1]
seq_length = input_shapes[-1]
- flat_input_ids = tf.reshape(inputs["input_ids"], (-1, seq_length)) if inputs["input_ids"] is not None else None
- flat_attention_mask = (
- tf.reshape(inputs["attention_mask"], (-1, seq_length)) if inputs["attention_mask"] is not None else None
- )
- flat_token_type_ids = (
- tf.reshape(inputs["token_type_ids"], (-1, seq_length)) if inputs["token_type_ids"] is not None else None
- )
- flat_position_ids = (
- tf.reshape(inputs["position_ids"], (-1, seq_length)) if inputs["position_ids"] is not None else None
- )
+ flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
+ flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
+ flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
+ flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
transformer_outputs = self.transformer(
input_ids=flat_input_ids,
past=inputs["past"],
past=past,
attention_mask=flat_attention_mask,
token_type_ids=flat_token_type_ids,
position_ids=flat_position_ids,
- head_mask=inputs["head_mask"],
- inputs_embeds=inputs["inputs_embeds"],
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
encoder_hidden_states=None,
encoder_attention_mask=None,
- use_cache=inputs["use_cache"],
- output_attentions=inputs["output_attentions"],
- output_hidden_states=inputs["output_hidden_states"],
- return_dict=inputs["return_dict"],
- training=inputs["training"],
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
)
hidden_states = transformer_outputs[0]
hidden_states = tf.reshape(hidden_states, input_shapes + shape_list(hidden_states)[-1:])
lm_logits = self.transformer.wte(hidden_states, mode="linear")
- mc_logits = self.multiple_choice_head(hidden_states, inputs["mc_token_ids"], training=inputs["training"])
+ mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids, training=training)
mc_logits = tf.squeeze(mc_logits, axis=-1)
if not inputs["return_dict"]:
if not return_dict:
return (lm_logits, mc_logits) + transformer_outputs[1:]
return TFGPT2DoubleHeadsModelOutput(
@@ -1256,6 +1172,7 @@ class TFGPT2ForSequenceClassification(TFGPT2PreTrainedModel, TFSequenceClassific
)
self.transformer = TFGPT2MainLayer(config, name="transformer")
+ @unpack_inputs
@add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
@@ -1285,9 +1202,7 @@ class TFGPT2ForSequenceClassification(TFGPT2PreTrainedModel, TFSequenceClassific
Labels for computing the cross entropy classification loss. Indices should be in `[0, ...,
config.vocab_size - 1]`.
"""
- inputs = input_processing(
- func=self.call,
- config=self.config,
+ transformer_outputs = self.transformer(
input_ids=input_ids,
past=past,
attention_mask=attention_mask,
@@ -1299,24 +1214,7 @@ class TFGPT2ForSequenceClassification(TFGPT2PreTrainedModel, TFSequenceClassific
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
- labels=labels,
training=training,
- kwargs_call=kwargs,
- )
- transformer_outputs = self.transformer(
- input_ids=inputs["input_ids"],
- past=inputs["past"],
- attention_mask=inputs["attention_mask"],
- token_type_ids=inputs["token_type_ids"],
- position_ids=inputs["position_ids"],
- head_mask=inputs["head_mask"],
- inputs_embeds=inputs["inputs_embeds"],
- use_cache=inputs["use_cache"],
- output_attentions=inputs["output_attentions"],
- output_hidden_states=inputs["output_hidden_states"],
- return_dict=inputs["return_dict"],
- training=inputs["training"],
)
hidden_states = transformer_outputs[0]
@@ -1326,12 +1224,12 @@ class TFGPT2ForSequenceClassification(TFGPT2PreTrainedModel, TFSequenceClassific
if self.config.pad_token_id is None:
sequence_lengths = -1
else:
if inputs["input_ids"] is not None:
if input_ids is not None:
sequence_lengths = (
tf.reduce_sum(
tf.cast(
- tf.math.not_equal(inputs["input_ids"], self.config.pad_token_id),
- dtype=inputs["input_ids"].dtype,
+ tf.math.not_equal(input_ids, self.config.pad_token_id),
+ dtype=input_ids.dtype,
),
-1,
keepdims=False,
@@ -1347,7 +1245,7 @@ class TFGPT2ForSequenceClassification(TFGPT2PreTrainedModel, TFSequenceClassific
)
loss = None
if inputs["labels"] is not None:
if labels is not None:
assert (
self.config.pad_token_id is not None or logits_shape[0] == 1
), "Cannot handle batch sizes > 1 if no padding token is defined."
@@ -1355,12 +1253,10 @@ class TFGPT2ForSequenceClassification(TFGPT2PreTrainedModel, TFSequenceClassific
if not tf.is_tensor(sequence_lengths):
in_logits = logits[0 : logits_shape[0], sequence_lengths]
- loss = self.hf_compute_loss(
- tf.reshape(inputs["labels"], [-1]), tf.reshape(in_logits, [-1, self.num_labels])
- )
+ loss = self.hf_compute_loss(tf.reshape(labels, [-1]), tf.reshape(in_logits, [-1, self.num_labels]))
pooled_logits = in_logits if in_logits is not None else logits
if not inputs["return_dict"]:
if not return_dict:
output = (pooled_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output
......
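
The refactor is internal to the modeling code; the public call interface is unchanged. As a quick sanity check, something like the following should behave the same before and after this commit (a standard transformers + TensorFlow install is assumed; the gpt2 checkpoint name is used only as the usual example):

    from transformers import GPT2Tokenizer, TFGPT2LMHeadModel

    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    model = TFGPT2LMHeadModel.from_pretrained("gpt2")

    # Keyword arguments are still accepted directly; they are now routed
    # through @unpack_inputs instead of an explicit input_processing call.
    inputs = tokenizer("Hello, my dog is cute", return_tensors="tf")
    outputs = model(**inputs, output_hidden_states=True, return_dict=True)

    print(outputs.logits.shape)        # (1, sequence_length, vocab_size)
    print(len(outputs.hidden_states))  # one per layer, plus the embedding output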