Unverified Commit 45360e1a authored by Robot Jelly, committed by GitHub

type hints for pytorch models (#17064)

* type hints for pytorch models

* fixed import error

* fixed some errors
parent db377a0b
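
The change is mechanical and repeats across every file touched below: previously untyped `forward` signatures gain annotations from `typing` (`Optional`, `Tuple`, `Union`) and `torch` tensor types, and return annotations become either a concrete tensor type or `Union[Tuple, <ModelOutput>]` so both tuple outputs and `return_dict` outputs are covered. A minimal sketch of that pattern, using a hypothetical `ToyAttention` module that is not part of the PR:

from typing import Optional, Tuple

import torch
from torch import nn


class ToyAttention(nn.Module):
    """Hypothetical layer (not from the PR) illustrating the annotation pattern applied throughout."""

    def __init__(self, hidden_size: int = 64):
        super().__init__()
        self.query = nn.Linear(hidden_size, hidden_size)
        self.key = nn.Linear(hidden_size, hidden_size)
        self.value = nn.Linear(hidden_size, hidden_size)

    # before: def forward(self, hidden_states, attention_mask=None, output_attentions=False):
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        # Behaviour is unchanged; only the signature gains type hints.
        scores = self.query(hidden_states) @ self.key(hidden_states).transpose(-1, -2)
        scores = scores / hidden_states.size(-1) ** 0.5
        if attention_mask is not None:
            scores = scores + attention_mask
        probs = scores.softmax(dim=-1)
        context = probs @ self.value(hidden_states)
        return (context, probs) if output_attentions else (context, None)


if __name__ == "__main__":
    layer = ToyAttention()
    context, attn = layer(torch.randn(2, 5, 64), output_attentions=True)
    print(context.shape, attn.shape)  # torch.Size([2, 5, 64]) torch.Size([2, 5, 5])
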
@@ -19,7 +19,7 @@ import copy
import math
import os
from dataclasses import dataclass
from typing import Optional, Tuple
from typing import Optional, Tuple, Union
import torch
import torch.utils.checkpoint
@@ -253,11 +253,11 @@ class CanineEmbeddings(nn.Module):
def forward(
self,
input_ids=None,
token_type_ids=None,
position_ids=None,
inputs_embeds=None,
):
input_ids: Optional[torch.LongTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
if input_ids is not None:
input_shape = input_ids.size()
else:
@@ -356,7 +356,11 @@ class ConvProjection(nn.Module):
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, inputs, final_seq_char_positions=None):
def forward(
self,
inputs: torch.Tensor,
final_seq_char_positions: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# inputs has shape [batch, mol_seq, molecule_hidden_size+char_hidden_final]
# we transpose it to be [batch, molecule_hidden_size+char_hidden_final, mol_seq]
inputs = torch.transpose(inputs, 1, 2)
@@ -419,12 +423,12 @@ class CanineSelfAttention(nn.Module):
def forward(
self,
from_tensor,
to_tensor,
attention_mask=None,
head_mask=None,
output_attentions=False,
):
from_tensor: torch.Tensor,
to_tensor: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
mixed_query_layer = self.query(from_tensor)
# If this is instantiated as a cross-attention module, the keys
@@ -496,7 +500,9 @@ class CanineSelfOutput(nn.Module):
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor):
def forward(
self, hidden_states: Tuple[torch.FloatTensor], input_tensor: torch.FloatTensor
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
@@ -574,11 +580,11 @@ class CanineAttention(nn.Module):
def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
output_attentions=False,
):
hidden_states: Tuple[torch.FloatTensor],
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]:
if not self.local:
self_outputs = self.self(hidden_states, hidden_states, attention_mask, head_mask, output_attentions)
attention_output = self_outputs[0]
@@ -656,7 +662,7 @@ class CanineIntermediate(nn.Module):
else:
self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states):
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states
@@ -669,7 +675,7 @@ class CanineOutput(nn.Module):
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor):
def forward(self, hidden_states: Tuple[torch.FloatTensor], input_tensor: torch.FloatTensor) -> torch.FloatTensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
@@ -706,11 +712,11 @@ class CanineLayer(nn.Module):
def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
output_attentions=False,
):
hidden_states: Tuple[torch.FloatTensor],
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]:
self_attention_outputs = self.attention(
hidden_states,
attention_mask,
@@ -767,13 +773,13 @@ class CanineEncoder(nn.Module):
def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
output_attentions=False,
output_hidden_states=False,
return_dict=True,
):
hidden_states: Tuple[torch.FloatTensor],
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = False,
output_hidden_states: Optional[bool] = False,
return_dict: Optional[bool] = True,
) -> Union[Tuple, BaseModelOutput]:
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
@@ -822,7 +828,7 @@ class CaninePooler(nn.Module):
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = nn.Tanh()
def forward(self, hidden_states):
def forward(self, hidden_states: Tuple[torch.FloatTensor]) -> torch.FloatTensor:
# We "pool" the model by simply taking the hidden state corresponding
# to the first token.
first_token_tensor = hidden_states[:, 0]
@@ -841,7 +847,7 @@ class CaninePredictionHeadTransform(nn.Module):
self.transform_act_fn = config.hidden_act
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
def forward(self, hidden_states):
def forward(self, hidden_states: Tuple[torch.FloatTensor]) -> torch.FloatTensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.transform_act_fn(hidden_states)
hidden_states = self.LayerNorm(hidden_states)
@@ -862,7 +868,7 @@ class CanineLMPredictionHead(nn.Module):
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
self.decoder.bias = self.bias
def forward(self, hidden_states):
def forward(self, hidden_states: Tuple[torch.FloatTensor]) -> torch.FloatTensor:
hidden_states = self.transform(hidden_states)
hidden_states = self.decoder(hidden_states)
return hidden_states
@@ -873,7 +879,10 @@ class CanineOnlyMLMHead(nn.Module):
super().__init__()
self.predictions = CanineLMPredictionHead(config)
def forward(self, sequence_output):
def forward(
self,
sequence_output: Tuple[torch.Tensor],
) -> Tuple[torch.Tensor]:
prediction_scores = self.predictions(sequence_output)
return prediction_scores
@@ -1093,16 +1102,16 @@ class CanineModel(CaninePreTrainedModel):
)
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CanineModelOutputWithPooling]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -1275,17 +1284,17 @@ class CanineForSequenceClassification(CaninePreTrainedModel):
)
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, SequenceClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
@@ -1372,17 +1381,17 @@ class CanineForMultipleChoice(CaninePreTrainedModel):
)
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, MultipleChoiceModelOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
@@ -1465,17 +1474,17 @@ class CanineForTokenClassification(CaninePreTrainedModel):
)
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, TokenClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
@@ -1543,18 +1552,18 @@ class CanineForQuestionAnswering(CaninePreTrainedModel):
)
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
start_positions=None,
end_positions=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
start_positions: Optional[torch.LongTensor] = None,
end_positions: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, QuestionAnsweringModelOutput]:
r"""
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for position (index) of the start of the labelled span for computing the token classification loss.
@@ -201,7 +201,13 @@ class ConvBertEmbeddings(nn.Module):
persistent=False,
)
def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
) -> torch.LongTensor:
if input_ids is not None:
input_shape = input_ids.size()
else:
@@ -287,7 +293,7 @@ class SeparableConv1D(nn.Module):
self.depthwise.weight.data.normal_(mean=0.0, std=config.initializer_range)
self.pointwise.weight.data.normal_(mean=0.0, std=config.initializer_range)
def forward(self, hidden_states):
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
x = self.depthwise(hidden_states)
x = self.pointwise(x)
x += self.bias
@@ -341,12 +347,12 @@ class ConvBertSelfAttention(nn.Module):
def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
output_attentions=False,
):
hidden_states: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
mixed_query_layer = self.query(hidden_states)
batch_size = hidden_states.size(0)
# If this is instantiated as a cross-attention module, the keys
@@ -426,7 +432,7 @@ class ConvBertSelfOutput(nn.Module):
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor):
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
@@ -460,12 +466,12 @@ class ConvBertAttention(nn.Module):
def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
output_attentions=False,
):
hidden_states: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor, Optional[torch.FloatTensor]]:
self_outputs = self.self(
hidden_states,
attention_mask,
@@ -489,7 +495,7 @@ class GroupedLinearLayer(nn.Module):
self.weight = nn.Parameter(torch.empty(self.num_groups, self.group_in_dim, self.group_out_dim))
self.bias = nn.Parameter(torch.empty(output_size))
def forward(self, hidden_states):
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
batch_size = list(hidden_states.size())[0]
x = torch.reshape(hidden_states, [-1, self.num_groups, self.group_in_dim])
x = x.permute(1, 0, 2)
@@ -514,7 +520,7 @@ class ConvBertIntermediate(nn.Module):
else:
self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states):
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states
@@ -532,7 +538,7 @@ class ConvBertOutput(nn.Module):
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor):
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
@@ -556,13 +562,13 @@ class ConvBertLayer(nn.Module):
def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
output_attentions=False,
):
hidden_states: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor, Optional[torch.FloatTensor]]:
self_attention_outputs = self.attention(
hidden_states,
attention_mask,
@@ -608,15 +614,15 @@ class ConvBertEncoder(nn.Module):
def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
output_attentions=False,
output_hidden_states=False,
return_dict=True,
):
hidden_states: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = False,
output_hidden_states: Optional[bool] = False,
return_dict: Optional[bool] = True,
) -> Union[Tuple, BaseModelOutputWithCrossAttentions]:
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
@@ -684,7 +690,7 @@ class ConvBertPredictionHeadTransform(nn.Module):
self.transform_act_fn = config.hidden_act
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
def forward(self, hidden_states):
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.transform_act_fn(hidden_states)
hidden_states = self.LayerNorm(hidden_states)
@@ -795,16 +801,16 @@ class ConvBertModel(ConvBertPreTrainedModel):
)
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithCrossAttentions]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -864,7 +870,7 @@ class ConvBertGeneratorPredictions(nn.Module):
self.LayerNorm = nn.LayerNorm(config.embedding_size, eps=config.layer_norm_eps)
self.dense = nn.Linear(config.hidden_size, config.embedding_size)
def forward(self, generator_hidden_states):
def forward(self, generator_hidden_states: torch.FloatTensor) -> torch.FloatTensor:
hidden_states = self.dense(generator_hidden_states)
hidden_states = get_activation("gelu")(hidden_states)
hidden_states = self.LayerNorm(hidden_states)
@@ -966,7 +972,7 @@ class ConvBertClassificationHead(nn.Module):
self.config = config
def forward(self, hidden_states, **kwargs):
def forward(self, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
x = hidden_states[:, 0, :] # take <s> token (equiv. to [CLS])
x = self.dropout(x)
x = self.dense(x)
@@ -15,6 +15,8 @@
""" PyTorch ConvNext model."""
from typing import Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
@@ -78,7 +80,7 @@ class ConvNextDropPath(nn.Module):
super().__init__()
self.drop_prob = drop_prob
def forward(self, x):
def forward(self, x: torch.Tensor) -> torch.Tensor:
return drop_path(x, self.drop_prob, self.training)
@@ -98,7 +100,7 @@ class ConvNextLayerNorm(nn.Module):
raise NotImplementedError(f"Unsupported data format: {self.data_format}")
self.normalized_shape = (normalized_shape,)
def forward(self, x):
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.data_format == "channels_last":
x = torch.nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
elif self.data_format == "channels_first":
@@ -121,7 +123,7 @@ class ConvNextEmbeddings(nn.Module):
)
self.layernorm = ConvNextLayerNorm(config.hidden_sizes[0], eps=1e-6, data_format="channels_first")
def forward(self, pixel_values):
def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
embeddings = self.patch_embeddings(pixel_values)
embeddings = self.layernorm(embeddings)
return embeddings
@@ -155,7 +157,7 @@ class ConvNextLayer(nn.Module):
)
self.drop_path = ConvNextDropPath(drop_path) if drop_path > 0.0 else nn.Identity()
def forward(self, hidden_states):
def forward(self, hidden_states: torch.FloatTensor) -> torch.Tensor:
input = hidden_states
x = self.dwconv(hidden_states)
x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
@@ -197,7 +199,7 @@ class ConvNextStage(nn.Module):
*[ConvNextLayer(config, dim=out_channels, drop_path=drop_path_rates[j]) for j in range(depth)]
)
def forward(self, hidden_states):
def forward(self, hidden_states: torch.FloatTensor) -> torch.Tensor:
hidden_states = self.downsampling_layer(hidden_states)
hidden_states = self.layers(hidden_states)
return hidden_states
@@ -224,7 +226,12 @@ class ConvNextEncoder(nn.Module):
cur += config.depths[i]
prev_chs = out_chs
def forward(self, hidden_states, output_hidden_states=False, return_dict=True):
def forward(
self,
hidden_states: torch.FloatTensor,
output_hidden_states: Optional[bool] = False,
return_dict: Optional[bool] = True,
) -> Union[Tuple, BaseModelOutputWithNoAttention]:
all_hidden_states = () if output_hidden_states else None
for i, layer_module in enumerate(self.stages):
@@ -325,7 +332,12 @@ class ConvNextModel(ConvNextPreTrainedModel):
modality="vision",
expected_output=_EXPECTED_OUTPUT_SHAPE,
)
def forward(self, pixel_values=None, output_hidden_states=None, return_dict=None):
def forward(
self,
pixel_values: torch.FloatTensor = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPoolingAndNoAttention]:
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
@@ -387,7 +399,13 @@ class ConvNextForImageClassification(ConvNextPreTrainedModel):
config_class=_CONFIG_FOR_DOC,
expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
)
def forward(self, pixel_values=None, labels=None, output_hidden_states=None, return_dict=None):
def forward(
self,
pixel_values: torch.FloatTensor = None,
labels: Optional[torch.LongTensor] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, ImageClassifierOutputWithNoAttention]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
@@ -17,7 +17,7 @@
import math
import os
from dataclasses import dataclass
from typing import Tuple, Union
from typing import Optional, Tuple, Union
import torch
import torch.utils.checkpoint
@@ -278,15 +278,15 @@ class DecisionTransformerGPT2Attention(nn.Module):
def forward(
self,
hidden_states,
layer_past=None,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
use_cache=False,
output_attentions=False,
):
hidden_states: Optional[Tuple[torch.FloatTensor]],
layer_past: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
if encoder_hidden_states is not None:
if not hasattr(self, "q_attn"):
raise ValueError(
@@ -340,7 +340,7 @@ class DecisionTransformerGPT2MLP(nn.Module):
self.act = ACT2FN[config.activation_function]
self.dropout = nn.Dropout(config.resid_pdrop)
def forward(self, hidden_states):
def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
hidden_states = self.c_fc(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states = self.c_proj(hidden_states)
@@ -369,15 +369,15 @@ class DecisionTransformerGPT2Block(nn.Module):
def forward(
self,
hidden_states,
layer_past=None,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
use_cache=False,
output_attentions=False,
):
hidden_states: Optional[Tuple[torch.FloatTensor]],
layer_past: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
residual = hidden_states
hidden_states = self.ln_1(hidden_states)
attn_outputs = self.attn(
@@ -510,20 +510,20 @@ class DecisionTransformerGPT2Model(DecisionTransformerGPT2PreTrainedModel):
# Copied from transformers.models.gpt2.modeling_gpt2.GPT2Model.forward
def forward(
self,
input_ids=None,
past_key_values=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -15,7 +15,7 @@
""" Classes to support Encoder-Decoder architectures"""
import warnings
from typing import Optional
from typing import Optional, Tuple, Union
import torch
from torch import nn
@@ -430,21 +430,21 @@ class EncoderDecoderModel(PreTrainedModel):
@replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids=None,
attention_mask=None,
decoder_input_ids=None,
decoder_attention_mask=None,
encoder_outputs=None,
past_key_values=None,
inputs_embeds=None,
decoder_inputs_embeds=None,
labels=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask: Optional[torch.BoolTensor] = None,
encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None,
past_key_values: Tuple[Tuple[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs,
):
) -> Union[Tuple, Seq2SeqLMOutput]:
r"""
Returns:
@@ -80,7 +80,7 @@ class GLPNDropPath(nn.Module):
super().__init__()
self.drop_prob = drop_prob
def forward(self, x):
def forward(self, x: torch.Tensor) -> torch.Tensor:
return drop_path(x, self.drop_prob, self.training)
@@ -18,7 +18,7 @@
import math
import os
from dataclasses import dataclass
from typing import Optional, Tuple
from typing import Optional, Tuple, Union
import torch
import torch.utils.checkpoint
@@ -289,15 +289,15 @@ class GPT2Attention(nn.Module):
def forward(
self,
hidden_states,
layer_past=None,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
use_cache=False,
output_attentions=False,
):
hidden_states: Optional[Tuple[torch.FloatTensor]],
layer_past: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
if encoder_hidden_states is not None:
if not hasattr(self, "q_attn"):
raise ValueError(
@@ -350,7 +350,7 @@ class GPT2MLP(nn.Module):
self.act = ACT2FN[config.activation_function]
self.dropout = nn.Dropout(config.resid_pdrop)
def forward(self, hidden_states):
def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
hidden_states = self.c_fc(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states = self.c_proj(hidden_states)
@@ -376,15 +376,15 @@ class GPT2Block(nn.Module):
def forward(
self,
hidden_states,
layer_past=None,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
use_cache=False,
output_attentions=False,
):
hidden_states: Optional[Tuple[torch.FloatTensor]],
layer_past: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
residual = hidden_states
hidden_states = self.ln_1(hidden_states)
attn_outputs = self.attn(
@@ -742,20 +742,20 @@ class GPT2Model(GPT2PreTrainedModel):
)
def forward(
self,
input_ids=None,
past_key_values=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -1020,21 +1020,21 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
)
def forward(
self,
input_ids=None,
past_key_values=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
labels=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
@@ -1189,22 +1189,22 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
@replace_return_docstrings(output_type=GPT2DoubleHeadsModelOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids=None,
past_key_values=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
mc_token_ids=None,
labels=None,
mc_labels=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
mc_token_ids: Optional[torch.LongTensor] = None,
labels: Optional[torch.LongTensor] = None,
mc_labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs,
):
) -> Union[Tuple, GPT2DoubleHeadsModelOutput]:
r"""
mc_token_ids (`torch.LongTensor` of shape `(batch_size, num_choices)`, *optional*, default to index of the last token of the input):
Index of the classification token in each input sequence. Selected in the range `[0, input_ids.size(-1) -
@@ -1352,19 +1352,19 @@ class GPT2ForSequenceClassification(GPT2PreTrainedModel):
)
def forward(
self,
input_ids=None,
past_key_values=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, SequenceClassifierOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
@@ -1488,19 +1488,19 @@ class GPT2ForTokenClassification(GPT2PreTrainedModel):
# fmt: on
def forward(
self,
input_ids=None,
past_key_values=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, TokenClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
@@ -14,7 +14,7 @@
# limitations under the License.
""" PyTorch GPT-J model."""
from typing import Tuple
from typing import Optional, Tuple, Union
import torch
import torch.utils.checkpoint
@@ -191,13 +191,16 @@ class GPTJAttention(nn.Module):
def forward(
self,
hidden_states,
attention_mask=None,
layer_past=None,
head_mask=None,
use_cache=False,
output_attentions=False,
):
hidden_states: Optional[torch.FloatTensor],
attention_mask: Optional[torch.FloatTensor] = None,
layer_past: Optional[Tuple[torch.Tensor]] = None,
head_mask: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
) -> Union[
Tuple[torch.Tensor, Tuple[torch.Tensor]],
Optional[Tuple[torch.Tensor, Tuple[torch.Tensor], Tuple[torch.Tensor, ...]]],
]:
query = self.q_proj(hidden_states)
key = self.k_proj(hidden_states)
@@ -271,7 +274,7 @@ class GPTJMLP(nn.Module):
self.act = ACT2FN[config.activation_function]
self.dropout = nn.Dropout(config.resid_pdrop)
def forward(self, hidden_states):
def forward(self, hidden_states: Optional[torch.FloatTensor]) -> torch.FloatTensor:
hidden_states = self.fc_in(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states = self.fc_out(hidden_states)
@@ -289,13 +292,13 @@ class GPTJBlock(nn.Module):
def forward(
self,
hidden_states,
layer_past=None,
attention_mask=None,
head_mask=None,
use_cache=False,
output_attentions=False,
):
hidden_states: Optional[torch.FloatTensor],
layer_past: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
residual = hidden_states
hidden_states = self.ln_1(hidden_states)
attn_outputs = self.attn(
@@ -533,18 +536,18 @@ class GPTJModel(GPTJPreTrainedModel):
)
def forward(
self,
input_ids=None,
past_key_values=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -787,19 +790,19 @@ class GPTJForCausalLM(GPTJPreTrainedModel):
)
def forward(
self,
input_ids=None,
past_key_values=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
@@ -911,19 +914,19 @@ class GPTJForSequenceClassification(GPTJPreTrainedModel):
)
def forward(
self,
input_ids=None,
past_key_values=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, SequenceClassifierOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
@@ -1038,18 +1041,18 @@ class GPTJForQuestionAnswering(GPTJPreTrainedModel):
)
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
start_positions=None,
end_positions=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
start_positions: Optional[torch.LongTensor] = None,
end_positions: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, QuestionAnsweringModelOutput]:
r"""
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for position (index) of the start of the labelled span for computing the token classification loss.
@@ -702,11 +702,11 @@ class MaskFormerSwinSelfAttention(nn.Module):
def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
output_attentions=False,
):
hidden_states: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor]:
batch_size, dim, num_channels = hidden_states.shape
mixed_query_layer = self.query(hidden_states)
@@ -764,7 +764,7 @@ class MaskFormerSwinSelfOutput(nn.Module):
self.dense = nn.Linear(dim, dim)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
def forward(self, hidden_states, input_tensor):
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
@@ -797,7 +797,13 @@ class MaskFormerSwinAttention(nn.Module):
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
self.pruned_heads = self.pruned_heads.union(heads)
def forward(self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False):
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor]:
self_outputs = self.self(hidden_states, attention_mask, head_mask, output_attentions)
attention_output = self.output(self_outputs[0], hidden_states)
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
@@ -814,7 +820,7 @@ class MaskFormerSwinIntermediate(nn.Module):
else:
self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states):
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states
@@ -827,7 +833,7 @@ class MaskFormerSwinOutput(nn.Module):
self.dense = nn.Linear(int(config.mlp_ratio * dim), dim)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states):
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
return hidden_states
@@ -20,7 +20,7 @@ import math
import os
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple
from typing import Optional, Tuple, Union
import torch
import torch.utils.checkpoint
@@ -154,8 +154,13 @@ class MegatronBertEmbeddings(nn.Module):
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
def forward(
self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
):
self,
input_ids: Optional[torch.LongTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.LongTensor] = None,
past_key_values_length: int = 0,
) -> torch.Tensor:
if input_ids is not None:
input_shape = input_ids.size()
else:
@@ -319,7 +324,7 @@ class MegatronBertSelfOutput(nn.Module):
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, residual):
def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
return residual + hidden_states
@@ -354,14 +359,14 @@ class MegatronBertAttention(nn.Module):
def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
past_key_value=None,
output_attentions=False,
):
hidden_states: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor]:
ln_outputs = self.ln(hidden_states)
self_outputs = self.self(
ln_outputs,
@@ -400,7 +405,7 @@ class MegatronBertOutput(nn.Module):
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor):
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
return input_tensor + hidden_states
@@ -425,14 +430,14 @@ class MegatronBertLayer(nn.Module):
def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
past_key_value=None,
output_attentions=False,
):
hidden_states: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor]:
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
self_attention_outputs = self.attention(
@@ -507,17 +512,17 @@ class MegatronBertEncoder(nn.Module):
def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
past_key_values=None,
use_cache=None,
output_attentions=False,
output_hidden_states=False,
return_dict=True,
):
hidden_states: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = False,
output_hidden_states: Optional[bool] = False,
return_dict: Optional[bool] = True,
) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
@@ -873,20 +878,20 @@ class MegatronBertModel(MegatronBertPreTrainedModel):
)
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
past_key_values=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPoolingAndCrossAttentions]:
r"""
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
@@ -1022,18 +1027,18 @@ class MegatronBertForPreTraining(MegatronBertPreTrainedModel):
@replace_return_docstrings(output_type=MegatronBertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
next_sentence_label=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
next_sentence_label: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, MegatronBertForPreTrainingOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
@@ -1133,21 +1138,21 @@ class MegatronBertForCausalLM(MegatronBertPreTrainedModel):
@replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
labels=None,
past_key_values=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
r"""
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
@@ -1287,19 +1292,19 @@ class MegatronBertForMaskedLM(MegatronBertPreTrainedModel):
)
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, MaskedLMOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
@@ -1379,18 +1384,18 @@ class MegatronBertForNextSentencePrediction(MegatronBertPreTrainedModel):
@replace_return_docstrings(output_type=NextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs
):
) -> Union[Tuple, NextSentencePredictorOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
@@ -1489,17 +1494,17 @@ class MegatronBertForSequenceClassification(MegatronBertPreTrainedModel):
)
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, SequenceClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
@@ -1588,17 +1593,17 @@ class MegatronBertForMultipleChoice(MegatronBertPreTrainedModel):
)
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, MultipleChoiceModelOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
......@@ -1684,17 +1689,17 @@ class MegatronBertForTokenClassification(MegatronBertPreTrainedModel):
)
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, TokenClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
......@@ -1765,18 +1770,18 @@ class MegatronBertForQuestionAnswering(MegatronBertPreTrainedModel):
)
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
start_positions=None,
end_positions=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
start_positions: Optional[torch.LongTensor] = None,
end_positions: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, QuestionAnsweringModelOutput]:
r"""
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for position (index) of the start of the labelled span for computing the token classification loss.
......
......@@ -164,7 +164,7 @@ class NoNorm(nn.Module):
self.bias = nn.Parameter(torch.zeros(feat_size))
self.weight = nn.Parameter(torch.ones(feat_size))
def forward(self, input_tensor):
def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
return input_tensor * self.weight + self.bias
......@@ -194,7 +194,13 @@ class MobileBertEmbeddings(nn.Module):
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
) -> torch.Tensor:
if input_ids is not None:
input_shape = input_ids.size()
else:
......@@ -260,13 +266,13 @@ class MobileBertSelfAttention(nn.Module):
def forward(
self,
query_tensor,
key_tensor,
value_tensor,
attention_mask=None,
head_mask=None,
output_attentions=None,
):
query_tensor: torch.Tensor,
key_tensor: torch.Tensor,
value_tensor: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
) -> Tuple[torch.Tensor]:
mixed_query_layer = self.query(query_tensor)
mixed_key_layer = self.key(key_tensor)
mixed_value_layer = self.value(value_tensor)
......@@ -306,7 +312,7 @@ class MobileBertSelfOutput(nn.Module):
if not self.use_bottleneck:
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, residual_tensor):
def forward(self, hidden_states: torch.Tensor, residual_tensor: torch.Tensor) -> torch.Tensor:
layer_outputs = self.dense(hidden_states)
if not self.use_bottleneck:
layer_outputs = self.dropout(layer_outputs)
......@@ -341,14 +347,14 @@ class MobileBertAttention(nn.Module):
def forward(
self,
query_tensor,
key_tensor,
value_tensor,
layer_input,
attention_mask=None,
head_mask=None,
output_attentions=None,
):
query_tensor: torch.Tensor,
key_tensor: torch.Tensor,
value_tensor: torch.Tensor,
layer_input: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
) -> Tuple[torch.Tensor]:
self_outputs = self.self(
query_tensor,
key_tensor,
......@@ -373,7 +379,7 @@ class MobileBertIntermediate(nn.Module):
else:
self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states):
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states
......@@ -386,7 +392,7 @@ class OutputBottleneck(nn.Module):
self.LayerNorm = NORM2FN[config.normalization_type](config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, residual_tensor):
def forward(self, hidden_states: torch.Tensor, residual_tensor: torch.Tensor) -> torch.Tensor:
layer_outputs = self.dense(hidden_states)
layer_outputs = self.dropout(layer_outputs)
layer_outputs = self.LayerNorm(layer_outputs + residual_tensor)
......@@ -404,7 +410,9 @@ class MobileBertOutput(nn.Module):
else:
self.bottleneck = OutputBottleneck(config)
def forward(self, intermediate_states, residual_tensor_1, residual_tensor_2):
def forward(
self, intermediate_states: torch.Tensor, residual_tensor_1: torch.Tensor, residual_tensor_2: torch.Tensor
) -> torch.Tensor:
layer_output = self.dense(intermediate_states)
if not self.use_bottleneck:
layer_output = self.dropout(layer_output)
......@@ -421,7 +429,7 @@ class BottleneckLayer(nn.Module):
self.dense = nn.Linear(config.hidden_size, config.intra_bottleneck_size)
self.LayerNorm = NORM2FN[config.normalization_type](config.intra_bottleneck_size, eps=config.layer_norm_eps)
def forward(self, hidden_states):
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
layer_input = self.dense(hidden_states)
layer_input = self.LayerNorm(layer_input)
return layer_input
......@@ -436,7 +444,7 @@ class Bottleneck(nn.Module):
if self.key_query_shared_bottleneck:
self.attention = BottleneckLayer(config)
def forward(self, hidden_states):
def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor]:
# This method can return three different tuples of values. These different values make use of bottlenecks,
# which are linear layers used to project the hidden states to a lower-dimensional vector, reducing memory
# usage. These linear layers have weights that are learned during training.
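A minimal sketch of the bottleneck projection the comment above describes, using hypothetical sizes in place of `config.hidden_size` and `config.intra_bottleneck_size`, and plain `nn.LayerNorm` standing in for whatever `NORM2FN[config.normalization_type]` selects:

```python
import torch
from torch import nn

# Hypothetical sizes standing in for config.hidden_size / config.intra_bottleneck_size.
hidden_size, intra_bottleneck_size = 512, 128
dense = nn.Linear(hidden_size, intra_bottleneck_size)   # learned down-projection
norm = nn.LayerNorm(intra_bottleneck_size)              # stand-in for NORM2FN[config.normalization_type]

hidden_states = torch.randn(2, 16, hidden_size)         # (batch, seq_len, hidden_size)
layer_input = norm(dense(hidden_states))                # (batch, seq_len, intra_bottleneck_size)
```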
......@@ -469,7 +477,7 @@ class FFNOutput(nn.Module):
self.dense = nn.Linear(config.intermediate_size, config.true_hidden_size)
self.LayerNorm = NORM2FN[config.normalization_type](config.true_hidden_size, eps=config.layer_norm_eps)
def forward(self, hidden_states, residual_tensor):
def forward(self, hidden_states: torch.Tensor, residual_tensor: torch.Tensor) -> torch.Tensor:
layer_outputs = self.dense(hidden_states)
layer_outputs = self.LayerNorm(layer_outputs + residual_tensor)
return layer_outputs
......@@ -481,7 +489,7 @@ class FFNLayer(nn.Module):
self.intermediate = MobileBertIntermediate(config)
self.output = FFNOutput(config)
def forward(self, hidden_states):
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
intermediate_output = self.intermediate(hidden_states)
layer_outputs = self.output(intermediate_output, hidden_states)
return layer_outputs
......@@ -503,11 +511,11 @@ class MobileBertLayer(nn.Module):
def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
output_attentions=None,
):
hidden_states: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
) -> Tuple[torch.Tensor]:
if self.use_bottleneck:
query_tensor, key_tensor, value_tensor, layer_input = self.bottleneck(hidden_states)
else:
......@@ -557,13 +565,13 @@ class MobileBertEncoder(nn.Module):
def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
output_attentions=False,
output_hidden_states=False,
return_dict=True,
):
hidden_states: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = False,
output_hidden_states: Optional[bool] = False,
return_dict: Optional[bool] = True,
) -> Union[Tuple, BaseModelOutput]:
all_hidden_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
for i, layer_module in enumerate(self.layer):
......@@ -599,7 +607,7 @@ class MobileBertPooler(nn.Module):
if self.do_activate:
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
def forward(self, hidden_states):
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# We "pool" the model by simply taking the hidden state corresponding
# to the first token.
first_token_tensor = hidden_states[:, 0]
......@@ -621,7 +629,7 @@ class MobileBertPredictionHeadTransform(nn.Module):
self.transform_act_fn = config.hidden_act
self.LayerNorm = NORM2FN["layer_norm"](config.hidden_size, eps=config.layer_norm_eps)
def forward(self, hidden_states):
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.transform_act_fn(hidden_states)
hidden_states = self.LayerNorm(hidden_states)
......@@ -640,7 +648,7 @@ class MobileBertLMPredictionHead(nn.Module):
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
self.decoder.bias = self.bias
def forward(self, hidden_states):
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.transform(hidden_states)
hidden_states = hidden_states.matmul(torch.cat([self.decoder.weight.t(), self.dense.weight], dim=0))
hidden_states += self.decoder.bias
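A minimal sketch of the bias-tying pattern referred to above, with hypothetical sizes: registering the output bias as its own `nn.Parameter` and aliasing it onto the decoder means both names point at the same tensor, so `resize_token_embeddings` only has to resize it once.

```python
import torch
from torch import nn

decoder = nn.Linear(512, 30522, bias=False)    # hypothetical hidden size / vocab size
bias = nn.Parameter(torch.zeros(30522))
decoder.bias = bias                            # alias, mirroring `self.decoder.bias = self.bias` above

assert decoder.bias is bias                    # one Parameter, two names
bias.data[0] = 1.0
assert decoder.bias[0].item() == 1.0           # an update made through either name is seen by both
```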
......@@ -652,7 +660,7 @@ class MobileBertOnlyMLMHead(nn.Module):
super().__init__()
self.predictions = MobileBertLMPredictionHead(config)
def forward(self, sequence_output):
def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
prediction_scores = self.predictions(sequence_output)
return prediction_scores
......@@ -663,7 +671,7 @@ class MobileBertPreTrainingHeads(nn.Module):
self.predictions = MobileBertLMPredictionHead(config)
self.seq_relationship = nn.Linear(config.hidden_size, 2)
def forward(self, sequence_output, pooled_output):
def forward(self, sequence_output: torch.Tensor, pooled_output: torch.Tensor) -> Tuple[torch.Tensor]:
prediction_scores = self.predictions(sequence_output)
seq_relationship_score = self.seq_relationship(pooled_output)
return prediction_scores, seq_relationship_score
......@@ -841,16 +849,16 @@ class MobileBertModel(MobileBertPreTrainedModel):
)
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
output_hidden_states=None,
output_attentions=None,
return_dict=None,
):
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
output_hidden_states: Optional[bool] = None,
output_attentions: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPooling]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
......@@ -943,18 +951,18 @@ class MobileBertForPreTraining(MobileBertPreTrainedModel):
@replace_return_docstrings(output_type=MobileBertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
next_sentence_label=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
next_sentence_label: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, MobileBertForPreTrainingOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
......@@ -1059,17 +1067,17 @@ class MobileBertForMaskedLM(MobileBertPreTrainedModel):
)
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, MaskedLMOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
......@@ -1115,7 +1123,7 @@ class MobileBertOnlyNSPHead(nn.Module):
super().__init__()
self.seq_relationship = nn.Linear(config.hidden_size, 2)
def forward(self, pooled_output):
def forward(self, pooled_output: torch.Tensor) -> torch.Tensor:
seq_relationship_score = self.seq_relationship(pooled_output)
return seq_relationship_score
......@@ -1138,18 +1146,18 @@ class MobileBertForNextSentencePrediction(MobileBertPreTrainedModel):
@replace_return_docstrings(output_type=NextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs,
):
) -> Union[Tuple, NextSentencePredictorOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
......
......@@ -19,7 +19,7 @@ import math
from dataclasses import dataclass
from functools import reduce
from operator import __add__
from typing import Any, Callable, Dict, Mapping, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union
import numpy as np
import torch
......@@ -177,7 +177,7 @@ class PerceiverEmbeddings(nn.Module):
super().__init__()
self.latents = nn.Parameter(torch.randn(config.num_latents, config.d_latents))
def forward(self, batch_size):
def forward(self, batch_size: int) -> torch.Tensor:
return self.latents.expand(batch_size, -1, -1) # Thanks, Phil Wang
......@@ -232,13 +232,13 @@ class PerceiverSelfAttention(nn.Module):
def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
inputs=None,
inputs_mask=None,
output_attentions=False,
):
hidden_states: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs: Optional[torch.FloatTensor] = None,
inputs_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor]:
hidden_states = self.layernorm1(hidden_states)
inputs = self.layernorm2(inputs)
......@@ -301,7 +301,7 @@ class PerceiverSelfOutput(nn.Module):
super().__init__()
self.dense = nn.Linear(input_channels, output_channels)
def forward(self, hidden_states):
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
return hidden_states
......@@ -377,13 +377,13 @@ class PerceiverAttention(nn.Module):
def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
inputs=None,
inputs_mask=None,
output_attentions=False,
):
hidden_states: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs: Optional[torch.FloatTensor] = None,
inputs_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor]:
self_outputs = self.self(
hidden_states,
attention_mask,
......@@ -418,7 +418,7 @@ class PerceiverMLP(nn.Module):
self.intermediate_act_fn = config.hidden_act
self.dense2 = nn.Linear(widening_factor * input_size, input_size)
def forward(self, hidden_states):
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense1(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
hidden_states = self.dense2(hidden_states)
......@@ -456,13 +456,13 @@ class PerceiverLayer(nn.Module):
def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
inputs=None,
inputs_mask=None,
output_attentions=False,
):
hidden_states: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs: Optional[torch.FloatTensor] = None,
inputs_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor]:
attention_outputs = self.attention(
hidden_states,
attention_mask,
......@@ -543,15 +543,15 @@ class PerceiverEncoder(nn.Module):
def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
inputs=None,
inputs_mask=None,
output_attentions=False,
output_hidden_states=False,
return_dict=True,
):
hidden_states: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs: Optional[torch.FloatTensor] = None,
inputs_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = False,
output_hidden_states: Optional[bool] = False,
return_dict: Optional[bool] = True,
) -> Union[Tuple, BaseModelOutputWithCrossAttentions]:
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions else None
......@@ -754,14 +754,14 @@ class PerceiverModel(PerceiverPreTrainedModel):
@replace_return_docstrings(output_type=PerceiverModelOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
inputs,
attention_mask=None,
subsampled_output_points=None,
head_mask=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
inputs: torch.FloatTensor,
attention_mask: Optional[torch.FloatTensor] = None,
subsampled_output_points: Optional[Dict[str, torch.Tensor]] = None,
head_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, PerceiverModelOutput]:
r"""
Returns:
......@@ -1871,7 +1871,7 @@ class PerceiverForMultimodalAutoencoding(PerceiverPreTrainedModel):
self,
inputs: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
subsampled_output_points: Optional[Dict[str, torch.tensor]] = None,
subsampled_output_points: Optional[Dict[str, torch.Tensor]] = None,
head_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
......@@ -2020,7 +2020,9 @@ class PerceiverProjectionDecoder(PerceiverAbstractDecoder):
def decoder_query(self, inputs, modality_sizes=None, inputs_without_pos=None, subsampled_points=None):
return None
def forward(self, query, z, query_mask=None):
def forward(
self, query: torch.Tensor, z: torch.FloatTensor, query_mask: Optional[torch.FloatTensor] = None
) -> torch.FloatTensor:
# (batch_size, num_latents, d_latents) -> (batch_size, d_latents)
z = torch.mean(z, dim=1)
# (batch_size, d_latents) -> (batch_size, config.num_labels)
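A shape-only sketch of what this projection decoder does, with hypothetical sizes: the latents are averaged over the latent dimension and then projected to class logits.

```python
import torch
from torch import nn

batch_size, num_latents, d_latents, num_labels = 2, 256, 1280, 10   # hypothetical sizes
z = torch.randn(batch_size, num_latents, d_latents)
classifier = nn.Linear(d_latents, num_labels)

logits = classifier(z.mean(dim=1))   # (batch_size, num_labels)
```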
......@@ -2044,11 +2046,11 @@ class PerceiverBasicDecoder(PerceiverAbstractDecoder):
The type of position encoding to use. Can be either "trainable", "fourier", or "none".
output_index_dims (`int`, *optional*):
The number of dimensions of the output queries. Ignored if 'position_encoding_type' == 'none'.
num_channels (`int`, *optional*):
num_channels (`int`, *optional*, defaults to 128):
The number of channels of the decoder queries. Ignored if 'position_encoding_type' == 'none'.
qk_channels (`int`, *optional*):
The number of channels of the queries and keys in the cross-attention layer.
v_channels (`int`, *optional*, defaults to 128):
v_channels (`int`, *optional*):
The number of channels of the values in the cross-attention layer.
num_heads (`int`, *optional*, defaults to 1):
The number of attention heads in the cross-attention layer.
......@@ -2066,23 +2068,23 @@ class PerceiverBasicDecoder(PerceiverAbstractDecoder):
def __init__(
self,
config,
output_num_channels,
position_encoding_type="trainable",
config: PerceiverConfig,
output_num_channels: int,
position_encoding_type: Optional[str] = "trainable",
# The following 2 arguments are ignored if position_encoding_type == 'none':
output_index_dims=None,
num_channels=128,
subsampled_index_dims=None,
qk_channels=None,
v_channels=None,
num_heads=1,
widening_factor=1,
use_query_residual=False,
concat_preprocessed_input=False,
final_project=True,
position_encoding_only=False,
output_index_dims: Optional[int] = None,
num_channels: Optional[int] = 128,
subsampled_index_dims: Optional[int] = None,
qk_channels: Optional[int] = None,
v_channels: Optional[int] = None,
num_heads: Optional[int] = 1,
widening_factor: Optional[int] = 1,
use_query_residual: Optional[bool] = False,
concat_preprocessed_input: Optional[bool] = False,
final_project: Optional[bool] = True,
position_encoding_only: Optional[bool] = False,
**position_encoding_kwargs,
):
) -> None:
super().__init__()
self.output_num_channels = output_num_channels
......@@ -2183,7 +2185,13 @@ class PerceiverBasicDecoder(PerceiverAbstractDecoder):
return pos_emb
def forward(self, query, z, query_mask=None, output_attentions=False):
def forward(
self,
query: torch.Tensor,
z: torch.FloatTensor,
query_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = False,
) -> PerceiverDecoderOutput:
# Cross-attention decoding.
# key, value: B x N x K; query: B x M x K
# Attention maps -> B x M x N
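A shape walk-through of the cross-attention step above, with hypothetical sizes: the M decoder queries attend over the N latents, giving B x M x N attention maps and a B x M x K decoded output.

```python
import torch

B, M, N, K = 2, 100, 256, 64                 # batch, decoder queries, latents, channels (hypothetical)
query = torch.randn(B, M, K)
key = value = torch.randn(B, N, K)

attn = query @ key.transpose(-1, -2)         # B x M x N attention maps
out = attn.softmax(dim=-1) @ value           # B x M x K decoded output
```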
......@@ -2239,7 +2247,13 @@ class PerceiverClassificationDecoder(PerceiverAbstractDecoder):
inputs, modality_sizes, inputs_without_pos, subsampled_points=subsampled_points
)
def forward(self, query, z, query_mask=None, output_attentions=False):
def forward(
self,
query: torch.Tensor,
z: torch.FloatTensor,
query_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = False,
) -> PerceiverDecoderOutput:
decoder_outputs = self.decoder(query, z, output_attentions=output_attentions)
# B x 1 x num_classes -> B x num_classes
......@@ -2268,7 +2282,13 @@ class PerceiverOpticalFlowDecoder(PerceiverAbstractDecoder):
raise ValueError("FlowDecoder doesn't support subsampling yet.")
return inputs
def forward(self, query, z, query_mask=None, output_attentions=False):
def forward(
self,
query: torch.Tensor,
z: torch.FloatTensor,
query_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = False,
) -> PerceiverDecoderOutput:
decoder_outputs = self.decoder(query, z, output_attentions=output_attentions)
preds = decoder_outputs.logits
# Output flow and rescale.
......@@ -2291,7 +2311,9 @@ class PerceiverBasicVideoAutoencodingDecoder(PerceiverAbstractDecoder):
The type of position encoding to use. Can be either "trainable", "fourier", or "none".
"""
def __init__(self, config, output_shape, position_encoding_type, **decoder_kwargs):
def __init__(
self, config: PerceiverConfig, output_shape: List[int], position_encoding_type: str, **decoder_kwargs
) -> None:
super().__init__()
if len(output_shape) != 4: # B, T, H, W
raise ValueError(f"Expected rank 4 output_shape, got {output_shape}.")
......@@ -2318,7 +2340,9 @@ class PerceiverBasicVideoAutoencodingDecoder(PerceiverAbstractDecoder):
subsampled_points=subsampled_points,
)
def forward(self, query, z, query_mask=None):
def forward(
self, query: torch.Tensor, z: torch.FloatTensor, query_mask: Optional[torch.FloatTensor] = None
) -> PerceiverDecoderOutput:
decoder_outputs = self.decoder(query, z)
logits = decoder_outputs.logits
......@@ -2378,14 +2402,14 @@ class PerceiverMultimodalDecoder(PerceiverAbstractDecoder):
def __init__(
self,
config,
modalities,
num_outputs,
output_num_channels,
min_padding_size=2,
subsampled_index_dims=None,
config: PerceiverConfig,
modalities: Dict[str, PerceiverAbstractDecoder],
num_outputs: int,
output_num_channels: int,
min_padding_size: Optional[int] = 2,
subsampled_index_dims: Optional[Dict[str, PerceiverAbstractDecoder]] = None,
**decoder_kwargs
):
) -> None:
super().__init__()
self.modalities = nn.ModuleDict(modalities)
self.subsampled_index_dims = subsampled_index_dims
......@@ -2447,7 +2471,13 @@ class PerceiverMultimodalDecoder(PerceiverAbstractDecoder):
[embed(modality, decoder_queries[modality]) for modality in sorted(self.modalities.keys())], dim=1
)
def forward(self, query, z, query_mask=None, output_attentions=False):
def forward(
self,
query: torch.Tensor,
z: torch.FloatTensor,
query_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = False,
) -> torch.Tensor:
# B x 1 x num_classes -> B x num_classes
decoder_outputs = self.decoder(query, z, output_attentions=output_attentions)
......@@ -2680,7 +2710,7 @@ class PerceiverTrainablePositionEncoding(PerceiverAbstractPositionEncoding):
def output_size(self, *args, **kwargs) -> int:
return self._num_channels
def forward(self, batch_size):
def forward(self, batch_size: int) -> torch.Tensor:
position_embeddings = self.position_embeddings
if batch_size is not None:
......@@ -2741,7 +2771,9 @@ class PerceiverFourierPositionEncoding(PerceiverAbstractPositionEncoding):
return encoding_size
def forward(self, index_dims, batch_size, device, pos=None):
def forward(
self, index_dims: List[int], batch_size: int, device: torch.device, pos: Optional[torch.FloatTensor] = None
) -> torch.FloatTensor:
pos = _check_or_build_spatial_positions(pos, index_dims, batch_size)
fourier_pos_enc = generate_fourier_features(
pos,
......@@ -2771,7 +2803,7 @@ class PerceiverTextPreprocessor(AbstractPreprocessor):
Model configuration.
"""
def __init__(self, config):
def __init__(self, config: PerceiverConfig) -> None:
super().__init__()
self.config = config
self.embeddings = nn.Embedding(num_embeddings=config.vocab_size, embedding_dim=config.d_model)
......@@ -2781,7 +2813,7 @@ class PerceiverTextPreprocessor(AbstractPreprocessor):
def num_channels(self) -> int:
return self.config.d_model
def forward(self, inputs):
def forward(self, inputs: torch.LongTensor) -> torch.FloatTensor:
embeddings = self.embeddings(inputs)
seq_length = inputs.shape[1]
......@@ -2800,13 +2832,13 @@ class PerceiverEmbeddingDecoder(nn.Module):
Model configuration.
"""
def __init__(self, config):
def __init__(self, config: PerceiverConfig) -> None:
super().__init__()
self.config = config
self.vocab_size = config.vocab_size
self.bias = nn.Parameter(torch.zeros(self.vocab_size))
def forward(self, hidden_states, embedding_layer):
def forward(self, hidden_states: torch.Tensor, embedding_layer: torch.Tensor) -> torch.Tensor:
batch_size, seq_len, d_model = hidden_states.shape
output = torch.matmul(hidden_states.reshape([-1, d_model]), embedding_layer.weight.T) # Flatten batch dim
output = output + self.bias
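A shape sketch of the tied-embedding decoding above, with hypothetical sizes (the learned bias is omitted): hidden states are flattened over the batch and sequence dimensions, multiplied by the transpose of the embedding matrix, and reshaped back to per-token vocabulary scores.

```python
import torch
from torch import nn

batch_size, seq_len, d_model, vocab_size = 2, 8, 768, 262   # hypothetical sizes
hidden_states = torch.randn(batch_size, seq_len, d_model)
embedding_layer = nn.Embedding(vocab_size, d_model)

output = hidden_states.reshape(-1, d_model) @ embedding_layer.weight.T   # (batch*seq, vocab_size)
output = output.reshape(batch_size, seq_len, vocab_size)
```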
......@@ -2859,7 +2891,7 @@ class PerceiverClassificationPostprocessor(nn.Module):
Number of channels in the input.
"""
def __init__(self, config, in_channels):
def __init__(self, config: PerceiverConfig, in_channels: int) -> None:
super().__init__()
self.classifier = nn.Linear(in_channels, config.num_labels)
......@@ -2881,7 +2913,7 @@ class PerceiverAudioPostprocessor(nn.Module):
Postprocessor type to use. Currently, only "patches" is supported.
"""
def __init__(self, config, in_channels, postproc_type: str = "patches"):
def __init__(self, config: PerceiverConfig, in_channels: int, postproc_type: str = "patches") -> None:
super().__init__()
if postproc_type not in ("patches",): # to be supported: 'conv', 'patches', 'pixels'
......@@ -2908,7 +2940,7 @@ class PerceiverProjectionPostprocessor(nn.Module):
Number of channels in the output.
"""
def __init__(self, in_channels, out_channels):
def __init__(self, in_channels: int, out_channels: int) -> None:
super().__init__()
self.classifier = nn.Linear(in_channels, out_channels)
......@@ -3155,7 +3187,7 @@ class PerceiverOneHotPreprocessor(AbstractPreprocessor):
Model configuration.
"""
def __init__(self, config):
def __init__(self, config: PerceiverConfig) -> None:
super().__init__()
self.config: PerceiverConfig = config
......
......@@ -18,6 +18,7 @@ RetriBERT model
import math
from typing import Optional
import torch
import torch.utils.checkpoint as checkpoint
......@@ -85,7 +86,7 @@ RETRIBERT_START_DOCSTRING = r"""
RETRIBERT_START_DOCSTRING,
)
class RetriBertModel(RetriBertPreTrainedModel):
def __init__(self, config):
def __init__(self, config: RetriBertConfig) -> None:
super().__init__(config)
self.projection_dim = config.projection_dim
......@@ -173,8 +174,13 @@ class RetriBertModel(RetriBertPreTrainedModel):
return self.project_doc(a_reps)
def forward(
self, input_ids_query, attention_mask_query, input_ids_doc, attention_mask_doc, checkpoint_batch_size=-1
):
self,
input_ids_query: torch.LongTensor,
attention_mask_query: Optional[torch.FloatTensor],
input_ids_doc: torch.LongTensor,
attention_mask_doc: Optional[torch.FloatTensor],
checkpoint_batch_size: int = -1,
) -> torch.FloatTensor:
r"""
Args:
input_ids_query (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
......
......@@ -112,7 +112,7 @@ class SegformerDropPath(nn.Module):
super().__init__()
self.drop_prob = drop_prob
def forward(self, x):
def forward(self, x: torch.Tensor) -> torch.Tensor:
return drop_path(x, self.drop_prob, self.training)
......
......@@ -18,7 +18,7 @@
import collections.abc
import math
from dataclasses import dataclass
from typing import Optional, Tuple
from typing import Optional, Tuple, Union
import torch
import torch.utils.checkpoint
......@@ -272,7 +272,9 @@ class SwinEmbeddings(nn.Module):
self.norm = nn.LayerNorm(config.embed_dim)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, pixel_values, bool_masked_pos=None):
def forward(
self, pixel_values: Optional[torch.FloatTensor], bool_masked_pos: Optional[torch.BoolTensor] = None
) -> Tuple[torch.Tensor]:
embeddings, output_dimensions = self.patch_embeddings(pixel_values)
embeddings = self.norm(embeddings)
batch_size, seq_len, _ = embeddings.size()
......@@ -317,7 +319,7 @@ class SwinPatchEmbeddings(nn.Module):
pixel_values = nn.functional.pad(pixel_values, pad_values)
return pixel_values
def forward(self, pixel_values):
def forward(self, pixel_values: Optional[torch.FloatTensor]) -> Tuple[torch.Tensor, Tuple[int]]:
_, _, height, width = pixel_values.shape
# pad the input to be divisible by self.patch_size, if needed
pixel_values = self.maybe_pad(pixel_values, height, width)
......@@ -342,7 +344,7 @@ class SwinPatchMerging(nn.Module):
Normalization layer class.
"""
def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
def __init__(self, input_resolution: Tuple[int], dim: int, norm_layer: nn.Module = nn.LayerNorm) -> None:
super().__init__()
self.input_resolution = input_resolution
self.dim = dim
......@@ -357,7 +359,7 @@ class SwinPatchMerging(nn.Module):
return input_feature
def forward(self, input_feature, input_dimensions):
def forward(self, input_feature: torch.Tensor, input_dimensions: Tuple[int, int]) -> torch.Tensor:
height, width = input_dimensions
# `dim` is height * width
batch_size, dim, num_channels = input_feature.shape
......@@ -438,11 +440,11 @@ class SwinSelfAttention(nn.Module):
def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
output_attentions=False,
):
hidden_states: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor]:
batch_size, dim, num_channels = hidden_states.shape
mixed_query_layer = self.query(hidden_states)
......@@ -499,7 +501,7 @@ class SwinSelfOutput(nn.Module):
self.dense = nn.Linear(dim, dim)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
def forward(self, hidden_states, input_tensor):
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
......@@ -531,7 +533,13 @@ class SwinAttention(nn.Module):
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
self.pruned_heads = self.pruned_heads.union(heads)
def forward(self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False):
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor]:
self_outputs = self.self(hidden_states, attention_mask, head_mask, output_attentions)
attention_output = self.output(self_outputs[0], hidden_states)
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
......@@ -547,7 +555,7 @@ class SwinIntermediate(nn.Module):
else:
self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states):
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states
......@@ -559,7 +567,7 @@ class SwinOutput(nn.Module):
self.dense = nn.Linear(int(config.mlp_ratio * dim), dim)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states):
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
return hidden_states
......@@ -621,7 +629,13 @@ class SwinLayer(nn.Module):
hidden_states = nn.functional.pad(hidden_states, pad_values)
return hidden_states, pad_values
def forward(self, hidden_states, input_dimensions, head_mask=None, output_attentions=False):
def forward(
self,
hidden_states: torch.Tensor,
input_dimensions: Tuple[int, int],
head_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
self.set_shift_and_window_size(input_dimensions)
height, width = input_dimensions
batch_size, _, channels = hidden_states.size()
......@@ -703,7 +717,13 @@ class SwinStage(nn.Module):
self.pointing = False
def forward(self, hidden_states, input_dimensions, head_mask=None, output_attentions=False):
def forward(
self,
hidden_states: torch.Tensor,
input_dimensions: Tuple[int, int],
head_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor]:
height, width = input_dimensions
for i, layer_module in enumerate(self.blocks):
......@@ -752,13 +772,13 @@ class SwinEncoder(nn.Module):
def forward(
self,
hidden_states,
input_dimensions,
head_mask=None,
output_attentions=False,
output_hidden_states=False,
return_dict=True,
):
hidden_states: torch.Tensor,
input_dimensions: Tuple[int, int],
head_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = False,
output_hidden_states: Optional[bool] = False,
return_dict: Optional[bool] = True,
) -> Union[Tuple, SwinEncoderOutput]:
all_input_dimensions = ()
all_hidden_states = () if output_hidden_states else None
all_reshaped_hidden_states = () if output_hidden_states else None
......@@ -920,13 +940,13 @@ class SwinModel(SwinPreTrainedModel):
)
def forward(
self,
pixel_values=None,
bool_masked_pos=None,
head_mask=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
pixel_values: Optional[torch.FloatTensor] = None,
bool_masked_pos: Optional[torch.BoolTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, SwinModelOutput]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
......@@ -999,13 +1019,13 @@ class SwinForMaskedImageModeling(SwinPreTrainedModel):
@replace_return_docstrings(output_type=SwinMaskedImageModelingOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
pixel_values=None,
bool_masked_pos=None,
head_mask=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
pixel_values: Optional[torch.FloatTensor] = None,
bool_masked_pos: Optional[torch.BoolTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, SwinMaskedImageModelingOutput]:
r"""
bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
......@@ -1115,13 +1135,13 @@ class SwinForImageClassification(SwinPreTrainedModel):
)
def forward(
self,
pixel_values=None,
head_mask=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
pixel_values: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, SwinImageClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
......
......@@ -881,14 +881,14 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
)
def forward(
self,
input_ids=None,
mems=None,
head_mask=None,
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
input_ids: Optional[torch.LongTensor] = None,
mems: Optional[List[torch.FloatTensor]] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, TransfoXLModelOutput]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
......@@ -1071,15 +1071,15 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
)
def forward(
self,
input_ids=None,
mems=None,
head_mask=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
input_ids: Optional[torch.LongTensor] = None,
mems: Optional[List[torch.FloatTensor]] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, TransfoXLLMHeadModelOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
......@@ -1215,11 +1215,11 @@ class TransfoXLForSequenceClassification(TransfoXLPreTrainedModel):
)
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
input_ids: Optional[torch.LongTensor] = None,
mems: Optional[List[torch.FloatTensor]] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
......
......@@ -16,6 +16,7 @@
import math
from collections import OrderedDict
from typing import Optional, Tuple, Union
import torch
import torch.utils.checkpoint
......@@ -81,7 +82,7 @@ class VanDropPath(nn.Module):
super().__init__()
self.drop_prob = drop_prob
def forward(self, x):
def forward(self, x: torch.Tensor) -> torch.Tensor:
return drop_path(x, self.drop_prob, self.training)
......@@ -146,7 +147,7 @@ class VanLargeKernelAttentionLayer(nn.Module):
super().__init__()
self.attention = VanLargeKernelAttention(hidden_size)
def forward(self, hidden_state):
def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
attention = self.attention(hidden_state)
attended = hidden_state * attention
return attended
......@@ -171,7 +172,7 @@ class VanSpatialAttentionLayer(nn.Module):
self.attention_layer = VanLargeKernelAttentionLayer(hidden_size)
self.post_projection = nn.Conv2d(hidden_size, hidden_size, kernel_size=1)
def forward(self, hidden_state):
def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
residual = hidden_state
hidden_state = self.pre_projection(hidden_state)
hidden_state = self.attention_layer(hidden_state)
......@@ -189,7 +190,7 @@ class VanLayerScaling(nn.Module):
super().__init__()
self.weight = nn.Parameter(initial_value * torch.ones((hidden_size)), requires_grad=True)
def forward(self, hidden_state):
def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
# unsqueezing for broadcasting
hidden_state = self.weight.unsqueeze(-1).unsqueeze(-1) * hidden_state
return hidden_state
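A broadcasting sketch for the layer scaling above, with hypothetical sizes: the per-channel weight of shape `(C,)` is unsqueezed to `(C, 1, 1)` so it scales a `(B, C, H, W)` feature map channel-wise.

```python
import torch

weight = torch.ones(64) * 1e-2                 # (C,) per-channel scale (hypothetical C=64)
hidden_state = torch.randn(2, 64, 7, 7)        # (B, C, H, W)

scaled = weight.unsqueeze(-1).unsqueeze(-1) * hidden_state   # (C, 1, 1) broadcasts over B, H, W
assert scaled.shape == hidden_state.shape
```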
......@@ -218,7 +219,7 @@ class VanLayer(nn.Module):
)
self.mlp_scaling = VanLayerScaling(hidden_size, config.layer_scale_init_value)
def forward(self, hidden_state):
def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
residual = hidden_state
# attention
hidden_state = self.pre_normomalization(hidden_state)
......@@ -269,7 +270,7 @@ class VanStage(nn.Module):
)
self.normalization = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
def forward(self, hidden_state):
def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
hidden_state = self.embeddings(hidden_state)
hidden_state = self.layers(hidden_state)
# rearrange b c h w -> b (h w) c
......@@ -316,7 +317,12 @@ class VanEncoder(nn.Module):
)
)
def forward(self, hidden_state, output_hidden_states=False, return_dict=True):
def forward(
self,
hidden_state: torch.Tensor,
output_hidden_states: Optional[bool] = False,
return_dict: Optional[bool] = True,
) -> Union[Tuple, BaseModelOutputWithNoAttention]:
all_hidden_states = () if output_hidden_states else None
for _, stage_module in enumerate(self.stages):
......@@ -411,7 +417,12 @@ class VanModel(VanPreTrainedModel):
modality="vision",
expected_output=_EXPECTED_OUTPUT_SHAPE,
)
def forward(self, pixel_values, output_hidden_states=None, return_dict=None):
def forward(
self,
pixel_values: Optional[torch.FloatTensor],
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPoolingAndNoAttention]:
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
......@@ -463,7 +474,13 @@ class VanForImageClassification(VanPreTrainedModel):
config_class=_CONFIG_FOR_DOC,
expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
)
def forward(self, pixel_values=None, labels=None, output_hidden_states=None, return_dict=None):
def forward(
self,
pixel_values: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, ImageClassifierOutputWithNoAttention]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
......
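The signatures above all follow one pattern: tensor arguments become `Optional[...] = None`, boolean switches stay `Optional[bool] = None` so the config can supply their defaults, and the return annotation is `Union[Tuple, <ModelOutput>]` because `return_dict` decides which of the two is returned. A minimal usage sketch of that last point, assuming the public `google/mobilebert-uncased` checkpoint:

```python
import torch
from transformers import AutoTokenizer, MobileBertModel

tokenizer = AutoTokenizer.from_pretrained("google/mobilebert-uncased")
model = MobileBertModel.from_pretrained("google/mobilebert-uncased")
inputs = tokenizer("hello world", return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs, return_dict=True)          # BaseModelOutputWithPooling
    outputs_tuple = model(**inputs, return_dict=False)   # plain tuple

assert outputs.last_hidden_state.shape == outputs_tuple[0].shape
```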