Unverified Commit 45360e1a authored by Robot Jelly, committed by GitHub

type hints for pytorch models (#17064)

* type hints for pytorch models

* fixed import error

* fixed some errors
parent db377a0b
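
The pattern applied throughout the diff below is the same in every file: untyped `forward` arguments gain `Optional[...]` tensor annotations, and the return value is annotated either with a concrete tensor type or with `Union[Tuple, <ModelOutput>]` to reflect the existing `return_dict` switch. As a rough illustration only, here is the annotated style on a hypothetical `ToyEncoder` module (not part of this PR; it only assumes `torch` and `transformers` are installed):

from typing import Optional, Tuple, Union

import torch
from torch import nn

# BaseModelOutput ships with the transformers library; it is used here only to
# illustrate the `Union[Tuple, ...]` return annotation adopted in this PR.
from transformers.modeling_outputs import BaseModelOutput


class ToyEncoder(nn.Module):
    """Hypothetical module showing the annotation style applied in this PR."""

    def __init__(self, vocab_size: int = 100, hidden_size: int = 8):
        super().__init__()
        self.embeddings = nn.Embedding(vocab_size, hidden_size)
        self.dense = nn.Linear(hidden_size, hidden_size)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = True,
    ) -> Union[Tuple, BaseModelOutput]:
        # Either token ids or precomputed embeddings must be supplied.
        if input_ids is not None:
            hidden_states = self.embeddings(input_ids)
        elif inputs_embeds is not None:
            hidden_states = inputs_embeds
        else:
            raise ValueError("Either input_ids or inputs_embeds must be provided")
        if attention_mask is not None:
            # Zero out padded positions (broadcast over the hidden dimension).
            hidden_states = hidden_states * attention_mask.unsqueeze(-1)
        hidden_states = self.dense(hidden_states)
        if not return_dict:
            return (hidden_states,)
        return BaseModelOutput(last_hidden_state=hidden_states)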
@@ -19,7 +19,7 @@ import copy
 import math
 import os
 from dataclasses import dataclass
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Union
 import torch
 import torch.utils.checkpoint
@@ -253,11 +253,11 @@ class CanineEmbeddings(nn.Module):
     def forward(
         self,
-        input_ids=None,
-        token_type_ids=None,
-        position_ids=None,
-        inputs_embeds=None,
-    ):
+        input_ids: Optional[torch.LongTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+    ) -> torch.FloatTensor:
         if input_ids is not None:
             input_shape = input_ids.size()
         else:
@@ -356,7 +356,11 @@ class ConvProjection(nn.Module):
         self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
-    def forward(self, inputs, final_seq_char_positions=None):
+    def forward(
+        self,
+        inputs: torch.Tensor,
+        final_seq_char_positions: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
         # inputs has shape [batch, mol_seq, molecule_hidden_size+char_hidden_final]
         # we transpose it to be [batch, molecule_hidden_size+char_hidden_final, mol_seq]
         inputs = torch.transpose(inputs, 1, 2)
@@ -419,12 +423,12 @@ class CanineSelfAttention(nn.Module):
     def forward(
         self,
-        from_tensor,
-        to_tensor,
-        attention_mask=None,
-        head_mask=None,
-        output_attentions=False,
-    ):
+        from_tensor: torch.Tensor,
+        to_tensor: torch.Tensor,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        output_attentions: Optional[bool] = False,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         mixed_query_layer = self.query(from_tensor)
         # If this is instantiated as a cross-attention module, the keys
@@ -496,7 +500,9 @@ class CanineSelfOutput(nn.Module):
         self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
-    def forward(self, hidden_states, input_tensor):
+    def forward(
+        self, hidden_states: Tuple[torch.FloatTensor], input_tensor: torch.FloatTensor
+    ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
         hidden_states = self.dense(hidden_states)
         hidden_states = self.dropout(hidden_states)
         hidden_states = self.LayerNorm(hidden_states + input_tensor)
@@ -574,11 +580,11 @@ class CanineAttention(nn.Module):
     def forward(
         self,
-        hidden_states,
-        attention_mask=None,
-        head_mask=None,
-        output_attentions=False,
-    ):
+        hidden_states: Tuple[torch.FloatTensor],
+        attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        output_attentions: Optional[bool] = False,
+    ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]:
         if not self.local:
             self_outputs = self.self(hidden_states, hidden_states, attention_mask, head_mask, output_attentions)
             attention_output = self_outputs[0]
@@ -656,7 +662,7 @@ class CanineIntermediate(nn.Module):
         else:
             self.intermediate_act_fn = config.hidden_act
-    def forward(self, hidden_states):
+    def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
         hidden_states = self.dense(hidden_states)
         hidden_states = self.intermediate_act_fn(hidden_states)
         return hidden_states
@@ -669,7 +675,7 @@ class CanineOutput(nn.Module):
         self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
-    def forward(self, hidden_states, input_tensor):
+    def forward(self, hidden_states: Tuple[torch.FloatTensor], input_tensor: torch.FloatTensor) -> torch.FloatTensor:
         hidden_states = self.dense(hidden_states)
         hidden_states = self.dropout(hidden_states)
         hidden_states = self.LayerNorm(hidden_states + input_tensor)
@@ -706,11 +712,11 @@ class CanineLayer(nn.Module):
     def forward(
         self,
-        hidden_states,
-        attention_mask=None,
-        head_mask=None,
-        output_attentions=False,
-    ):
+        hidden_states: Tuple[torch.FloatTensor],
+        attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        output_attentions: Optional[bool] = False,
+    ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]:
         self_attention_outputs = self.attention(
             hidden_states,
             attention_mask,
@@ -767,13 +773,13 @@ class CanineEncoder(nn.Module):
     def forward(
         self,
-        hidden_states,
-        attention_mask=None,
-        head_mask=None,
-        output_attentions=False,
-        output_hidden_states=False,
-        return_dict=True,
-    ):
+        hidden_states: Tuple[torch.FloatTensor],
+        attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        output_attentions: Optional[bool] = False,
+        output_hidden_states: Optional[bool] = False,
+        return_dict: Optional[bool] = True,
+    ) -> Union[Tuple, BaseModelOutput]:
         all_hidden_states = () if output_hidden_states else None
         all_self_attentions = () if output_attentions else None
@@ -822,7 +828,7 @@ class CaninePooler(nn.Module):
         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
         self.activation = nn.Tanh()
-    def forward(self, hidden_states):
+    def forward(self, hidden_states: Tuple[torch.FloatTensor]) -> torch.FloatTensor:
         # We "pool" the model by simply taking the hidden state corresponding
         # to the first token.
         first_token_tensor = hidden_states[:, 0]
@@ -841,7 +847,7 @@ class CaninePredictionHeadTransform(nn.Module):
             self.transform_act_fn = config.hidden_act
         self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
-    def forward(self, hidden_states):
+    def forward(self, hidden_states: Tuple[torch.FloatTensor]) -> torch.FloatTensor:
         hidden_states = self.dense(hidden_states)
         hidden_states = self.transform_act_fn(hidden_states)
         hidden_states = self.LayerNorm(hidden_states)
@@ -862,7 +868,7 @@ class CanineLMPredictionHead(nn.Module):
         # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
         self.decoder.bias = self.bias
-    def forward(self, hidden_states):
+    def forward(self, hidden_states: Tuple[torch.FloatTensor]) -> torch.FloatTensor:
         hidden_states = self.transform(hidden_states)
         hidden_states = self.decoder(hidden_states)
         return hidden_states
@@ -873,7 +879,10 @@ class CanineOnlyMLMHead(nn.Module):
         super().__init__()
         self.predictions = CanineLMPredictionHead(config)
-    def forward(self, sequence_output):
+    def forward(
+        self,
+        sequence_output: Tuple[torch.Tensor],
+    ) -> Tuple[torch.Tensor]:
         prediction_scores = self.predictions(sequence_output)
         return prediction_scores
@@ -1093,16 +1102,16 @@ class CanineModel(CaninePreTrainedModel):
     )
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        token_type_ids=None,
-        position_ids=None,
-        head_mask=None,
-        inputs_embeds=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, CanineModelOutputWithPooling]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -1275,17 +1284,17 @@ class CanineForSequenceClassification(CaninePreTrainedModel):
     )
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        token_type_ids=None,
-        position_ids=None,
-        head_mask=None,
-        inputs_embeds=None,
-        labels=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, SequenceClassifierOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
             Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
@@ -1372,17 +1381,17 @@ class CanineForMultipleChoice(CaninePreTrainedModel):
     )
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        token_type_ids=None,
-        position_ids=None,
-        head_mask=None,
-        inputs_embeds=None,
-        labels=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, MultipleChoiceModelOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
@@ -1465,17 +1474,17 @@ class CanineForTokenClassification(CaninePreTrainedModel):
     )
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        token_type_ids=None,
-        position_ids=None,
-        head_mask=None,
-        inputs_embeds=None,
-        labels=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, TokenClassifierOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
             Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
@@ -1543,18 +1552,18 @@ class CanineForQuestionAnswering(CaninePreTrainedModel):
     )
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        token_type_ids=None,
-        position_ids=None,
-        head_mask=None,
-        inputs_embeds=None,
-        start_positions=None,
-        end_positions=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        start_positions: Optional[torch.LongTensor] = None,
+        end_positions: Optional[torch.LongTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, QuestionAnsweringModelOutput]:
         r"""
         start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
...
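
For context, the `Union[Tuple, CanineModelOutputWithPooling]` style return annotations above only describe the existing `return_dict` behaviour, they do not change it. A short usage sketch of how both branches of the `Union` show up at call time (assuming `transformers` is installed and the `google/canine-s` checkpoint is reachable):

import torch
from transformers import CanineModel, CanineTokenizer

tokenizer = CanineTokenizer.from_pretrained("google/canine-s")
model = CanineModel.from_pretrained("google/canine-s")
inputs = tokenizer("hello world", return_tensors="pt")

# With return_dict=True (the default), the annotated return type resolves to
# CanineModelOutputWithPooling, so attribute access is visible to type checkers.
outputs = model(**inputs)
last_hidden: torch.FloatTensor = outputs.last_hidden_state

# With return_dict=False, the same forward returns a plain tuple, which is why
# the annotation is Union[Tuple, CanineModelOutputWithPooling].
legacy_outputs = model(**inputs, return_dict=False)
assert isinstance(legacy_outputs, tuple)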
@@ -201,7 +201,13 @@ class ConvBertEmbeddings(nn.Module):
             persistent=False,
         )
-    def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+    ) -> torch.LongTensor:
         if input_ids is not None:
             input_shape = input_ids.size()
         else:
@@ -287,7 +293,7 @@ class SeparableConv1D(nn.Module):
         self.depthwise.weight.data.normal_(mean=0.0, std=config.initializer_range)
         self.pointwise.weight.data.normal_(mean=0.0, std=config.initializer_range)
-    def forward(self, hidden_states):
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         x = self.depthwise(hidden_states)
         x = self.pointwise(x)
         x += self.bias
@@ -341,12 +347,12 @@ class ConvBertSelfAttention(nn.Module):
     def forward(
         self,
-        hidden_states,
-        attention_mask=None,
-        head_mask=None,
-        encoder_hidden_states=None,
-        output_attentions=False,
-    ):
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = False,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         mixed_query_layer = self.query(hidden_states)
         batch_size = hidden_states.size(0)
         # If this is instantiated as a cross-attention module, the keys
@@ -426,7 +432,7 @@ class ConvBertSelfOutput(nn.Module):
         self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
-    def forward(self, hidden_states, input_tensor):
+    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
         hidden_states = self.dense(hidden_states)
         hidden_states = self.dropout(hidden_states)
         hidden_states = self.LayerNorm(hidden_states + input_tensor)
@@ -460,12 +466,12 @@ class ConvBertAttention(nn.Module):
     def forward(
         self,
-        hidden_states,
-        attention_mask=None,
-        head_mask=None,
-        encoder_hidden_states=None,
-        output_attentions=False,
-    ):
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = False,
+    ) -> Tuple[torch.Tensor, Optional[torch.FloatTensor]]:
         self_outputs = self.self(
             hidden_states,
             attention_mask,
@@ -489,7 +495,7 @@ class GroupedLinearLayer(nn.Module):
         self.weight = nn.Parameter(torch.empty(self.num_groups, self.group_in_dim, self.group_out_dim))
         self.bias = nn.Parameter(torch.empty(output_size))
-    def forward(self, hidden_states):
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         batch_size = list(hidden_states.size())[0]
         x = torch.reshape(hidden_states, [-1, self.num_groups, self.group_in_dim])
         x = x.permute(1, 0, 2)
@@ -514,7 +520,7 @@ class ConvBertIntermediate(nn.Module):
         else:
             self.intermediate_act_fn = config.hidden_act
-    def forward(self, hidden_states):
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states = self.dense(hidden_states)
         hidden_states = self.intermediate_act_fn(hidden_states)
         return hidden_states
@@ -532,7 +538,7 @@ class ConvBertOutput(nn.Module):
         self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
-    def forward(self, hidden_states, input_tensor):
+    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
         hidden_states = self.dense(hidden_states)
         hidden_states = self.dropout(hidden_states)
         hidden_states = self.LayerNorm(hidden_states + input_tensor)
@@ -556,13 +562,13 @@ class ConvBertLayer(nn.Module):
     def forward(
         self,
-        hidden_states,
-        attention_mask=None,
-        head_mask=None,
-        encoder_hidden_states=None,
-        encoder_attention_mask=None,
-        output_attentions=False,
-    ):
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = False,
+    ) -> Tuple[torch.Tensor, Optional[torch.FloatTensor]]:
         self_attention_outputs = self.attention(
             hidden_states,
             attention_mask,
@@ -608,15 +614,15 @@ class ConvBertEncoder(nn.Module):
     def forward(
         self,
-        hidden_states,
-        attention_mask=None,
-        head_mask=None,
-        encoder_hidden_states=None,
-        encoder_attention_mask=None,
-        output_attentions=False,
-        output_hidden_states=False,
-        return_dict=True,
-    ):
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = False,
+        output_hidden_states: Optional[bool] = False,
+        return_dict: Optional[bool] = True,
+    ) -> Union[Tuple, BaseModelOutputWithCrossAttentions]:
         all_hidden_states = () if output_hidden_states else None
         all_self_attentions = () if output_attentions else None
         all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
@@ -684,7 +690,7 @@ class ConvBertPredictionHeadTransform(nn.Module):
             self.transform_act_fn = config.hidden_act
         self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
-    def forward(self, hidden_states):
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states = self.dense(hidden_states)
         hidden_states = self.transform_act_fn(hidden_states)
         hidden_states = self.LayerNorm(hidden_states)
@@ -795,16 +801,16 @@ class ConvBertModel(ConvBertPreTrainedModel):
     )
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        token_type_ids=None,
-        position_ids=None,
-        head_mask=None,
-        inputs_embeds=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, BaseModelOutputWithCrossAttentions]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -864,7 +870,7 @@ class ConvBertGeneratorPredictions(nn.Module):
         self.LayerNorm = nn.LayerNorm(config.embedding_size, eps=config.layer_norm_eps)
         self.dense = nn.Linear(config.hidden_size, config.embedding_size)
-    def forward(self, generator_hidden_states):
+    def forward(self, generator_hidden_states: torch.FloatTensor) -> torch.FloatTensor:
         hidden_states = self.dense(generator_hidden_states)
         hidden_states = get_activation("gelu")(hidden_states)
         hidden_states = self.LayerNorm(hidden_states)
@@ -966,7 +972,7 @@ class ConvBertClassificationHead(nn.Module):
         self.config = config
-    def forward(self, hidden_states, **kwargs):
+    def forward(self, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
         x = hidden_states[:, 0, :]  # take <s> token (equiv. to [CLS])
         x = self.dropout(x)
         x = self.dense(x)
...
@@ -15,6 +15,8 @@
 """ PyTorch ConvNext model."""
+from typing import Optional, Tuple, Union
 import torch
 import torch.utils.checkpoint
 from torch import nn
@@ -78,7 +80,7 @@ class ConvNextDropPath(nn.Module):
         super().__init__()
         self.drop_prob = drop_prob
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         return drop_path(x, self.drop_prob, self.training)
@@ -98,7 +100,7 @@ class ConvNextLayerNorm(nn.Module):
             raise NotImplementedError(f"Unsupported data format: {self.data_format}")
         self.normalized_shape = (normalized_shape,)
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         if self.data_format == "channels_last":
             x = torch.nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
         elif self.data_format == "channels_first":
@@ -121,7 +123,7 @@ class ConvNextEmbeddings(nn.Module):
         )
         self.layernorm = ConvNextLayerNorm(config.hidden_sizes[0], eps=1e-6, data_format="channels_first")
-    def forward(self, pixel_values):
+    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
         embeddings = self.patch_embeddings(pixel_values)
         embeddings = self.layernorm(embeddings)
         return embeddings
@@ -155,7 +157,7 @@ class ConvNextLayer(nn.Module):
         )
         self.drop_path = ConvNextDropPath(drop_path) if drop_path > 0.0 else nn.Identity()
-    def forward(self, hidden_states):
+    def forward(self, hidden_states: torch.FloatTensor) -> torch.Tensor:
         input = hidden_states
         x = self.dwconv(hidden_states)
         x = x.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
@@ -197,7 +199,7 @@ class ConvNextStage(nn.Module):
             *[ConvNextLayer(config, dim=out_channels, drop_path=drop_path_rates[j]) for j in range(depth)]
         )
-    def forward(self, hidden_states):
+    def forward(self, hidden_states: torch.FloatTensor) -> torch.Tensor:
         hidden_states = self.downsampling_layer(hidden_states)
         hidden_states = self.layers(hidden_states)
         return hidden_states
@@ -224,7 +226,12 @@ class ConvNextEncoder(nn.Module):
             cur += config.depths[i]
             prev_chs = out_chs
-    def forward(self, hidden_states, output_hidden_states=False, return_dict=True):
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        output_hidden_states: Optional[bool] = False,
+        return_dict: Optional[bool] = True,
+    ) -> Union[Tuple, BaseModelOutputWithNoAttention]:
         all_hidden_states = () if output_hidden_states else None
         for i, layer_module in enumerate(self.stages):
@@ -325,7 +332,12 @@ class ConvNextModel(ConvNextPreTrainedModel):
         modality="vision",
         expected_output=_EXPECTED_OUTPUT_SHAPE,
     )
-    def forward(self, pixel_values=None, output_hidden_states=None, return_dict=None):
+    def forward(
+        self,
+        pixel_values: torch.FloatTensor = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, BaseModelOutputWithPoolingAndNoAttention]:
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
         )
@@ -387,7 +399,13 @@ class ConvNextForImageClassification(ConvNextPreTrainedModel):
         config_class=_CONFIG_FOR_DOC,
         expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
     )
-    def forward(self, pixel_values=None, labels=None, output_hidden_states=None, return_dict=None):
+    def forward(
+        self,
+        pixel_values: torch.FloatTensor = None,
+        labels: Optional[torch.LongTensor] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, ImageClassifierOutputWithNoAttention]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
             Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
...
@@ -17,7 +17,7 @@
 import math
 import os
 from dataclasses import dataclass
-from typing import Tuple, Union
+from typing import Optional, Tuple, Union
 import torch
 import torch.utils.checkpoint
@@ -278,15 +278,15 @@ class DecisionTransformerGPT2Attention(nn.Module):
     def forward(
         self,
-        hidden_states,
-        layer_past=None,
-        attention_mask=None,
-        head_mask=None,
-        encoder_hidden_states=None,
-        encoder_attention_mask=None,
-        use_cache=False,
-        output_attentions=False,
-    ):
+        hidden_states: Optional[Tuple[torch.FloatTensor]],
+        layer_past: Optional[Tuple[torch.Tensor]] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = False,
+        output_attentions: Optional[bool] = False,
+    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
         if encoder_hidden_states is not None:
             if not hasattr(self, "q_attn"):
                 raise ValueError(
@@ -340,7 +340,7 @@ class DecisionTransformerGPT2MLP(nn.Module):
         self.act = ACT2FN[config.activation_function]
         self.dropout = nn.Dropout(config.resid_pdrop)
-    def forward(self, hidden_states):
+    def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
         hidden_states = self.c_fc(hidden_states)
         hidden_states = self.act(hidden_states)
         hidden_states = self.c_proj(hidden_states)
@@ -369,15 +369,15 @@ class DecisionTransformerGPT2Block(nn.Module):
     def forward(
         self,
-        hidden_states,
-        layer_past=None,
-        attention_mask=None,
-        head_mask=None,
-        encoder_hidden_states=None,
-        encoder_attention_mask=None,
-        use_cache=False,
-        output_attentions=False,
-    ):
+        hidden_states: Optional[Tuple[torch.FloatTensor]],
+        layer_past: Optional[Tuple[torch.Tensor]] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = False,
+        output_attentions: Optional[bool] = False,
+    ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
         residual = hidden_states
         hidden_states = self.ln_1(hidden_states)
         attn_outputs = self.attn(
@@ -510,20 +510,20 @@ class DecisionTransformerGPT2Model(DecisionTransformerGPT2PreTrainedModel):
     # Copied from transformers.models.gpt2.modeling_gpt2.GPT2Model.forward
     def forward(
         self,
-        input_ids=None,
-        past_key_values=None,
-        attention_mask=None,
-        token_type_ids=None,
-        position_ids=None,
-        head_mask=None,
-        inputs_embeds=None,
-        encoder_hidden_states=None,
-        encoder_attention_mask=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
...
@@ -15,7 +15,7 @@
 """ Classes to support Encoder-Decoder architectures"""
 import warnings
-from typing import Optional
+from typing import Optional, Tuple, Union
 import torch
 from torch import nn
@@ -430,21 +430,21 @@ class EncoderDecoderModel(PreTrainedModel):
     @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        decoder_input_ids=None,
-        decoder_attention_mask=None,
-        encoder_outputs=None,
-        past_key_values=None,
-        inputs_embeds=None,
-        decoder_inputs_embeds=None,
-        labels=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        decoder_input_ids: Optional[torch.LongTensor] = None,
+        decoder_attention_mask: Optional[torch.BoolTensor] = None,
+        encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None,
+        past_key_values: Tuple[Tuple[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
         **kwargs,
-    ):
+    ) -> Union[Tuple, Seq2SeqLMOutput]:
         r"""
         Returns:
...
@@ -80,7 +80,7 @@ class GLPNDropPath(nn.Module):
         super().__init__()
         self.drop_prob = drop_prob
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         return drop_path(x, self.drop_prob, self.training)
...
@@ -18,7 +18,7 @@
 import math
 import os
 from dataclasses import dataclass
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Union
 import torch
 import torch.utils.checkpoint
@@ -289,15 +289,15 @@ class GPT2Attention(nn.Module):
     def forward(
         self,
-        hidden_states,
-        layer_past=None,
-        attention_mask=None,
-        head_mask=None,
-        encoder_hidden_states=None,
-        encoder_attention_mask=None,
-        use_cache=False,
-        output_attentions=False,
-    ):
+        hidden_states: Optional[Tuple[torch.FloatTensor]],
+        layer_past: Optional[Tuple[torch.Tensor]] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = False,
+        output_attentions: Optional[bool] = False,
+    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
         if encoder_hidden_states is not None:
             if not hasattr(self, "q_attn"):
                 raise ValueError(
@@ -350,7 +350,7 @@ class GPT2MLP(nn.Module):
         self.act = ACT2FN[config.activation_function]
         self.dropout = nn.Dropout(config.resid_pdrop)
-    def forward(self, hidden_states):
+    def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
         hidden_states = self.c_fc(hidden_states)
         hidden_states = self.act(hidden_states)
         hidden_states = self.c_proj(hidden_states)
@@ -376,15 +376,15 @@ class GPT2Block(nn.Module):
     def forward(
         self,
-        hidden_states,
-        layer_past=None,
-        attention_mask=None,
-        head_mask=None,
-        encoder_hidden_states=None,
-        encoder_attention_mask=None,
-        use_cache=False,
-        output_attentions=False,
-    ):
+        hidden_states: Optional[Tuple[torch.FloatTensor]],
+        layer_past: Optional[Tuple[torch.Tensor]] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = False,
+        output_attentions: Optional[bool] = False,
+    ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
         residual = hidden_states
         hidden_states = self.ln_1(hidden_states)
         attn_outputs = self.attn(
@@ -742,20 +742,20 @@ class GPT2Model(GPT2PreTrainedModel):
     )
    def forward(
         self,
-        input_ids=None,
-        past_key_values=None,
-        attention_mask=None,
-        token_type_ids=None,
-        position_ids=None,
-        head_mask=None,
-        inputs_embeds=None,
-        encoder_hidden_states=None,
-        encoder_attention_mask=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -1020,21 +1020,21 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
     )
     def forward(
         self,
-        input_ids=None,
-        past_key_values=None,
-        attention_mask=None,
-        token_type_ids=None,
-        position_ids=None,
-        head_mask=None,
-        inputs_embeds=None,
-        encoder_hidden_states=None,
-        encoder_attention_mask=None,
-        labels=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
@@ -1189,22 +1189,22 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
     @replace_return_docstrings(output_type=GPT2DoubleHeadsModelOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
-        input_ids=None,
-        past_key_values=None,
-        attention_mask=None,
-        token_type_ids=None,
-        position_ids=None,
-        head_mask=None,
-        inputs_embeds=None,
-        mc_token_ids=None,
-        labels=None,
-        mc_labels=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
+        input_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        mc_token_ids: Optional[torch.LongTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        mc_labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
         **kwargs,
-    ):
+    ) -> Union[Tuple, GPT2DoubleHeadsModelOutput]:
         r"""
         mc_token_ids (`torch.LongTensor` of shape `(batch_size, num_choices)`, *optional*, default to index of the last token of the input):
             Index of the classification token in each input sequence. Selected in the range `[0, input_ids.size(-1) -
@@ -1352,19 +1352,19 @@ class GPT2ForSequenceClassification(GPT2PreTrainedModel):
     )
     def forward(
         self,
-        input_ids=None,
-        past_key_values=None,
-        attention_mask=None,
-        token_type_ids=None,
-        position_ids=None,
-        head_mask=None,
-        inputs_embeds=None,
-        labels=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
             Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
@@ -1488,19 +1488,19 @@ class GPT2ForTokenClassification(GPT2PreTrainedModel):
     # fmt: on
     def forward(
         self,
-        input_ids=None,
-        past_key_values=None,
-        attention_mask=None,
-        token_type_ids=None,
-        position_ids=None,
-        head_mask=None,
-        inputs_embeds=None,
-        labels=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, TokenClassifierOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
             Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
...
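
The GPT-2 hunks above also type the cache: `past_key_values` becomes `Optional[Tuple[Tuple[torch.Tensor]]]`, i.e. one tuple of cached key/value tensors per layer in the transformers release this PR targets. A hedged sketch (assuming the public `gpt2` checkpoint is reachable) of how that cache round-trips through the typed signature:

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
input_ids = tokenizer("Type hints make", return_tensors="pt").input_ids

# use_cache=True returns past_key_values; in this release it is a tuple with one
# entry per layer, each holding the cached key and value tensors.
out = model(input_ids, use_cache=True)
past = out.past_key_values
next_token = out.logits[:, -1:].argmax(dim=-1)

# Feeding the cache back in only requires the newly generated token.
out = model(next_token, past_key_values=past, use_cache=True)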
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
# limitations under the License. # limitations under the License.
""" PyTorch GPT-J model.""" """ PyTorch GPT-J model."""
from typing import Tuple from typing import Optional, Tuple, Union
import torch import torch
import torch.utils.checkpoint import torch.utils.checkpoint
...@@ -191,13 +191,16 @@ class GPTJAttention(nn.Module): ...@@ -191,13 +191,16 @@ class GPTJAttention(nn.Module):
def forward( def forward(
self, self,
hidden_states, hidden_states: Optional[torch.FloatTensor],
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
layer_past=None, layer_past: Optional[Tuple[torch.Tensor]] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
use_cache=False, use_cache: Optional[bool] = False,
output_attentions=False, output_attentions: Optional[bool] = False,
): ) -> Union[
Tuple[torch.Tensor, Tuple[torch.Tensor]],
Optional[Tuple[torch.Tensor, Tuple[torch.Tensor], Tuple[torch.Tensor, ...]]],
]:
query = self.q_proj(hidden_states) query = self.q_proj(hidden_states)
key = self.k_proj(hidden_states) key = self.k_proj(hidden_states)
...@@ -271,7 +274,7 @@ class GPTJMLP(nn.Module): ...@@ -271,7 +274,7 @@ class GPTJMLP(nn.Module):
self.act = ACT2FN[config.activation_function] self.act = ACT2FN[config.activation_function]
self.dropout = nn.Dropout(config.resid_pdrop) self.dropout = nn.Dropout(config.resid_pdrop)
def forward(self, hidden_states): def forward(self, hidden_states: Optional[torch.FloatTensor]) -> torch.FloatTensor:
hidden_states = self.fc_in(hidden_states) hidden_states = self.fc_in(hidden_states)
hidden_states = self.act(hidden_states) hidden_states = self.act(hidden_states)
hidden_states = self.fc_out(hidden_states) hidden_states = self.fc_out(hidden_states)
...@@ -289,13 +292,13 @@ class GPTJBlock(nn.Module): ...@@ -289,13 +292,13 @@ class GPTJBlock(nn.Module):
def forward( def forward(
self, self,
hidden_states, hidden_states: Optional[torch.FloatTensor],
layer_past=None, layer_past: Optional[Tuple[torch.Tensor]] = None,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
use_cache=False, use_cache: Optional[bool] = False,
output_attentions=False, output_attentions: Optional[bool] = False,
): ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
residual = hidden_states residual = hidden_states
hidden_states = self.ln_1(hidden_states) hidden_states = self.ln_1(hidden_states)
attn_outputs = self.attn( attn_outputs = self.attn(
...@@ -533,18 +536,18 @@ class GPTJModel(GPTJPreTrainedModel): ...@@ -533,18 +536,18 @@ class GPTJModel(GPTJPreTrainedModel):
) )
def forward( def forward(
self, self,
input_ids=None, input_ids: Optional[torch.LongTensor] = None,
past_key_values=None, past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids=None, token_type_ids: Optional[torch.LongTensor] = None,
position_ids=None, position_ids: Optional[torch.LongTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds=None, inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache=None, use_cache: Optional[bool] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
): ) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = ( output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
...@@ -787,19 +790,19 @@ class GPTJForCausalLM(GPTJPreTrainedModel): ...@@ -787,19 +790,19 @@ class GPTJForCausalLM(GPTJPreTrainedModel):
) )
def forward( def forward(
self, self,
input_ids=None, input_ids: Optional[torch.LongTensor] = None,
past_key_values=None, past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids=None, token_type_ids: Optional[torch.LongTensor] = None,
position_ids=None, position_ids: Optional[torch.LongTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds=None, inputs_embeds: Optional[torch.FloatTensor] = None,
labels=None, labels: Optional[torch.LongTensor] = None,
use_cache=None, use_cache: Optional[bool] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
): ) -> Union[Tuple, CausalLMOutputWithPast]:
r""" r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
...@@ -911,19 +914,19 @@ class GPTJForSequenceClassification(GPTJPreTrainedModel): ...@@ -911,19 +914,19 @@ class GPTJForSequenceClassification(GPTJPreTrainedModel):
) )
def forward( def forward(
self, self,
input_ids=None, input_ids: Optional[torch.LongTensor] = None,
past_key_values=None, past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids=None, token_type_ids: Optional[torch.LongTensor] = None,
position_ids=None, position_ids: Optional[torch.LongTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds=None, inputs_embeds: Optional[torch.FloatTensor] = None,
labels=None, labels: Optional[torch.LongTensor] = None,
use_cache=None, use_cache: Optional[bool] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
): ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
r""" r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
...@@ -1038,18 +1041,18 @@ class GPTJForQuestionAnswering(GPTJPreTrainedModel): ...@@ -1038,18 +1041,18 @@ class GPTJForQuestionAnswering(GPTJPreTrainedModel):
) )
def forward( def forward(
self, self,
input_ids=None, input_ids: Optional[torch.LongTensor] = None,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids=None, token_type_ids: Optional[torch.LongTensor] = None,
position_ids=None, position_ids: Optional[torch.LongTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds=None, inputs_embeds: Optional[torch.FloatTensor] = None,
start_positions=None, start_positions: Optional[torch.LongTensor] = None,
end_positions=None, end_positions: Optional[torch.LongTensor] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
): ) -> Union[Tuple, QuestionAnsweringModelOutput]:
r""" r"""
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for position (index) of the start of the labelled span for computing the token classification loss. Labels for position (index) of the start of the labelled span for computing the token classification loss.
......
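To make the annotated GPT-J signatures above concrete, here is a minimal usage sketch (not part of the diff). The tiny config values are arbitrary so the snippet runs without downloading a checkpoint, and the labels simply reuse the input ids because, as the docstring notes, the model shifts them internally.

import torch
from transformers import GPTJConfig, GPTJForCausalLM
from transformers.modeling_outputs import CausalLMOutputWithPast

# Arbitrary tiny config, purely for illustration.
config = GPTJConfig(vocab_size=128, n_positions=64, n_embd=32, n_layer=2, n_head=4, rotary_dim=4)
model = GPTJForCausalLM(config)

input_ids = torch.randint(0, config.vocab_size, (1, 10))        # torch.LongTensor
attention_mask = torch.ones_like(input_ids, dtype=torch.float)  # Optional[torch.FloatTensor]

# Labels are shifted inside the model, so the input ids can be passed directly.
outputs: CausalLMOutputWithPast = model(
    input_ids=input_ids,
    attention_mask=attention_mask,
    labels=input_ids,
    return_dict=True,
)
print(outputs.loss, outputs.logits.shape)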
...@@ -702,11 +702,11 @@ class MaskFormerSwinSelfAttention(nn.Module): ...@@ -702,11 +702,11 @@ class MaskFormerSwinSelfAttention(nn.Module):
def forward( def forward(
self, self,
hidden_states, hidden_states: torch.Tensor,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
output_attentions=False, output_attentions: Optional[bool] = False,
): ) -> Tuple[torch.Tensor]:
batch_size, dim, num_channels = hidden_states.shape batch_size, dim, num_channels = hidden_states.shape
mixed_query_layer = self.query(hidden_states) mixed_query_layer = self.query(hidden_states)
...@@ -764,7 +764,7 @@ class MaskFormerSwinSelfOutput(nn.Module): ...@@ -764,7 +764,7 @@ class MaskFormerSwinSelfOutput(nn.Module):
self.dense = nn.Linear(dim, dim) self.dense = nn.Linear(dim, dim)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob) self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
def forward(self, hidden_states, input_tensor): def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states) hidden_states = self.dropout(hidden_states)
...@@ -797,7 +797,13 @@ class MaskFormerSwinAttention(nn.Module): ...@@ -797,7 +797,13 @@ class MaskFormerSwinAttention(nn.Module):
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
self.pruned_heads = self.pruned_heads.union(heads) self.pruned_heads = self.pruned_heads.union(heads)
def forward(self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False): def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor]:
self_outputs = self.self(hidden_states, attention_mask, head_mask, output_attentions) self_outputs = self.self(hidden_states, attention_mask, head_mask, output_attentions)
attention_output = self.output(self_outputs[0], hidden_states) attention_output = self.output(self_outputs[0], hidden_states)
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
...@@ -814,7 +820,7 @@ class MaskFormerSwinIntermediate(nn.Module): ...@@ -814,7 +820,7 @@ class MaskFormerSwinIntermediate(nn.Module):
else: else:
self.intermediate_act_fn = config.hidden_act self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states) hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states return hidden_states
...@@ -827,7 +833,7 @@ class MaskFormerSwinOutput(nn.Module): ...@@ -827,7 +833,7 @@ class MaskFormerSwinOutput(nn.Module):
self.dense = nn.Linear(int(config.mlp_ratio * dim), dim) self.dense = nn.Linear(int(config.mlp_ratio * dim), dim)
self.dropout = nn.Dropout(config.hidden_dropout_prob) self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states) hidden_states = self.dropout(hidden_states)
return hidden_states return hidden_states
......
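The `Tuple[torch.Tensor]` annotations above follow the library's loose convention for attention modules: the return value is `(attention_output,)`, optionally extended with the attention probabilities when output_attentions is set. The toy module below (ToyAttention is an invented name, not a library class) shows that call pattern in isolation.

from typing import Optional, Tuple
import torch
from torch import nn

class ToyAttention(nn.Module):
    def __init__(self, dim: int) -> None:
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        # Unscaled dot-product attention over the sequence dimension, kept minimal on purpose.
        attention_probs = torch.softmax(hidden_states @ hidden_states.transpose(-1, -2), dim=-1)
        attention_output = self.proj(attention_probs @ hidden_states)
        outputs = (attention_output, attention_probs) if output_attentions else (attention_output,)
        return outputs

attn = ToyAttention(dim=16)
out = attn(torch.randn(2, 5, 16), output_attentions=True)
print(len(out), out[0].shape)  # 2 torch.Size([2, 5, 16])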
...@@ -20,7 +20,7 @@ import math ...@@ -20,7 +20,7 @@ import math
import os import os
import warnings import warnings
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional, Tuple from typing import Optional, Tuple, Union
import torch import torch
import torch.utils.checkpoint import torch.utils.checkpoint
...@@ -154,8 +154,13 @@ class MegatronBertEmbeddings(nn.Module): ...@@ -154,8 +154,13 @@ class MegatronBertEmbeddings(nn.Module):
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
def forward( def forward(
self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0 self,
): input_ids: Optional[torch.LongTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
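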
past_key_values_length: int = 0,
) -> torch.Tensor:
if input_ids is not None: if input_ids is not None:
input_shape = input_ids.size() input_shape = input_ids.size()
else: else:
...@@ -319,7 +324,7 @@ class MegatronBertSelfOutput(nn.Module): ...@@ -319,7 +324,7 @@ class MegatronBertSelfOutput(nn.Module):
self.dense = nn.Linear(config.hidden_size, config.hidden_size) self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.dropout = nn.Dropout(config.hidden_dropout_prob) self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, residual): def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states) hidden_states = self.dropout(hidden_states)
return residual + hidden_states return residual + hidden_states
...@@ -354,14 +359,14 @@ class MegatronBertAttention(nn.Module): ...@@ -354,14 +359,14 @@ class MegatronBertAttention(nn.Module):
def forward( def forward(
self, self,
hidden_states, hidden_states: torch.Tensor,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states=None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask=None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value=None, past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
output_attentions=False, output_attentions: Optional[bool] = False,
): ) -> Tuple[torch.Tensor]:
ln_outputs = self.ln(hidden_states) ln_outputs = self.ln(hidden_states)
self_outputs = self.self( self_outputs = self.self(
ln_outputs, ln_outputs,
...@@ -400,7 +405,7 @@ class MegatronBertOutput(nn.Module): ...@@ -400,7 +405,7 @@ class MegatronBertOutput(nn.Module):
self.dense = nn.Linear(config.intermediate_size, config.hidden_size) self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
self.dropout = nn.Dropout(config.hidden_dropout_prob) self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor): def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states) hidden_states = self.dropout(hidden_states)
return input_tensor + hidden_states return input_tensor + hidden_states
...@@ -425,14 +430,14 @@ class MegatronBertLayer(nn.Module): ...@@ -425,14 +430,14 @@ class MegatronBertLayer(nn.Module):
def forward( def forward(
self, self,
hidden_states, hidden_states: torch.Tensor,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states=None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask=None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value=None, past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
output_attentions=False, output_attentions: Optional[bool] = False,
): ) -> Tuple[torch.Tensor]:
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2 # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
self_attention_outputs = self.attention( self_attention_outputs = self.attention(
...@@ -507,17 +512,17 @@ class MegatronBertEncoder(nn.Module): ...@@ -507,17 +512,17 @@ class MegatronBertEncoder(nn.Module):
def forward( def forward(
self, self,
hidden_states, hidden_states: torch.Tensor,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states=None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask=None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_values=None, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
use_cache=None, use_cache: Optional[bool] = None,
output_attentions=False, output_attentions: Optional[bool] = False,
output_hidden_states=False, output_hidden_states: Optional[bool] = False,
return_dict=True, return_dict: Optional[bool] = True,
): ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
all_hidden_states = () if output_hidden_states else None all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None all_self_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
...@@ -873,20 +878,20 @@ class MegatronBertModel(MegatronBertPreTrainedModel): ...@@ -873,20 +878,20 @@ class MegatronBertModel(MegatronBertPreTrainedModel):
) )
def forward( def forward(
self, self,
input_ids=None, input_ids: Optional[torch.LongTensor] = None,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids=None, token_type_ids: Optional[torch.LongTensor] = None,
position_ids=None, position_ids: Optional[torch.LongTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds=None, inputs_embeds: Optional[torch.FloatTensor] = None,
encoder_hidden_states=None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask=None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_values=None, past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
use_cache=None, use_cache: Optional[bool] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
): ) -> Union[Tuple, BaseModelOutputWithPoolingAndCrossAttentions]:
r""" r"""
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
...@@ -1022,18 +1027,18 @@ class MegatronBertForPreTraining(MegatronBertPreTrainedModel): ...@@ -1022,18 +1027,18 @@ class MegatronBertForPreTraining(MegatronBertPreTrainedModel):
@replace_return_docstrings(output_type=MegatronBertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC) @replace_return_docstrings(output_type=MegatronBertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
def forward( def forward(
self, self,
input_ids=None, input_ids: Optional[torch.LongTensor] = None,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids=None, token_type_ids: Optional[torch.LongTensor] = None,
position_ids=None, position_ids: Optional[torch.LongTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds=None, inputs_embeds: Optional[torch.FloatTensor] = None,
labels=None, labels: Optional[torch.LongTensor] = None,
next_sentence_label=None, next_sentence_label: Optional[torch.LongTensor] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
): ) -> Union[Tuple, MegatronBertForPreTrainingOutput]:
r""" r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
...@@ -1133,21 +1138,21 @@ class MegatronBertForCausalLM(MegatronBertPreTrainedModel): ...@@ -1133,21 +1138,21 @@ class MegatronBertForCausalLM(MegatronBertPreTrainedModel):
@replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC) @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
def forward( def forward(
self, self,
input_ids=None, input_ids: Optional[torch.LongTensor] = None,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids=None, token_type_ids: Optional[torch.LongTensor] = None,
position_ids=None, position_ids: Optional[torch.LongTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds=None, inputs_embeds: Optional[torch.FloatTensor] = None,
encoder_hidden_states=None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask=None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
labels=None, labels: Optional[torch.LongTensor] = None,
past_key_values=None, past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
use_cache=None, use_cache: Optional[bool] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
): ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
r""" r"""
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
...@@ -1287,19 +1292,19 @@ class MegatronBertForMaskedLM(MegatronBertPreTrainedModel): ...@@ -1287,19 +1292,19 @@ class MegatronBertForMaskedLM(MegatronBertPreTrainedModel):
) )
def forward( def forward(
self, self,
input_ids=None, input_ids: Optional[torch.LongTensor] = None,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids=None, token_type_ids: Optional[torch.LongTensor] = None,
position_ids=None, position_ids: Optional[torch.LongTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds=None, inputs_embeds: Optional[torch.FloatTensor] = None,
encoder_hidden_states=None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask=None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
labels=None, labels: Optional[torch.LongTensor] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
): ) -> Union[Tuple, MaskedLMOutput]:
r""" r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
...@@ -1379,18 +1384,18 @@ class MegatronBertForNextSentencePrediction(MegatronBertPreTrainedModel): ...@@ -1379,18 +1384,18 @@ class MegatronBertForNextSentencePrediction(MegatronBertPreTrainedModel):
@replace_return_docstrings(output_type=NextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC) @replace_return_docstrings(output_type=NextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC)
def forward( def forward(
self, self,
input_ids=None, input_ids: Optional[torch.LongTensor] = None,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids=None, token_type_ids: Optional[torch.LongTensor] = None,
position_ids=None, position_ids: Optional[torch.LongTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds=None, inputs_embeds: Optional[torch.FloatTensor] = None,
labels=None, labels: Optional[torch.LongTensor] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
**kwargs **kwargs
): ) -> Union[Tuple, NextSentencePredictorOutput]:
r""" r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
...@@ -1489,17 +1494,17 @@ class MegatronBertForSequenceClassification(MegatronBertPreTrainedModel): ...@@ -1489,17 +1494,17 @@ class MegatronBertForSequenceClassification(MegatronBertPreTrainedModel):
) )
def forward( def forward(
self, self,
input_ids=None, input_ids: Optional[torch.LongTensor] = None,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids=None, token_type_ids: Optional[torch.LongTensor] = None,
position_ids=None, position_ids: Optional[torch.LongTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds=None, inputs_embeds: Optional[torch.FloatTensor] = None,
labels=None, labels: Optional[torch.LongTensor] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
): ) -> Union[Tuple, SequenceClassifierOutput]:
r""" r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
...@@ -1588,17 +1593,17 @@ class MegatronBertForMultipleChoice(MegatronBertPreTrainedModel): ...@@ -1588,17 +1593,17 @@ class MegatronBertForMultipleChoice(MegatronBertPreTrainedModel):
) )
def forward( def forward(
self, self,
input_ids=None, input_ids: Optional[torch.LongTensor] = None,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids=None, token_type_ids: Optional[torch.LongTensor] = None,
position_ids=None, position_ids: Optional[torch.LongTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds=None, inputs_embeds: Optional[torch.FloatTensor] = None,
labels=None, labels: Optional[torch.LongTensor] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
): ) -> Union[Tuple, MultipleChoiceModelOutput]:
r""" r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
...@@ -1684,17 +1689,17 @@ class MegatronBertForTokenClassification(MegatronBertPreTrainedModel): ...@@ -1684,17 +1689,17 @@ class MegatronBertForTokenClassification(MegatronBertPreTrainedModel):
) )
def forward( def forward(
self, self,
input_ids=None, input_ids: Optional[torch.LongTensor] = None,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids=None, token_type_ids: Optional[torch.LongTensor] = None,
position_ids=None, position_ids: Optional[torch.LongTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds=None, inputs_embeds: Optional[torch.FloatTensor] = None,
labels=None, labels: Optional[torch.LongTensor] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
): ) -> Union[Tuple, TokenClassifierOutput]:
r""" r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
...@@ -1765,18 +1770,18 @@ class MegatronBertForQuestionAnswering(MegatronBertPreTrainedModel): ...@@ -1765,18 +1770,18 @@ class MegatronBertForQuestionAnswering(MegatronBertPreTrainedModel):
) )
def forward( def forward(
self, self,
input_ids=None, input_ids: Optional[torch.LongTensor] = None,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids=None, token_type_ids: Optional[torch.LongTensor] = None,
position_ids=None, position_ids: Optional[torch.LongTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds=None, inputs_embeds: Optional[torch.FloatTensor] = None,
start_positions=None, start_positions: Optional[torch.LongTensor] = None,
end_positions=None, end_positions: Optional[torch.LongTensor] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
): ) -> Union[Tuple, QuestionAnsweringModelOutput]:
r""" r"""
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for position (index) of the start of the labelled span for computing the token classification loss. Labels for position (index) of the start of the labelled span for computing the token classification loss.
......
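As a quick check of the question-answering signature above, this hedged sketch builds a randomly initialised MegatronBERT QA head with an arbitrary small config and passes start_positions and end_positions as LongTensors, matching the new annotations.

import torch
from transformers import MegatronBertConfig, MegatronBertForQuestionAnswering

# Small, arbitrary config so the example runs without pretrained weights.
config = MegatronBertConfig(
    vocab_size=100, hidden_size=32, num_hidden_layers=2, num_attention_heads=4,
    intermediate_size=64, max_position_embeddings=64,
)
model = MegatronBertForQuestionAnswering(config)

input_ids = torch.randint(0, config.vocab_size, (1, 12))  # torch.LongTensor
outputs = model(
    input_ids=input_ids,
    start_positions=torch.tensor([3]),  # Optional[torch.LongTensor]
    end_positions=torch.tensor([7]),    # Optional[torch.LongTensor]
    return_dict=True,
)
print(outputs.loss, outputs.start_logits.shape, outputs.end_logits.shape)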
...@@ -164,7 +164,7 @@ class NoNorm(nn.Module): ...@@ -164,7 +164,7 @@ class NoNorm(nn.Module):
self.bias = nn.Parameter(torch.zeros(feat_size)) self.bias = nn.Parameter(torch.zeros(feat_size))
self.weight = nn.Parameter(torch.ones(feat_size)) self.weight = nn.Parameter(torch.ones(feat_size))
def forward(self, input_tensor): def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
return input_tensor * self.weight + self.bias return input_tensor * self.weight + self.bias
...@@ -194,7 +194,13 @@ class MobileBertEmbeddings(nn.Module): ...@@ -194,7 +194,13 @@ class MobileBertEmbeddings(nn.Module):
# position_ids (1, len position emb) is contiguous in memory and exported when serialized # position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))) self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None): def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
) -> torch.Tensor:
if input_ids is not None: if input_ids is not None:
input_shape = input_ids.size() input_shape = input_ids.size()
else: else:
...@@ -260,13 +266,13 @@ class MobileBertSelfAttention(nn.Module): ...@@ -260,13 +266,13 @@ class MobileBertSelfAttention(nn.Module):
def forward( def forward(
self, self,
query_tensor, query_tensor: torch.Tensor,
key_tensor, key_tensor: torch.Tensor,
value_tensor, value_tensor: torch.Tensor,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
): ) -> Tuple[torch.Tensor]:
mixed_query_layer = self.query(query_tensor) mixed_query_layer = self.query(query_tensor)
mixed_key_layer = self.key(key_tensor) mixed_key_layer = self.key(key_tensor)
mixed_value_layer = self.value(value_tensor) mixed_value_layer = self.value(value_tensor)
...@@ -306,7 +312,7 @@ class MobileBertSelfOutput(nn.Module): ...@@ -306,7 +312,7 @@ class MobileBertSelfOutput(nn.Module):
if not self.use_bottleneck: if not self.use_bottleneck:
self.dropout = nn.Dropout(config.hidden_dropout_prob) self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, residual_tensor): def forward(self, hidden_states: torch.Tensor, residual_tensor: torch.Tensor) -> torch.Tensor:
layer_outputs = self.dense(hidden_states) layer_outputs = self.dense(hidden_states)
if not self.use_bottleneck: if not self.use_bottleneck:
layer_outputs = self.dropout(layer_outputs) layer_outputs = self.dropout(layer_outputs)
...@@ -341,14 +347,14 @@ class MobileBertAttention(nn.Module): ...@@ -341,14 +347,14 @@ class MobileBertAttention(nn.Module):
def forward( def forward(
self, self,
query_tensor, query_tensor: torch.Tensor,
key_tensor, key_tensor: torch.Tensor,
value_tensor, value_tensor: torch.Tensor,
layer_input, layer_input: torch.Tensor,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
): ) -> Tuple[torch.Tensor]:
self_outputs = self.self( self_outputs = self.self(
query_tensor, query_tensor,
key_tensor, key_tensor,
...@@ -373,7 +379,7 @@ class MobileBertIntermediate(nn.Module): ...@@ -373,7 +379,7 @@ class MobileBertIntermediate(nn.Module):
else: else:
self.intermediate_act_fn = config.hidden_act self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states) hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states return hidden_states
...@@ -386,7 +392,7 @@ class OutputBottleneck(nn.Module): ...@@ -386,7 +392,7 @@ class OutputBottleneck(nn.Module):
self.LayerNorm = NORM2FN[config.normalization_type](config.hidden_size, eps=config.layer_norm_eps) self.LayerNorm = NORM2FN[config.normalization_type](config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob) self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, residual_tensor): def forward(self, hidden_states: torch.Tensor, residual_tensor: torch.Tensor) -> torch.Tensor:
layer_outputs = self.dense(hidden_states) layer_outputs = self.dense(hidden_states)
layer_outputs = self.dropout(layer_outputs) layer_outputs = self.dropout(layer_outputs)
layer_outputs = self.LayerNorm(layer_outputs + residual_tensor) layer_outputs = self.LayerNorm(layer_outputs + residual_tensor)
...@@ -404,7 +410,9 @@ class MobileBertOutput(nn.Module): ...@@ -404,7 +410,9 @@ class MobileBertOutput(nn.Module):
else: else:
self.bottleneck = OutputBottleneck(config) self.bottleneck = OutputBottleneck(config)
def forward(self, intermediate_states, residual_tensor_1, residual_tensor_2): def forward(
self, intermediate_states: torch.Tensor, residual_tensor_1: torch.Tensor, residual_tensor_2: torch.Tensor
) -> torch.Tensor:
layer_output = self.dense(intermediate_states) layer_output = self.dense(intermediate_states)
if not self.use_bottleneck: if not self.use_bottleneck:
layer_output = self.dropout(layer_output) layer_output = self.dropout(layer_output)
...@@ -421,7 +429,7 @@ class BottleneckLayer(nn.Module): ...@@ -421,7 +429,7 @@ class BottleneckLayer(nn.Module):
self.dense = nn.Linear(config.hidden_size, config.intra_bottleneck_size) self.dense = nn.Linear(config.hidden_size, config.intra_bottleneck_size)
self.LayerNorm = NORM2FN[config.normalization_type](config.intra_bottleneck_size, eps=config.layer_norm_eps) self.LayerNorm = NORM2FN[config.normalization_type](config.intra_bottleneck_size, eps=config.layer_norm_eps)
def forward(self, hidden_states): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
layer_input = self.dense(hidden_states) layer_input = self.dense(hidden_states)
layer_input = self.LayerNorm(layer_input) layer_input = self.LayerNorm(layer_input)
return layer_input return layer_input
...@@ -436,7 +444,7 @@ class Bottleneck(nn.Module): ...@@ -436,7 +444,7 @@ class Bottleneck(nn.Module):
if self.key_query_shared_bottleneck: if self.key_query_shared_bottleneck:
self.attention = BottleneckLayer(config) self.attention = BottleneckLayer(config)
def forward(self, hidden_states): def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor]:
# This method can return three different tuples of values. These different values make use of bottlenecks, # This method can return three different tuples of values. These different values make use of bottlenecks,
# which are linear layers used to project the hidden states to a lower-dimensional vector, reducing memory # which are linear layers used to project the hidden states to a lower-dimensional vector, reducing memory
# usage. These linear layers have weights that are learned during training. # usage. These linear layers have weights that are learned during training.
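A minimal, self-contained sketch of the bottleneck projection described in the comment above (arbitrary code using the MobileBERT default sizes; the real BottleneckLayer uses NORM2FN[config.normalization_type] rather than a plain LayerNorm):

import torch
from torch import nn

hidden_size, intra_bottleneck_size = 512, 128
# Learned projection into the narrower space in which attention is computed.
bottleneck = nn.Sequential(
    nn.Linear(hidden_size, intra_bottleneck_size),
    nn.LayerNorm(intra_bottleneck_size),
)
hidden_states = torch.randn(2, 16, hidden_size)
print(bottleneck(hidden_states).shape)  # torch.Size([2, 16, 128])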
...@@ -469,7 +477,7 @@ class FFNOutput(nn.Module): ...@@ -469,7 +477,7 @@ class FFNOutput(nn.Module):
self.dense = nn.Linear(config.intermediate_size, config.true_hidden_size) self.dense = nn.Linear(config.intermediate_size, config.true_hidden_size)
self.LayerNorm = NORM2FN[config.normalization_type](config.true_hidden_size, eps=config.layer_norm_eps) self.LayerNorm = NORM2FN[config.normalization_type](config.true_hidden_size, eps=config.layer_norm_eps)
def forward(self, hidden_states, residual_tensor): def forward(self, hidden_states: torch.Tensor, residual_tensor: torch.Tensor) -> torch.Tensor:
layer_outputs = self.dense(hidden_states) layer_outputs = self.dense(hidden_states)
layer_outputs = self.LayerNorm(layer_outputs + residual_tensor) layer_outputs = self.LayerNorm(layer_outputs + residual_tensor)
return layer_outputs return layer_outputs
...@@ -481,7 +489,7 @@ class FFNLayer(nn.Module): ...@@ -481,7 +489,7 @@ class FFNLayer(nn.Module):
self.intermediate = MobileBertIntermediate(config) self.intermediate = MobileBertIntermediate(config)
self.output = FFNOutput(config) self.output = FFNOutput(config)
def forward(self, hidden_states): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
intermediate_output = self.intermediate(hidden_states) intermediate_output = self.intermediate(hidden_states)
layer_outputs = self.output(intermediate_output, hidden_states) layer_outputs = self.output(intermediate_output, hidden_states)
return layer_outputs return layer_outputs
...@@ -503,11 +511,11 @@ class MobileBertLayer(nn.Module): ...@@ -503,11 +511,11 @@ class MobileBertLayer(nn.Module):
def forward( def forward(
self, self,
hidden_states, hidden_states: torch.Tensor,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
): ) -> Tuple[torch.Tensor]:
if self.use_bottleneck: if self.use_bottleneck:
query_tensor, key_tensor, value_tensor, layer_input = self.bottleneck(hidden_states) query_tensor, key_tensor, value_tensor, layer_input = self.bottleneck(hidden_states)
else: else:
...@@ -557,13 +565,13 @@ class MobileBertEncoder(nn.Module): ...@@ -557,13 +565,13 @@ class MobileBertEncoder(nn.Module):
def forward( def forward(
self, self,
hidden_states, hidden_states: torch.Tensor,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
output_attentions=False, output_attentions: Optional[bool] = False,
output_hidden_states=False, output_hidden_states: Optional[bool] = False,
return_dict=True, return_dict: Optional[bool] = True,
): ) -> Union[Tuple, BaseModelOutput]:
all_hidden_states = () if output_hidden_states else None all_hidden_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None all_attentions = () if output_attentions else None
for i, layer_module in enumerate(self.layer): for i, layer_module in enumerate(self.layer):
...@@ -599,7 +607,7 @@ class MobileBertPooler(nn.Module): ...@@ -599,7 +607,7 @@ class MobileBertPooler(nn.Module):
if self.do_activate: if self.do_activate:
self.dense = nn.Linear(config.hidden_size, config.hidden_size) self.dense = nn.Linear(config.hidden_size, config.hidden_size)
def forward(self, hidden_states): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# We "pool" the model by simply taking the hidden state corresponding # We "pool" the model by simply taking the hidden state corresponding
# to the first token. # to the first token.
first_token_tensor = hidden_states[:, 0] first_token_tensor = hidden_states[:, 0]
...@@ -621,7 +629,7 @@ class MobileBertPredictionHeadTransform(nn.Module): ...@@ -621,7 +629,7 @@ class MobileBertPredictionHeadTransform(nn.Module):
self.transform_act_fn = config.hidden_act self.transform_act_fn = config.hidden_act
self.LayerNorm = NORM2FN["layer_norm"](config.hidden_size, eps=config.layer_norm_eps) self.LayerNorm = NORM2FN["layer_norm"](config.hidden_size, eps=config.layer_norm_eps)
def forward(self, hidden_states): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.transform_act_fn(hidden_states) hidden_states = self.transform_act_fn(hidden_states)
hidden_states = self.LayerNorm(hidden_states) hidden_states = self.LayerNorm(hidden_states)
...@@ -640,7 +648,7 @@ class MobileBertLMPredictionHead(nn.Module): ...@@ -640,7 +648,7 @@ class MobileBertLMPredictionHead(nn.Module):
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
self.decoder.bias = self.bias self.decoder.bias = self.bias
def forward(self, hidden_states): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.transform(hidden_states) hidden_states = self.transform(hidden_states)
hidden_states = hidden_states.matmul(torch.cat([self.decoder.weight.t(), self.dense.weight], dim=0)) hidden_states = hidden_states.matmul(torch.cat([self.decoder.weight.t(), self.dense.weight], dim=0))
hidden_states += self.decoder.bias hidden_states += self.decoder.bias
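The matmul above stitches two weight matrices into one (hidden_size, vocab_size) projection: decoder.weight, the output embedding of size (vocab_size, embedding_size), is transposed to cover the first embedding_size rows, and dense.weight supplies the remaining hidden_size - embedding_size rows. A quick shape check with the MobileBERT default sizes (illustrative only, random weights):

import torch
vocab_size, hidden_size, embedding_size = 30522, 512, 128
decoder_weight = torch.randn(vocab_size, embedding_size)            # output embedding
dense_weight = torch.randn(hidden_size - embedding_size, vocab_size)  # extra learned rows
projection = torch.cat([decoder_weight.t(), dense_weight], dim=0)
print(projection.shape)  # torch.Size([512, 30522])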
...@@ -652,7 +660,7 @@ class MobileBertOnlyMLMHead(nn.Module): ...@@ -652,7 +660,7 @@ class MobileBertOnlyMLMHead(nn.Module):
super().__init__() super().__init__()
self.predictions = MobileBertLMPredictionHead(config) self.predictions = MobileBertLMPredictionHead(config)
def forward(self, sequence_output): def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
prediction_scores = self.predictions(sequence_output) prediction_scores = self.predictions(sequence_output)
return prediction_scores return prediction_scores
...@@ -663,7 +671,7 @@ class MobileBertPreTrainingHeads(nn.Module): ...@@ -663,7 +671,7 @@ class MobileBertPreTrainingHeads(nn.Module):
self.predictions = MobileBertLMPredictionHead(config) self.predictions = MobileBertLMPredictionHead(config)
self.seq_relationship = nn.Linear(config.hidden_size, 2) self.seq_relationship = nn.Linear(config.hidden_size, 2)
def forward(self, sequence_output, pooled_output): def forward(self, sequence_output: torch.Tensor, pooled_output: torch.Tensor) -> Tuple[torch.Tensor]:
prediction_scores = self.predictions(sequence_output) prediction_scores = self.predictions(sequence_output)
seq_relationship_score = self.seq_relationship(pooled_output) seq_relationship_score = self.seq_relationship(pooled_output)
return prediction_scores, seq_relationship_score return prediction_scores, seq_relationship_score
...@@ -841,16 +849,16 @@ class MobileBertModel(MobileBertPreTrainedModel): ...@@ -841,16 +849,16 @@ class MobileBertModel(MobileBertPreTrainedModel):
) )
def forward( def forward(
self, self,
input_ids=None, input_ids: Optional[torch.LongTensor] = None,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids=None, token_type_ids: Optional[torch.LongTensor] = None,
position_ids=None, position_ids: Optional[torch.LongTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds=None, inputs_embeds: Optional[torch.FloatTensor] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
): ) -> Union[Tuple, BaseModelOutputWithPooling]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = ( output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
...@@ -943,18 +951,18 @@ class MobileBertForPreTraining(MobileBertPreTrainedModel): ...@@ -943,18 +951,18 @@ class MobileBertForPreTraining(MobileBertPreTrainedModel):
@replace_return_docstrings(output_type=MobileBertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC) @replace_return_docstrings(output_type=MobileBertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
def forward( def forward(
self, self,
input_ids=None, input_ids: Optional[torch.LongTensor] = None,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids=None, token_type_ids: Optional[torch.LongTensor] = None,
position_ids=None, position_ids: Optional[torch.LongTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds=None, inputs_embeds: Optional[torch.FloatTensor] = None,
labels=None, labels: Optional[torch.LongTensor] = None,
next_sentence_label=None, next_sentence_label: Optional[torch.LongTensor] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
): ) -> Union[Tuple, MobileBertForPreTrainingOutput]:
r""" r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
...@@ -1059,17 +1067,17 @@ class MobileBertForMaskedLM(MobileBertPreTrainedModel): ...@@ -1059,17 +1067,17 @@ class MobileBertForMaskedLM(MobileBertPreTrainedModel):
) )
def forward( def forward(
self, self,
input_ids=None, input_ids: Optional[torch.LongTensor] = None,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids=None, token_type_ids: Optional[torch.LongTensor] = None,
position_ids=None, position_ids: Optional[torch.LongTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds=None, inputs_embeds: Optional[torch.FloatTensor] = None,
labels=None, labels: Optional[torch.LongTensor] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
): ) -> Union[Tuple, MaskedLMOutput]:
r""" r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
...@@ -1115,7 +1123,7 @@ class MobileBertOnlyNSPHead(nn.Module): ...@@ -1115,7 +1123,7 @@ class MobileBertOnlyNSPHead(nn.Module):
super().__init__() super().__init__()
self.seq_relationship = nn.Linear(config.hidden_size, 2) self.seq_relationship = nn.Linear(config.hidden_size, 2)
def forward(self, pooled_output): def forward(self, pooled_output: torch.Tensor) -> torch.Tensor:
seq_relationship_score = self.seq_relationship(pooled_output) seq_relationship_score = self.seq_relationship(pooled_output)
return seq_relationship_score return seq_relationship_score
...@@ -1138,18 +1146,18 @@ class MobileBertForNextSentencePrediction(MobileBertPreTrainedModel): ...@@ -1138,18 +1146,18 @@ class MobileBertForNextSentencePrediction(MobileBertPreTrainedModel):
@replace_return_docstrings(output_type=NextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC) @replace_return_docstrings(output_type=NextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC)
def forward( def forward(
self, self,
input_ids=None, input_ids: Optional[torch.LongTensor] = None,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids=None, token_type_ids: Optional[torch.LongTensor] = None,
position_ids=None, position_ids: Optional[torch.LongTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds=None, inputs_embeds: Optional[torch.FloatTensor] = None,
labels=None, labels: Optional[torch.LongTensor] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
**kwargs, **kwargs,
): ) -> Union[Tuple, NextSentencePredictorOutput]:
r""" r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
......
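Putting the MobileBERT pre-training signature to work: a hedged sketch with a randomly initialised model (only the depth is reduced; every other config value is the default). labels and next_sentence_label are LongTensors, while the output flags are plain bools.

import torch
from transformers import MobileBertConfig, MobileBertForPreTraining

config = MobileBertConfig(num_hidden_layers=2)  # shallow model, just for the sketch
model = MobileBertForPreTraining(config)

input_ids = torch.randint(0, config.vocab_size, (1, 16))
labels = input_ids.clone()               # masked-LM labels, Optional[torch.LongTensor]
next_sentence_label = torch.tensor([1])  # 0 = continuation, 1 = random pair

outputs = model(
    input_ids=input_ids,
    labels=labels,
    next_sentence_label=next_sentence_label,
    output_attentions=False,  # Optional[bool]
    return_dict=True,
)
print(outputs.loss, outputs.prediction_logits.shape, outputs.seq_relationship_logits.shape)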
...@@ -19,7 +19,7 @@ import math ...@@ -19,7 +19,7 @@ import math
from dataclasses import dataclass from dataclasses import dataclass
from functools import reduce from functools import reduce
from operator import __add__ from operator import __add__
from typing import Any, Callable, Dict, Mapping, Optional, Tuple, Union from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union
import numpy as np import numpy as np
import torch import torch
...@@ -177,7 +177,7 @@ class PerceiverEmbeddings(nn.Module): ...@@ -177,7 +177,7 @@ class PerceiverEmbeddings(nn.Module):
super().__init__() super().__init__()
self.latents = nn.Parameter(torch.randn(config.num_latents, config.d_latents)) self.latents = nn.Parameter(torch.randn(config.num_latents, config.d_latents))
def forward(self, batch_size): def forward(self, batch_size: int) -> torch.Tensor:
return self.latents.expand(batch_size, -1, -1) # Thanks, Phil Wang return self.latents.expand(batch_size, -1, -1) # Thanks, Phil Wang
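The `batch_size: int` annotation above makes the latent expansion explicit: a single learned (num_latents, d_latents) parameter is broadcast to every example in the batch without copying memory. A tiny, self-contained sketch of that broadcast (arbitrary sizes, not the library class):

import torch
from torch import nn

num_latents, d_latents, batch_size = 4, 8, 3
latents = nn.Parameter(torch.randn(num_latents, d_latents))
expanded = latents.expand(batch_size, -1, -1)  # view, not a copy
print(expanded.shape)  # torch.Size([3, 4, 8])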
...@@ -232,13 +232,13 @@ class PerceiverSelfAttention(nn.Module): ...@@ -232,13 +232,13 @@ class PerceiverSelfAttention(nn.Module):
def forward( def forward(
self, self,
hidden_states, hidden_states: torch.Tensor,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
inputs=None, inputs: Optional[torch.FloatTensor] = None,
inputs_mask=None, inputs_mask: Optional[torch.FloatTensor] = None,
output_attentions=False, output_attentions: Optional[bool] = False,
): ) -> Tuple[torch.Tensor]:
hidden_states = self.layernorm1(hidden_states) hidden_states = self.layernorm1(hidden_states)
inputs = self.layernorm2(inputs) inputs = self.layernorm2(inputs)
...@@ -301,7 +301,7 @@ class PerceiverSelfOutput(nn.Module): ...@@ -301,7 +301,7 @@ class PerceiverSelfOutput(nn.Module):
super().__init__() super().__init__()
self.dense = nn.Linear(input_channels, output_channels) self.dense = nn.Linear(input_channels, output_channels)
def forward(self, hidden_states): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
return hidden_states return hidden_states
...@@ -377,13 +377,13 @@ class PerceiverAttention(nn.Module): ...@@ -377,13 +377,13 @@ class PerceiverAttention(nn.Module):
def forward( def forward(
self, self,
hidden_states, hidden_states: torch.Tensor,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
inputs=None, inputs: Optional[torch.FloatTensor] = None,
inputs_mask=None, inputs_mask: Optional[torch.FloatTensor] = None,
output_attentions=False, output_attentions: Optional[bool] = False,
): ) -> Tuple[torch.Tensor]:
self_outputs = self.self( self_outputs = self.self(
hidden_states, hidden_states,
attention_mask, attention_mask,
...@@ -418,7 +418,7 @@ class PerceiverMLP(nn.Module): ...@@ -418,7 +418,7 @@ class PerceiverMLP(nn.Module):
self.intermediate_act_fn = config.hidden_act self.intermediate_act_fn = config.hidden_act
self.dense2 = nn.Linear(widening_factor * input_size, input_size) self.dense2 = nn.Linear(widening_factor * input_size, input_size)
def forward(self, hidden_states): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense1(hidden_states) hidden_states = self.dense1(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states) hidden_states = self.intermediate_act_fn(hidden_states)
hidden_states = self.dense2(hidden_states) hidden_states = self.dense2(hidden_states)
...@@ -456,13 +456,13 @@ class PerceiverLayer(nn.Module): ...@@ -456,13 +456,13 @@ class PerceiverLayer(nn.Module):
def forward( def forward(
self, self,
hidden_states, hidden_states: torch.Tensor,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
inputs=None, inputs: Optional[torch.FloatTensor] = None,
inputs_mask=None, inputs_mask: Optional[torch.FloatTensor] = None,
output_attentions=False, output_attentions: Optional[bool] = False,
): ) -> Tuple[torch.Tensor]:
attention_outputs = self.attention( attention_outputs = self.attention(
hidden_states, hidden_states,
attention_mask, attention_mask,
...@@ -543,15 +543,15 @@ class PerceiverEncoder(nn.Module): ...@@ -543,15 +543,15 @@ class PerceiverEncoder(nn.Module):
def forward( def forward(
self, self,
hidden_states, hidden_states: torch.Tensor,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
inputs=None, inputs: Optional[torch.FloatTensor] = None,
inputs_mask=None, inputs_mask: Optional[torch.FloatTensor] = None,
output_attentions=False, output_attentions: Optional[bool] = False,
output_hidden_states=False, output_hidden_states: Optional[bool] = False,
return_dict=True, return_dict: Optional[bool] = True,
): ) -> Union[Tuple, BaseModelOutputWithCrossAttentions]:
all_hidden_states = () if output_hidden_states else None all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None all_self_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions else None all_cross_attentions = () if output_attentions else None
...@@ -754,14 +754,14 @@ class PerceiverModel(PerceiverPreTrainedModel): ...@@ -754,14 +754,14 @@ class PerceiverModel(PerceiverPreTrainedModel):
@replace_return_docstrings(output_type=PerceiverModelOutput, config_class=_CONFIG_FOR_DOC) @replace_return_docstrings(output_type=PerceiverModelOutput, config_class=_CONFIG_FOR_DOC)
def forward( def forward(
self, self,
inputs, inputs: torch.FloatTensor,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
subsampled_output_points=None, subsampled_output_points: Optional[Dict[str, torch.Tensor]] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
): ) -> Union[Tuple, PerceiverModelOutput]:
r""" r"""
Returns: Returns:
...@@ -1871,7 +1871,7 @@ class PerceiverForMultimodalAutoencoding(PerceiverPreTrainedModel): ...@@ -1871,7 +1871,7 @@ class PerceiverForMultimodalAutoencoding(PerceiverPreTrainedModel):
self, self,
inputs: Optional[torch.Tensor] = None, inputs: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
subsampled_output_points: Optional[Dict[str, torch.tensor]] = None, subsampled_output_points: Optional[Dict[str, torch.Tensor]] = None,
head_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
...@@ -2020,7 +2020,9 @@ class PerceiverProjectionDecoder(PerceiverAbstractDecoder): ...@@ -2020,7 +2020,9 @@ class PerceiverProjectionDecoder(PerceiverAbstractDecoder):
def decoder_query(self, inputs, modality_sizes=None, inputs_without_pos=None, subsampled_points=None): def decoder_query(self, inputs, modality_sizes=None, inputs_without_pos=None, subsampled_points=None):
return None return None
def forward(self, query, z, query_mask=None): def forward(
self, query: torch.Tensor, z: torch.FloatTensor, query_mask: Optional[torch.FloatTensor] = None
) -> torch.FloatTensor:
# (batch_size, num_latents, d_latents) -> (batch_size, d_latents) # (batch_size, num_latents, d_latents) -> (batch_size, d_latents)
z = torch.mean(z, dim=1) z = torch.mean(z, dim=1)
# (batch_size, d_latents) -> (batch_size, config.num_labels) # (batch_size, d_latents) -> (batch_size, config.num_labels)
...@@ -2044,11 +2046,11 @@ class PerceiverBasicDecoder(PerceiverAbstractDecoder): ...@@ -2044,11 +2046,11 @@ class PerceiverBasicDecoder(PerceiverAbstractDecoder):
The type of position encoding to use. Can be either "trainable", "fourier", or "none". The type of position encoding to use. Can be either "trainable", "fourier", or "none".
output_index_dims (`int`, *optional*): output_index_dims (`int`, *optional*):
The number of dimensions of the output queries. Ignored if 'position_encoding_type' == 'none'. The number of dimensions of the output queries. Ignored if 'position_encoding_type' == 'none'.
num_channels (`int`, *optional*): num_channels (`int`, *optional*, defaults to 128):
The number of channels of the decoder queries. Ignored if 'position_encoding_type' == 'none'. The number of channels of the decoder queries. Ignored if 'position_encoding_type' == 'none'.
qk_channels (`int`, *optional*): qk_channels (`int`, *optional*):
The number of channels of the queries and keys in the cross-attention layer. The number of channels of the queries and keys in the cross-attention layer.
v_channels (`int`, *optional*, defaults to 128): v_channels (`int`, *optional*):
The number of channels of the values in the cross-attention layer. The number of channels of the values in the cross-attention layer.
num_heads (`int`, *optional*, defaults to 1): num_heads (`int`, *optional*, defaults to 1):
The number of attention heads in the cross-attention layer. The number of attention heads in the cross-attention layer.
...@@ -2066,23 +2068,23 @@ class PerceiverBasicDecoder(PerceiverAbstractDecoder): ...@@ -2066,23 +2068,23 @@ class PerceiverBasicDecoder(PerceiverAbstractDecoder):
def __init__( def __init__(
self, self,
config, config: PerceiverConfig,
output_num_channels, output_num_channels: int,
position_encoding_type="trainable", position_encoding_type: Optional[str] = "trainable",
# The following 2 arguments are ignored if position_encoding_type == 'none': # The following 2 arguments are ignored if position_encoding_type == 'none':
output_index_dims=None, output_index_dims: Optional[int] = None,
num_channels=128, num_channels: Optional[int] = 128,
subsampled_index_dims=None, subsampled_index_dims: Optional[int] = None,
qk_channels=None, qk_channels: Optional[int] = None,
v_channels=None, v_channels: Optional[int] = None,
num_heads=1, num_heads: Optional[int] = 1,
widening_factor=1, widening_factor: Optional[int] = 1,
use_query_residual=False, use_query_residual: Optional[bool] = False,
concat_preprocessed_input=False, concat_preprocessed_input: Optional[bool] = False,
final_project=True, final_project: Optional[bool] = True,
position_encoding_only=False, position_encoding_only: Optional[bool] = False,
**position_encoding_kwargs, **position_encoding_kwargs,
): ) -> None:
super().__init__() super().__init__()
self.output_num_channels = output_num_channels self.output_num_channels = output_num_channels
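The annotated constructor above can be exercised on its own. Below is a hedged sketch with arbitrary sizes; the trainable_position_encoding_kwargs keyword is an assumption about how **position_encoding_kwargs is forwarded to the position-encoding builder elsewhere in this file, and PerceiverBasicDecoder is imported from its defining module.

import torch
from transformers import PerceiverConfig
from transformers.models.perceiver.modeling_perceiver import PerceiverBasicDecoder

config = PerceiverConfig(d_latents=64)  # arbitrary small latent width
decoder = PerceiverBasicDecoder(
    config,
    output_num_channels=config.d_latents,
    output_index_dims=1,
    num_channels=config.d_latents,
    trainable_position_encoding_kwargs=dict(num_channels=config.d_latents, index_dims=1),
    use_query_residual=True,
)

inputs = torch.randn(2, 10, 32)                 # only the batch dimension is used for trainable queries
latents = torch.randn(2, 16, config.d_latents)  # stand-in for the encoder output z
query = decoder.decoder_query(inputs)
print(decoder(query, z=latents).logits.shape)   # expected: torch.Size([2, 1, 64])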
...@@ -2183,7 +2185,13 @@ class PerceiverBasicDecoder(PerceiverAbstractDecoder): ...@@ -2183,7 +2185,13 @@ class PerceiverBasicDecoder(PerceiverAbstractDecoder):
return pos_emb return pos_emb
def forward(self, query, z, query_mask=None, output_attentions=False): def forward(
self,
query: torch.Tensor,
z: torch.FloatTensor,
query_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = False,
) -> PerceiverDecoderOutput:
# Cross-attention decoding. # Cross-attention decoding.
# key, value: B x N x K; query: B x M x K # key, value: B x N x K; query: B x M x K
# Attention maps -> B x N x M # Attention maps -> B x N x M
...@@ -2239,7 +2247,13 @@ class PerceiverClassificationDecoder(PerceiverAbstractDecoder): ...@@ -2239,7 +2247,13 @@ class PerceiverClassificationDecoder(PerceiverAbstractDecoder):
inputs, modality_sizes, inputs_without_pos, subsampled_points=subsampled_points inputs, modality_sizes, inputs_without_pos, subsampled_points=subsampled_points
) )
def forward(self, query, z, query_mask=None, output_attentions=False): def forward(
self,
query: torch.Tensor,
z: torch.FloatTensor,
query_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = False,
) -> PerceiverDecoderOutput:
decoder_outputs = self.decoder(query, z, output_attentions=output_attentions) decoder_outputs = self.decoder(query, z, output_attentions=output_attentions)
# B x 1 x num_classes -> B x num_classes # B x 1 x num_classes -> B x num_classes
...@@ -2268,7 +2282,13 @@ class PerceiverOpticalFlowDecoder(PerceiverAbstractDecoder): ...@@ -2268,7 +2282,13 @@ class PerceiverOpticalFlowDecoder(PerceiverAbstractDecoder):
raise ValueError("FlowDecoder doesn't support subsampling yet.") raise ValueError("FlowDecoder doesn't support subsampling yet.")
return inputs return inputs
def forward(self, query, z, query_mask=None, output_attentions=False): def forward(
self,
query: torch.Tensor,
z: torch.FloatTensor,
query_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = False,
) -> PerceiverDecoderOutput:
decoder_outputs = self.decoder(query, z, output_attentions=output_attentions) decoder_outputs = self.decoder(query, z, output_attentions=output_attentions)
preds = decoder_outputs.logits preds = decoder_outputs.logits
# Output flow and rescale. # Output flow and rescale.
...@@ -2291,7 +2311,9 @@ class PerceiverBasicVideoAutoencodingDecoder(PerceiverAbstractDecoder): ...@@ -2291,7 +2311,9 @@ class PerceiverBasicVideoAutoencodingDecoder(PerceiverAbstractDecoder):
The type of position encoding to use. Can be either "trainable", "fourier", or "none". The type of position encoding to use. Can be either "trainable", "fourier", or "none".
""" """
def __init__(self, config, output_shape, position_encoding_type, **decoder_kwargs): def __init__(
self, config: PerceiverConfig, output_shape: List[int], position_encoding_type: str, **decoder_kwargs
) -> None:
super().__init__() super().__init__()
if len(output_shape) != 4: # B, T, H, W if len(output_shape) != 4: # B, T, H, W
raise ValueError(f"Expected rank 4 output_shape, got {output_shape}.") raise ValueError(f"Expected rank 4 output_shape, got {output_shape}.")
...@@ -2318,7 +2340,9 @@ class PerceiverBasicVideoAutoencodingDecoder(PerceiverAbstractDecoder): ...@@ -2318,7 +2340,9 @@ class PerceiverBasicVideoAutoencodingDecoder(PerceiverAbstractDecoder):
subsampled_points=subsampled_points, subsampled_points=subsampled_points,
) )
def forward(self, query, z, query_mask=None): def forward(
self, query: torch.Tensor, z: torch.FloatTensor, query_mask: Optional[torch.FloatTensor] = None
) -> PerceiverDecoderOutput:
decoder_outputs = self.decoder(query, z) decoder_outputs = self.decoder(query, z)
logits = decoder_outputs.logits logits = decoder_outputs.logits
...@@ -2378,14 +2402,14 @@ class PerceiverMultimodalDecoder(PerceiverAbstractDecoder): ...@@ -2378,14 +2402,14 @@ class PerceiverMultimodalDecoder(PerceiverAbstractDecoder):
def __init__( def __init__(
self, self,
config, config: PerceiverConfig,
modalities, modalities: Dict[str, PerceiverAbstractDecoder],
num_outputs, num_outputs: int,
output_num_channels, output_num_channels: int,
min_padding_size=2, min_padding_size: Optional[int] = 2,
subsampled_index_dims=None, subsampled_index_dims: Optional[Dict[str, PerceiverAbstractDecoder]] = None,
**decoder_kwargs **decoder_kwargs
): ) -> None:
super().__init__() super().__init__()
self.modalities = nn.ModuleDict(modalities) self.modalities = nn.ModuleDict(modalities)
self.subsampled_index_dims = subsampled_index_dims self.subsampled_index_dims = subsampled_index_dims
...@@ -2447,7 +2471,13 @@ class PerceiverMultimodalDecoder(PerceiverAbstractDecoder): ...@@ -2447,7 +2471,13 @@ class PerceiverMultimodalDecoder(PerceiverAbstractDecoder):
[embed(modality, decoder_queries[modality]) for modality in sorted(self.modalities.keys())], dim=1 [embed(modality, decoder_queries[modality]) for modality in sorted(self.modalities.keys())], dim=1
) )
def forward(self, query, z, query_mask=None, output_attentions=False): def forward(
self,
query: torch.Tensor,
z: torch.FloatTensor,
query_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = False,
) -> torch.Tensor:
# B x 1 x num_classes -> B x num_classes # B x 1 x num_classes -> B x num_classes
decoder_outputs = self.decoder(query, z, output_attentions=output_attentions) decoder_outputs = self.decoder(query, z, output_attentions=output_attentions)
...@@ -2680,7 +2710,7 @@ class PerceiverTrainablePositionEncoding(PerceiverAbstractPositionEncoding): ...@@ -2680,7 +2710,7 @@ class PerceiverTrainablePositionEncoding(PerceiverAbstractPositionEncoding):
def output_size(self, *args, **kwargs) -> int: def output_size(self, *args, **kwargs) -> int:
return self._num_channels return self._num_channels
def forward(self, batch_size): def forward(self, batch_size: int) -> torch.Tensor:
position_embeddings = self.position_embeddings position_embeddings = self.position_embeddings
if batch_size is not None: if batch_size is not None:
...@@ -2741,7 +2771,9 @@ class PerceiverFourierPositionEncoding(PerceiverAbstractPositionEncoding): ...@@ -2741,7 +2771,9 @@ class PerceiverFourierPositionEncoding(PerceiverAbstractPositionEncoding):
return encoding_size return encoding_size
def forward(self, index_dims, batch_size, device, pos=None): def forward(
self, index_dims: List[int], batch_size: int, device, pos: torch.FloatTensor = None
) -> torch.FloatTensor:
pos = _check_or_build_spatial_positions(pos, index_dims, batch_size) pos = _check_or_build_spatial_positions(pos, index_dims, batch_size)
fourier_pos_enc = generate_fourier_features( fourier_pos_enc = generate_fourier_features(
pos, pos,
...@@ -2771,7 +2803,7 @@ class PerceiverTextPreprocessor(AbstractPreprocessor): ...@@ -2771,7 +2803,7 @@ class PerceiverTextPreprocessor(AbstractPreprocessor):
Model configuration. Model configuration.
""" """
def __init__(self, config): def __init__(self, config: PerceiverConfig) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
self.embeddings = nn.Embedding(num_embeddings=config.vocab_size, embedding_dim=config.d_model) self.embeddings = nn.Embedding(num_embeddings=config.vocab_size, embedding_dim=config.d_model)
...@@ -2781,7 +2813,7 @@ class PerceiverTextPreprocessor(AbstractPreprocessor): ...@@ -2781,7 +2813,7 @@ class PerceiverTextPreprocessor(AbstractPreprocessor):
def num_channels(self) -> int: def num_channels(self) -> int:
return self.config.d_model return self.config.d_model
def forward(self, inputs): def forward(self, inputs: torch.LongTensor) -> torch.FloatTensor:
embeddings = self.embeddings(inputs) embeddings = self.embeddings(inputs)
seq_length = inputs.shape[1] seq_length = inputs.shape[1]
...@@ -2800,13 +2832,13 @@ class PerceiverEmbeddingDecoder(nn.Module): ...@@ -2800,13 +2832,13 @@ class PerceiverEmbeddingDecoder(nn.Module):
Model configuration. Model configuration.
""" """
def __init__(self, config): def __init__(self, config: PerceiverConfig) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
self.vocab_size = config.vocab_size self.vocab_size = config.vocab_size
self.bias = nn.Parameter(torch.zeros(self.vocab_size)) self.bias = nn.Parameter(torch.zeros(self.vocab_size))
def forward(self, hidden_states, embedding_layer): def forward(self, hidden_states: torch.Tensor, embedding_layer: torch.Tensor) -> torch.Tensor:
batch_size, seq_len, d_model = hidden_states.shape batch_size, seq_len, d_model = hidden_states.shape
output = torch.matmul(hidden_states.reshape([-1, d_model]), embedding_layer.weight.T) # Flatten batch dim output = torch.matmul(hidden_states.reshape([-1, d_model]), embedding_layer.weight.T) # Flatten batch dim
output = output + self.bias output = output + self.bias
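The annotated `PerceiverEmbeddingDecoder.forward` above scores decoder states against the input embedding matrix (weight tying). A minimal standalone sketch of that pattern, with assumed vocabulary and hidden sizes, not the library code itself:

    import torch
    from torch import nn

    vocab_size, d_model = 262, 768                       # assumed sizes, not taken from any config
    embedding_layer = nn.Embedding(vocab_size, d_model)
    bias = nn.Parameter(torch.zeros(vocab_size))

    hidden_states = torch.randn(2, 32, d_model)          # (batch, seq_len, d_model)
    batch_size, seq_len, _ = hidden_states.shape
    # flatten the batch, project onto the tied embedding weights, then restore the batch dim
    logits = torch.matmul(hidden_states.reshape(-1, d_model), embedding_layer.weight.T) + bias
    logits = logits.reshape(batch_size, seq_len, vocab_size)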
...@@ -2859,7 +2891,7 @@ class PerceiverClassificationPostprocessor(nn.Module): ...@@ -2859,7 +2891,7 @@ class PerceiverClassificationPostprocessor(nn.Module):
Number of channels in the input. Number of channels in the input.
""" """
def __init__(self, config, in_channels): def __init__(self, config: PerceiverConfig, in_channels: int) -> None:
super().__init__() super().__init__()
self.classifier = nn.Linear(in_channels, config.num_labels) self.classifier = nn.Linear(in_channels, config.num_labels)
...@@ -2881,7 +2913,7 @@ class PerceiverAudioPostprocessor(nn.Module): ...@@ -2881,7 +2913,7 @@ class PerceiverAudioPostprocessor(nn.Module):
Postprocessor type to use. Currently, only "patches" is supported. Postprocessor type to use. Currently, only "patches" is supported.
""" """
def __init__(self, config, in_channels, postproc_type: str = "patches"): def __init__(self, config: PerceiverConfig, in_channels: int, postproc_type: str = "patches") -> None:
super().__init__() super().__init__()
if postproc_type not in ("patches",): # to be supported: 'conv', 'patches', 'pixels' if postproc_type not in ("patches",): # to be supported: 'conv', 'patches', 'pixels'
...@@ -2908,7 +2940,7 @@ class PerceiverProjectionPostprocessor(nn.Module): ...@@ -2908,7 +2940,7 @@ class PerceiverProjectionPostprocessor(nn.Module):
Number of channels in the output. Number of channels in the output.
""" """
def __init__(self, in_channels, out_channels): def __init__(self, in_channels: int, out_channels: int) -> None:
super().__init__() super().__init__()
self.classifier = nn.Linear(in_channels, out_channels) self.classifier = nn.Linear(in_channels, out_channels)
...@@ -3155,7 +3187,7 @@ class PerceiverOneHotPreprocessor(AbstractPreprocessor): ...@@ -3155,7 +3187,7 @@ class PerceiverOneHotPreprocessor(AbstractPreprocessor):
Model configuration. Model configuration.
""" """
def __init__(self, config): def __init__(self, config: PerceiverConfig) -> None:
super().__init__() super().__init__()
self.config: PerceiverConfig = config self.config: PerceiverConfig = config
......
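The annotation pattern applied throughout this file: tensor arguments get concrete torch dtypes, boolean flags become `Optional[bool]`, and constructors are annotated `-> None`. A self-contained sketch of a decoder-style `forward` annotated the same way (a toy module, not the library code; the real decoders return a `PerceiverDecoderOutput`):

    from typing import Optional, Tuple

    import torch
    from torch import nn


    class ToyDecoder(nn.Module):
        """Toy stand-in illustrating the annotation style used for the Perceiver decoders."""

        def __init__(self, d_latents: int, num_outputs: int) -> None:
            super().__init__()
            self.final_layer = nn.Linear(d_latents, num_outputs)

        def forward(
            self,
            query: torch.Tensor,
            z: torch.FloatTensor,
            query_mask: Optional[torch.FloatTensor] = None,
            output_attentions: Optional[bool] = False,
        ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
            # query: (batch, num_queries, d_latents); z: (batch, num_latents, d_latents)
            logits = self.final_layer(query)
            attentions = None  # a real decoder would return cross-attention maps here
            return logits, attentions


    decoder = ToyDecoder(d_latents=256, num_outputs=10)
    logits, _ = decoder(torch.randn(2, 1, 256), torch.randn(2, 784, 256))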
...@@ -18,6 +18,7 @@ RetriBERT model ...@@ -18,6 +18,7 @@ RetriBERT model
import math import math
from typing import Optional
import torch import torch
import torch.utils.checkpoint as checkpoint import torch.utils.checkpoint as checkpoint
...@@ -85,7 +86,7 @@ RETRIBERT_START_DOCSTRING = r""" ...@@ -85,7 +86,7 @@ RETRIBERT_START_DOCSTRING = r"""
RETRIBERT_START_DOCSTRING, RETRIBERT_START_DOCSTRING,
) )
class RetriBertModel(RetriBertPreTrainedModel): class RetriBertModel(RetriBertPreTrainedModel):
def __init__(self, config): def __init__(self, config: RetriBertConfig) -> None:
super().__init__(config) super().__init__(config)
self.projection_dim = config.projection_dim self.projection_dim = config.projection_dim
...@@ -173,8 +174,13 @@ class RetriBertModel(RetriBertPreTrainedModel): ...@@ -173,8 +174,13 @@ class RetriBertModel(RetriBertPreTrainedModel):
return self.project_doc(a_reps) return self.project_doc(a_reps)
def forward( def forward(
self, input_ids_query, attention_mask_query, input_ids_doc, attention_mask_doc, checkpoint_batch_size=-1 self,
): input_ids_query: torch.LongTensor,
attention_mask_query: Optional[torch.FloatTensor],
input_ids_doc: torch.LongTensor,
attention_mask_doc: Optional[torch.FloatTensor],
checkpoint_batch_size: int = -1,
) -> torch.FloatTensor:
r""" r"""
Args: Args:
input_ids_query (`torch.LongTensor` of shape `(batch_size, sequence_length)`): input_ids_query (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
......
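A hedged usage sketch of the newly annotated `RetriBertModel.forward`. The model is built from a default `RetriBertConfig` (randomly initialized, no download), and the tensor shapes are illustrative only:

    import torch
    from transformers import RetriBertConfig, RetriBertModel

    config = RetriBertConfig(projection_dim=128)
    model = RetriBertModel(config)

    input_ids_query = torch.randint(0, config.vocab_size, (2, 16))    # torch.LongTensor
    attention_mask_query = torch.ones_like(input_ids_query).float()   # Optional[torch.FloatTensor]
    input_ids_doc = torch.randint(0, config.vocab_size, (2, 32))
    attention_mask_doc = torch.ones_like(input_ids_doc).float()

    loss = model(
        input_ids_query,
        attention_mask_query,
        input_ids_doc,
        attention_mask_doc,
        checkpoint_batch_size=-1,  # -1 disables gradient checkpointing over the doc batch
    )
    print(loss.item())  # scalar torch.FloatTensor, as the new return annotation states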
...@@ -112,7 +112,7 @@ class SegformerDropPath(nn.Module): ...@@ -112,7 +112,7 @@ class SegformerDropPath(nn.Module):
super().__init__() super().__init__()
self.drop_prob = drop_prob self.drop_prob = drop_prob
def forward(self, x): def forward(self, x: torch.Tensor) -> torch.Tensor:
return drop_path(x, self.drop_prob, self.training) return drop_path(x, self.drop_prob, self.training)
......
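`SegformerDropPath` only wraps the shared `drop_path` helper; a self-contained sketch of that stochastic-depth operation under the same typed signature (not the library function itself):

    import torch


    def drop_path_sketch(x: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
        if drop_prob == 0.0 or not training:
            return x
        keep_prob = 1.0 - drop_prob
        # one keep/drop decision per sample, broadcast over the remaining dims
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        mask = x.new_empty(shape).bernoulli_(keep_prob)
        return x.div(keep_prob) * mask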
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
import collections.abc import collections.abc
import math import math
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional, Tuple from typing import Optional, Tuple, Union
import torch import torch
import torch.utils.checkpoint import torch.utils.checkpoint
...@@ -272,7 +272,9 @@ class SwinEmbeddings(nn.Module): ...@@ -272,7 +272,9 @@ class SwinEmbeddings(nn.Module):
self.norm = nn.LayerNorm(config.embed_dim) self.norm = nn.LayerNorm(config.embed_dim)
self.dropout = nn.Dropout(config.hidden_dropout_prob) self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, pixel_values, bool_masked_pos=None): def forward(
self, pixel_values: Optional[torch.FloatTensor], bool_masked_pos: Optional[torch.BoolTensor] = None
) -> Tuple[torch.Tensor]:
embeddings, output_dimensions = self.patch_embeddings(pixel_values) embeddings, output_dimensions = self.patch_embeddings(pixel_values)
embeddings = self.norm(embeddings) embeddings = self.norm(embeddings)
batch_size, seq_len, _ = embeddings.size() batch_size, seq_len, _ = embeddings.size()
...@@ -317,7 +319,7 @@ class SwinPatchEmbeddings(nn.Module): ...@@ -317,7 +319,7 @@ class SwinPatchEmbeddings(nn.Module):
pixel_values = nn.functional.pad(pixel_values, pad_values) pixel_values = nn.functional.pad(pixel_values, pad_values)
return pixel_values return pixel_values
def forward(self, pixel_values): def forward(self, pixel_values: Optional[torch.FloatTensor]) -> Tuple[torch.Tensor, Tuple[int]]:
_, _, height, width = pixel_values.shape _, _, height, width = pixel_values.shape
# pad the input to be divisible by self.patch_size, if needed # pad the input to be divisible by self.patch_size, if needed
pixel_values = self.maybe_pad(pixel_values, height, width) pixel_values = self.maybe_pad(pixel_values, height, width)
...@@ -342,7 +344,7 @@ class SwinPatchMerging(nn.Module): ...@@ -342,7 +344,7 @@ class SwinPatchMerging(nn.Module):
Normalization layer class. Normalization layer class.
""" """
def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): def __init__(self, input_resolution: Tuple[int], dim: int, norm_layer: nn.Module = nn.LayerNorm) -> None:
super().__init__() super().__init__()
self.input_resolution = input_resolution self.input_resolution = input_resolution
self.dim = dim self.dim = dim
...@@ -357,7 +359,7 @@ class SwinPatchMerging(nn.Module): ...@@ -357,7 +359,7 @@ class SwinPatchMerging(nn.Module):
return input_feature return input_feature
def forward(self, input_feature, input_dimensions): def forward(self, input_feature: torch.Tensor, input_dimensions: Tuple[int, int]) -> torch.Tensor:
height, width = input_dimensions height, width = input_dimensions
# `dim` is height * width # `dim` is height * width
batch_size, dim, num_channels = input_feature.shape batch_size, dim, num_channels = input_feature.shape
...@@ -438,11 +440,11 @@ class SwinSelfAttention(nn.Module): ...@@ -438,11 +440,11 @@ class SwinSelfAttention(nn.Module):
def forward( def forward(
self, self,
hidden_states, hidden_states: torch.Tensor,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
output_attentions=False, output_attentions: Optional[bool] = False,
): ) -> Tuple[torch.Tensor]:
batch_size, dim, num_channels = hidden_states.shape batch_size, dim, num_channels = hidden_states.shape
mixed_query_layer = self.query(hidden_states) mixed_query_layer = self.query(hidden_states)
...@@ -499,7 +501,7 @@ class SwinSelfOutput(nn.Module): ...@@ -499,7 +501,7 @@ class SwinSelfOutput(nn.Module):
self.dense = nn.Linear(dim, dim) self.dense = nn.Linear(dim, dim)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob) self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
def forward(self, hidden_states, input_tensor): def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states) hidden_states = self.dropout(hidden_states)
...@@ -531,7 +533,13 @@ class SwinAttention(nn.Module): ...@@ -531,7 +533,13 @@ class SwinAttention(nn.Module):
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
self.pruned_heads = self.pruned_heads.union(heads) self.pruned_heads = self.pruned_heads.union(heads)
def forward(self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False): def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor]:
self_outputs = self.self(hidden_states, attention_mask, head_mask, output_attentions) self_outputs = self.self(hidden_states, attention_mask, head_mask, output_attentions)
attention_output = self.output(self_outputs[0], hidden_states) attention_output = self.output(self_outputs[0], hidden_states)
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
...@@ -547,7 +555,7 @@ class SwinIntermediate(nn.Module): ...@@ -547,7 +555,7 @@ class SwinIntermediate(nn.Module):
else: else:
self.intermediate_act_fn = config.hidden_act self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states) hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states return hidden_states
...@@ -559,7 +567,7 @@ class SwinOutput(nn.Module): ...@@ -559,7 +567,7 @@ class SwinOutput(nn.Module):
self.dense = nn.Linear(int(config.mlp_ratio * dim), dim) self.dense = nn.Linear(int(config.mlp_ratio * dim), dim)
self.dropout = nn.Dropout(config.hidden_dropout_prob) self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states) hidden_states = self.dropout(hidden_states)
return hidden_states return hidden_states
...@@ -621,7 +629,13 @@ class SwinLayer(nn.Module): ...@@ -621,7 +629,13 @@ class SwinLayer(nn.Module):
hidden_states = nn.functional.pad(hidden_states, pad_values) hidden_states = nn.functional.pad(hidden_states, pad_values)
return hidden_states, pad_values return hidden_states, pad_values
def forward(self, hidden_states, input_dimensions, head_mask=None, output_attentions=False): def forward(
self,
hidden_states: torch.Tensor,
input_dimensions: Tuple[int, int],
head_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
self.set_shift_and_window_size(input_dimensions) self.set_shift_and_window_size(input_dimensions)
height, width = input_dimensions height, width = input_dimensions
batch_size, _, channels = hidden_states.size() batch_size, _, channels = hidden_states.size()
...@@ -703,7 +717,13 @@ class SwinStage(nn.Module): ...@@ -703,7 +717,13 @@ class SwinStage(nn.Module):
self.pointing = False self.pointing = False
def forward(self, hidden_states, input_dimensions, head_mask=None, output_attentions=False): def forward(
self,
hidden_states: torch.Tensor,
input_dimensions: Tuple[int, int],
head_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor]:
height, width = input_dimensions height, width = input_dimensions
for i, layer_module in enumerate(self.blocks): for i, layer_module in enumerate(self.blocks):
...@@ -752,13 +772,13 @@ class SwinEncoder(nn.Module): ...@@ -752,13 +772,13 @@ class SwinEncoder(nn.Module):
def forward( def forward(
self, self,
hidden_states, hidden_states: torch.Tensor,
input_dimensions, input_dimensions: Tuple[int, int],
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
output_attentions=False, output_attentions: Optional[bool] = False,
output_hidden_states=False, output_hidden_states: Optional[bool] = False,
return_dict=True, return_dict: Optional[bool] = True,
): ) -> Union[Tuple, SwinEncoderOutput]:
all_input_dimensions = () all_input_dimensions = ()
all_hidden_states = () if output_hidden_states else None all_hidden_states = () if output_hidden_states else None
all_reshaped_hidden_states = () if output_hidden_states else None all_reshaped_hidden_states = () if output_hidden_states else None
...@@ -920,13 +940,13 @@ class SwinModel(SwinPreTrainedModel): ...@@ -920,13 +940,13 @@ class SwinModel(SwinPreTrainedModel):
) )
def forward( def forward(
self, self,
pixel_values=None, pixel_values: Optional[torch.FloatTensor] = None,
bool_masked_pos=None, bool_masked_pos: Optional[torch.BoolTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
): ) -> Union[Tuple, SwinModelOutput]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = ( output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
...@@ -999,13 +1019,13 @@ class SwinForMaskedImageModeling(SwinPreTrainedModel): ...@@ -999,13 +1019,13 @@ class SwinForMaskedImageModeling(SwinPreTrainedModel):
@replace_return_docstrings(output_type=SwinMaskedImageModelingOutput, config_class=_CONFIG_FOR_DOC) @replace_return_docstrings(output_type=SwinMaskedImageModelingOutput, config_class=_CONFIG_FOR_DOC)
def forward( def forward(
self, self,
pixel_values=None, pixel_values: Optional[torch.FloatTensor] = None,
bool_masked_pos=None, bool_masked_pos: Optional[torch.BoolTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
): ) -> Union[Tuple, SwinMaskedImageModelingOutput]:
r""" r"""
bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`): bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
...@@ -1115,13 +1135,13 @@ class SwinForImageClassification(SwinPreTrainedModel): ...@@ -1115,13 +1135,13 @@ class SwinForImageClassification(SwinPreTrainedModel):
) )
def forward( def forward(
self, self,
pixel_values=None, pixel_values: Optional[torch.FloatTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
labels=None, labels: Optional[torch.LongTensor] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
): ) -> Union[Tuple, SwinImageClassifierOutput]:
r""" r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the image classification/regression loss. Indices should be in `[0, ..., Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
......
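A hedged usage sketch of the typed `SwinForImageClassification.forward`; the model is randomly initialized from a default `SwinConfig`, so the label and shapes are illustrative only:

    import torch
    from transformers import SwinConfig, SwinForImageClassification

    config = SwinConfig(num_labels=10)        # randomly initialized, no download needed
    model = SwinForImageClassification(config)

    pixel_values = torch.randn(1, 3, config.image_size, config.image_size)  # Optional[torch.FloatTensor]
    labels = torch.tensor([3])                                              # Optional[torch.LongTensor]

    outputs = model(pixel_values=pixel_values, labels=labels, return_dict=True)
    print(outputs.loss.item(), outputs.logits.shape)  # Union[Tuple, SwinImageClassifierOutput]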
...@@ -881,14 +881,14 @@ class TransfoXLModel(TransfoXLPreTrainedModel): ...@@ -881,14 +881,14 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
) )
def forward( def forward(
self, self,
input_ids=None, input_ids: Optional[torch.LongTensor] = None,
mems=None, mems: Optional[List[torch.FloatTensor]] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds=None, inputs_embeds: Optional[torch.FloatTensor] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
): ) -> Union[Tuple, TransfoXLModelOutput]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = ( output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
...@@ -1071,15 +1071,15 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel): ...@@ -1071,15 +1071,15 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
) )
def forward( def forward(
self, self,
input_ids=None, input_ids: Optional[torch.LongTensor] = None,
mems=None, mems: Optional[List[torch.FloatTensor]] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds=None, inputs_embeds: Optional[torch.FloatTensor] = None,
labels=None, labels: Optional[torch.LongTensor] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
): ) -> Union[Tuple, TransfoXLLMHeadModelOutput]:
r""" r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
...@@ -1215,11 +1215,11 @@ class TransfoXLForSequenceClassification(TransfoXLPreTrainedModel): ...@@ -1215,11 +1215,11 @@ class TransfoXLForSequenceClassification(TransfoXLPreTrainedModel):
) )
def forward( def forward(
self, self,
input_ids: Optional[torch.Tensor] = None, input_ids: Optional[torch.LongTensor] = None,
mems: Optional[List[torch.FloatTensor]] = None, mems: Optional[List[torch.FloatTensor]] = None,
head_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.Tensor] = None, labels: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None, return_dict: Optional[bool] = None,
......
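A hedged sketch of the typed `TransfoXLLMHeadModel.forward`, showing how the `Optional[List[torch.FloatTensor]]` `mems` are carried between segments; the checkpoint name is an assumption:

    import torch
    from transformers import TransfoXLLMHeadModel, TransfoXLTokenizer

    tokenizer = TransfoXLTokenizer.from_pretrained("transfo-xl-wt103")   # assumed checkpoint
    model = TransfoXLLMHeadModel.from_pretrained("transfo-xl-wt103")

    input_ids = tokenizer("Transformer-XL caches hidden states between segments", return_tensors="pt")["input_ids"]
    first = model(input_ids=input_ids, return_dict=True)                    # TransfoXLLMHeadModelOutput
    second = model(input_ids=input_ids, mems=first.mems, return_dict=True)  # reuse the cached memory
    print(len(second.mems))  # list of cached hidden-state tensors threaded from the first call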
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
import math import math
from collections import OrderedDict from collections import OrderedDict
from typing import Optional, Tuple, Union
import torch import torch
import torch.utils.checkpoint import torch.utils.checkpoint
...@@ -81,7 +82,7 @@ class VanDropPath(nn.Module): ...@@ -81,7 +82,7 @@ class VanDropPath(nn.Module):
super().__init__() super().__init__()
self.drop_prob = drop_prob self.drop_prob = drop_prob
def forward(self, x): def forward(self, x: torch.Tensor) -> torch.Tensor:
return drop_path(x, self.drop_prob, self.training) return drop_path(x, self.drop_prob, self.training)
...@@ -146,7 +147,7 @@ class VanLargeKernelAttentionLayer(nn.Module): ...@@ -146,7 +147,7 @@ class VanLargeKernelAttentionLayer(nn.Module):
super().__init__() super().__init__()
self.attention = VanLargeKernelAttention(hidden_size) self.attention = VanLargeKernelAttention(hidden_size)
def forward(self, hidden_state): def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
attention = self.attention(hidden_state) attention = self.attention(hidden_state)
attended = hidden_state * attention attended = hidden_state * attention
return attended return attended
...@@ -171,7 +172,7 @@ class VanSpatialAttentionLayer(nn.Module): ...@@ -171,7 +172,7 @@ class VanSpatialAttentionLayer(nn.Module):
self.attention_layer = VanLargeKernelAttentionLayer(hidden_size) self.attention_layer = VanLargeKernelAttentionLayer(hidden_size)
self.post_projection = nn.Conv2d(hidden_size, hidden_size, kernel_size=1) self.post_projection = nn.Conv2d(hidden_size, hidden_size, kernel_size=1)
def forward(self, hidden_state): def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
residual = hidden_state residual = hidden_state
hidden_state = self.pre_projection(hidden_state) hidden_state = self.pre_projection(hidden_state)
hidden_state = self.attention_layer(hidden_state) hidden_state = self.attention_layer(hidden_state)
...@@ -189,7 +190,7 @@ class VanLayerScaling(nn.Module): ...@@ -189,7 +190,7 @@ class VanLayerScaling(nn.Module):
super().__init__() super().__init__()
self.weight = nn.Parameter(initial_value * torch.ones((hidden_size)), requires_grad=True) self.weight = nn.Parameter(initial_value * torch.ones((hidden_size)), requires_grad=True)
def forward(self, hidden_state): def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
# unsqueezing for broadcasting # unsqueezing for broadcasting
hidden_state = self.weight.unsqueeze(-1).unsqueeze(-1) * hidden_state hidden_state = self.weight.unsqueeze(-1).unsqueeze(-1) * hidden_state
return hidden_state return hidden_state
...@@ -218,7 +219,7 @@ class VanLayer(nn.Module): ...@@ -218,7 +219,7 @@ class VanLayer(nn.Module):
) )
self.mlp_scaling = VanLayerScaling(hidden_size, config.layer_scale_init_value) self.mlp_scaling = VanLayerScaling(hidden_size, config.layer_scale_init_value)
def forward(self, hidden_state): def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
residual = hidden_state residual = hidden_state
# attention # attention
hidden_state = self.pre_normomalization(hidden_state) hidden_state = self.pre_normomalization(hidden_state)
...@@ -269,7 +270,7 @@ class VanStage(nn.Module): ...@@ -269,7 +270,7 @@ class VanStage(nn.Module):
) )
self.normalization = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps) self.normalization = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
def forward(self, hidden_state): def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
hidden_state = self.embeddings(hidden_state) hidden_state = self.embeddings(hidden_state)
hidden_state = self.layers(hidden_state) hidden_state = self.layers(hidden_state)
# rearrange b c h w -> b (h w) c # rearrange b c h w -> b (h w) c
...@@ -316,7 +317,12 @@ class VanEncoder(nn.Module): ...@@ -316,7 +317,12 @@ class VanEncoder(nn.Module):
) )
) )
def forward(self, hidden_state, output_hidden_states=False, return_dict=True): def forward(
self,
hidden_state: torch.Tensor,
output_hidden_states: Optional[bool] = False,
return_dict: Optional[bool] = True,
) -> Union[Tuple, BaseModelOutputWithNoAttention]:
all_hidden_states = () if output_hidden_states else None all_hidden_states = () if output_hidden_states else None
for _, stage_module in enumerate(self.stages): for _, stage_module in enumerate(self.stages):
...@@ -411,7 +417,12 @@ class VanModel(VanPreTrainedModel): ...@@ -411,7 +417,12 @@ class VanModel(VanPreTrainedModel):
modality="vision", modality="vision",
expected_output=_EXPECTED_OUTPUT_SHAPE, expected_output=_EXPECTED_OUTPUT_SHAPE,
) )
def forward(self, pixel_values, output_hidden_states=None, return_dict=None): def forward(
self,
pixel_values: Optional[torch.FloatTensor],
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPoolingAndNoAttention]:
output_hidden_states = ( output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
) )
...@@ -463,7 +474,13 @@ class VanForImageClassification(VanPreTrainedModel): ...@@ -463,7 +474,13 @@ class VanForImageClassification(VanPreTrainedModel):
config_class=_CONFIG_FOR_DOC, config_class=_CONFIG_FOR_DOC,
expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT, expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
) )
def forward(self, pixel_values=None, labels=None, output_hidden_states=None, return_dict=None): def forward(
self,
pixel_values: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, ImageClassifierOutputWithNoAttention]:
r""" r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the image classification/regression loss. Indices should be in `[0, ..., Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
......
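A hedged usage sketch of the typed `VanForImageClassification.forward`, using a randomly initialized `VanConfig` so no weights are downloaded; the sizes and label are illustrative:

    import torch
    from transformers import VanConfig, VanForImageClassification

    config = VanConfig(num_labels=5)
    model = VanForImageClassification(config)                     # randomly initialized

    pixel_values = torch.randn(1, config.num_channels, 224, 224)  # Optional[torch.FloatTensor]
    labels = torch.tensor([2])                                    # Optional[torch.LongTensor]

    outputs = model(pixel_values=pixel_values, labels=labels, return_dict=True)
    print(outputs.loss.item(), outputs.logits.shape)              # ImageClassifierOutputWithNoAttention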