"...git@developer.sourcefind.cn:Wenxuan/LightX2V.git" did not exist on "18e2b23ad081f705265b467300d1a88cd02ce1f5"
Unverified commit 9c5ae87f authored by karthikrangasai, committed by GitHub

Type hint complete Albert model file. (#16682)



* Type hint complete Albert model file.

* Update typing.

* Update src/transformers/models/albert/modeling_albert.py
Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
parent 2bf95e2b
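
The change is mechanical and repeats the same pattern in every hunk below: parameters that default to `None` gain `Optional[...]` annotations matching the tensor kind the docstrings already describe (`torch.LongTensor` for ids, `torch.FloatTensor` for masks and embeddings), flags become `bool` or `Optional[bool]`, and each `forward` gets an explicit return annotation. A minimal, self-contained sketch of that pattern (the `toy_forward` function and its shapes are made up for illustration and are not part of the commit):

from typing import Optional, Tuple, Union

import torch


def toy_forward(
    input_ids: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.FloatTensor] = None,
    return_dict: Optional[bool] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
    # Optional[...] mirrors the `= None` defaults; the return annotation documents
    # the "model output or plain tuple" convention the real models follow.
    if input_ids is None:
        raise ValueError("You have to specify input_ids")
    hidden = torch.nn.functional.one_hot(input_ids, num_classes=8).float()
    if attention_mask is not None:
        hidden = hidden * attention_mask.unsqueeze(-1)
    return hidden if return_dict else (hidden,)


print(toy_forward(torch.tensor([[1, 2, 3]]), return_dict=True).shape)  # torch.Size([1, 3, 8])
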
@@ -17,7 +17,7 @@
 import math
 import os
 from dataclasses import dataclass
-from typing import Optional, Tuple
+from typing import Dict, List, Optional, Tuple, Union
 import torch
 from packaging import version
@@ -198,7 +198,7 @@ class AlbertEmbeddings(nn.Module):
     Construct the embeddings from word, position and token_type embeddings.
     """
-    def __init__(self, config):
+    def __init__(self, config: AlbertConfig):
         super().__init__()
         self.word_embeddings = nn.Embedding(config.vocab_size, config.embedding_size, padding_idx=config.pad_token_id)
         self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.embedding_size)
@@ -221,8 +221,13 @@ class AlbertEmbeddings(nn.Module):
     # Copied from transformers.models.bert.modeling_bert.BertEmbeddings.forward
     def forward(
-        self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
-    ):
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        past_key_values_length: int = 0,
+    ) -> torch.Tensor:
         if input_ids is not None:
             input_shape = input_ids.size()
         else:
@@ -258,7 +263,7 @@ class AlbertEmbeddings(nn.Module):
 class AlbertAttention(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config: AlbertConfig):
         super().__init__()
         if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
             raise ValueError(
@@ -287,12 +292,12 @@ class AlbertAttention(nn.Module):
             self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
     # Copied from transformers.models.bert.modeling_bert.BertSelfAttention.transpose_for_scores
-    def transpose_for_scores(self, x):
+    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
         new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
         x = x.view(new_x_shape)
         return x.permute(0, 2, 1, 3)
-    def prune_heads(self, heads):
+    def prune_heads(self, heads: List[int]) -> None:
         if len(heads) == 0:
             return
         heads, index = find_pruneable_heads_and_indices(
@@ -310,7 +315,13 @@ class AlbertAttention(nn.Module):
         self.all_head_size = self.attention_head_size * self.num_attention_heads
         self.pruned_heads = self.pruned_heads.union(heads)
-    def forward(self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False):
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        output_attentions: bool = False,
+    ) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
         mixed_query_layer = self.query(hidden_states)
         mixed_key_layer = self.key(hidden_states)
         mixed_value_layer = self.value(hidden_states)
@@ -364,7 +375,7 @@ class AlbertAttention(nn.Module):
 class AlbertLayer(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config: AlbertConfig):
         super().__init__()
         self.config = config
@@ -378,8 +389,13 @@ class AlbertLayer(nn.Module):
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
     def forward(
-        self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False, output_hidden_states=False
-    ):
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        output_attentions: bool = False,
+        output_hidden_states: bool = False,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         attention_output = self.attention(hidden_states, attention_mask, head_mask, output_attentions)
         ffn_output = apply_chunking_to_forward(
@@ -392,7 +408,7 @@ class AlbertLayer(nn.Module):
         return (hidden_states,) + attention_output[1:]  # add attentions if we output them
-    def ff_chunk(self, attention_output):
+    def ff_chunk(self, attention_output: torch.Tensor) -> torch.Tensor:
         ffn_output = self.ffn(attention_output)
         ffn_output = self.activation(ffn_output)
         ffn_output = self.ffn_output(ffn_output)
@@ -400,14 +416,19 @@ class AlbertLayer(nn.Module):
 class AlbertLayerGroup(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config: AlbertConfig):
         super().__init__()
         self.albert_layers = nn.ModuleList([AlbertLayer(config) for _ in range(config.inner_group_num)])
     def forward(
-        self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False, output_hidden_states=False
-    ):
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        output_attentions: bool = False,
+        output_hidden_states: bool = False,
+    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
         layer_hidden_states = ()
         layer_attentions = ()
@@ -430,7 +451,7 @@ class AlbertLayerGroup(nn.Module):
 class AlbertTransformer(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config: AlbertConfig):
         super().__init__()
         self.config = config
@@ -439,13 +460,13 @@ class AlbertTransformer(nn.Module):
     def forward(
         self,
-        hidden_states,
-        attention_mask=None,
-        head_mask=None,
-        output_attentions=False,
-        output_hidden_states=False,
-        return_dict=True,
-    ):
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        output_attentions: bool = False,
+        output_hidden_states: bool = False,
+        return_dict: bool = True,
+    ) -> Union[BaseModelOutput, Tuple]:
         hidden_states = self.embedding_hidden_mapping_in(hidden_states)
         all_hidden_states = (hidden_states,) if output_hidden_states else None
@@ -619,7 +640,7 @@ class AlbertModel(AlbertPreTrainedModel):
     config_class = AlbertConfig
     base_model_prefix = "albert"
-    def __init__(self, config, add_pooling_layer=True):
+    def __init__(self, config: AlbertConfig, add_pooling_layer: bool = True):
         super().__init__(config)
         self.config = config
@@ -635,13 +656,13 @@ class AlbertModel(AlbertPreTrainedModel):
         # Initialize weights and apply final processing
         self.post_init()
-    def get_input_embeddings(self):
+    def get_input_embeddings(self) -> nn.Embedding:
         return self.embeddings.word_embeddings
-    def set_input_embeddings(self, value):
+    def set_input_embeddings(self, value: nn.Embedding) -> None:
         self.embeddings.word_embeddings = value
-    def _prune_heads(self, heads_to_prune):
+    def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
         """
         Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} ALBERT has
         a different architecture in that its layers are shared across groups, which then has inner groups. If an ALBERT
@@ -667,16 +688,16 @@ class AlbertModel(AlbertPreTrainedModel):
     )
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        token_type_ids=None,
-        position_ids=None,
-        head_mask=None,
-        inputs_embeds=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[BaseModelOutputWithPooling, Tuple]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -745,7 +766,7 @@ class AlbertModel(AlbertPreTrainedModel):
     ALBERT_START_DOCSTRING,
 )
 class AlbertForPreTraining(AlbertPreTrainedModel):
-    def __init__(self, config):
+    def __init__(self, config: AlbertConfig):
         super().__init__(config)
         self.albert = AlbertModel(config)
@@ -755,31 +776,31 @@ class AlbertForPreTraining(AlbertPreTrainedModel):
         # Initialize weights and apply final processing
         self.post_init()
-    def get_output_embeddings(self):
+    def get_output_embeddings(self) -> nn.Linear:
         return self.predictions.decoder
-    def set_output_embeddings(self, new_embeddings):
+    def set_output_embeddings(self, new_embeddings: nn.Linear) -> None:
         self.predictions.decoder = new_embeddings
-    def get_input_embeddings(self):
+    def get_input_embeddings(self) -> nn.Embedding:
         return self.albert.embeddings.word_embeddings
     @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @replace_return_docstrings(output_type=AlbertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        token_type_ids=None,
-        position_ids=None,
-        head_mask=None,
-        inputs_embeds=None,
-        labels=None,
-        sentence_order_label=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        sentence_order_label: Optional[torch.LongTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[AlbertForPreTrainingOutput, Tuple]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
             Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
@@ -848,7 +869,7 @@ class AlbertForPreTraining(AlbertPreTrainedModel):
 class AlbertMLMHead(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config: AlbertConfig):
         super().__init__()
         self.LayerNorm = nn.LayerNorm(config.embedding_size, eps=config.layer_norm_eps)
@@ -858,7 +879,7 @@ class AlbertMLMHead(nn.Module):
         self.activation = ACT2FN[config.hidden_act]
         self.decoder.bias = self.bias
-    def forward(self, hidden_states):
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states = self.dense(hidden_states)
         hidden_states = self.activation(hidden_states)
         hidden_states = self.LayerNorm(hidden_states)
@@ -868,19 +889,19 @@ class AlbertMLMHead(nn.Module):
         return prediction_scores
-    def _tie_weights(self):
+    def _tie_weights(self) -> None:
         # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
         self.bias = self.decoder.bias
 class AlbertSOPHead(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config: AlbertConfig):
         super().__init__()
         self.dropout = nn.Dropout(config.classifier_dropout_prob)
         self.classifier = nn.Linear(config.hidden_size, config.num_labels)
-    def forward(self, pooled_output):
+    def forward(self, pooled_output: torch.Tensor) -> torch.Tensor:
         dropout_pooled_output = self.dropout(pooled_output)
         logits = self.classifier(dropout_pooled_output)
         return logits
@@ -903,30 +924,30 @@ class AlbertForMaskedLM(AlbertPreTrainedModel):
         # Initialize weights and apply final processing
         self.post_init()
-    def get_output_embeddings(self):
+    def get_output_embeddings(self) -> nn.Linear:
         return self.predictions.decoder
-    def set_output_embeddings(self, new_embeddings):
+    def set_output_embeddings(self, new_embeddings: nn.Linear) -> None:
         self.predictions.decoder = new_embeddings
-    def get_input_embeddings(self):
+    def get_input_embeddings(self) -> nn.Embedding:
         return self.albert.embeddings.word_embeddings
     @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        token_type_ids=None,
-        position_ids=None,
-        head_mask=None,
-        inputs_embeds=None,
-        labels=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[MaskedLMOutput, Tuple]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
             Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
@@ -1006,7 +1027,7 @@ class AlbertForMaskedLM(AlbertPreTrainedModel):
     ALBERT_START_DOCSTRING,
 )
 class AlbertForSequenceClassification(AlbertPreTrainedModel):
-    def __init__(self, config):
+    def __init__(self, config: AlbertConfig):
         super().__init__(config)
         self.num_labels = config.num_labels
         self.config = config
@@ -1029,17 +1050,17 @@ class AlbertForSequenceClassification(AlbertPreTrainedModel):
     )
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        token_type_ids=None,
-        position_ids=None,
-        head_mask=None,
-        inputs_embeds=None,
-        labels=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[SequenceClassifierOutput, Tuple]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
             Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
@@ -1111,7 +1132,7 @@ class AlbertForTokenClassification(AlbertPreTrainedModel):
     _keys_to_ignore_on_load_unexpected = [r"pooler"]
-    def __init__(self, config):
+    def __init__(self, config: AlbertConfig):
         super().__init__(config)
         self.num_labels = config.num_labels
@@ -1139,17 +1160,17 @@ class AlbertForTokenClassification(AlbertPreTrainedModel):
     )
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        token_type_ids=None,
-        position_ids=None,
-        head_mask=None,
-        inputs_embeds=None,
-        labels=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[TokenClassifierOutput, Tuple]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
             Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
@@ -1201,7 +1222,7 @@ class AlbertForQuestionAnswering(AlbertPreTrainedModel):
     _keys_to_ignore_on_load_unexpected = [r"pooler"]
-    def __init__(self, config):
+    def __init__(self, config: AlbertConfig):
         super().__init__(config)
         self.num_labels = config.num_labels
@@ -1224,18 +1245,18 @@ class AlbertForQuestionAnswering(AlbertPreTrainedModel):
     )
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        token_type_ids=None,
-        position_ids=None,
-        head_mask=None,
-        inputs_embeds=None,
-        start_positions=None,
-        end_positions=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        start_positions: Optional[torch.LongTensor] = None,
+        end_positions: Optional[torch.LongTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[QuestionAnsweringModelOutput, Tuple]:
         r"""
         start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
             Labels for position (index) of the start of the labelled span for computing the token classification loss.
@@ -1262,7 +1283,7 @@ class AlbertForQuestionAnswering(AlbertPreTrainedModel):
         sequence_output = outputs[0]
-        logits = self.qa_outputs(sequence_output)
+        logits: torch.Tensor = self.qa_outputs(sequence_output)
         start_logits, end_logits = logits.split(1, dim=-1)
         start_logits = start_logits.squeeze(-1).contiguous()
         end_logits = end_logits.squeeze(-1).contiguous()
@@ -1305,7 +1326,7 @@ class AlbertForQuestionAnswering(AlbertPreTrainedModel):
     ALBERT_START_DOCSTRING,
 )
 class AlbertForMultipleChoice(AlbertPreTrainedModel):
-    def __init__(self, config):
+    def __init__(self, config: AlbertConfig):
         super().__init__(config)
         self.albert = AlbertModel(config)
@@ -1324,17 +1345,17 @@ class AlbertForMultipleChoice(AlbertPreTrainedModel):
     )
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        token_type_ids=None,
-        position_ids=None,
-        head_mask=None,
-        inputs_embeds=None,
-        labels=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[MultipleChoiceModelOutput, Tuple]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
             Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
@@ -1368,7 +1389,7 @@ class AlbertForMultipleChoice(AlbertPreTrainedModel):
         pooled_output = outputs[1]
         pooled_output = self.dropout(pooled_output)
-        logits = self.classifier(pooled_output)
+        logits: torch.Tensor = self.classifier(pooled_output)
         reshaped_logits = logits.view(-1, num_choices)
         loss = None
......
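
With the annotations above in place, the public surface of `AlbertModel.forward` reads directly from its signature: `torch.LongTensor` ids in, `BaseModelOutputWithPooling` (or a plain tuple) out. A short usage sketch against a tiny, randomly initialized model; the config values are arbitrary and chosen only to keep the example fast:

import torch
from transformers import AlbertConfig, AlbertModel

# Tiny random-weight model so the example runs without downloading a checkpoint.
config = AlbertConfig(
    vocab_size=30000,
    embedding_size=32,
    hidden_size=64,
    num_hidden_layers=2,
    num_attention_heads=4,
    intermediate_size=128,
)
model = AlbertModel(config)

input_ids = torch.randint(0, config.vocab_size, (1, 12), dtype=torch.long)  # torch.LongTensor
attention_mask = torch.ones(1, 12)                                          # torch.FloatTensor

# return_dict=True -> BaseModelOutputWithPooling, matching the annotated return type.
outputs = model(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
print(outputs.last_hidden_state.shape)  # torch.Size([1, 12, 64])
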
@@ -203,8 +203,13 @@ class BertEmbeddings(nn.Module):
         )
     def forward(
-        self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
-    ):
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        past_key_values_length: int = 0,
+    ) -> torch.Tensor:
         if input_ids is not None:
             input_shape = input_ids.size()
         else:
@@ -266,7 +271,7 @@ class BertSelfAttention(nn.Module):
         self.is_decoder = config.is_decoder
-    def transpose_for_scores(self, x):
+    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
         new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
         x = x.view(new_x_shape)
         return x.permute(0, 2, 1, 3)
......
@@ -182,7 +182,7 @@ class Data2VecTextSelfAttention(nn.Module):
         self.is_decoder = config.is_decoder
-    def transpose_for_scores(self, x):
+    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
         new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
         x = x.view(new_x_shape)
         return x.permute(0, 2, 1, 3)
......
@@ -174,8 +174,13 @@ class ElectraEmbeddings(nn.Module):
     # Copied from transformers.models.bert.modeling_bert.BertEmbeddings.forward
     def forward(
-        self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
-    ):
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        past_key_values_length: int = 0,
+    ) -> torch.Tensor:
         if input_ids is not None:
             input_shape = input_ids.size()
         else:
@@ -238,7 +243,7 @@ class ElectraSelfAttention(nn.Module):
         self.is_decoder = config.is_decoder
-    def transpose_for_scores(self, x):
+    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
         new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
         x = x.view(new_x_shape)
         return x.permute(0, 2, 1, 3)
......
@@ -154,7 +154,7 @@ class LayoutLMSelfAttention(nn.Module):
         self.is_decoder = config.is_decoder
-    def transpose_for_scores(self, x):
+    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
         new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
         x = x.view(new_x_shape)
         return x.permute(0, 2, 1, 3)
......
@@ -212,7 +212,7 @@ class MegatronBertSelfAttention(nn.Module):
         self.is_decoder = config.is_decoder
-    def transpose_for_scores(self, x):
+    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
         new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
         x = x.view(new_x_shape)
         return x.permute(0, 2, 1, 3)
......
@@ -19,6 +19,7 @@
 import math
 import os
 import warnings
+from typing import Optional
 import torch
 import torch.utils.checkpoint
@@ -173,8 +174,13 @@ class QDQBertEmbeddings(nn.Module):
         )
     def forward(
-        self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
-    ):
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        past_key_values_length: int = 0,
+    ) -> torch.Tensor:
         if input_ids is not None:
             input_shape = input_ids.size()
         else:
......
@@ -189,8 +189,13 @@ class RealmEmbeddings(nn.Module):
         )
     def forward(
-        self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
-    ):
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        past_key_values_length: int = 0,
+    ) -> torch.Tensor:
         if input_ids is not None:
             input_shape = input_ids.size()
         else:
@@ -253,7 +258,7 @@ class RealmSelfAttention(nn.Module):
         self.is_decoder = config.is_decoder
-    def transpose_for_scores(self, x):
+    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
         new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
         x = x.view(new_x_shape)
         return x.permute(0, 2, 1, 3)
......
@@ -182,7 +182,7 @@ class RobertaSelfAttention(nn.Module):
         self.is_decoder = config.is_decoder
-    def transpose_for_scores(self, x):
+    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
         new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
         x = x.view(new_x_shape)
         return x.permute(0, 2, 1, 3)
......
@@ -126,7 +126,7 @@ class SplinterSelfAttention(nn.Module):
         self.is_decoder = config.is_decoder
-    def transpose_for_scores(self, x):
+    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
         new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
         x = x.view(new_x_shape)
         return x.permute(0, 2, 1, 3)
......
@@ -176,7 +176,7 @@ class XLMRobertaXLSelfAttention(nn.Module):
         self.is_decoder = config.is_decoder
-    def transpose_for_scores(self, x):
+    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
         new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
         x = x.view(new_x_shape)
         return x.permute(0, 2, 1, 3)
......
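
Nearly every file touched above re-annotates the same copied `transpose_for_scores` helper, so the new `x: torch.Tensor -> torch.Tensor` annotation documents a pure shape transformation. A stand-alone version of that helper, with example sizes chosen only for illustration:

import torch

num_attention_heads, attention_head_size = 4, 16  # example values, not taken from any config


def transpose_for_scores(x: torch.Tensor) -> torch.Tensor:
    # (batch, seq_len, num_heads * head_size) -> (batch, num_heads, seq_len, head_size)
    new_x_shape = x.size()[:-1] + (num_attention_heads, attention_head_size)
    x = x.view(new_x_shape)
    return x.permute(0, 2, 1, 3)


hidden_states = torch.randn(2, 10, num_attention_heads * attention_head_size)
print(transpose_for_scores(hidden_states).shape)  # torch.Size([2, 4, 10, 16])
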