"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "afe5d42d8d1d80af911ed980c2936bfe887078f6"
Unverified commit d49f8d31, authored by Clémentine Fourrier and committed by GitHub

Added type hints for Pytorch Marian calls (#16200)



* Added type hinting for forward functions in PyTorch Marian

* typo correction

* Removed type hints on functions from BART per Suraj Patil request

* fix import problem

* fix typo

* corrected tuple call

* ran black

* after fix-copies
Some Optional tags on primitives were removed; past_key_values in MarianForCausalLM changed from Tuple of Tuple to List

* Fixing copies to roformer and pegasus

Co-authored-by: Clementine Fourrier <cfourrie@inria.fr>
Co-authored-by: matt <rocketknight1@gmail.com>
parent a2379b92
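The diff below only adds annotations, so runtime behaviour is unchanged. For orientation, a minimal usage sketch of the signatures being annotated (not part of the commit; "Helsinki-NLP/opus-mt-en-de" is just an example checkpoint):

# Minimal usage sketch (not part of the commit): the type hints only annotate the
# existing call signatures, so MarianMTModel is driven exactly as before.
import torch
from transformers import MarianMTModel, MarianTokenizer

name = "Helsinki-NLP/opus-mt-en-de"  # example checkpoint
tokenizer = MarianTokenizer.from_pretrained(name)
model = MarianMTModel.from_pretrained(name)

batch = tokenizer(["Type hints do not change runtime behaviour."], return_tensors="pt")

# forward() is now annotated `-> Seq2SeqLMOutput`; passing the source ids as labels
# here is only to exercise the optional `labels: Optional[torch.LongTensor]` argument.
with torch.no_grad():
    out = model(**batch, labels=batch["input_ids"])
print(type(out).__name__, out.loss.item(), out.logits.shape)

# generate() is untouched; it calls the now-typed prepare_inputs_for_generation internally.
print(tokenizer.batch_decode(model.generate(**batch), skip_special_tokens=True))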
src/transformers/models/marian/modeling_marian.py
@@ -18,7 +18,7 @@
 import copy
 import math
 import random
-from typing import List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Union
 
 import numpy as np
 import torch
@@ -109,12 +109,12 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
 class MarianSinusoidalPositionalEmbedding(nn.Embedding):
     """This module produces sinusoidal positional embeddings of any length."""
 
-    def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None):
+    def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None) -> None:
         super().__init__(num_positions, embedding_dim)
         self.weight = self._init_weight(self.weight)
 
     @staticmethod
-    def _init_weight(out: nn.Parameter):
+    def _init_weight(out: nn.Parameter) -> nn.Parameter:
         """
         Identical to the XLM create_sinusoidal_embeddings except features are not interleaved. The cos features are in
         the 2nd half of the vector. [dim // 2:]
@@ -131,7 +131,7 @@ class MarianSinusoidalPositionalEmbedding(nn.Embedding):
         return out
 
     @torch.no_grad()
-    def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0):
+    def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0) -> torch.Tensor:
         """`input_ids_shape` is expected to be [bsz x seqlen]."""
         bsz, seq_len = input_ids_shape[:2]
         positions = torch.arange(
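The two hunks above only add return annotations to MarianSinusoidalPositionalEmbedding. A minimal sketch of what the annotated forward accepts and returns (not part of the commit; the internal import path is assumed from the usual transformers layout):

# The annotated forward takes a torch.Size and returns a torch.Tensor of position embeddings.
import torch
from transformers.models.marian.modeling_marian import MarianSinusoidalPositionalEmbedding  # assumed path

emb = MarianSinusoidalPositionalEmbedding(num_positions=512, embedding_dim=64)

pos = emb(torch.Size([2, 10]))                                  # shape (10, 64): one row per position
pos_next = emb(torch.Size([2, 1]), past_key_values_length=10)   # next decoding step, shape (1, 64)
print(pos.shape, pos_next.shape)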
@@ -477,7 +477,7 @@ class MarianPreTrainedModel(PreTrainedModel):
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
 
-    def _init_weights(self, module):
+    def _init_weights(self, module: Union[nn.Linear, nn.Embedding, MarianSinusoidalPositionalEmbedding]):
         std = self.config.init_std
         if isinstance(module, nn.Linear):
             module.weight.data.normal_(mean=0.0, std=std)
@@ -665,9 +665,7 @@ class MarianEncoder(MarianPreTrainedModel):
         self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
 
         self.embed_positions = MarianSinusoidalPositionalEmbedding(
-            config.max_position_embeddings,
-            embed_dim,
-            self.padding_idx,
+            config.max_position_embeddings, embed_dim, self.padding_idx
         )
         self.layers = nn.ModuleList([MarianEncoderLayer(config) for _ in range(config.encoder_layers)])
@@ -683,14 +681,14 @@ class MarianEncoder(MarianPreTrainedModel):
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        head_mask=None,
-        inputs_embeds=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[torch.Tensor], BaseModelOutput]:
         r"""
         Args:
             input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -833,9 +831,7 @@ class MarianDecoder(MarianPreTrainedModel):
         self.embed_tokens = nn.Embedding(config.decoder_vocab_size, config.d_model, self.padding_idx)
 
         self.embed_positions = MarianSinusoidalPositionalEmbedding(
-            config.max_position_embeddings,
-            config.d_model,
-            self.padding_idx,
+            config.max_position_embeddings, config.d_model, self.padding_idx
         )
         self.layers = nn.ModuleList([MarianDecoderLayer(config) for _ in range(config.decoder_layers)])
@@ -870,19 +866,19 @@ class MarianDecoder(MarianPreTrainedModel):
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        encoder_hidden_states=None,
-        encoder_attention_mask=None,
-        head_mask=None,
-        cross_attn_head_mask=None,
-        past_key_values=None,
-        inputs_embeds=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
         r"""
         Args:
             input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -1082,8 +1078,7 @@ class MarianDecoder(MarianPreTrainedModel):
 @add_start_docstrings(
-    "The bare Marian Model outputting raw hidden-states without any specific head on top.",
-    MARIAN_START_DOCSTRING,
+    "The bare Marian Model outputting raw hidden-states without any specific head on top.", MARIAN_START_DOCSTRING
 )
 class MarianModel(MarianPreTrainedModel):
     def __init__(self, config: MarianConfig):
@@ -1143,7 +1138,7 @@ class MarianModel(MarianPreTrainedModel):
     def get_decoder(self):
         return self.decoder
 
-    def resize_decoder_token_embeddings(self, new_num_tokens):
+    def resize_decoder_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
         if self.config.share_encoder_decoder_embeddings:
             raise ValueError(
                 "`resize_decoder_token_embeddings` should not be called if `config.share_encoder_decoder_embeddings` "
@@ -1171,22 +1166,22 @@ class MarianModel(MarianPreTrainedModel):
     @replace_return_docstrings(output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        decoder_input_ids=None,
-        decoder_attention_mask=None,
-        head_mask=None,
-        decoder_head_mask=None,
-        cross_attn_head_mask=None,
-        encoder_outputs=None,
-        past_key_values=None,
-        inputs_embeds=None,
-        decoder_inputs_embeds=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        decoder_input_ids: Optional[torch.LongTensor] = None,
+        decoder_attention_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        decoder_head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        encoder_outputs: Optional[Union[Tuple[torch.Tensor], BaseModelOutput]] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Seq2SeqModelOutput:
         r"""
         Returns:
@@ -1279,10 +1274,7 @@ class MarianMTModel(MarianPreTrainedModel):
         r"embed_positions",
     ]
-    _keys_to_ignore_on_save = [
-        "model.encoder.embed_positions.weight",
-        "model.decoder.embed_positions.weight",
-    ]
+    _keys_to_ignore_on_save = ["model.encoder.embed_positions.weight", "model.decoder.embed_positions.weight"]
 
     def __init__(self, config: MarianConfig):
         super().__init__(config)
@@ -1309,7 +1301,7 @@ class MarianMTModel(MarianPreTrainedModel):
         self._resize_final_logits_bias(new_num_tokens)
         return new_embeddings
 
-    def _resize_token_embeddings(self, new_num_tokens):
+    def _resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
         old_embeddings = self.get_input_embeddings()
         new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
         self.set_input_embeddings(new_embeddings)
@@ -1370,7 +1362,7 @@ class MarianMTModel(MarianPreTrainedModel):
     def get_output_embeddings(self):
         return self.lm_head
 
-    def set_output_embeddings(self, new_embeddings):
+    def set_output_embeddings(self, new_embeddings: nn.Embedding):
         self.lm_head = new_embeddings
 
     def tie_weights(self):
@@ -1400,23 +1392,23 @@ class MarianMTModel(MarianPreTrainedModel):
     @add_end_docstrings(MARIAN_GENERATION_EXAMPLE)
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        decoder_input_ids=None,
-        decoder_attention_mask=None,
-        head_mask=None,
-        decoder_head_mask=None,
-        cross_attn_head_mask=None,
-        encoder_outputs=None,
-        past_key_values=None,
-        inputs_embeds=None,
-        decoder_inputs_embeds=None,
-        labels=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        decoder_input_ids: Optional[torch.LongTensor] = None,
+        decoder_attention_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        decoder_head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        encoder_outputs: Optional[Union[Tuple[torch.Tensor], BaseModelOutput]] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Seq2SeqLMOutput:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
             Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
@@ -1479,16 +1471,16 @@ class MarianMTModel(MarianPreTrainedModel):
     def prepare_inputs_for_generation(
         self,
-        decoder_input_ids,
-        past=None,
-        attention_mask=None,
-        head_mask=None,
-        decoder_head_mask=None,
-        cross_attn_head_mask=None,
-        use_cache=None,
-        encoder_outputs=None,
-        **kwargs
-    ):
+        decoder_input_ids: torch.LongTensor,
+        past: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        decoder_head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        use_cache: Optional[bool] = None,
+        encoder_outputs: Optional[Union[Tuple[torch.Tensor], BaseModelOutput]] = None,
+        **kwargs,
+    ) -> Dict:
         # cut decoder_input_ids if past is used
         if past is not None:
             decoder_input_ids = decoder_input_ids[:, -1:]
...
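Since the Marian changes above are annotations only, they can be read back with the standard inspect module. A minimal sketch (not part of the commit), assuming an installed transformers build that contains this change:

import inspect
from transformers import MarianMTModel

fwd = inspect.signature(MarianMTModel.forward)
print(fwd.return_annotation)                          # Seq2SeqLMOutput
print(fwd.parameters["past_key_values"].annotation)   # Optional[Tuple[Tuple[torch.FloatTensor]]]

prep = inspect.signature(MarianMTModel.prepare_inputs_for_generation)
print(prep.return_annotation)                         # Dict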
src/transformers/models/pegasus/modeling_pegasus.py
@@ -109,12 +109,12 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
 class PegasusSinusoidalPositionalEmbedding(nn.Embedding):
     """This module produces sinusoidal positional embeddings of any length."""
 
-    def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None):
+    def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None) -> None:
         super().__init__(num_positions, embedding_dim)
         self.weight = self._init_weight(self.weight)
 
     @staticmethod
-    def _init_weight(out: nn.Parameter):
+    def _init_weight(out: nn.Parameter) -> nn.Parameter:
         """
         Identical to the XLM create_sinusoidal_embeddings except features are not interleaved. The cos features are in
         the 2nd half of the vector. [dim // 2:]
@@ -131,7 +131,7 @@ class PegasusSinusoidalPositionalEmbedding(nn.Embedding):
         return out
 
     @torch.no_grad()
-    def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0):
+    def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0) -> torch.Tensor:
         """`input_ids_shape` is expected to be [bsz x seqlen]."""
         bsz, seq_len = input_ids_shape[:2]
         positions = torch.arange(
...
src/transformers/models/roformer/modeling_roformer.py
@@ -73,12 +73,12 @@ ROFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [
 class RoFormerSinusoidalPositionalEmbedding(nn.Embedding):
     """This module produces sinusoidal positional embeddings of any length."""
 
-    def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None):
+    def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None) -> None:
         super().__init__(num_positions, embedding_dim)
         self.weight = self._init_weight(self.weight)
 
     @staticmethod
-    def _init_weight(out: nn.Parameter):
+    def _init_weight(out: nn.Parameter) -> nn.Parameter:
         """
         Identical to the XLM create_sinusoidal_embeddings except features are not interleaved. The cos features are in
         the 2nd half of the vector. [dim // 2:]
@@ -95,7 +95,7 @@ class RoFormerSinusoidalPositionalEmbedding(nn.Embedding):
         return out
 
     @torch.no_grad()
-    def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0):
+    def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0) -> torch.Tensor:
         """`input_ids_shape` is expected to be [bsz x seqlen]."""
         bsz, seq_len = input_ids_shape[:2]
         positions = torch.arange(
...
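Because the Pegasus and RoFormer classes are maintained through fix-copies, their annotations should now match the Marian original. A minimal consistency sketch (not part of the commit; the internal import paths are assumptions about the usual transformers layout):

import inspect
from transformers.models.marian.modeling_marian import MarianSinusoidalPositionalEmbedding
from transformers.models.pegasus.modeling_pegasus import PegasusSinusoidalPositionalEmbedding
from transformers.models.roformer.modeling_roformer import RoFormerSinusoidalPositionalEmbedding

for cls in (
    MarianSinusoidalPositionalEmbedding,
    PegasusSinusoidalPositionalEmbedding,
    RoFormerSinusoidalPositionalEmbedding,
):
    # torch.no_grad() wraps forward with functools.wraps, so the original annotations remain visible.
    sig = inspect.signature(cls.forward)
    print(cls.__name__, "->", sig.return_annotation)   # expected: torch.Tensor for all three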