Unverified commit bb3a1d34, authored by Rishav Chandra Varma, committed by GitHub

Adding missing type hints for mBART model (TF) (#16281)



* Added type hints for the mBART TensorFlow (TF) implementation

* Adding missing type hints for mBART model

TensorFlow implementation of the model updated with the missing type hints

* Missing type hints - correction

For the TF model

* Code fixup using the `make quality` checks

* Type hints - typo fix

* make fix-copies and make fixup

* type hints

* updated files
Co-authored-by: matt <rocketknight1@gmail.com>
parent 935330dd
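For context, a minimal usage sketch of the signatures this commit annotates. The tiny configuration values and token ids below are illustrative only and are not part of the commit; they simply keep the example small enough to run.

```python
import tensorflow as tf
from transformers import MBartConfig, TFMBartModel

# Small, randomly initialised model: just enough to exercise the annotated signature.
config = MBartConfig(
    vocab_size=128,
    d_model=16,
    encoder_layers=1,
    decoder_layers=1,
    encoder_attention_heads=2,
    decoder_attention_heads=2,
    encoder_ffn_dim=32,
    decoder_ffn_dim=32,
)
model = TFMBartModel(config)

input_ids = tf.constant([[0, 10, 11, 12, 2]])  # input_ids: TFModelInputType
attention_mask = tf.ones_like(input_ids)       # attention_mask: Optional[tf.Tensor]
outputs = model(
    input_ids=input_ids,
    attention_mask=attention_mask,
    decoder_input_ids=input_ids,               # decoder_input_ids: Optional[tf.Tensor]
    training=False,                            # training: Optional[bool]
)
print(outputs.last_hidden_state.shape)  # TFSeq2SeqModelOutput when return_dict defaults to True
```

The Blenderbot and Pegasus hunks below appear to come from `make fix-copies` keeping the copied layer implementations in sync with the mBART change.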
@@ -301,7 +301,13 @@ class TFBlenderbotEncoderLayer(tf.keras.layers.Layer):
         self.fc2 = tf.keras.layers.Dense(self.embed_dim, name="fc2")
         self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")
 
-    def call(self, hidden_states: tf.Tensor, attention_mask: tf.Tensor, layer_head_mask: tf.Tensor, training=False):
+    def call(
+        self,
+        hidden_states: tf.Tensor,
+        attention_mask: tf.Tensor,
+        layer_head_mask: tf.Tensor,
+        training: Optional[bool] = False,
+    ):
         """
         Args:
             hidden_states (`tf.Tensor`): input to the layer of shape *(seq_len, batch, embed_dim)*
@@ -370,14 +376,14 @@ class TFBlenderbotDecoderLayer(tf.keras.layers.Layer):
 
     def call(
         self,
-        hidden_states,
+        hidden_states: tf.Tensor,
         attention_mask: Optional[tf.Tensor] = None,
         encoder_hidden_states: Optional[tf.Tensor] = None,
         encoder_attention_mask: Optional[tf.Tensor] = None,
         layer_head_mask: Optional[tf.Tensor] = None,
         cross_attn_layer_head_mask: Optional[tf.Tensor] = None,
         past_key_value: Optional[Tuple[tf.Tensor]] = None,
-        training=False,
+        training: Optional[bool] = False,
     ) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]:
         """
         Args:
......
@@ -32,6 +32,7 @@ from ...modeling_tf_outputs import (
 from ...modeling_tf_utils import (
     DUMMY_INPUTS,
     TFCausalLanguageModelingLoss,
+    TFModelInputType,
     TFPreTrainedModel,
     TFSharedEmbeddings,
     TFWrappedEmbeddings,
@@ -299,7 +300,13 @@ class TFMBartEncoderLayer(tf.keras.layers.Layer):
         self.fc2 = tf.keras.layers.Dense(self.embed_dim, name="fc2")
         self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")
 
-    def call(self, hidden_states: tf.Tensor, attention_mask: tf.Tensor, layer_head_mask: tf.Tensor, training=False):
+    def call(
+        self,
+        hidden_states: tf.Tensor,
+        attention_mask: tf.Tensor,
+        layer_head_mask: tf.Tensor,
+        training: Optional[bool] = False,
+    ):
         """
         Args:
             hidden_states (`tf.Tensor`): input to the layer of shape *(seq_len, batch, embed_dim)*
@@ -367,14 +374,14 @@ class TFMBartDecoderLayer(tf.keras.layers.Layer):
 
     def call(
         self,
-        hidden_states,
+        hidden_states: tf.Tensor,
         attention_mask: Optional[tf.Tensor] = None,
         encoder_hidden_states: Optional[tf.Tensor] = None,
         encoder_attention_mask: Optional[tf.Tensor] = None,
         layer_head_mask: Optional[tf.Tensor] = None,
         cross_attn_layer_head_mask: Optional[tf.Tensor] = None,
         past_key_value: Optional[Tuple[tf.Tensor]] = None,
-        training=False,
+        training: Optional[bool] = False,
     ) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]:
         """
         Args:
@@ -669,16 +676,16 @@ class TFMBartEncoder(tf.keras.layers.Layer):
     @unpack_inputs
     def call(
         self,
-        input_ids=None,
-        inputs_embeds=None,
-        attention_mask=None,
-        head_mask=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-        training=False,
+        input_ids: Optional[TFModelInputType] = None,
+        inputs_embeds: Optional[tf.Tensor] = None,
+        attention_mask: Optional[tf.Tensor] = None,
+        head_mask: Optional[tf.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        training: Optional[bool] = False,
         **kwargs,
-    ):
+    ) -> Union[TFBaseModelOutput, Tuple[tf.Tensor]]:
         """
         Args:
             input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`):
@@ -828,21 +835,23 @@ class TFMBartDecoder(tf.keras.layers.Layer):
     @unpack_inputs
     def call(
         self,
-        input_ids=None,
-        inputs_embeds=None,
-        attention_mask=None,
-        encoder_hidden_states=None,
-        encoder_attention_mask=None,
-        head_mask=None,
-        cross_attn_head_mask=None,
-        past_key_values=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-        training=False,
+        input_ids: TFModelInputType = None,
+        inputs_embeds: Optional[tf.Tensor] = None,
+        attention_mask: Optional[tf.Tensor] = None,
+        encoder_hidden_states: Optional[tf.Tensor] = None,
+        encoder_attention_mask: Optional[tf.Tensor] = None,
+        head_mask: Optional[tf.Tensor] = None,
+        cross_attn_head_mask: Optional[tf.Tensor] = None,
+        past_key_values: Optional[Tuple[Tuple[tf.Tensor]]] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        training: Optional[bool] = False,
         **kwargs,
-    ):
+    ) -> Union[
+        TFBaseModelOutputWithPastAndCrossAttentions, Tuple[tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor]
+    ]:
         r"""
         Args:
             input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`):
@@ -1049,24 +1058,24 @@ class TFMBartMainLayer(tf.keras.layers.Layer):
     @unpack_inputs
     def call(
         self,
-        input_ids=None,
-        attention_mask=None,
-        decoder_input_ids=None,
-        decoder_attention_mask=None,
-        head_mask=None,
-        decoder_head_mask=None,
-        cross_attn_head_mask=None,
+        input_ids: TFModelInputType = None,
+        attention_mask: Optional[tf.Tensor] = None,
+        decoder_input_ids: Optional[tf.Tensor] = None,
+        decoder_attention_mask: Optional[tf.Tensor] = None,
+        head_mask: Optional[tf.Tensor] = None,
+        decoder_head_mask: Optional[tf.Tensor] = None,
+        cross_attn_head_mask: Optional[tf.Tensor] = None,
         encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
-        past_key_values=None,
-        inputs_embeds=None,
-        decoder_inputs_embeds=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-        training=False,
+        past_key_values: Optional[Tuple[Tuple[tf.Tensor]]] = None,
+        inputs_embeds: Optional[tf.Tensor] = None,
+        decoder_inputs_embeds: Optional[tf.Tensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        training: Optional[bool] = False,
         **kwargs
-    ):
+    ) -> Union[TFSeq2SeqModelOutput, tf.Tensor]:
         if decoder_input_ids is None and decoder_inputs_embeds is None:
             use_cache = False
@@ -1157,24 +1166,24 @@ class TFMBartModel(TFMBartPreTrainedModel):
     )
     def call(
         self,
-        input_ids=None,
-        attention_mask=None,
-        decoder_input_ids=None,
-        decoder_attention_mask=None,
-        head_mask=None,
-        decoder_head_mask=None,
-        cross_attn_head_mask=None,
+        input_ids: TFModelInputType = None,
+        attention_mask: Optional[tf.Tensor] = None,
+        decoder_input_ids: Optional[tf.Tensor] = None,
+        decoder_attention_mask: Optional[tf.Tensor] = None,
+        head_mask: Optional[tf.Tensor] = None,
+        decoder_head_mask: Optional[tf.Tensor] = None,
+        cross_attn_head_mask: Optional[tf.Tensor] = None,
         encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
-        past_key_values=None,
-        inputs_embeds=None,
-        decoder_inputs_embeds=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-        training=False,
+        past_key_values: Optional[Tuple[Tuple[tf.Tensor]]] = None,
+        inputs_embeds: Optional[tf.Tensor] = None,
+        decoder_inputs_embeds: Optional[tf.Tensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        training: Optional[bool] = False,
         **kwargs
-    ):
+    ) -> Union[TFSeq2SeqModelOutput, Tuple[tf.Tensor]]:
         outputs = self.model(
             input_ids=input_ids,
@@ -1261,25 +1270,25 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel, TFCausalLanguageMo
     @add_end_docstrings(MBART_GENERATION_EXAMPLE)
     def call(
         self,
-        input_ids=None,
-        attention_mask=None,
-        decoder_input_ids=None,
-        decoder_attention_mask=None,
-        head_mask=None,
-        decoder_head_mask=None,
-        cross_attn_head_mask=None,
+        input_ids: TFModelInputType = None,
+        attention_mask: Optional[tf.Tensor] = None,
+        decoder_input_ids: Optional[tf.Tensor] = None,
+        decoder_attention_mask: Optional[tf.Tensor] = None,
+        head_mask: Optional[tf.Tensor] = None,
+        decoder_head_mask: Optional[tf.Tensor] = None,
+        cross_attn_head_mask: Optional[tf.Tensor] = None,
         encoder_outputs: Optional[TFBaseModelOutput] = None,
-        past_key_values=None,
-        inputs_embeds=None,
-        decoder_inputs_embeds=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-        labels=None,
-        training=False,
+        past_key_values: Optional[Tuple[Tuple[tf.Tensor]]] = None,
+        inputs_embeds: Optional[tf.Tensor] = None,
+        decoder_inputs_embeds: Optional[tf.Tensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        labels: Optional[tf.Tensor] = None,
+        training: Optional[bool] = False,
         **kwargs,
-    ):
+    ) -> Union[TFSeq2SeqLMOutput, Tuple[tf.Tensor]]:
         """
         labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
             Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
......
@@ -341,7 +341,13 @@ class TFPegasusEncoderLayer(tf.keras.layers.Layer):
         self.fc2 = tf.keras.layers.Dense(self.embed_dim, name="fc2")
         self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")
 
-    def call(self, hidden_states: tf.Tensor, attention_mask: tf.Tensor, layer_head_mask: tf.Tensor, training=False):
+    def call(
+        self,
+        hidden_states: tf.Tensor,
+        attention_mask: tf.Tensor,
+        layer_head_mask: tf.Tensor,
+        training: Optional[bool] = False,
+    ):
         """
         Args:
             hidden_states (`tf.Tensor`): input to the layer of shape *(seq_len, batch, embed_dim)*
@@ -410,14 +416,14 @@ class TFPegasusDecoderLayer(tf.keras.layers.Layer):
 
     def call(
        self,
-        hidden_states,
+        hidden_states: tf.Tensor,
         attention_mask: Optional[tf.Tensor] = None,
         encoder_hidden_states: Optional[tf.Tensor] = None,
         encoder_attention_mask: Optional[tf.Tensor] = None,
         layer_head_mask: Optional[tf.Tensor] = None,
         cross_attn_layer_head_mask: Optional[tf.Tensor] = None,
         past_key_value: Optional[Tuple[tf.Tensor]] = None,
-        training=False,
+        training: Optional[bool] = False,
     ) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]:
         """
         Args:
......
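As a follow-up sketch (illustrative, not part of the commit): the added annotations are visible to editors and type checkers, and can also be read back with the standard library's typing helpers.

```python
import typing
from transformers import TFMBartForConditionalGeneration

# Inspect the annotations this change adds to the conditional-generation head's call.
hints = typing.get_type_hints(TFMBartForConditionalGeneration.call)
print(hints["training"])  # expected: typing.Optional[bool]
print(hints["return"])    # expected: Union[TFSeq2SeqLMOutput, Tuple[tf.Tensor]]
```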