"...git@developer.sourcefind.cn:dadigang/Ventoy.git" did not exist on "bfe194c190597ff9a77efc76a21c441bc19b6549"
Unverified commit ffd19ee1 authored by Dahlbomii, committed by GitHub

TF GPT-J Type hints and TF decorator (#16488)



* Type hints and TF decorator added

* Type hints and TF decorator added

* make style
Co-authored-by: matt <rocketknight1@gmail.com>
parent 277d49a5
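The change summarized above swaps the hand-written `input_processing(...)` boilerplate in each `call` for the `@unpack_inputs` decorator and adds `Optional`/`Union[np.ndarray, tf.Tensor]` type hints. As a rough, hypothetical sketch of the idea behind such a decorator (the real `unpack_inputs` in `modeling_tf_utils.py` additionally handles config-derived defaults, dict/tuple first arguments, and graph mode, so this is an illustration, not the library code):

```python
import functools
import inspect


def unpack_inputs_sketch(call_fn):
    """Hypothetical, simplified stand-in for transformers' `unpack_inputs`."""
    signature = inspect.signature(call_fn)

    @functools.wraps(call_fn)
    def wrapper(self, *args, **kwargs):
        # Bind whatever the caller passed to the declared parameter names and
        # fill in the declared defaults for anything omitted, so the body of
        # `call` can use plain names instead of reading an `inputs` dict.
        bound = signature.bind(self, *args, **kwargs)
        bound.apply_defaults()
        unpacked = dict(bound.arguments)
        unpacked.pop("self")
        extra = unpacked.pop("kwargs", {})  # stray **kwargs, as `input_processing` accepted
        return call_fn(self, **unpacked, **extra)

    return wrapper
```

This sketch assumes keyword-style parameters only (as the GPT-J `call` signatures have); it exists just to show why the method bodies below can read `input_ids`, `return_dict`, etc. directly.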
@@ -14,8 +14,9 @@
# limitations under the License.
""" TF 2.0 GPT-J model."""
from typing import Optional, Tuple
from typing import Optional, Tuple, Union
import numpy as np
import tensorflow as tf
from ...activations_tf import get_tf_activation
@@ -33,6 +34,7 @@ from ...modeling_tf_outputs import (
)
from ...modeling_tf_utils import (
TFCausalLanguageModelingLoss,
TFModelInputType,
TFPreTrainedModel,
TFQuestionAnsweringLoss,
TFSequenceClassificationLoss,
@@ -40,6 +42,7 @@ from ...modeling_tf_utils import (
get_initializer,
input_processing,
keras_serializable,
unpack_inputs,
)
from ...tf_utils import shape_list
from ...utils import logging
@@ -670,6 +673,7 @@ class TFGPTJModel(TFGPTJPreTrainedModel):
super().__init__(config, *inputs, **kwargs)
self.transformer = TFGPTJMainLayer(config, name="transformer")
@unpack_inputs
@add_start_docstrings_to_model_forward(GPTJ_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
@@ -679,18 +683,18 @@ class TFGPTJModel(TFGPTJPreTrainedModel):
)
def call(
self,
input_ids=None,
past_key_values=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
training=False,
input_ids: Optional[TFModelInputType] = None,
past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
token_type_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
training: Optional[bool] = False,
**kwargs,
):
r"""
@@ -698,9 +702,8 @@ class TFGPTJModel(TFGPTJPreTrainedModel):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past`). Set to `False` during training, `True` during generation
"""
inputs = input_processing(
func=self.call,
config=self.config,
outputs = self.transformer(
input_ids=input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
@@ -713,21 +716,6 @@ class TFGPTJModel(TFGPTJPreTrainedModel):
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
outputs = self.transformer(
input_ids=inputs["input_ids"],
past_key_values=inputs["past_key_values"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
use_cache=inputs["use_cache"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
)
return outputs
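For context, a minimal usage sketch of the rewritten `TFGPTJModel.call`, exercising the `use_cache`/`past_key_values` behaviour described in the docstring above. It assumes the public `EleutherAI/gpt-j-6B` checkpoint and that TF weights can be loaded for it (`from_pt=True` would be needed otherwise); this is not part of the diff.

```python
import tensorflow as tf
from transformers import AutoTokenizer, TFGPTJModel

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
model = TFGPTJModel.from_pretrained("EleutherAI/gpt-j-6B")  # or from_pt=True if only PyTorch weights exist

inputs = tokenizer("Hello there", return_tensors="tf")

# First pass: request the key/value cache alongside the hidden states.
outputs = model(input_ids=inputs["input_ids"], use_cache=True)

# Next pass: feed only the newest token plus the cached keys/values, which is
# what the `past_key_values` note in the docstring refers to.
next_token = tf.constant([[tokenizer.eos_token_id]])
outputs = model(
    input_ids=next_token,
    past_key_values=outputs.past_key_values,
    use_cache=True,
)
print(outputs.last_hidden_state.shape)  # (1, 1, hidden_size)
```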
@@ -793,6 +781,7 @@ class TFGPTJForCausalLM(TFGPTJPreTrainedModel, TFCausalLanguageModelingLoss):
"use_cache": use_cache,
}
@unpack_inputs
@add_start_docstrings_to_model_forward(GPTJ_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
@@ -802,19 +791,19 @@ class TFGPTJForCausalLM(TFGPTJPreTrainedModel, TFCausalLanguageModelingLoss):
)
def call(
self,
input_ids=None,
past_key_values=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
training=False,
input_ids: Optional[TFModelInputType] = None,
past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
token_type_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
labels: Optional[Union[np.ndarray, tf.Tensor]] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
training: Optional[bool] = False,
**kwargs,
):
r"""
@@ -823,9 +812,8 @@ class TFGPTJForCausalLM(TFGPTJPreTrainedModel, TFCausalLanguageModelingLoss):
`labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
"""
inputs = input_processing(
func=self.call,
config=self.config,
transformer_outputs = self.transformer(
input_ids=input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
@@ -833,39 +821,23 @@ class TFGPTJForCausalLM(TFGPTJPreTrainedModel, TFCausalLanguageModelingLoss):
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
labels=labels,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
transformer_outputs = self.transformer(
input_ids=inputs["input_ids"],
past_key_values=inputs["past_key_values"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
use_cache=inputs["use_cache"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
)
hidden_states = transformer_outputs[0]
lm_logits = self.lm_head(hidden_states)
loss = None
if inputs["labels"] is not None:
if labels is not None:
# shift labels to the left and cut last logit token
shifted_logits = lm_logits[:, :-1]
labels = inputs["labels"][:, 1:]
labels = labels[:, 1:]
loss = self.hf_compute_loss(labels, shifted_logits)
if not inputs["return_dict"]:
if not return_dict:
output = (lm_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output
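As a standalone illustration of the label handling in this hunk: the loss drops the last logit position and the first label so that position t predicts token t+1. A toy sketch with made-up shapes, using a plain Keras loss (unlike `hf_compute_loss`, it does not mask `-100` labels):

```python
import tensorflow as tf

batch, seq_len, vocab = 2, 5, 11
lm_logits = tf.random.normal((batch, seq_len, vocab))
labels = tf.random.uniform((batch, seq_len), maxval=vocab, dtype=tf.int32)

# Shift: drop the last logit and the first label.
shifted_logits = lm_logits[:, :-1]  # (batch, seq_len - 1, vocab)
shifted_labels = labels[:, 1:]      # (batch, seq_len - 1)

loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
loss = loss_fn(shifted_labels, shifted_logits)
print(float(loss))
```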
@@ -914,6 +886,7 @@ class TFGPTJForSequenceClassification(TFGPTJPreTrainedModel, TFSequenceClassific
name="score",
)
@unpack_inputs
@add_start_docstrings_to_model_forward(GPTJ_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
@@ -923,19 +896,19 @@ class TFGPTJForSequenceClassification(TFGPTJPreTrainedModel, TFSequenceClassific
)
def call(
self,
input_ids=None,
past_key_values=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
training=False,
input_ids: Optional[TFModelInputType] = None,
past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
token_type_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
labels: Optional[Union[np.ndarray, tf.Tensor]] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
training: Optional[bool] = False,
**kwargs,
):
r"""
@@ -944,9 +917,8 @@ class TFGPTJForSequenceClassification(TFGPTJPreTrainedModel, TFSequenceClassific
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
inputs = input_processing(
func=self.call,
config=self.config,
transformer_outputs = self.transformer(
input_ids=input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
@@ -954,27 +926,11 @@ class TFGPTJForSequenceClassification(TFGPTJPreTrainedModel, TFSequenceClassific
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
labels=labels,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
transformer_outputs = self.transformer(
input_ids=inputs["input_ids"],
past_key_values=inputs["past_key_values"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
use_cache=inputs["use_cache"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
)
hidden_states = transformer_outputs[0]
logits = self.score(hidden_states)
@@ -983,12 +939,12 @@ class TFGPTJForSequenceClassification(TFGPTJPreTrainedModel, TFSequenceClassific
if self.config.pad_token_id is None:
sequence_lengths = -1
else:
if inputs["input_ids"] is not None:
if input_ids is not None:
sequence_lengths = (
tf.reduce_sum(
tf.cast(
tf.math.not_equal(inputs["input_ids"], self.config.pad_token_id),
dtype=inputs["input_ids"].dtype,
tf.math.not_equal(input_ids, self.config.pad_token_id),
dtype=input_ids.dtype,
),
-1,
keepdims=False,
@@ -1004,7 +960,7 @@ class TFGPTJForSequenceClassification(TFGPTJPreTrainedModel, TFSequenceClassific
)
loss = None
if inputs["labels"] is not None:
if labels is not None:
assert (
self.config.pad_token_id is not None or logits_shape[0] == 1
), "Cannot handle batch sizes > 1 if no padding token is defined."
@@ -1012,12 +968,10 @@ class TFGPTJForSequenceClassification(TFGPTJPreTrainedModel, TFSequenceClassific
if not tf.is_tensor(sequence_lengths):
in_logits = logits[0 : logits_shape[0], sequence_lengths]
loss = self.hf_compute_loss(
tf.reshape(inputs["labels"], [-1]), tf.reshape(in_logits, [-1, self.num_labels])
)
loss = self.hf_compute_loss(tf.reshape(labels, [-1]), tf.reshape(in_logits, [-1, self.num_labels]))
pooled_logits = in_logits if in_logits is not None else logits
if not inputs["return_dict"]:
if not return_dict:
output = (pooled_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output
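The classification head pools the logit at the last non-padding token, and the `sequence_lengths` computation in the hunks above counts tokens that differ from `pad_token_id`. A self-contained toy sketch of the same idea (values, shapes, and the explicit gather are illustrative; a real model takes `pad_token_id` and `num_labels` from its config):

```python
import tensorflow as tf

pad_token_id = 0
input_ids = tf.constant([[5, 7, 9, 0, 0],
                         [3, 4, 0, 0, 0]])

# Count non-padding tokens per row, then subtract 1 to index the last real token.
sequence_lengths = (
    tf.reduce_sum(
        tf.cast(tf.math.not_equal(input_ids, pad_token_id), dtype=input_ids.dtype),
        axis=-1,
    )
    - 1
)
print(sequence_lengths.numpy())  # [2 1]

# Pool the logit of the last non-padding position for each row.
logits = tf.random.normal((2, 5, 3))  # (batch, seq_len, num_labels)
pooled_logits = tf.gather(logits, sequence_lengths, batch_dims=1, axis=1)
print(pooled_logits.shape)  # (2, 3)
```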
@@ -1057,6 +1011,7 @@ class TFGPTJForQuestionAnswering(TFGPTJPreTrainedModel, TFQuestionAnsweringLoss)
self.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
)
@unpack_inputs
@add_start_docstrings_to_model_forward(GPTJ_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
@@ -1066,19 +1021,19 @@ class TFGPTJForQuestionAnswering(TFGPTJPreTrainedModel, TFQuestionAnsweringLoss)
)
def call(
self,
input_ids=None,
past_key_values=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
start_positions=None,
end_positions=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
training=False,
input_ids: Optional[TFModelInputType] = None,
past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
token_type_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
start_positions: Optional[Union[np.ndarray, tf.Tensor]] = None,
end_positions: Optional[Union[np.ndarray, tf.Tensor]] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
training: Optional[bool] = False,
**kwargs,
):
r"""
@@ -1091,9 +1046,8 @@ class TFGPTJForQuestionAnswering(TFGPTJPreTrainedModel, TFQuestionAnsweringLoss)
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
are not taken into account for computing the loss.
"""
inputs = input_processing(
func=self.call,
config=self.config,
transformer_outputs = self.transformer(
input_ids=input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
@@ -1101,26 +1055,10 @@ class TFGPTJForQuestionAnswering(TFGPTJPreTrainedModel, TFQuestionAnsweringLoss)
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
start_positions=start_positions,
end_positions=end_positions,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
transformer_outputs = self.transformer(
input_ids=inputs["input_ids"],
past_key_values=inputs["past_key_values"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
)
sequence_output = transformer_outputs[0]
@@ -1130,12 +1068,12 @@ class TFGPTJForQuestionAnswering(TFGPTJPreTrainedModel, TFQuestionAnsweringLoss)
end_logits = tf.squeeze(end_logits, axis=-1)
loss = None
if inputs["start_positions"] is not None and inputs["end_positions"] is not None:
labels = {"start_position": inputs["start_positions"]}
labels["end_position"] = inputs["end_positions"]
if start_positions is not None and end_positions is not None:
labels = {"start_position": start_positions}
labels["end_position"] = end_positions
loss = self.hf_compute_loss(labels, (start_logits, end_logits))
if not inputs["return_dict"]:
if not return_dict:
output = (start_logits, end_logits) + transformer_outputs[2:]
return ((loss,) + output) if loss is not None else output
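For illustration only, a toy sketch of how a QA head like the one above typically turns a `(batch, seq_len, 2)` projection into start/end logits and picks a span at inference time; the shapes and the argmax decoding are assumptions for the example, not code from the diff:

```python
import tensorflow as tf

batch, seq_len = 2, 6
qa_logits = tf.random.normal((batch, seq_len, 2))  # stand-in for the qa_outputs projection

# Split the last axis into start and end logits and drop the size-1 dimension.
start_logits, end_logits = tf.split(qa_logits, 2, axis=-1)
start_logits = tf.squeeze(start_logits, axis=-1)  # (batch, seq_len)
end_logits = tf.squeeze(end_logits, axis=-1)      # (batch, seq_len)

# At inference time, the most likely answer span per example:
start_positions = tf.argmax(start_logits, axis=-1)
end_positions = tf.argmax(end_logits, axis=-1)
print(start_positions.numpy(), end_positions.numpy())
```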