"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "4d541f516fade15d39b35065155239fd3bf0299a"
Unverified Commit d50f62f2 authored by Robot Jelly, committed by GitHub

added type hints for BART model (#16270)



* added type hints for BART model

* make fixup, adding imports to copied files

* Adding some missing types to cookiecutter

* Adding some missing types to cookiecutter

* Adding some missing types to cookiecutter
Co-authored-by: matt <rocketknight1@gmail.com>
parent 460f36d3
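
The change is mechanical and repeated across every file below: bare `forward`/`call` signatures gain parameter annotations and an explicit return annotation. As a minimal illustrative sketch of the pattern (the class name is a hypothetical stand-in, not a line from this diff):

from typing import List, Optional, Tuple, Union

import torch

from transformers.modeling_outputs import Seq2SeqModelOutput


class ExampleSeq2SeqModel:
    # Hypothetical stand-in for a model like BartModel: arguments keep their None
    # defaults but gain tensor types, and the return type spells out the
    # tuple-vs-ModelOutput union, mirroring the annotations added in the diff.
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, Seq2SeqModelOutput]:
        ...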
@@ -17,7 +17,7 @@ import copy
 import math
 import random
 import warnings
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple, Union
 import torch
 import torch.utils.checkpoint
@@ -297,11 +297,11 @@ class BartEncoderLayer(nn.Module):
     def forward(
         self,
-        hidden_states: torch.Tensor,
-        attention_mask: torch.Tensor,
-        layer_head_mask: torch.Tensor,
-        output_attentions: bool = False,
-    ):
+        hidden_states: torch.FloatTensor,
+        attention_mask: torch.FloatTensor,
+        layer_head_mask: torch.FloatTensor,
+        output_attentions: Optional[bool] = False,
+    ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]:
         """
         Args:
             hidden_states (`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
@@ -384,7 +384,7 @@ class BartDecoderLayer(nn.Module):
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = True,
-    ):
+    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
         Args:
             hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
@@ -478,7 +478,7 @@ class BartClassificationHead(nn.Module):
         self.dropout = nn.Dropout(p=pooler_dropout)
         self.out_proj = nn.Linear(inner_dim, num_classes)
-    def forward(self, hidden_states: torch.Tensor):
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states = self.dropout(hidden_states)
         hidden_states = self.dense(hidden_states)
         hidden_states = torch.tanh(hidden_states)
@@ -728,14 +728,14 @@ class BartEncoder(BartPretrainedModel):
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        head_mask=None,
-        inputs_embeds=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, BaseModelOutput]:
         r"""
         Args:
             input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -917,19 +917,19 @@ class BartDecoder(BartPretrainedModel):
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        encoder_hidden_states=None,
-        encoder_attention_mask=None,
-        head_mask=None,
-        cross_attn_head_mask=None,
-        past_key_values=None,
-        inputs_embeds=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
         r"""
         Args:
             input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -1172,22 +1172,22 @@ class BartModel(BartPretrainedModel):
     )
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        decoder_input_ids=None,
-        decoder_attention_mask=None,
-        head_mask=None,
-        decoder_head_mask=None,
-        cross_attn_head_mask=None,
-        encoder_outputs=None,
-        past_key_values=None,
-        inputs_embeds=None,
-        decoder_inputs_embeds=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        decoder_input_ids: Optional[torch.LongTensor] = None,
+        decoder_attention_mask: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        decoder_head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        encoder_outputs: Optional[List[torch.FloatTensor]] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, Seq2SeqModelOutput]:
         # different to other models, Bart automatically creates decoder_input_ids from
         # input_ids if no decoder_input_ids are provided
@@ -1306,23 +1306,23 @@ class BartForConditionalGeneration(BartPretrainedModel):
     @add_end_docstrings(BART_GENERATION_EXAMPLE)
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        decoder_input_ids=None,
-        decoder_attention_mask=None,
-        head_mask=None,
-        decoder_head_mask=None,
-        cross_attn_head_mask=None,
-        encoder_outputs=None,
-        past_key_values=None,
-        inputs_embeds=None,
-        decoder_inputs_embeds=None,
-        labels=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        decoder_input_ids: Optional[torch.LongTensor] = None,
+        decoder_attention_mask: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        decoder_head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        encoder_outputs: Optional[List[torch.FloatTensor]] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, Seq2SeqLMOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
             Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
@@ -1454,22 +1454,22 @@ class BartForSequenceClassification(BartPretrainedModel):
     )
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        decoder_input_ids=None,
-        decoder_attention_mask=None,
-        head_mask=None,
-        decoder_head_mask=None,
-        cross_attn_head_mask=None,
-        encoder_outputs=None,
-        inputs_embeds=None,
-        decoder_inputs_embeds=None,
-        labels=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        decoder_input_ids: Optional[torch.LongTensor] = None,
+        decoder_attention_mask: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        decoder_head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        encoder_outputs: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, Seq2SeqSequenceClassifierOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
             Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
@@ -1580,23 +1580,23 @@ class BartForQuestionAnswering(BartPretrainedModel):
     )
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        decoder_input_ids=None,
-        decoder_attention_mask=None,
-        head_mask=None,
-        decoder_head_mask=None,
-        cross_attn_head_mask=None,
-        encoder_outputs=None,
-        start_positions=None,
-        end_positions=None,
-        inputs_embeds=None,
-        decoder_inputs_embeds=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: torch.Tensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        decoder_input_ids: Optional[torch.LongTensor] = None,
+        decoder_attention_mask: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        decoder_head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        encoder_outputs: Optional[List[torch.FloatTensor]] = None,
+        start_positions: Optional[torch.LongTensor] = None,
+        end_positions: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, Seq2SeqQuestionAnsweringModelOutput]:
         r"""
         start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
             Labels for position (index) of the start of the labelled span for computing the token classification loss.
@@ -1721,20 +1721,20 @@ class BartForCausalLM(BartPretrainedModel):
     @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        encoder_hidden_states=None,
-        encoder_attention_mask=None,
-        head_mask=None,
-        cross_attn_head_mask=None,
-        past_key_values=None,
-        inputs_embeds=None,
-        labels=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
         r"""
         Args:
             input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
...
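
What the new PyTorch return annotations describe in practice: `Union[Tuple, Seq2SeqModelOutput]` documents that the model returns a `Seq2SeqModelOutput` when `return_dict=True` and a plain tuple otherwise. A short usage sketch, not part of the diff, with the checkpoint name only as an example:

import torch
from transformers import BartModel, BartTokenizer

tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
model = BartModel.from_pretrained("facebook/bart-base")
inputs = tokenizer("Type hints document the forward signature.", return_tensors="pt")

with torch.no_grad():
    # return_dict=True -> the Seq2SeqModelOutput branch of the annotated Union.
    outputs = model(**inputs, return_dict=True)
    print(type(outputs).__name__, outputs.last_hidden_state.shape)

    # return_dict=False -> the plain Tuple branch.
    hidden_state, *rest = model(**inputs, return_dict=False)
    print(hidden_state.shape)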
@@ -18,6 +18,7 @@
 import random
 from typing import Optional, Tuple, Union
+import numpy as np
 import tensorflow as tf
 from ...activations_tf import get_tf_activation
@@ -39,6 +40,7 @@ from ...modeling_tf_outputs import (
 from ...modeling_tf_utils import (
     DUMMY_INPUTS,
     TFCausalLanguageModelingLoss,
+    TFModelInputType,
     TFPreTrainedModel,
     TFSharedEmbeddings,
     TFWrappedEmbeddings,
@@ -170,7 +172,7 @@ class TFBartAttention(tf.keras.layers.Layer):
         past_key_value: Optional[Tuple[Tuple[tf.Tensor]]] = None,
         attention_mask: Optional[tf.Tensor] = None,
         layer_head_mask: Optional[tf.Tensor] = None,
-        training=False,
+        training: Optional[bool] = False,
     ) -> Tuple[tf.Tensor, Optional[tf.Tensor]]:
         """Input shape: Batch x Time x Channel"""
@@ -297,7 +299,13 @@ class TFBartEncoderLayer(tf.keras.layers.Layer):
         self.fc2 = tf.keras.layers.Dense(self.embed_dim, name="fc2")
         self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")
-    def call(self, hidden_states: tf.Tensor, attention_mask: tf.Tensor, layer_head_mask: tf.Tensor, training=False):
+    def call(
+        self,
+        hidden_states: tf.Tensor,
+        attention_mask: Optional[Union[np.ndarray, tf.Tensor]],
+        layer_head_mask: Optional[tf.Tensor],
+        training: Optional[bool] = False,
+    ) -> tf.Tensor:
         """
         Args:
             hidden_states (`tf.Tensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
@@ -365,14 +373,14 @@ class TFBartDecoderLayer(tf.keras.layers.Layer):
     def call(
         self,
-        hidden_states,
-        attention_mask: Optional[tf.Tensor] = None,
-        encoder_hidden_states: Optional[tf.Tensor] = None,
-        encoder_attention_mask: Optional[tf.Tensor] = None,
+        hidden_states: tf.Tensor,
+        attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        encoder_hidden_states: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        encoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
         layer_head_mask: Optional[tf.Tensor] = None,
         cross_attn_layer_head_mask: Optional[tf.Tensor] = None,
-        past_key_value: Optional[Tuple[tf.Tensor]] = None,
-        training=False,
+        past_key_value: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
+        training: Optional[bool] = False,
     ) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]:
         """
         Args:
@@ -663,16 +671,16 @@ class TFBartEncoder(tf.keras.layers.Layer):
     @unpack_inputs
     def call(
         self,
-        input_ids=None,
-        inputs_embeds=None,
-        attention_mask=None,
-        head_mask=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-        training=False,
+        input_ids: Optional[TFModelInputType] = None,
+        inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        training: Optional[bool] = False,
         **kwargs,
-    ):
+    ) -> Union[TFBaseModelOutput, Tuple[tf.Tensor]]:
         """
         Args:
             input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`):
@@ -813,21 +821,21 @@ class TFBartDecoder(tf.keras.layers.Layer):
     @unpack_inputs
     def call(
         self,
-        input_ids=None,
-        inputs_embeds=None,
-        attention_mask=None,
-        encoder_hidden_states=None,
-        encoder_attention_mask=None,
-        head_mask=None,
-        cross_attn_head_mask=None,
-        past_key_values=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-        training=False,
+        input_ids: Optional[TFModelInputType] = None,
+        inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        encoder_hidden_states: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        encoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        cross_attn_head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        training: Optional[bool] = False,
         **kwargs,
-    ):
+    ) -> Union[TFBaseModelOutputWithPastAndCrossAttentions, Tuple[tf.Tensor]]:
         r"""
         Args:
             input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`):
@@ -1030,24 +1038,24 @@ class TFBartMainLayer(tf.keras.layers.Layer):
     @unpack_inputs
     def call(
         self,
-        input_ids=None,
-        attention_mask=None,
-        decoder_input_ids=None,
-        decoder_attention_mask=None,
-        head_mask=None,
-        decoder_head_mask=None,
-        cross_attn_head_mask=None,
+        input_ids: Optional[TFModelInputType] = None,
+        attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        decoder_input_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        decoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        decoder_head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        cross_attn_head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
         encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
-        past_key_values=None,
-        inputs_embeds=None,
-        decoder_inputs_embeds=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-        training=False,
+        past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
+        inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        decoder_inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        training: Optional[bool] = False,
         **kwargs
-    ):
+    ) -> Union[TFSeq2SeqModelOutput, Tuple[tf.Tensor]]:
         if decoder_input_ids is None and decoder_inputs_embeds is None:
             use_cache = False
@@ -1143,24 +1151,24 @@ class TFBartModel(TFBartPretrainedModel):
    @unpack_inputs
    def call(
        self,
-        input_ids=None,
-        attention_mask=None,
-        decoder_input_ids=None,
-        decoder_attention_mask=None,
-        head_mask=None,
-        decoder_head_mask=None,
-        cross_attn_head_mask=None,
+        input_ids: Optional[TFModelInputType] = None,
+        attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        decoder_input_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        decoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        decoder_head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        cross_attn_head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
         encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
-        past_key_values=None,
-        inputs_embeds=None,
-        decoder_inputs_embeds=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-        training=False,
+        past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
+        inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        decoder_inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        training: Optional[bool] = False,
         **kwargs
-    ):
+    ) -> Union[TFBaseModelOutput, Tuple[tf.Tensor]]:
         outputs = self.model(
             input_ids=input_ids,
@@ -1248,25 +1256,25 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel, TFCausalLanguageMode
     @unpack_inputs
     def call(
         self,
-        input_ids=None,
-        attention_mask=None,
-        decoder_input_ids=None,
-        decoder_attention_mask=None,
-        head_mask=None,
-        decoder_head_mask=None,
-        cross_attn_head_mask=None,
+        input_ids: Optional[TFModelInputType] = None,
+        attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        decoder_input_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        decoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        decoder_head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        cross_attn_head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
         encoder_outputs: Optional[TFBaseModelOutput] = None,
-        past_key_values=None,
-        inputs_embeds=None,
-        decoder_inputs_embeds=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-        labels=None,
-        training=False,
+        past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
+        inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        decoder_inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        labels: Optional[tf.Tensor] = None,
+        training: Optional[bool] = False,
         **kwargs,
-    ):
+    ) -> Union[TFSeq2SeqLMOutput, Tuple[tf.Tensor]]:
         r"""
         labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
             Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
...
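
On the TensorFlow side, `Optional[TFModelInputType]` and `Union[np.ndarray, tf.Tensor]` make explicit that `call` accepts NumPy arrays as well as TensorFlow tensors. A small sketch under that assumption (not part of the diff; the checkpoint name is only an example and assumes TF weights are available for it):

import numpy as np
from transformers import BartTokenizer, TFBartModel

tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
model = TFBartModel.from_pretrained("facebook/bart-base")

# return_tensors="np" yields np.ndarray inputs, covered by the new Union annotations.
enc = tokenizer("Type hints for the TF call signature too.", return_tensors="np")
outputs = model(input_ids=enc["input_ids"], attention_mask=enc["attention_mask"])
print(outputs.last_hidden_state.shape)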
@@ -18,7 +18,7 @@
 import copy
 import math
 import random
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple, Union
 import numpy as np
 import torch
@@ -1573,7 +1573,7 @@ class BigBirdPegasusClassificationHead(nn.Module):
         self.dropout = nn.Dropout(p=pooler_dropout)
         self.out_proj = nn.Linear(inner_dim, num_classes)
-    def forward(self, hidden_states: torch.Tensor):
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states = self.dropout(hidden_states)
         hidden_states = self.dense(hidden_states)
         hidden_states = torch.tanh(hidden_states)
@@ -2367,22 +2367,22 @@ class BigBirdPegasusModel(BigBirdPegasusPreTrainedModel):
     )
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        decoder_input_ids=None,
-        decoder_attention_mask=None,
-        head_mask=None,
-        decoder_head_mask=None,
-        cross_attn_head_mask=None,
-        encoder_outputs=None,
-        past_key_values=None,
-        inputs_embeds=None,
-        decoder_inputs_embeds=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        decoder_input_ids: Optional[torch.LongTensor] = None,
+        decoder_attention_mask: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        decoder_head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        encoder_outputs: Optional[List[torch.FloatTensor]] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, Seq2SeqModelOutput]:
         # different to other models, BigBirdPegasus automatically creates decoder_input_ids from
         # input_ids if no decoder_input_ids are provided
@@ -2503,23 +2503,23 @@ class BigBirdPegasusForConditionalGeneration(BigBirdPegasusPreTrainedModel):
     @add_end_docstrings(BIGBIRD_PEGASUS_GENERATION_EXAMPLE)
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        decoder_input_ids=None,
-        decoder_attention_mask=None,
-        head_mask=None,
-        decoder_head_mask=None,
-        cross_attn_head_mask=None,
-        encoder_outputs=None,
-        past_key_values=None,
-        inputs_embeds=None,
-        decoder_inputs_embeds=None,
-        labels=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        decoder_input_ids: Optional[torch.LongTensor] = None,
+        decoder_attention_mask: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        decoder_head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        encoder_outputs: Optional[List[torch.FloatTensor]] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, Seq2SeqLMOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
             Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
@@ -2652,22 +2652,22 @@ class BigBirdPegasusForSequenceClassification(BigBirdPegasusPreTrainedModel):
     )
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        decoder_input_ids=None,
-        decoder_attention_mask=None,
-        head_mask=None,
-        decoder_head_mask=None,
-        cross_attn_head_mask=None,
-        encoder_outputs=None,
-        inputs_embeds=None,
-        decoder_inputs_embeds=None,
-        labels=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        decoder_input_ids: Optional[torch.LongTensor] = None,
+        decoder_attention_mask: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        decoder_head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        encoder_outputs: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, Seq2SeqSequenceClassifierOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
             Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
@@ -2779,23 +2779,23 @@ class BigBirdPegasusForQuestionAnswering(BigBirdPegasusPreTrainedModel):
     )
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        decoder_input_ids=None,
-        decoder_attention_mask=None,
-        head_mask=None,
-        decoder_head_mask=None,
-        cross_attn_head_mask=None,
-        encoder_outputs=None,
-        start_positions=None,
-        end_positions=None,
-        inputs_embeds=None,
-        decoder_inputs_embeds=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: torch.Tensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        decoder_input_ids: Optional[torch.LongTensor] = None,
+        decoder_attention_mask: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        decoder_head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        encoder_outputs: Optional[List[torch.FloatTensor]] = None,
+        start_positions: Optional[torch.LongTensor] = None,
+        end_positions: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, Seq2SeqQuestionAnsweringModelOutput]:
         r"""
         start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
             Labels for position (index) of the start of the labelled span for computing the token classification loss.
...
@@ -20,7 +20,7 @@ import math
 import os
 import random
 import warnings
-from typing import Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union
 import torch
 import torch.utils.checkpoint
@@ -1440,20 +1440,20 @@ class BlenderbotForCausalLM(BlenderbotPreTrainedModel):
     @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        encoder_hidden_states=None,
-        encoder_attention_mask=None,
-        head_mask=None,
-        cross_attn_head_mask=None,
-        past_key_values=None,
-        inputs_embeds=None,
-        labels=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
         r"""
         Args:
             input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
...
@@ -173,7 +173,7 @@ class TFBlenderbotAttention(tf.keras.layers.Layer):
         past_key_value: Optional[Tuple[Tuple[tf.Tensor]]] = None,
         attention_mask: Optional[tf.Tensor] = None,
         layer_head_mask: Optional[tf.Tensor] = None,
-        training=False,
+        training: Optional[bool] = False,
     ) -> Tuple[tf.Tensor, Optional[tf.Tensor]]:
         """Input shape: Batch x Time x Channel"""
...
@@ -18,7 +18,7 @@
 import copy
 import math
 import random
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple, Union
 import torch
 import torch.utils.checkpoint
@@ -288,11 +288,11 @@ class BlenderbotSmallEncoderLayer(nn.Module):
     def forward(
         self,
-        hidden_states: torch.Tensor,
-        attention_mask: torch.Tensor,
-        layer_head_mask: torch.Tensor,
-        output_attentions: bool = False,
-    ):
+        hidden_states: torch.FloatTensor,
+        attention_mask: torch.FloatTensor,
+        layer_head_mask: torch.FloatTensor,
+        output_attentions: Optional[bool] = False,
+    ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]:
         """
         Args:
             hidden_states (`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
@@ -376,7 +376,7 @@ class BlenderbotSmallDecoderLayer(nn.Module):
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = True,
-    ):
+    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
         Args:
             hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
@@ -1411,20 +1411,20 @@ class BlenderbotSmallForCausalLM(BlenderbotSmallPreTrainedModel):
     @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        encoder_hidden_states=None,
-        encoder_attention_mask=None,
-        head_mask=None,
-        cross_attn_head_mask=None,
-        past_key_values=None,
-        inputs_embeds=None,
-        labels=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
         r"""
         Args:
             input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
...
@@ -18,6 +18,7 @@
 import random
 from typing import Optional, Tuple, Union
+import numpy as np
 import tensorflow as tf
 from ...activations_tf import get_tf_activation
@@ -172,7 +173,7 @@ class TFBlenderbotSmallAttention(tf.keras.layers.Layer):
         past_key_value: Optional[Tuple[Tuple[tf.Tensor]]] = None,
         attention_mask: Optional[tf.Tensor] = None,
         layer_head_mask: Optional[tf.Tensor] = None,
-        training=False,
+        training: Optional[bool] = False,
     ) -> Tuple[tf.Tensor, Optional[tf.Tensor]]:
         """Input shape: Batch x Time x Channel"""
@@ -300,7 +301,13 @@ class TFBlenderbotSmallEncoderLayer(tf.keras.layers.Layer):
         self.fc2 = tf.keras.layers.Dense(self.embed_dim, name="fc2")
         self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")
-    def call(self, hidden_states: tf.Tensor, attention_mask: tf.Tensor, layer_head_mask: tf.Tensor, training=False):
+    def call(
+        self,
+        hidden_states: tf.Tensor,
+        attention_mask: Optional[Union[np.ndarray, tf.Tensor]],
+        layer_head_mask: Optional[tf.Tensor],
+        training: Optional[bool] = False,
+    ) -> tf.Tensor:
         """
         Args:
             hidden_states (`tf.Tensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
@@ -369,14 +376,14 @@ class TFBlenderbotSmallDecoderLayer(tf.keras.layers.Layer):
     def call(
         self,
-        hidden_states,
-        attention_mask: Optional[tf.Tensor] = None,
-        encoder_hidden_states: Optional[tf.Tensor] = None,
-        encoder_attention_mask: Optional[tf.Tensor] = None,
+        hidden_states: tf.Tensor,
+        attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        encoder_hidden_states: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        encoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
         layer_head_mask: Optional[tf.Tensor] = None,
         cross_attn_layer_head_mask: Optional[tf.Tensor] = None,
-        past_key_value: Optional[Tuple[tf.Tensor]] = None,
-        training=False,
+        past_key_value: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
+        training: Optional[bool] = False,
     ) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]:
         """
         Args:
...
@@ -754,7 +754,7 @@ class TFHubertAttention(tf.keras.layers.Layer):
         past_key_value: Optional[Tuple[Tuple[tf.Tensor]]] = None,
         attention_mask: Optional[tf.Tensor] = None,
         layer_head_mask: Optional[tf.Tensor] = None,
-        training=False,
+        training: Optional[bool] = False,
     ) -> Tuple[tf.Tensor, Optional[tf.Tensor]]:
         """Input shape: Batch x Time x Channel"""
...
@@ -18,7 +18,7 @@
 import copy
 import math
 import random
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple, Union
 import numpy as np
 import torch
@@ -305,11 +305,11 @@ class MarianEncoderLayer(nn.Module):
     def forward(
         self,
-        hidden_states: torch.Tensor,
-        attention_mask: torch.Tensor,
-        layer_head_mask: torch.Tensor,
-        output_attentions: bool = False,
-    ):
+        hidden_states: torch.FloatTensor,
+        attention_mask: torch.FloatTensor,
+        layer_head_mask: torch.FloatTensor,
+        output_attentions: Optional[bool] = False,
+    ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]:
         """
         Args:
             hidden_states (`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
@@ -393,7 +393,7 @@ class MarianDecoderLayer(nn.Module):
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = True,
-    ):
+    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
         Args:
             hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
@@ -1573,20 +1573,20 @@ class MarianForCausalLM(MarianPreTrainedModel):
     @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        encoder_hidden_states=None,
-        encoder_attention_mask=None,
-        head_mask=None,
-        cross_attn_head_mask=None,
-        past_key_values=None,
-        inputs_embeds=None,
-        labels=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
         r"""
         Args:
             input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
...
@@ -212,7 +212,7 @@ class TFMarianAttention(tf.keras.layers.Layer):
         past_key_value: Optional[Tuple[Tuple[tf.Tensor]]] = None,
         attention_mask: Optional[tf.Tensor] = None,
         layer_head_mask: Optional[tf.Tensor] = None,
-        training=False,
+        training: Optional[bool] = False,
     ) -> Tuple[tf.Tensor, Optional[tf.Tensor]]:
         """Input shape: Batch x Time x Channel"""
@@ -340,7 +340,13 @@ class TFMarianEncoderLayer(tf.keras.layers.Layer):
         self.fc2 = tf.keras.layers.Dense(self.embed_dim, name="fc2")
         self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")
-    def call(self, hidden_states: tf.Tensor, attention_mask: tf.Tensor, layer_head_mask: tf.Tensor, training=False):
+    def call(
+        self,
+        hidden_states: tf.Tensor,
+        attention_mask: Optional[Union[np.ndarray, tf.Tensor]],
+        layer_head_mask: Optional[tf.Tensor],
+        training: Optional[bool] = False,
+    ) -> tf.Tensor:
         """
         Args:
             hidden_states (`tf.Tensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
@@ -409,14 +415,14 @@ class TFMarianDecoderLayer(tf.keras.layers.Layer):
     def call(
         self,
-        hidden_states,
-        attention_mask: Optional[tf.Tensor] = None,
-        encoder_hidden_states: Optional[tf.Tensor] = None,
-        encoder_attention_mask: Optional[tf.Tensor] = None,
+        hidden_states: tf.Tensor,
+        attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        encoder_hidden_states: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        encoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
         layer_head_mask: Optional[tf.Tensor] = None,
         cross_attn_layer_head_mask: Optional[tf.Tensor] = None,
-        past_key_value: Optional[Tuple[tf.Tensor]] = None,
-        training=False,
+        past_key_value: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
+        training: Optional[bool] = False,
    ) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]:
         """
         Args:
...
@@ -16,7 +16,7 @@
 import copy
 import math
 import random
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple, Union
 import torch
 import torch.utils.checkpoint
@@ -485,7 +485,7 @@ class MBartClassificationHead(nn.Module):
         self.dropout = nn.Dropout(p=pooler_dropout)
         self.out_proj = nn.Linear(inner_dim, num_classes)
-    def forward(self, hidden_states: torch.Tensor):
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states = self.dropout(hidden_states)
         hidden_states = self.dense(hidden_states)
         hidden_states = torch.tanh(hidden_states)
@@ -1445,22 +1445,22 @@ class MBartForSequenceClassification(MBartPreTrainedModel):
     # Copied from transformers.models.bart.modeling_bart.BartForSequenceClassification.forward
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        decoder_input_ids=None,
-        decoder_attention_mask=None,
-        head_mask=None,
-        decoder_head_mask=None,
-        cross_attn_head_mask=None,
-        encoder_outputs=None,
-        inputs_embeds=None,
-        decoder_inputs_embeds=None,
-        labels=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        decoder_input_ids: Optional[torch.LongTensor] = None,
+        decoder_attention_mask: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        decoder_head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        encoder_outputs: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, Seq2SeqSequenceClassifierOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
             Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
@@ -1572,23 +1572,23 @@ class MBartForQuestionAnswering(MBartPreTrainedModel):
     # Copied from transformers.models.bart.modeling_bart.BartForQuestionAnswering.forward
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        decoder_input_ids=None,
-        decoder_attention_mask=None,
-        head_mask=None,
-        decoder_head_mask=None,
-        cross_attn_head_mask=None,
-        encoder_outputs=None,
-        start_positions=None,
-        end_positions=None,
-        inputs_embeds=None,
-        decoder_inputs_embeds=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: torch.Tensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        decoder_input_ids: Optional[torch.LongTensor] = None,
+        decoder_attention_mask: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        decoder_head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        encoder_outputs: Optional[List[torch.FloatTensor]] = None,
+        start_positions: Optional[torch.LongTensor] = None,
+        end_positions: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, Seq2SeqQuestionAnsweringModelOutput]:
         r"""
         start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
             Labels for position (index) of the start of the labelled span for computing the token classification loss.
@@ -1715,20 +1715,20 @@ class MBartForCausalLM(MBartPreTrainedModel):
     @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        encoder_hidden_states=None,
-        encoder_attention_mask=None,
-        head_mask=None,
-        cross_attn_head_mask=None,
-        past_key_values=None,
-        inputs_embeds=None,
-        labels=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
         r"""
         Args:
             input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
...
@@ -172,7 +172,7 @@ class TFMBartAttention(tf.keras.layers.Layer):
         past_key_value: Optional[Tuple[Tuple[tf.Tensor]]] = None,
         attention_mask: Optional[tf.Tensor] = None,
         layer_head_mask: Optional[tf.Tensor] = None,
-        training=False,
+        training: Optional[bool] = False,
     ) -> Tuple[tf.Tensor, Optional[tf.Tensor]]:
         """Input shape: Batch x Time x Channel"""
...
@@ -17,7 +17,7 @@
 import copy
 import math
 import random
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple, Union
 import numpy as np
 import torch
@@ -1543,20 +1543,20 @@ class PegasusForCausalLM(PegasusPreTrainedModel):
     # Copied from transformers.models.bart.modeling_bart.BartForCausalLM.forward with Bart->Pegasus, facebook/bart-base->google/pegasus-large
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        encoder_hidden_states=None,
-        encoder_attention_mask=None,
-        head_mask=None,
-        cross_attn_head_mask=None,
-        past_key_values=None,
-        inputs_embeds=None,
-        labels=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
         r"""
         Args:
             input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
...
@@ -213,7 +213,7 @@ class TFPegasusAttention(tf.keras.layers.Layer):
         past_key_value: Optional[Tuple[Tuple[tf.Tensor]]] = None,
         attention_mask: Optional[tf.Tensor] = None,
         layer_head_mask: Optional[tf.Tensor] = None,
-        training=False,
+        training: Optional[bool] = False,
     ) -> Tuple[tf.Tensor, Optional[tf.Tensor]]:
         """Input shape: Batch x Time x Channel"""
...
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
import copy import copy
import math import math
import random import random
from typing import Optional, Tuple from typing import List, Optional, Tuple, Union
import torch import torch
import torch.utils.checkpoint import torch.utils.checkpoint
...@@ -302,11 +302,11 @@ class PLBartEncoderLayer(nn.Module): ...@@ -302,11 +302,11 @@ class PLBartEncoderLayer(nn.Module):
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.FloatTensor,
attention_mask: torch.Tensor, attention_mask: torch.FloatTensor,
layer_head_mask: torch.Tensor, layer_head_mask: torch.FloatTensor,
output_attentions: bool = False, output_attentions: Optional[bool] = False,
): ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]:
""" """
Args: Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)` hidden_states (`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
...@@ -390,7 +390,7 @@ class PLBartDecoderLayer(nn.Module): ...@@ -390,7 +390,7 @@ class PLBartDecoderLayer(nn.Module):
past_key_value: Optional[Tuple[torch.Tensor]] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = True, use_cache: Optional[bool] = True,
): ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
""" """
Args: Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
...@@ -485,7 +485,7 @@ class PLBartClassificationHead(nn.Module): ...@@ -485,7 +485,7 @@ class PLBartClassificationHead(nn.Module):
self.dropout = nn.Dropout(p=pooler_dropout) self.dropout = nn.Dropout(p=pooler_dropout)
self.out_proj = nn.Linear(inner_dim, num_classes) self.out_proj = nn.Linear(inner_dim, num_classes)
def forward(self, hidden_states: torch.Tensor): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dropout(hidden_states) hidden_states = self.dropout(hidden_states)
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = torch.tanh(hidden_states) hidden_states = torch.tanh(hidden_states)
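The classification head above is the simplest case in this diff: a plain Tensor -> Tensor forward. A self-contained sketch of the same dropout -> dense -> tanh -> dropout -> projection pattern (dimensions are made up for illustration):

import torch
from torch import nn


class TinyClassificationHead(nn.Module):
    """Mirrors the typed head pattern shown in the hunk above."""

    def __init__(self, input_dim: int, inner_dim: int, num_classes: int, pooler_dropout: float):
        super().__init__()
        self.dense = nn.Linear(input_dim, inner_dim)
        self.dropout = nn.Dropout(p=pooler_dropout)
        self.out_proj = nn.Linear(inner_dim, num_classes)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.dense(hidden_states)
        hidden_states = torch.tanh(hidden_states)
        hidden_states = self.dropout(hidden_states)
        return self.out_proj(hidden_states)


head = TinyClassificationHead(input_dim=8, inner_dim=8, num_classes=3, pooler_dropout=0.1)
print(head(torch.randn(2, 8)).shape)  # torch.Size([2, 3])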
...@@ -699,14 +699,14 @@ class PLBartEncoder(PLBartPreTrainedModel): ...@@ -699,14 +699,14 @@ class PLBartEncoder(PLBartPreTrainedModel):
def forward( def forward(
self, self,
input_ids=None, input_ids: torch.LongTensor = None,
attention_mask=None, attention_mask: Optional[torch.Tensor] = None,
head_mask=None, head_mask: Optional[torch.Tensor] = None,
inputs_embeds=None, inputs_embeds: Optional[torch.FloatTensor] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
): ) -> Union[Tuple, BaseModelOutput]:
r""" r"""
Args: Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
...@@ -889,19 +889,19 @@ class PLBartDecoder(PLBartPreTrainedModel): ...@@ -889,19 +889,19 @@ class PLBartDecoder(PLBartPreTrainedModel):
def forward( def forward(
self, self,
input_ids=None, input_ids: torch.LongTensor = None,
attention_mask=None, attention_mask: Optional[torch.Tensor] = None,
encoder_hidden_states=None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask=None, encoder_attention_mask: Optional[torch.LongTensor] = None,
head_mask=None, head_mask: Optional[torch.Tensor] = None,
cross_attn_head_mask=None, cross_attn_head_mask: Optional[torch.Tensor] = None,
past_key_values=None, past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds=None, inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache=None, use_cache: Optional[bool] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
): ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
r""" r"""
Args: Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
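The decoder signature above types the cache as Optional[List[torch.FloatTensor]]; at runtime the object is a tuple of per-layer tuples of key/value tensors, which that hint only approximates. A hedged sketch of inspecting it through the user-facing seq2seq model (checkpoint name is illustrative):

import torch
from transformers import BartForConditionalGeneration, BartTokenizer

tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")

enc = tokenizer("caching keys and values", return_tensors="pt")
start = torch.tensor([[model.config.decoder_start_token_id]])

out = model(**enc, decoder_input_ids=start, use_cache=True)
cache = out.past_key_values

print(len(cache))         # one entry per decoder layer
print(len(cache[0]))      # self-attention k, v plus cross-attention k, v
print(cache[0][0].shape)  # (batch, num_heads, generated_len, head_dim)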
...@@ -1416,22 +1416,22 @@ class PLBartForSequenceClassification(PLBartPreTrainedModel): ...@@ -1416,22 +1416,22 @@ class PLBartForSequenceClassification(PLBartPreTrainedModel):
# Copied from transformers.models.bart.modeling_bart.BartForSequenceClassification.forward # Copied from transformers.models.bart.modeling_bart.BartForSequenceClassification.forward
def forward( def forward(
self, self,
input_ids=None, input_ids: torch.LongTensor = None,
attention_mask=None, attention_mask: Optional[torch.Tensor] = None,
decoder_input_ids=None, decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask=None, decoder_attention_mask: Optional[torch.LongTensor] = None,
head_mask=None, head_mask: Optional[torch.Tensor] = None,
decoder_head_mask=None, decoder_head_mask: Optional[torch.Tensor] = None,
cross_attn_head_mask=None, cross_attn_head_mask: Optional[torch.Tensor] = None,
encoder_outputs=None, encoder_outputs: Optional[List[torch.FloatTensor]] = None,
inputs_embeds=None, inputs_embeds: Optional[torch.FloatTensor] = None,
decoder_inputs_embeds=None, decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
labels=None, labels: Optional[torch.LongTensor] = None,
use_cache=None, use_cache: Optional[bool] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
): ) -> Union[Tuple, Seq2SeqSequenceClassifierOutput]:
r""" r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
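For the sequence-classification signature above, the new `labels` annotation corresponds to the loss-computing path. A minimal sketch (checkpoint and label are illustrative; the classification head is randomly initialised when loaded this way):

import torch
from transformers import BartForSequenceClassification, BartTokenizer

tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
model = BartForSequenceClassification.from_pretrained("facebook/bart-base", num_labels=2)

inputs = tokenizer("typed forward signatures", return_tensors="pt")
outputs = model(**inputs, labels=torch.tensor([1]))

# With return_dict=True (the default) this is a Seq2SeqSequenceClassifierOutput
print(outputs.loss, outputs.logits.shape)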
...@@ -1562,20 +1562,20 @@ class PLBartForCausalLM(PLBartPreTrainedModel): ...@@ -1562,20 +1562,20 @@ class PLBartForCausalLM(PLBartPreTrainedModel):
@replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC) @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
def forward( def forward(
self, self,
input_ids=None, input_ids: torch.LongTensor = None,
attention_mask=None, attention_mask: Optional[torch.Tensor] = None,
encoder_hidden_states=None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask=None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
head_mask=None, head_mask: Optional[torch.Tensor] = None,
cross_attn_head_mask=None, cross_attn_head_mask: Optional[torch.Tensor] = None,
past_key_values=None, past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds=None, inputs_embeds: Optional[torch.FloatTensor] = None,
labels=None, labels: Optional[torch.LongTensor] = None,
use_cache=None, use_cache: Optional[bool] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
): ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
r""" r"""
Args: Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
......
...@@ -274,7 +274,7 @@ class TFSpeech2TextAttention(tf.keras.layers.Layer): ...@@ -274,7 +274,7 @@ class TFSpeech2TextAttention(tf.keras.layers.Layer):
past_key_value: Optional[Tuple[Tuple[tf.Tensor]]] = None, past_key_value: Optional[Tuple[Tuple[tf.Tensor]]] = None,
attention_mask: Optional[tf.Tensor] = None, attention_mask: Optional[tf.Tensor] = None,
layer_head_mask: Optional[tf.Tensor] = None, layer_head_mask: Optional[tf.Tensor] = None,
training=False, training: Optional[bool] = False,
) -> Tuple[tf.Tensor, Optional[tf.Tensor]]: ) -> Tuple[tf.Tensor, Optional[tf.Tensor]]:
"""Input shape: Batch x Time x Channel""" """Input shape: Batch x Time x Channel"""
......
...@@ -783,7 +783,7 @@ class TFWav2Vec2Attention(tf.keras.layers.Layer): ...@@ -783,7 +783,7 @@ class TFWav2Vec2Attention(tf.keras.layers.Layer):
past_key_value: Optional[Tuple[Tuple[tf.Tensor]]] = None, past_key_value: Optional[Tuple[Tuple[tf.Tensor]]] = None,
attention_mask: Optional[tf.Tensor] = None, attention_mask: Optional[tf.Tensor] = None,
layer_head_mask: Optional[tf.Tensor] = None, layer_head_mask: Optional[tf.Tensor] = None,
training=False, training: Optional[bool] = False,
) -> Tuple[tf.Tensor, Optional[tf.Tensor]]: ) -> Tuple[tf.Tensor, Optional[tf.Tensor]]:
"""Input shape: Batch x Time x Channel""" """Input shape: Batch x Time x Channel"""
......
...@@ -25,7 +25,7 @@ import math ...@@ -25,7 +25,7 @@ import math
import os import os
import sys import sys
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Optional from typing import Optional, List
import datasets import datasets
from datasets import load_dataset from datasets import load_dataset
......
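The extra List import in the example script above is the kind of thing needed for List-typed dataclass arguments; a hedged sketch under that assumption (field names are invented for illustration):

from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class DataArguments:
    train_file: Optional[str] = field(default=None)
    # List-typed fields need a default_factory rather than a mutable default
    eval_metrics: List[str] = field(default_factory=lambda: ["rouge1", "rouge2"])


args = DataArguments()
print(args.train_file, args.eval_metrics)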
...@@ -1571,7 +1571,7 @@ class {{cookiecutter.camelcase_modelname}}ForQuestionAnswering({{cookiecutter.ca ...@@ -1571,7 +1571,7 @@ class {{cookiecutter.camelcase_modelname}}ForQuestionAnswering({{cookiecutter.ca
import math import math
import copy import copy
import random import random
from typing import Optional, Tuple from typing import Optional, Tuple, List, Union
import torch import torch
from torch import nn from torch import nn
......