Unverified Commit 1e6141c3 authored by Edoardo Abati's avatar Edoardo Abati Committed by GitHub
Browse files

Add type hints to TFPegasusModel (#19858)

* added typing to call in TFPegasusModel and TFPegasusForConditionalGeneration

* fixed type for TFPegasusForConditionalGeneration call
parent ecf29db0
...@@ -33,6 +33,7 @@ from ...modeling_tf_outputs import ( ...@@ -33,6 +33,7 @@ from ...modeling_tf_outputs import (
from ...modeling_tf_utils import ( from ...modeling_tf_utils import (
DUMMY_INPUTS, DUMMY_INPUTS,
TFCausalLanguageModelingLoss, TFCausalLanguageModelingLoss,
TFModelInputType,
TFPreTrainedModel, TFPreTrainedModel,
keras_serializable, keras_serializable,
unpack_inputs, unpack_inputs,
...@@ -1232,25 +1233,25 @@ class TFPegasusModel(TFPegasusPreTrainedModel): ...@@ -1232,25 +1233,25 @@ class TFPegasusModel(TFPegasusPreTrainedModel):
) )
def call( def call(
self, self,
input_ids=None, input_ids: Optional[TFModelInputType] = None,
attention_mask=None, attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
decoder_input_ids=None, decoder_input_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
decoder_attention_mask=None, decoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
decoder_position_ids=None, decoder_position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
head_mask=None, head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
decoder_head_mask=None, decoder_head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
cross_attn_head_mask=None, cross_attn_head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None, encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
past_key_values=None, past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
inputs_embeds=None, inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
decoder_inputs_embeds=None, decoder_inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
use_cache=None, use_cache: Optional[bool] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
training=False, training: bool = False,
**kwargs **kwargs
): ) -> Union[TFSeq2SeqModelOutput, Tuple[tf.Tensor]]:
outputs = self.model( outputs = self.model(
input_ids=input_ids, input_ids=input_ids,
...@@ -1361,25 +1362,25 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel, TFCausalLangua ...@@ -1361,25 +1362,25 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel, TFCausalLangua
@add_end_docstrings(PEGASUS_GENERATION_EXAMPLE) @add_end_docstrings(PEGASUS_GENERATION_EXAMPLE)
def call( def call(
self, self,
input_ids=None, input_ids: Optional[TFModelInputType] = None,
attention_mask=None, attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
decoder_input_ids=None, decoder_input_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
decoder_attention_mask=None, decoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
decoder_position_ids=None, decoder_position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
head_mask=None, head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
decoder_head_mask=None, decoder_head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
cross_attn_head_mask=None, cross_attn_head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
encoder_outputs: Optional[TFBaseModelOutput] = None, encoder_outputs: Optional[TFBaseModelOutput] = None,
past_key_values=None, past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
inputs_embeds=None, inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
decoder_inputs_embeds=None, decoder_inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
use_cache=None, use_cache: Optional[bool] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
labels=None, labels: Optional[Union[np.ndarray, tf.Tensor]] = None,
training=False, training: bool = False,
): ) -> Union[TFSeq2SeqLMOutput, Tuple[tf.Tensor]]:
""" """
labels (`tf.tensor` of shape `(batch_size, sequence_length)`, *optional*): labels (`tf.tensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment