Unverified commit ffd19ee1, authored by Dahlbomii, committed by GitHub

TF GPT-J Type hints and TF decorator (#16488)



* Type hints and TF decorator added

* Type hints and TF decorator added

* make style
Co-authored-by: matt <rocketknight1@gmail.com>
parent 277d49a5
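For readers skimming the diff below: the PR swaps the hand-rolled `input_processing(...)` block at the top of each `call` for the `@unpack_inputs` decorator, so the body can use `input_ids`, `labels`, `return_dict`, etc. directly, and the signatures gain explicit `TFModelInputType` / `Optional[Union[np.ndarray, tf.Tensor]]` hints. A minimal sketch of the decorator idea follows; the name is hypothetical and it omits the real normalization work that `modeling_tf_utils.unpack_inputs` performs.

import functools


def unpack_inputs_sketch(call_fn):
    """Hypothetical stand-in for `unpack_inputs`: do input normalization once
    in a wrapper, so the wrapped `call` receives plain keyword arguments
    instead of rebuilding an `inputs` dict in its own body."""

    @functools.wraps(call_fn)
    def run_call(self, *args, **kwargs):
        # The real decorator delegates to `input_processing(func=call_fn,
        # config=self.config, ...)` at this point; the sketch simply forwards
        # the arguments unchanged to stay self-contained.
        return call_fn(self, *args, **kwargs)

    return run_call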
@@ -14,8 +14,9 @@
 # limitations under the License.
 """ TF 2.0 GPT-J model."""
 
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Union
 
+import numpy as np
 import tensorflow as tf
 
 from ...activations_tf import get_tf_activation
@@ -33,6 +34,7 @@ from ...modeling_tf_outputs import (
 )
 from ...modeling_tf_utils import (
     TFCausalLanguageModelingLoss,
+    TFModelInputType,
     TFPreTrainedModel,
     TFQuestionAnsweringLoss,
     TFSequenceClassificationLoss,
@@ -40,6 +42,7 @@ from ...modeling_tf_utils import (
     get_initializer,
     input_processing,
     keras_serializable,
+    unpack_inputs,
 )
 from ...tf_utils import shape_list
 from ...utils import logging
@@ -670,6 +673,7 @@ class TFGPTJModel(TFGPTJPreTrainedModel):
         super().__init__(config, *inputs, **kwargs)
         self.transformer = TFGPTJMainLayer(config, name="transformer")
 
+    @unpack_inputs
     @add_start_docstrings_to_model_forward(GPTJ_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
         processor_class=_TOKENIZER_FOR_DOC,
@@ -679,18 +683,18 @@ class TFGPTJModel(TFGPTJPreTrainedModel):
     )
     def call(
         self,
-        input_ids=None,
-        past_key_values=None,
-        attention_mask=None,
-        token_type_ids=None,
-        position_ids=None,
-        head_mask=None,
-        inputs_embeds=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-        training=False,
+        input_ids: Optional[TFModelInputType] = None,
+        past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
+        attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        token_type_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        training: Optional[bool] = False,
         **kwargs,
     ):
         r"""
@@ -698,9 +702,8 @@ class TFGPTJModel(TFGPTJPreTrainedModel):
             If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past`). Set to `False` during training, `True` during generation
         """
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
+
+        outputs = self.transformer(
             input_ids=input_ids,
             past_key_values=past_key_values,
             attention_mask=attention_mask,
@@ -713,21 +716,6 @@
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
             training=training,
-            kwargs_call=kwargs,
-        )
-        outputs = self.transformer(
-            input_ids=inputs["input_ids"],
-            past_key_values=inputs["past_key_values"],
-            attention_mask=inputs["attention_mask"],
-            token_type_ids=inputs["token_type_ids"],
-            position_ids=inputs["position_ids"],
-            head_mask=inputs["head_mask"],
-            inputs_embeds=inputs["inputs_embeds"],
-            use_cache=inputs["use_cache"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
         )
 
         return outputs
@@ -793,6 +781,7 @@ class TFGPTJForCausalLM(TFGPTJPreTrainedModel, TFCausalLanguageModelingLoss):
             "use_cache": use_cache,
         }
 
+    @unpack_inputs
    @add_start_docstrings_to_model_forward(GPTJ_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
         processor_class=_TOKENIZER_FOR_DOC,
@@ -802,19 +791,19 @@ class TFGPTJForCausalLM(TFGPTJPreTrainedModel, TFCausalLanguageModelingLoss):
     )
     def call(
         self,
-        input_ids=None,
-        past_key_values=None,
-        attention_mask=None,
-        token_type_ids=None,
-        position_ids=None,
-        head_mask=None,
-        inputs_embeds=None,
-        labels=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-        training=False,
+        input_ids: Optional[TFModelInputType] = None,
+        past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
+        attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        token_type_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        labels: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        training: Optional[bool] = False,
         **kwargs,
     ):
         r"""
@@ -823,9 +812,8 @@ class TFGPTJForCausalLM(TFGPTJPreTrainedModel, TFCausalLanguageModelingLoss):
            `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
            are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
         """
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
+
+        transformer_outputs = self.transformer(
             input_ids=input_ids,
             past_key_values=past_key_values,
             attention_mask=attention_mask,
@@ -833,39 +821,23 @@ class TFGPTJForCausalLM(TFGPTJPreTrainedModel, TFCausalLanguageModelingLoss):
             position_ids=position_ids,
             head_mask=head_mask,
             inputs_embeds=inputs_embeds,
-            labels=labels,
             use_cache=use_cache,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
             training=training,
-            kwargs_call=kwargs,
-        )
-        transformer_outputs = self.transformer(
-            input_ids=inputs["input_ids"],
-            past_key_values=inputs["past_key_values"],
-            attention_mask=inputs["attention_mask"],
-            token_type_ids=inputs["token_type_ids"],
-            position_ids=inputs["position_ids"],
-            head_mask=inputs["head_mask"],
-            inputs_embeds=inputs["inputs_embeds"],
-            use_cache=inputs["use_cache"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
         )
         hidden_states = transformer_outputs[0]
         lm_logits = self.lm_head(hidden_states)
 
         loss = None
-        if inputs["labels"] is not None:
+        if labels is not None:
             # shift labels to the left and cut last logit token
             shifted_logits = lm_logits[:, :-1]
-            labels = inputs["labels"][:, 1:]
+            labels = labels[:, 1:]
             loss = self.hf_compute_loss(labels, shifted_logits)
 
-        if not inputs["return_dict"]:
+        if not return_dict:
             output = (lm_logits,) + transformer_outputs[1:]
             return ((loss,) + output) if loss is not None else output
@@ -914,6 +886,7 @@ class TFGPTJForSequenceClassification(TFGPTJPreTrainedModel, TFSequenceClassificationLoss):
             name="score",
         )
 
+    @unpack_inputs
    @add_start_docstrings_to_model_forward(GPTJ_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
         processor_class=_TOKENIZER_FOR_DOC,
@@ -923,19 +896,19 @@ class TFGPTJForSequenceClassification(TFGPTJPreTrainedModel, TFSequenceClassificationLoss):
     )
     def call(
         self,
-        input_ids=None,
-        past_key_values=None,
-        attention_mask=None,
-        token_type_ids=None,
-        position_ids=None,
-        head_mask=None,
-        inputs_embeds=None,
-        labels=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-        training=False,
+        input_ids: Optional[TFModelInputType] = None,
+        past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
+        attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        token_type_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        labels: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        training: Optional[bool] = False,
         **kwargs,
     ):
         r"""
@@ -944,9 +917,8 @@ class TFGPTJForSequenceClassification(TFGPTJPreTrainedModel, TFSequenceClassificationLoss):
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
         """
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
+
+        transformer_outputs = self.transformer(
             input_ids=input_ids,
             past_key_values=past_key_values,
             attention_mask=attention_mask,
@@ -954,27 +926,11 @@ class TFGPTJForSequenceClassification(TFGPTJPreTrainedModel, TFSequenceClassificationLoss):
             position_ids=position_ids,
             head_mask=head_mask,
             inputs_embeds=inputs_embeds,
-            labels=labels,
             use_cache=use_cache,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
             training=training,
-            kwargs_call=kwargs,
-        )
-        transformer_outputs = self.transformer(
-            input_ids=inputs["input_ids"],
-            past_key_values=inputs["past_key_values"],
-            attention_mask=inputs["attention_mask"],
-            token_type_ids=inputs["token_type_ids"],
-            position_ids=inputs["position_ids"],
-            head_mask=inputs["head_mask"],
-            inputs_embeds=inputs["inputs_embeds"],
-            use_cache=inputs["use_cache"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
         )
         hidden_states = transformer_outputs[0]
         logits = self.score(hidden_states)
@@ -983,12 +939,12 @@ class TFGPTJForSequenceClassification(TFGPTJPreTrainedModel, TFSequenceClassificationLoss):
         if self.config.pad_token_id is None:
             sequence_lengths = -1
         else:
-            if inputs["input_ids"] is not None:
+            if input_ids is not None:
                 sequence_lengths = (
                     tf.reduce_sum(
                         tf.cast(
-                            tf.math.not_equal(inputs["input_ids"], self.config.pad_token_id),
-                            dtype=inputs["input_ids"].dtype,
+                            tf.math.not_equal(input_ids, self.config.pad_token_id),
+                            dtype=input_ids.dtype,
                         ),
                         -1,
                         keepdims=False,
@@ -1004,7 +960,7 @@ class TFGPTJForSequenceClassification(TFGPTJPreTrainedModel, TFSequenceClassificationLoss):
                 )
 
         loss = None
-        if inputs["labels"] is not None:
+        if labels is not None:
             assert (
                 self.config.pad_token_id is not None or logits_shape[0] == 1
             ), "Cannot handle batch sizes > 1 if no padding token is defined."
@@ -1012,12 +968,10 @@ class TFGPTJForSequenceClassification(TFGPTJPreTrainedModel, TFSequenceClassificationLoss):
             if not tf.is_tensor(sequence_lengths):
                 in_logits = logits[0 : logits_shape[0], sequence_lengths]
 
-            loss = self.hf_compute_loss(
-                tf.reshape(inputs["labels"], [-1]), tf.reshape(in_logits, [-1, self.num_labels])
-            )
+            loss = self.hf_compute_loss(tf.reshape(labels, [-1]), tf.reshape(in_logits, [-1, self.num_labels]))
 
         pooled_logits = in_logits if in_logits is not None else logits
 
-        if not inputs["return_dict"]:
+        if not return_dict:
             output = (pooled_logits,) + transformer_outputs[1:]
             return ((loss,) + output) if loss is not None else output
@@ -1057,6 +1011,7 @@ class TFGPTJForQuestionAnswering(TFGPTJPreTrainedModel, TFQuestionAnsweringLoss):
             self.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
         )
 
+    @unpack_inputs
    @add_start_docstrings_to_model_forward(GPTJ_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
         processor_class=_TOKENIZER_FOR_DOC,
@@ -1066,19 +1021,19 @@ class TFGPTJForQuestionAnswering(TFGPTJPreTrainedModel, TFQuestionAnsweringLoss):
     )
     def call(
         self,
-        input_ids=None,
-        past_key_values=None,
-        attention_mask=None,
-        token_type_ids=None,
-        position_ids=None,
-        head_mask=None,
-        inputs_embeds=None,
-        start_positions=None,
-        end_positions=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-        training=False,
+        input_ids: Optional[TFModelInputType] = None,
+        past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
+        attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        token_type_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        start_positions: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        end_positions: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        training: Optional[bool] = False,
         **kwargs,
     ):
         r"""
@@ -1091,9 +1046,8 @@ class TFGPTJForQuestionAnswering(TFGPTJPreTrainedModel, TFQuestionAnsweringLoss):
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
         """
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
+
+        transformer_outputs = self.transformer(
             input_ids=input_ids,
             past_key_values=past_key_values,
             attention_mask=attention_mask,
@@ -1101,26 +1055,10 @@ class TFGPTJForQuestionAnswering(TFGPTJPreTrainedModel, TFQuestionAnsweringLoss):
             position_ids=position_ids,
             head_mask=head_mask,
             inputs_embeds=inputs_embeds,
-            start_positions=start_positions,
-            end_positions=end_positions,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
             training=training,
-            kwargs_call=kwargs,
-        )
-        transformer_outputs = self.transformer(
-            input_ids=inputs["input_ids"],
-            past_key_values=inputs["past_key_values"],
-            attention_mask=inputs["attention_mask"],
-            token_type_ids=inputs["token_type_ids"],
-            position_ids=inputs["position_ids"],
-            head_mask=inputs["head_mask"],
-            inputs_embeds=inputs["inputs_embeds"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
         )
         sequence_output = transformer_outputs[0]
@@ -1130,12 +1068,12 @@ class TFGPTJForQuestionAnswering(TFGPTJPreTrainedModel, TFQuestionAnsweringLoss):
         end_logits = tf.squeeze(end_logits, axis=-1)
 
         loss = None
-        if inputs["start_positions"] is not None and inputs["end_positions"] is not None:
-            labels = {"start_position": inputs["start_positions"]}
-            labels["end_position"] = inputs["end_positions"]
+        if start_positions is not None and end_positions is not None:
+            labels = {"start_position": start_positions}
+            labels["end_position"] = end_positions
             loss = self.hf_compute_loss(labels, (start_logits, end_logits))
 
-        if not inputs["return_dict"]:
+        if not return_dict:
             output = (start_logits, end_logits) + transformer_outputs[2:]
             return ((loss,) + output) if loss is not None else output
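After the change, `call` still accepts either NumPy arrays or TF tensors for the array-like arguments, matching the new hints. A hedged usage sketch, assuming the usual GPT-J checkpoint name; it is illustrative only (the checkpoint is large, and loading TF weights may require `from_pt=True` depending on the revision), and any small GPT-J-style test checkpoint works the same way.

import tensorflow as tf
from transformers import AutoTokenizer, TFGPTJModel

# Illustrative checkpoint; swap in a smaller one for quick experiments.
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
model = TFGPTJModel.from_pretrained("EleutherAI/gpt-j-6B")

enc = tokenizer("Hello, my dog is cute", return_tensors="np")
outputs = model(
    input_ids=enc["input_ids"],                          # np.ndarray is accepted
    attention_mask=tf.constant(enc["attention_mask"]),   # and so is tf.Tensor
    return_dict=True,
)
print(outputs.last_hidden_state.shape)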