Unverified Commit 7ba1d4e5 authored by Joaq, committed by GitHub

Add type hints for ProphetNet (Pytorch) (#17223)



* added type hints to prophetnet

* reformatted with black

* fix parts that black misformatted

* fix imports

* fix imports

* Update src/transformers/models/prophetnet/configuration_prophetnet.py
Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>

* update OPTIONAL type hint and docstring
Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
parent d6b8e9ce
src/transformers/models/prophetnet/configuration_prophetnet.py

@@ -14,6 +14,7 @@
 # limitations under the License.
 """ ProphetNet model configuration"""
+from typing import Callable, Optional, Union
 from ...configuration_utils import PretrainedConfig
 from ...utils import logging
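An aside on the convention used throughout this diff: `Optional[X]` is shorthand for `Union[X, None]`, and the commit applies it to keyword arguments that carry (non-None) defaults. A minimal illustration; `ActivationArg` is a made-up alias, not part of the codebase:

```python
from typing import Callable, Optional, Union

# Mirrors the hint added to activation_function below: a string name,
# a callable, or None.
ActivationArg = Optional[Union[str, Callable]]

# Optional[float] and Union[float, None] are the same type object.
assert Optional[float] == Union[float, None]
```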
@@ -105,32 +106,32 @@ class ProphetNetConfig(PretrainedConfig):
     def __init__(
         self,
-        activation_dropout=0.1,
-        activation_function="gelu",
-        vocab_size=30522,
-        hidden_size=1024,
-        encoder_ffn_dim=4096,
-        num_encoder_layers=12,
-        num_encoder_attention_heads=16,
-        decoder_ffn_dim=4096,
-        num_decoder_layers=12,
-        num_decoder_attention_heads=16,
-        attention_dropout=0.1,
-        dropout=0.1,
-        max_position_embeddings=512,
-        init_std=0.02,
-        is_encoder_decoder=True,
-        add_cross_attention=True,
-        decoder_start_token_id=0,
-        ngram=2,
-        num_buckets=32,
-        relative_max_distance=128,
-        disable_ngram_loss=False,
-        eps=0.0,
-        use_cache=True,
-        pad_token_id=0,
-        bos_token_id=1,
-        eos_token_id=2,
+        activation_dropout: Optional[float] = 0.1,
+        activation_function: Optional[Union[str, Callable]] = "gelu",
+        vocab_size: Optional[int] = 30522,
+        hidden_size: Optional[int] = 1024,
+        encoder_ffn_dim: Optional[int] = 4096,
+        num_encoder_layers: Optional[int] = 12,
+        num_encoder_attention_heads: Optional[int] = 16,
+        decoder_ffn_dim: Optional[int] = 4096,
+        num_decoder_layers: Optional[int] = 12,
+        num_decoder_attention_heads: Optional[int] = 16,
+        attention_dropout: Optional[float] = 0.1,
+        dropout: Optional[float] = 0.1,
+        max_position_embeddings: Optional[int] = 512,
+        init_std: Optional[float] = 0.02,
+        is_encoder_decoder: Optional[bool] = True,
+        add_cross_attention: Optional[bool] = True,
+        decoder_start_token_id: Optional[int] = 0,
+        ngram: Optional[int] = 2,
+        num_buckets: Optional[int] = 32,
+        relative_max_distance: Optional[int] = 128,
+        disable_ngram_loss: Optional[bool] = False,
+        eps: Optional[float] = 0.0,
+        use_cache: Optional[bool] = True,
+        pad_token_id: Optional[int] = 0,
+        bos_token_id: Optional[int] = 1,
+        eos_token_id: Optional[int] = 2,
         **kwargs
     ):
         self.vocab_size = vocab_size
...
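For reference, a quick sketch of how the now-annotated constructor is called; `ProphetNetConfig` is the real class, the argument values are arbitrary examples:

```python
from transformers import ProphetNetConfig

# With explicit hints on the keyword arguments, a type checker can flag
# e.g. hidden_size="big" or activation_function=3 at analysis time.
config = ProphetNetConfig(
    hidden_size=1024,            # Optional[int]
    activation_function="gelu",  # Optional[Union[str, Callable]]
    ngram=2,                     # Optional[int]
)
print(config.hidden_size)  # 1024
```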
src/transformers/models/prophetnet/modeling_prophetnet.py

@@ -345,7 +345,7 @@ class ProphetNetSeq2SeqModelOutput(ModelOutput):
            If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
            hidden_size)` is output.
-        last_hidden_state_ngram (`torch.FloatTensor` of shape `(batch_size,ngram * decoder_sequence_length, config.vocab_size)`):
+        last_hidden_state_ngram (`torch.FloatTensor` of shape `(batch_size,ngram * decoder_sequence_length, config.vocab_size)`, *optional*):
             Sequence of predict stream hidden-states at the output of the last layer of the decoder of the model.
         past_key_values (`List[torch.FloatTensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
             List of `torch.FloatTensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size,
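Since `last_hidden_state_ngram` is now documented as *optional*, callers should guard against `None`. A minimal sketch, assuming the public `microsoft/prophetnet-large-uncased` checkpoint:

```python
from transformers import ProphetNetModel, ProphetNetTokenizer

tokenizer = ProphetNetTokenizer.from_pretrained("microsoft/prophetnet-large-uncased")
model = ProphetNetModel.from_pretrained("microsoft/prophetnet-large-uncased")

inputs = tokenizer("Hello world", return_tensors="pt")
outputs = model(input_ids=inputs["input_ids"], decoder_input_ids=inputs["input_ids"])

# The ngram prediction stream can be absent, hence the *optional* marker.
if outputs.last_hidden_state_ngram is not None:
    print(outputs.last_hidden_state_ngram.shape)
```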
@@ -590,7 +590,7 @@ class ProphetNetPositionalEmbeddings(nn.Embedding):
     the forward function.
     """

-    def __init__(self, config: ProphetNetConfig):
+    def __init__(self, config: ProphetNetConfig) -> None:
         self.max_length = config.max_position_embeddings
         super().__init__(config.max_position_embeddings, config.hidden_size, config.pad_token_id)
@@ -1407,7 +1407,7 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel):
         embeddings instead of randomly initialized word embeddings.
     """

-    def __init__(self, config: ProphetNetConfig, word_embeddings: nn.Embedding = None):
+    def __init__(self, config: ProphetNetConfig, word_embeddings: Optional[nn.Embedding] = None):
         super().__init__(config)

         self.ngram = config.ngram
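The `Optional[nn.Embedding]` hint makes the two construction modes explicit: pass `None` for randomly initialized embeddings, or hand in a shared table. A sketch (the import path is the module's real location in transformers):

```python
import torch.nn as nn
from transformers import ProphetNetConfig
from transformers.models.prophetnet.modeling_prophetnet import ProphetNetDecoder

config = ProphetNetConfig()
# Share one embedding table with the decoder instead of letting it create its own.
shared = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
decoder = ProphetNetDecoder(config, word_embeddings=shared)
assert decoder.word_embeddings is shared
```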
@@ -1769,7 +1769,7 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel):
     PROPHETNET_START_DOCSTRING,
 )
 class ProphetNetModel(ProphetNetPreTrainedModel):
-    def __init__(self, config):
+    def __init__(self, config: ProphetNetConfig):
         super().__init__(config)
         self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
@@ -2106,7 +2106,7 @@ class ProphetNetForConditionalGeneration(ProphetNetPreTrainedModel):
     PROPHETNET_START_DOCSTRING,
 )
 class ProphetNetForCausalLM(ProphetNetPreTrainedModel):
-    def __init__(self, config):
+    def __init__(self, config: ProphetNetConfig):
         # set config for CLM
         config = copy.deepcopy(config)
         config.is_decoder = True
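As the hunk shows, the constructor deep-copies the config and flips the copy to decoder-only mode, so the caller's `ProphetNetConfig` instance is left untouched:

```python
from transformers import ProphetNetConfig, ProphetNetForCausalLM

config = ProphetNetConfig()
lm = ProphetNetForCausalLM(config)  # internally deep-copies and sets is_decoder = True
print(lm.config.is_decoder)         # True
print(config.is_decoder)            # False -- the original object is unchanged
```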
@@ -2341,7 +2341,7 @@ class ProphetNetDecoderWrapper(ProphetNetPreTrainedModel):
     classes.
     """

-    def __init__(self, config):
+    def __init__(self, config: ProphetNetConfig):
         super().__init__(config)
         self.decoder = ProphetNetDecoder(config)
...
src/transformers/models/prophetnet/tokenization_prophetnet.py

@@ -15,7 +15,7 @@
 import collections
 import os
-from typing import List, Optional, Tuple
+from typing import Iterable, List, Optional, Tuple
 from ...tokenization_utils import PreTrainedTokenizer
 from ...utils import logging
@@ -111,17 +111,17 @@ class ProphetNetTokenizer(PreTrainedTokenizer):
     def __init__(
         self,
-        vocab_file,
-        do_lower_case=True,
-        do_basic_tokenize=True,
-        never_split=None,
-        unk_token="[UNK]",
-        sep_token="[SEP]",
-        x_sep_token="[X_SEP]",
-        pad_token="[PAD]",
-        mask_token="[MASK]",
-        tokenize_chinese_chars=True,
-        strip_accents=None,
+        vocab_file: str,
+        do_lower_case: Optional[bool] = True,
+        do_basic_tokenize: Optional[bool] = True,
+        never_split: Optional[Iterable] = None,
+        unk_token: Optional[str] = "[UNK]",
+        sep_token: Optional[str] = "[SEP]",
+        x_sep_token: Optional[str] = "[X_SEP]",
+        pad_token: Optional[str] = "[PAD]",
+        mask_token: Optional[str] = "[MASK]",
+        tokenize_chinese_chars: Optional[bool] = True,
+        strip_accents: Optional[bool] = None,
         **kwargs
     ):
         super().__init__(
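In practice the annotated constructor is reached through `from_pretrained`, which supplies `vocab_file` from the downloaded checkpoint:

```python
from transformers import ProphetNetTokenizer

# "microsoft/prophetnet-large-uncased" is the public ProphetNet checkpoint.
tokenizer = ProphetNetTokenizer.from_pretrained("microsoft/prophetnet-large-uncased")
ids = tokenizer("Hello world")["input_ids"]
print(tokenizer.convert_ids_to_tokens(ids))  # e.g. ['hello', 'world', '[SEP]']
```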
@@ -177,21 +177,24 @@ class ProphetNetTokenizer(PreTrainedTokenizer):
         split_tokens = self.wordpiece_tokenizer.tokenize(text)
         return split_tokens

-    def _convert_token_to_id(self, token):
+    def _convert_token_to_id(self, token: str):
         """Converts a token (str) in an id using the vocab."""
         return self.vocab.get(token, self.vocab.get(self.unk_token))

-    def _convert_id_to_token(self, index):
+    def _convert_id_to_token(self, index: int):
         """Converts an index (integer) in a token (str) using the vocab."""
         return self.ids_to_tokens.get(index, self.unk_token)

-    def convert_tokens_to_string(self, tokens):
+    def convert_tokens_to_string(self, tokens: List[str]):
         """Converts a sequence of tokens (string) in a single string."""
         out_string = " ".join(tokens).replace(" ##", "").strip()
         return out_string

     def get_special_tokens_mask(
-        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
+        self,
+        token_ids_0: List[int],
+        token_ids_1: Optional[List[int]] = None,
+        already_has_special_tokens: Optional[bool] = False,
     ) -> List[int]:
         """
         Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
...
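A short sketch of the reflowed `get_special_tokens_mask` signature in use; 1 flags special tokens, 0 flags ordinary ones:

```python
from transformers import ProphetNetTokenizer

tokenizer = ProphetNetTokenizer.from_pretrained("microsoft/prophetnet-large-uncased")
ids = tokenizer("Hello world")["input_ids"]

# already_has_special_tokens=True because ids already ends with [SEP].
mask = tokenizer.get_special_tokens_mask(ids, already_has_special_tokens=True)
print(mask)  # e.g. [0, 0, 1]
```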