Unverified commit ec3aace0 authored by Jacob Dineen, committed by GitHub

Add type annotations for Rembert/Splinter and copies (#16338)



* undo black autoformat

* minor fix to rembert forward with default

* make fix-copies, make quality

* Adding types to template model

* Removing List from the template types

* Remove `Optional` from a couple of types that don't accept `None`
Co-authored-by: matt <rocketknight1@gmail.com>
parent c30798ec
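
For reference, the annotation pattern applied throughout the diff below follows one rule: arguments that must always be passed are typed plainly, arguments that may be `None` are wrapped in `Optional[...]`, and each `forward` gains an explicit return annotation. A minimal sketch of the same pattern on a hypothetical toy module (not code from this PR):

```python
from typing import Optional, Tuple

import torch
from torch import nn


class ToyAttention(nn.Module):
    # Hypothetical module illustrating the typing style used in this PR.
    def __init__(self, hidden_size: int) -> None:
        super().__init__()
        self.query = nn.Linear(hidden_size, hidden_size)

    def forward(
        self,
        hidden_states: torch.Tensor,  # always required, so no Optional
        attention_mask: Optional[torch.FloatTensor] = None,  # may be None
        output_attentions: Optional[bool] = False,
    ) -> Tuple:
        query_layer = self.query(hidden_states)
        outputs = (query_layer,)
        if output_attentions:
            outputs = outputs + (attention_mask,)
        return outputs
```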
@@ -16,6 +16,7 @@
 import math
+from typing import List, Optional, Tuple, Union
 
 import torch
 import torch.utils.checkpoint
@@ -69,8 +70,13 @@ class SplinterEmbeddings(nn.Module):
         self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
 
     def forward(
-        self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
-    ):
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        past_key_values_length: Optional[int] = 0,
+    ) -> Tuple:
         if input_ids is not None:
             input_shape = input_ids.size()
         else:
@@ -132,14 +138,14 @@ class SplinterSelfAttention(nn.Module):
     def forward(
         self,
-        hidden_states,
-        attention_mask=None,
-        head_mask=None,
-        encoder_hidden_states=None,
-        encoder_attention_mask=None,
-        past_key_value=None,
-        output_attentions=False,
-    ):
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        output_attentions: Optional[bool] = False,
+    ) -> Tuple:
         mixed_query_layer = self.query(hidden_states)
 
         # If this is instantiated as a cross-attention module, the keys
@@ -233,7 +239,7 @@ class SplinterSelfOutput(nn.Module):
         self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
 
-    def forward(self, hidden_states, input_tensor):
+    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
         hidden_states = self.dense(hidden_states)
         hidden_states = self.dropout(hidden_states)
         hidden_states = self.LayerNorm(hidden_states + input_tensor)
@@ -268,14 +274,14 @@ class SplinterAttention(nn.Module):
     def forward(
         self,
-        hidden_states,
-        attention_mask=None,
-        head_mask=None,
-        encoder_hidden_states=None,
-        encoder_attention_mask=None,
-        past_key_value=None,
-        output_attentions=False,
-    ):
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        output_attentions: Optional[bool] = False,
+    ) -> Tuple:
         self_outputs = self.self(
             hidden_states,
             attention_mask,
@@ -300,7 +306,7 @@ class SplinterIntermediate(nn.Module):
         else:
             self.intermediate_act_fn = config.hidden_act
 
-    def forward(self, hidden_states):
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states = self.dense(hidden_states)
         hidden_states = self.intermediate_act_fn(hidden_states)
         return hidden_states
@@ -314,7 +320,7 @@ class SplinterOutput(nn.Module):
         self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
 
-    def forward(self, hidden_states, input_tensor):
+    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
         hidden_states = self.dense(hidden_states)
         hidden_states = self.dropout(hidden_states)
         hidden_states = self.LayerNorm(hidden_states + input_tensor)
@@ -339,14 +345,14 @@ class SplinterLayer(nn.Module):
     def forward(
         self,
-        hidden_states,
-        attention_mask=None,
-        head_mask=None,
-        encoder_hidden_states=None,
-        encoder_attention_mask=None,
-        past_key_value=None,
-        output_attentions=False,
-    ):
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        output_attentions: Optional[bool] = False,
+    ) -> Tuple:
         # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
         self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
         self_attention_outputs = self.attention(
@@ -417,17 +423,17 @@ class SplinterEncoder(nn.Module):
     def forward(
         self,
-        hidden_states,
-        attention_mask=None,
-        head_mask=None,
-        encoder_hidden_states=None,
-        encoder_attention_mask=None,
-        past_key_values=None,
-        use_cache=None,
-        output_attentions=False,
-        output_hidden_states=False,
-        return_dict=True,
-    ):
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = False,
+        output_hidden_states: Optional[bool] = False,
+        return_dict: Optional[bool] = True,
+    ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
         all_hidden_states = () if output_hidden_states else None
         all_self_attentions = () if output_attentions else None
         all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
@@ -643,20 +649,20 @@ class SplinterModel(SplinterPreTrainedModel):
     )
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        token_type_ids=None,
-        position_ids=None,
-        head_mask=None,
-        inputs_embeds=None,
-        encoder_hidden_states=None,
-        encoder_attention_mask=None,
-        past_key_values=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        token_type_ids: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
         r"""
         encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
             Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
@@ -773,7 +779,7 @@ class SplinterFullyConnectedLayer(nn.Module):
         self.act_fn = ACT2FN[hidden_act]
         self.LayerNorm = nn.LayerNorm(self.output_dim)
 
-    def forward(self, inputs):
+    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
         hidden_states = self.dense(inputs)
         hidden_states = self.act_fn(hidden_states)
         hidden_states = self.LayerNorm(hidden_states)
@@ -845,19 +851,19 @@ class SplinterForQuestionAnswering(SplinterPreTrainedModel):
     )
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        token_type_ids=None,
-        position_ids=None,
-        head_mask=None,
-        inputs_embeds=None,
-        start_positions=None,
-        end_positions=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-        question_positions=None,
-    ):
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        token_type_ids: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        start_positions: Optional[torch.LongTensor] = None,
+        end_positions: Optional[torch.LongTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        question_positions: Optional[torch.LongTensor] = None,
+    ) -> Union[Tuple, QuestionAnsweringModelOutput]:
         r"""
         start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
             Labels for position (index) of the start of the labelled span for computing the token classification loss.
......
@@ -449,7 +449,7 @@ class TapasSelfOutput(nn.Module):
         self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
 
-    def forward(self, hidden_states, input_tensor):
+    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
         hidden_states = self.dense(hidden_states)
         hidden_states = self.dropout(hidden_states)
         hidden_states = self.LayerNorm(hidden_states + input_tensor)
@@ -485,14 +485,14 @@ class TapasAttention(nn.Module):
     # Copied from transformers.models.bert.modeling_bert.BertAttention.forward
     def forward(
         self,
-        hidden_states,
-        attention_mask=None,
-        head_mask=None,
-        encoder_hidden_states=None,
-        encoder_attention_mask=None,
-        past_key_value=None,
-        output_attentions=False,
-    ):
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        output_attentions: Optional[bool] = False,
+    ) -> Tuple:
         self_outputs = self.self(
             hidden_states,
             attention_mask,
@@ -517,7 +517,7 @@ class TapasIntermediate(nn.Module):
         else:
             self.intermediate_act_fn = config.hidden_act
 
-    def forward(self, hidden_states):
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states = self.dense(hidden_states)
         hidden_states = self.intermediate_act_fn(hidden_states)
         return hidden_states
@@ -531,7 +531,7 @@ class TapasOutput(nn.Module):
         self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
 
-    def forward(self, hidden_states, input_tensor):
+    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
         hidden_states = self.dense(hidden_states)
         hidden_states = self.dropout(hidden_states)
         hidden_states = self.LayerNorm(hidden_states + input_tensor)
@@ -556,14 +556,14 @@ class TapasLayer(nn.Module):
     # Copied from transformers.models.bert.modeling_bert.BertLayer.forward
     def forward(
         self,
-        hidden_states,
-        attention_mask=None,
-        head_mask=None,
-        encoder_hidden_states=None,
-        encoder_attention_mask=None,
-        past_key_value=None,
-        output_attentions=False,
-    ):
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        output_attentions: Optional[bool] = False,
+    ) -> Tuple:
         # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
         self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
         self_attention_outputs = self.attention(
@@ -700,7 +700,7 @@ class TapasPooler(nn.Module):
         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
         self.activation = nn.Tanh()
 
-    def forward(self, hidden_states):
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         # We "pool" the model by simply taking the hidden state corresponding
         # to the first token.
         first_token_tensor = hidden_states[:, 0]
@@ -720,7 +720,7 @@ class TapasPredictionHeadTransform(nn.Module):
             self.transform_act_fn = config.hidden_act
         self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
 
-    def forward(self, hidden_states):
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states = self.dense(hidden_states)
         hidden_states = self.transform_act_fn(hidden_states)
         hidden_states = self.LayerNorm(hidden_states)
@@ -754,7 +754,7 @@ class TapasOnlyMLMHead(nn.Module):
         super().__init__()
         self.predictions = TapasLMPredictionHead(config)
 
-    def forward(self, sequence_output):
+    def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
         prediction_scores = self.predictions(sequence_output)
         return prediction_scores
......
@@ -273,7 +273,7 @@ class VisualBertSelfOutput(nn.Module):
         self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
 
-    def forward(self, hidden_states, input_tensor):
+    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
         hidden_states = self.dense(hidden_states)
         hidden_states = self.dropout(hidden_states)
         hidden_states = self.LayerNorm(hidden_states + input_tensor)
@@ -333,7 +333,7 @@ class VisualBertIntermediate(nn.Module):
         else:
             self.intermediate_act_fn = config.hidden_act
 
-    def forward(self, hidden_states):
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states = self.dense(hidden_states)
         hidden_states = self.intermediate_act_fn(hidden_states)
         return hidden_states
@@ -347,7 +347,7 @@ class VisualBertOutput(nn.Module):
         self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
 
-    def forward(self, hidden_states, input_tensor):
+    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
         hidden_states = self.dense(hidden_states)
         hidden_states = self.dropout(hidden_states)
         hidden_states = self.LayerNorm(hidden_states + input_tensor)
@@ -464,7 +464,7 @@ class VisualBertPooler(nn.Module):
         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
         self.activation = nn.Tanh()
 
-    def forward(self, hidden_states):
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         # We "pool" the model by simply taking the hidden state corresponding
         # to the first token.
         first_token_tensor = hidden_states[:, 0]
@@ -484,7 +484,7 @@ class VisualBertPredictionHeadTransform(nn.Module):
             self.transform_act_fn = config.hidden_act
         self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
 
-    def forward(self, hidden_states):
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states = self.dense(hidden_states)
         hidden_states = self.transform_act_fn(hidden_states)
         hidden_states = self.LayerNorm(hidden_states)
......
@@ -187,14 +187,14 @@ class XLMRobertaXLSelfAttention(nn.Module):
     def forward(
         self,
-        hidden_states,
-        attention_mask=None,
-        head_mask=None,
-        encoder_hidden_states=None,
-        encoder_attention_mask=None,
-        past_key_value=None,
-        output_attentions=False,
-    ):
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        output_attentions: Optional[bool] = False,
+    ) -> Tuple:
         mixed_query_layer = self.query(hidden_states)
 
         # If this is instantiated as a cross-attention module, the keys
@@ -354,7 +354,7 @@ class XLMRobertaXLIntermediate(nn.Module):
         else:
             self.intermediate_act_fn = config.hidden_act
 
-    def forward(self, hidden_states):
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states = self.dense(hidden_states)
         hidden_states = self.intermediate_act_fn(hidden_states)
         return hidden_states
@@ -565,7 +565,7 @@ class XLMRobertaXLPooler(nn.Module):
         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
         self.activation = nn.Tanh()
 
-    def forward(self, hidden_states):
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         # We "pool" the model by simply taking the hidden state corresponding
         # to the first token.
         first_token_tensor = hidden_states[:, 0]
......
@@ -445,7 +445,7 @@ class YosoSelfOutput(nn.Module):
         self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
 
-    def forward(self, hidden_states, input_tensor):
+    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
         hidden_states = self.dense(hidden_states)
         hidden_states = self.dropout(hidden_states)
         hidden_states = self.LayerNorm(hidden_states + input_tensor)
@@ -494,7 +494,7 @@ class YosoIntermediate(nn.Module):
         else:
             self.intermediate_act_fn = config.hidden_act
 
-    def forward(self, hidden_states):
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states = self.dense(hidden_states)
         hidden_states = self.intermediate_act_fn(hidden_states)
         return hidden_states
@@ -508,7 +508,7 @@ class YosoOutput(nn.Module):
         self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
 
-    def forward(self, hidden_states, input_tensor):
+    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
         hidden_states = self.dense(hidden_states)
         hidden_states = self.dropout(hidden_states)
         hidden_states = self.LayerNorm(hidden_states + input_tensor)
@@ -610,7 +610,7 @@ class YosoPredictionHeadTransform(nn.Module):
             self.transform_act_fn = config.hidden_act
         self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
 
-    def forward(self, hidden_states):
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states = self.dense(hidden_states)
         hidden_states = self.transform_act_fn(hidden_states)
         hidden_states = self.LayerNorm(hidden_states)
@@ -644,7 +644,7 @@ class YosoOnlyMLMHead(nn.Module):
         super().__init__()
         self.predictions = YosoLMPredictionHead(config)
 
-    def forward(self, sequence_output):
+    def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
         prediction_scores = self.predictions(sequence_output)
         return prediction_scores
......
@@ -25,6 +25,7 @@ import torch.utils.checkpoint
 from packaging import version
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+from typing import Optional, Tuple, Union
 
 from ...activations import ACT2FN
 from ...file_utils import (
......
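
With the annotations in place, the public signatures now document what they accept and return. A short usage sketch, assuming a recent `transformers` release that ships the Splinter model together with `torch` (tiny randomly initialized model, purely to exercise the typed `forward`):

```python
import torch
from transformers import SplinterConfig, SplinterModel

# Small config so the example runs quickly on CPU.
config = SplinterConfig(hidden_size=32, num_hidden_layers=2, num_attention_heads=2, intermediate_size=64)
model = SplinterModel(config)

input_ids = torch.randint(0, config.vocab_size, (1, 8))
# return_dict=True yields a BaseModelOutputWithPastAndCrossAttentions; False yields a plain tuple,
# matching the new `Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]` return annotation.
outputs = model(input_ids=input_ids, return_dict=True)
print(outputs.last_hidden_state.shape)  # torch.Size([1, 8, 32])
```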