"docs/source/vscode:/vscode.git/clone" did not exist on "ee2482b6f8eca320f1eff112673adae215dbc927"
Unverified Commit 9c5ae87f authored by karthikrangasai, committed by GitHub

Type hint complete Albert model file. (#16682)



* Type hint complete Albert model file.

* Update typing.

* Update src/transformers/models/albert/modeling_albert.py
Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
parent 2bf95e2b
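A minimal, hypothetical usage sketch of the newly annotated API (the tiny config values below are illustrative assumptions so the example runs quickly on CPU; they are not part of the diff):

# Hypothetical usage sketch, not part of the commit.
import torch
from transformers import AlbertConfig, AlbertModel
from transformers.modeling_outputs import BaseModelOutputWithPooling

# Small, illustrative config values (assumed, not from the diff).
config = AlbertConfig(
    vocab_size=1000,
    embedding_size=32,
    hidden_size=64,
    num_hidden_layers=2,
    num_attention_heads=4,
    intermediate_size=128,
)
model = AlbertModel(config).eval()

input_ids = torch.randint(0, config.vocab_size, (1, 8))  # torch.LongTensor
attention_mask = torch.ones(1, 8)                        # torch.FloatTensor

# With return_dict left at its default, the annotated return type
# Union[BaseModelOutputWithPooling, Tuple] resolves to the dataclass branch.
outputs = model(input_ids=input_ids, attention_mask=attention_mask)
assert isinstance(outputs, BaseModelOutputWithPooling)
print(outputs.last_hidden_state.shape)  # torch.Size([1, 8, 64])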
@@ -17,7 +17,7 @@
import math
import os
from dataclasses import dataclass
from typing import Optional, Tuple
from typing import Dict, List, Optional, Tuple, Union
import torch
from packaging import version
@@ -198,7 +198,7 @@ class AlbertEmbeddings(nn.Module):
Construct the embeddings from word, position and token_type embeddings.
"""
def __init__(self, config):
def __init__(self, config: AlbertConfig):
super().__init__()
self.word_embeddings = nn.Embedding(config.vocab_size, config.embedding_size, padding_idx=config.pad_token_id)
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.embedding_size)
@@ -221,8 +221,13 @@ class AlbertEmbeddings(nn.Module):
# Copied from transformers.models.bert.modeling_bert.BertEmbeddings.forward
def forward(
self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
):
self,
input_ids: Optional[torch.LongTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
past_key_values_length: int = 0,
) -> torch.Tensor:
if input_ids is not None:
input_shape = input_ids.size()
else:
@@ -258,7 +263,7 @@ class AlbertEmbeddings(nn.Module):
class AlbertAttention(nn.Module):
def __init__(self, config):
def __init__(self, config: AlbertConfig):
super().__init__()
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
raise ValueError(
@@ -287,12 +292,12 @@ class AlbertAttention(nn.Module):
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
# Copied from transformers.models.bert.modeling_bert.BertSelfAttention.transpose_for_scores
def transpose_for_scores(self, x):
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
def prune_heads(self, heads):
def prune_heads(self, heads: List[int]) -> None:
if len(heads) == 0:
return
heads, index = find_pruneable_heads_and_indices(
@@ -310,7 +315,13 @@ class AlbertAttention(nn.Module):
self.all_head_size = self.attention_head_size * self.num_attention_heads
self.pruned_heads = self.pruned_heads.union(heads)
def forward(self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False):
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
output_attentions: bool = False,
) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
mixed_query_layer = self.query(hidden_states)
mixed_key_layer = self.key(hidden_states)
mixed_value_layer = self.value(hidden_states)
@@ -364,7 +375,7 @@ class AlbertAttention(nn.Module):
class AlbertLayer(nn.Module):
def __init__(self, config):
def __init__(self, config: AlbertConfig):
super().__init__()
self.config = config
@@ -378,8 +389,13 @@ class AlbertLayer(nn.Module):
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(
self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False, output_hidden_states=False
):
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
output_attentions: bool = False,
output_hidden_states: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
attention_output = self.attention(hidden_states, attention_mask, head_mask, output_attentions)
ffn_output = apply_chunking_to_forward(
@@ -392,7 +408,7 @@ class AlbertLayer(nn.Module):
return (hidden_states,) + attention_output[1:] # add attentions if we output them
def ff_chunk(self, attention_output):
def ff_chunk(self, attention_output: torch.Tensor) -> torch.Tensor:
ffn_output = self.ffn(attention_output)
ffn_output = self.activation(ffn_output)
ffn_output = self.ffn_output(ffn_output)
@@ -400,14 +416,19 @@ class AlbertLayer(nn.Module):
class AlbertLayerGroup(nn.Module):
def __init__(self, config):
def __init__(self, config: AlbertConfig):
super().__init__()
self.albert_layers = nn.ModuleList([AlbertLayer(config) for _ in range(config.inner_group_num)])
def forward(
self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False, output_hidden_states=False
):
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
output_attentions: bool = False,
output_hidden_states: bool = False,
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
layer_hidden_states = ()
layer_attentions = ()
@@ -430,7 +451,7 @@ class AlbertLayerGroup(nn.Module):
class AlbertTransformer(nn.Module):
def __init__(self, config):
def __init__(self, config: AlbertConfig):
super().__init__()
self.config = config
@@ -439,13 +460,13 @@ class AlbertTransformer(nn.Module):
def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
output_attentions=False,
output_hidden_states=False,
return_dict=True,
):
hidden_states: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
) -> Union[BaseModelOutput, Tuple]:
hidden_states = self.embedding_hidden_mapping_in(hidden_states)
all_hidden_states = (hidden_states,) if output_hidden_states else None
@@ -619,7 +640,7 @@ class AlbertModel(AlbertPreTrainedModel):
config_class = AlbertConfig
base_model_prefix = "albert"
def __init__(self, config, add_pooling_layer=True):
def __init__(self, config: AlbertConfig, add_pooling_layer: bool = True):
super().__init__(config)
self.config = config
@@ -635,13 +656,13 @@ class AlbertModel(AlbertPreTrainedModel):
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
def get_input_embeddings(self) -> nn.Embedding:
return self.embeddings.word_embeddings
def set_input_embeddings(self, value):
def set_input_embeddings(self, value: nn.Embedding) -> None:
self.embeddings.word_embeddings = value
def _prune_heads(self, heads_to_prune):
def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
"""
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} ALBERT has
a different architecture in that its layers are shared across groups, which then has inner groups. If an ALBERT
@@ -667,16 +688,16 @@ class AlbertModel(AlbertPreTrainedModel):
)
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[BaseModelOutputWithPooling, Tuple]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -745,7 +766,7 @@ class AlbertModel(AlbertPreTrainedModel):
ALBERT_START_DOCSTRING,
)
class AlbertForPreTraining(AlbertPreTrainedModel):
def __init__(self, config):
def __init__(self, config: AlbertConfig):
super().__init__(config)
self.albert = AlbertModel(config)
@@ -755,31 +776,31 @@ class AlbertForPreTraining(AlbertPreTrainedModel):
# Initialize weights and apply final processing
self.post_init()
def get_output_embeddings(self):
def get_output_embeddings(self) -> nn.Linear:
return self.predictions.decoder
def set_output_embeddings(self, new_embeddings):
def set_output_embeddings(self, new_embeddings: nn.Linear) -> None:
self.predictions.decoder = new_embeddings
def get_input_embeddings(self):
def get_input_embeddings(self) -> nn.Embedding:
return self.albert.embeddings.word_embeddings
@add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=AlbertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
sentence_order_label=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
sentence_order_label: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[AlbertForPreTrainingOutput, Tuple]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
@@ -848,7 +869,7 @@ class AlbertForPreTraining(AlbertPreTrainedModel):
class AlbertMLMHead(nn.Module):
def __init__(self, config):
def __init__(self, config: AlbertConfig):
super().__init__()
self.LayerNorm = nn.LayerNorm(config.embedding_size, eps=config.layer_norm_eps)
@@ -858,7 +879,7 @@ class AlbertMLMHead(nn.Module):
self.activation = ACT2FN[config.hidden_act]
self.decoder.bias = self.bias
def forward(self, hidden_states):
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.activation(hidden_states)
hidden_states = self.LayerNorm(hidden_states)
@@ -868,19 +889,19 @@ class AlbertMLMHead(nn.Module):
return prediction_scores
def _tie_weights(self):
def _tie_weights(self) -> None:
# To tie those two weights if they get disconnected (on TPU or when the bias is resized)
self.bias = self.decoder.bias
class AlbertSOPHead(nn.Module):
def __init__(self, config):
def __init__(self, config: AlbertConfig):
super().__init__()
self.dropout = nn.Dropout(config.classifier_dropout_prob)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
def forward(self, pooled_output):
def forward(self, pooled_output: torch.Tensor) -> torch.Tensor:
dropout_pooled_output = self.dropout(pooled_output)
logits = self.classifier(dropout_pooled_output)
return logits
@@ -903,30 +924,30 @@ class AlbertForMaskedLM(AlbertPreTrainedModel):
# Initialize weights and apply final processing
self.post_init()
def get_output_embeddings(self):
def get_output_embeddings(self) -> nn.Linear:
return self.predictions.decoder
def set_output_embeddings(self, new_embeddings):
def set_output_embeddings(self, new_embeddings: nn.Linear) -> None:
self.predictions.decoder = new_embeddings
def get_input_embeddings(self):
def get_input_embeddings(self) -> nn.Embedding:
return self.albert.embeddings.word_embeddings
@add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[MaskedLMOutput, Tuple]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
@@ -1006,7 +1027,7 @@ class AlbertForMaskedLM(AlbertPreTrainedModel):
ALBERT_START_DOCSTRING,
)
class AlbertForSequenceClassification(AlbertPreTrainedModel):
def __init__(self, config):
def __init__(self, config: AlbertConfig):
super().__init__(config)
self.num_labels = config.num_labels
self.config = config
@@ -1029,17 +1050,17 @@ class AlbertForSequenceClassification(AlbertPreTrainedModel):
)
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[SequenceClassifierOutput, Tuple]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
@@ -1111,7 +1132,7 @@ class AlbertForTokenClassification(AlbertPreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"pooler"]
def __init__(self, config):
def __init__(self, config: AlbertConfig):
super().__init__(config)
self.num_labels = config.num_labels
@@ -1139,17 +1160,17 @@ class AlbertForTokenClassification(AlbertPreTrainedModel):
)
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[TokenClassifierOutput, Tuple]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
@@ -1201,7 +1222,7 @@ class AlbertForQuestionAnswering(AlbertPreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"pooler"]
def __init__(self, config):
def __init__(self, config: AlbertConfig):
super().__init__(config)
self.num_labels = config.num_labels
@@ -1224,18 +1245,18 @@ class AlbertForQuestionAnswering(AlbertPreTrainedModel):
)
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
start_positions=None,
end_positions=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
start_positions: Optional[torch.LongTensor] = None,
end_positions: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[QuestionAnsweringModelOutput, Tuple]:
r"""
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for position (index) of the start of the labelled span for computing the token classification loss.
@@ -1262,7 +1283,7 @@ class AlbertForQuestionAnswering(AlbertPreTrainedModel):
sequence_output = outputs[0]
logits = self.qa_outputs(sequence_output)
logits: torch.Tensor = self.qa_outputs(sequence_output)
start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1).contiguous()
end_logits = end_logits.squeeze(-1).contiguous()
@@ -1305,7 +1326,7 @@ class AlbertForQuestionAnswering(AlbertPreTrainedModel):
ALBERT_START_DOCSTRING,
)
class AlbertForMultipleChoice(AlbertPreTrainedModel):
def __init__(self, config):
def __init__(self, config: AlbertConfig):
super().__init__(config)
self.albert = AlbertModel(config)
@@ -1324,17 +1345,17 @@ class AlbertForMultipleChoice(AlbertPreTrainedModel):
)
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[MultipleChoiceModelOutput, Tuple]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
@@ -1368,7 +1389,7 @@ class AlbertForMultipleChoice(AlbertPreTrainedModel):
pooled_output = outputs[1]
pooled_output = self.dropout(pooled_output)
logits = self.classifier(pooled_output)
logits: torch.Tensor = self.classifier(pooled_output)
reshaped_logits = logits.view(-1, num_choices)
loss = None
......
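The hunks below propagate the same annotations to the Bert-style embedding and attention modules that carry # Copied from markers. A quick, hypothetical way to confirm the new Albert hints at runtime (assumes an installed transformers build that includes this change):

# Hypothetical check, not part of the commit: inspect the annotations added above.
import typing

from transformers import AlbertModel

hints = typing.get_type_hints(AlbertModel.forward)
print(hints["input_ids"])  # e.g. typing.Optional[torch.LongTensor]
print(hints["return"])     # e.g. typing.Union[BaseModelOutputWithPooling, typing.Tuple]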
@@ -203,8 +203,13 @@ class BertEmbeddings(nn.Module):
)
def forward(
self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
):
self,
input_ids: Optional[torch.LongTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
past_key_values_length: int = 0,
) -> torch.Tensor:
if input_ids is not None:
input_shape = input_ids.size()
else:
@@ -266,7 +271,7 @@ class BertSelfAttention(nn.Module):
self.is_decoder = config.is_decoder
def transpose_for_scores(self, x):
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
......
@@ -182,7 +182,7 @@ class Data2VecTextSelfAttention(nn.Module):
self.is_decoder = config.is_decoder
def transpose_for_scores(self, x):
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
......
@@ -174,8 +174,13 @@ class ElectraEmbeddings(nn.Module):
# Copied from transformers.models.bert.modeling_bert.BertEmbeddings.forward
def forward(
self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
):
self,
input_ids: Optional[torch.LongTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
past_key_values_length: int = 0,
) -> torch.Tensor:
if input_ids is not None:
input_shape = input_ids.size()
else:
@@ -238,7 +243,7 @@ class ElectraSelfAttention(nn.Module):
self.is_decoder = config.is_decoder
def transpose_for_scores(self, x):
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
......
@@ -154,7 +154,7 @@ class LayoutLMSelfAttention(nn.Module):
self.is_decoder = config.is_decoder
def transpose_for_scores(self, x):
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
......
@@ -212,7 +212,7 @@ class MegatronBertSelfAttention(nn.Module):
self.is_decoder = config.is_decoder
def transpose_for_scores(self, x):
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
......
@@ -19,6 +19,7 @@
import math
import os
import warnings
from typing import Optional
import torch
import torch.utils.checkpoint
@@ -173,8 +174,13 @@ class QDQBertEmbeddings(nn.Module):
)
def forward(
self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
):
self,
input_ids: Optional[torch.LongTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
past_key_values_length: int = 0,
) -> torch.Tensor:
if input_ids is not None:
input_shape = input_ids.size()
else:
......
@@ -189,8 +189,13 @@ class RealmEmbeddings(nn.Module):
)
def forward(
self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
):
self,
input_ids: Optional[torch.LongTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
past_key_values_length: int = 0,
) -> torch.Tensor:
if input_ids is not None:
input_shape = input_ids.size()
else:
@@ -253,7 +258,7 @@ class RealmSelfAttention(nn.Module):
self.is_decoder = config.is_decoder
def transpose_for_scores(self, x):
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
......
@@ -182,7 +182,7 @@ class RobertaSelfAttention(nn.Module):
self.is_decoder = config.is_decoder
def transpose_for_scores(self, x):
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
......
@@ -126,7 +126,7 @@ class SplinterSelfAttention(nn.Module):
self.is_decoder = config.is_decoder
def transpose_for_scores(self, x):
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
......
@@ -176,7 +176,7 @@ class XLMRobertaXLSelfAttention(nn.Module):
self.is_decoder = config.is_decoder
def transpose_for_scores(self, x):
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
......
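For reference, the transpose_for_scores helper annotated as (x: torch.Tensor) -> torch.Tensor throughout these hunks performs the same reshape in every model. A self-contained sketch with arbitrary illustrative shapes:

# Standalone sketch of the reshape; shapes and head counts are illustrative only.
import torch

num_attention_heads = 4
attention_head_size = 16
hidden_size = num_attention_heads * attention_head_size  # 64

def transpose_for_scores(x: torch.Tensor) -> torch.Tensor:
    # (batch, seq_len, hidden) -> (batch, seq_len, num_heads, head_size)
    new_x_shape = x.size()[:-1] + (num_attention_heads, attention_head_size)
    x = x.view(new_x_shape)
    # -> (batch, num_heads, seq_len, head_size), the layout the attention weights expect
    return x.permute(0, 2, 1, 3)

hidden_states = torch.randn(2, 10, hidden_size)   # (batch=2, seq_len=10, hidden=64)
print(transpose_for_scores(hidden_states).shape)  # torch.Size([2, 4, 10, 16])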