Unverified commit ec3aace0, authored by Jacob Dineen and committed by GitHub
Browse files

Add type annotations for Rembert/Splinter and copies (#16338)



* undo black autoformat

* minor fix to rembert forward with default

* make fix-copies, make quality

* Adding types to template model

* Removing List from the template types

* Remove `Optional` from a couple of types that don't accept `None`
Co-authored-by: matt <rocketknight1@gmail.com>
parent c30798ec
...@@ -257,14 +257,14 @@ class BertSelfAttention(nn.Module): ...@@ -257,14 +257,14 @@ class BertSelfAttention(nn.Module):
def forward( def forward(
self, self,
hidden_states, hidden_states: torch.Tensor,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states=None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask=None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value=None, past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
output_attentions=False, output_attentions: Optional[bool] = False,
): ) -> Tuple:
mixed_query_layer = self.query(hidden_states) mixed_query_layer = self.query(hidden_states)
# If this is instantiated as a cross-attention module, the keys # If this is instantiated as a cross-attention module, the keys
...@@ -357,7 +357,7 @@ class BertSelfOutput(nn.Module): ...@@ -357,7 +357,7 @@ class BertSelfOutput(nn.Module):
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob) self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor): def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states) hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor) hidden_states = self.LayerNorm(hidden_states + input_tensor)
...@@ -391,14 +391,14 @@ class BertAttention(nn.Module): ...@@ -391,14 +391,14 @@ class BertAttention(nn.Module):
def forward( def forward(
self, self,
hidden_states, hidden_states: torch.Tensor,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states=None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask=None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value=None, past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
output_attentions=False, output_attentions: Optional[bool] = False,
): ) -> Tuple:
self_outputs = self.self( self_outputs = self.self(
hidden_states, hidden_states,
attention_mask, attention_mask,
...@@ -422,7 +422,7 @@ class BertIntermediate(nn.Module): ...@@ -422,7 +422,7 @@ class BertIntermediate(nn.Module):
else: else:
self.intermediate_act_fn = config.hidden_act self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states) hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states return hidden_states
...@@ -435,7 +435,7 @@ class BertOutput(nn.Module): ...@@ -435,7 +435,7 @@ class BertOutput(nn.Module):
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob) self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor): def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states) hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor) hidden_states = self.LayerNorm(hidden_states + input_tensor)
...@@ -459,14 +459,14 @@ class BertLayer(nn.Module): ...@@ -459,14 +459,14 @@ class BertLayer(nn.Module):
def forward( def forward(
self, self,
hidden_states, hidden_states: torch.Tensor,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states=None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask=None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value=None, past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
output_attentions=False, output_attentions: Optional[bool] = False,
): ) -> Tuple:
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2 # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
self_attention_outputs = self.attention( self_attention_outputs = self.attention(
...@@ -536,17 +536,17 @@ class BertEncoder(nn.Module): ...@@ -536,17 +536,17 @@ class BertEncoder(nn.Module):
def forward( def forward(
self, self,
hidden_states, hidden_states: torch.Tensor,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states=None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask=None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_values=None, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
use_cache=None, use_cache: Optional[bool] = None,
output_attentions=False, output_attentions: Optional[bool] = False,
output_hidden_states=False, output_hidden_states: Optional[bool] = False,
return_dict=True, return_dict: Optional[bool] = True,
): ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
all_hidden_states = () if output_hidden_states else None all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None all_self_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
...@@ -630,7 +630,7 @@ class BertPooler(nn.Module): ...@@ -630,7 +630,7 @@ class BertPooler(nn.Module):
self.dense = nn.Linear(config.hidden_size, config.hidden_size) self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = nn.Tanh() self.activation = nn.Tanh()
def forward(self, hidden_states): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# We "pool" the model by simply taking the hidden state corresponding # We "pool" the model by simply taking the hidden state corresponding
# to the first token. # to the first token.
first_token_tensor = hidden_states[:, 0] first_token_tensor = hidden_states[:, 0]
...@@ -649,7 +649,7 @@ class BertPredictionHeadTransform(nn.Module): ...@@ -649,7 +649,7 @@ class BertPredictionHeadTransform(nn.Module):
self.transform_act_fn = config.hidden_act self.transform_act_fn = config.hidden_act
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
def forward(self, hidden_states): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.transform_act_fn(hidden_states) hidden_states = self.transform_act_fn(hidden_states)
hidden_states = self.LayerNorm(hidden_states) hidden_states = self.LayerNorm(hidden_states)
...@@ -681,7 +681,7 @@ class BertOnlyMLMHead(nn.Module): ...@@ -681,7 +681,7 @@ class BertOnlyMLMHead(nn.Module):
super().__init__() super().__init__()
self.predictions = BertLMPredictionHead(config) self.predictions = BertLMPredictionHead(config)
def forward(self, sequence_output): def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
prediction_scores = self.predictions(sequence_output) prediction_scores = self.predictions(sequence_output)
return prediction_scores return prediction_scores
......
...@@ -1311,7 +1311,7 @@ class BigBirdSelfOutput(nn.Module): ...@@ -1311,7 +1311,7 @@ class BigBirdSelfOutput(nn.Module):
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob) self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor): def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states) hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor) hidden_states = self.LayerNorm(hidden_states + input_tensor)
...@@ -1412,7 +1412,7 @@ class BigBirdIntermediate(nn.Module): ...@@ -1412,7 +1412,7 @@ class BigBirdIntermediate(nn.Module):
else: else:
self.intermediate_act_fn = config.hidden_act self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states) hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states return hidden_states
...@@ -1426,7 +1426,7 @@ class BigBirdOutput(nn.Module): ...@@ -1426,7 +1426,7 @@ class BigBirdOutput(nn.Module):
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob) self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor): def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states) hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor) hidden_states = self.LayerNorm(hidden_states + input_tensor)
...@@ -1684,7 +1684,7 @@ class BigBirdPredictionHeadTransform(nn.Module): ...@@ -1684,7 +1684,7 @@ class BigBirdPredictionHeadTransform(nn.Module):
self.transform_act_fn = config.hidden_act self.transform_act_fn = config.hidden_act
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
def forward(self, hidden_states): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.transform_act_fn(hidden_states) hidden_states = self.transform_act_fn(hidden_states)
hidden_states = self.LayerNorm(hidden_states) hidden_states = self.LayerNorm(hidden_states)
...@@ -1718,7 +1718,7 @@ class BigBirdOnlyMLMHead(nn.Module): ...@@ -1718,7 +1718,7 @@ class BigBirdOnlyMLMHead(nn.Module):
super().__init__() super().__init__()
self.predictions = BigBirdLMPredictionHead(config) self.predictions = BigBirdLMPredictionHead(config)
def forward(self, sequence_output): def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
prediction_scores = self.predictions(sequence_output) prediction_scores = self.predictions(sequence_output)
return prediction_scores return prediction_scores
......
...@@ -193,14 +193,14 @@ class Data2VecTextSelfAttention(nn.Module): ...@@ -193,14 +193,14 @@ class Data2VecTextSelfAttention(nn.Module):
def forward( def forward(
self, self,
hidden_states, hidden_states: torch.Tensor,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states=None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask=None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value=None, past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
output_attentions=False, output_attentions: Optional[bool] = False,
): ) -> Tuple:
mixed_query_layer = self.query(hidden_states) mixed_query_layer = self.query(hidden_states)
# If this is instantiated as a cross-attention module, the keys # If this is instantiated as a cross-attention module, the keys
...@@ -294,7 +294,7 @@ class Data2VecTextSelfOutput(nn.Module): ...@@ -294,7 +294,7 @@ class Data2VecTextSelfOutput(nn.Module):
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob) self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor): def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states) hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor) hidden_states = self.LayerNorm(hidden_states + input_tensor)
...@@ -329,14 +329,14 @@ class Data2VecTextAttention(nn.Module): ...@@ -329,14 +329,14 @@ class Data2VecTextAttention(nn.Module):
def forward( def forward(
self, self,
hidden_states, hidden_states: torch.Tensor,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states=None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask=None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value=None, past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
output_attentions=False, output_attentions: Optional[bool] = False,
): ) -> Tuple:
self_outputs = self.self( self_outputs = self.self(
hidden_states, hidden_states,
attention_mask, attention_mask,
...@@ -361,7 +361,7 @@ class Data2VecTextIntermediate(nn.Module): ...@@ -361,7 +361,7 @@ class Data2VecTextIntermediate(nn.Module):
else: else:
self.intermediate_act_fn = config.hidden_act self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states) hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states return hidden_states
...@@ -375,7 +375,7 @@ class Data2VecTextOutput(nn.Module): ...@@ -375,7 +375,7 @@ class Data2VecTextOutput(nn.Module):
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob) self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor): def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states) hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor) hidden_states = self.LayerNorm(hidden_states + input_tensor)
...@@ -400,14 +400,14 @@ class Data2VecTextLayer(nn.Module): ...@@ -400,14 +400,14 @@ class Data2VecTextLayer(nn.Module):
def forward( def forward(
self, self,
hidden_states, hidden_states: torch.Tensor,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states=None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask=None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value=None, past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
output_attentions=False, output_attentions: Optional[bool] = False,
): ) -> Tuple:
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2 # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
self_attention_outputs = self.attention( self_attention_outputs = self.attention(
...@@ -478,17 +478,17 @@ class Data2VecTextEncoder(nn.Module): ...@@ -478,17 +478,17 @@ class Data2VecTextEncoder(nn.Module):
def forward( def forward(
self, self,
hidden_states, hidden_states: torch.Tensor,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states=None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask=None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_values=None, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
use_cache=None, use_cache: Optional[bool] = None,
output_attentions=False, output_attentions: Optional[bool] = False,
output_hidden_states=False, output_hidden_states: Optional[bool] = False,
return_dict=True, return_dict: Optional[bool] = True,
): ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
all_hidden_states = () if output_hidden_states else None all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None all_self_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
...@@ -573,7 +573,7 @@ class Data2VecTextPooler(nn.Module): ...@@ -573,7 +573,7 @@ class Data2VecTextPooler(nn.Module):
self.dense = nn.Linear(config.hidden_size, config.hidden_size) self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = nn.Tanh() self.activation = nn.Tanh()
def forward(self, hidden_states): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# We "pool" the model by simply taking the hidden state corresponding # We "pool" the model by simply taking the hidden state corresponding
# to the first token. # to the first token.
first_token_tensor = hidden_states[:, 0] first_token_tensor = hidden_states[:, 0]
......
...@@ -313,7 +313,7 @@ class DebertaIntermediate(nn.Module): ...@@ -313,7 +313,7 @@ class DebertaIntermediate(nn.Module):
else: else:
self.intermediate_act_fn = config.hidden_act self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states) hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states return hidden_states
......
...@@ -301,7 +301,7 @@ class DebertaV2Intermediate(nn.Module): ...@@ -301,7 +301,7 @@ class DebertaV2Intermediate(nn.Module):
else: else:
self.intermediate_act_fn = config.hidden_act self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states) hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states return hidden_states
......
...@@ -250,14 +250,14 @@ class ElectraSelfAttention(nn.Module): ...@@ -250,14 +250,14 @@ class ElectraSelfAttention(nn.Module):
def forward( def forward(
self, self,
hidden_states, hidden_states: torch.Tensor,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states=None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask=None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value=None, past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
output_attentions=False, output_attentions: Optional[bool] = False,
): ) -> Tuple:
mixed_query_layer = self.query(hidden_states) mixed_query_layer = self.query(hidden_states)
# If this is instantiated as a cross-attention module, the keys # If this is instantiated as a cross-attention module, the keys
...@@ -351,7 +351,7 @@ class ElectraSelfOutput(nn.Module): ...@@ -351,7 +351,7 @@ class ElectraSelfOutput(nn.Module):
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob) self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor): def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states) hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor) hidden_states = self.LayerNorm(hidden_states + input_tensor)
...@@ -386,14 +386,14 @@ class ElectraAttention(nn.Module): ...@@ -386,14 +386,14 @@ class ElectraAttention(nn.Module):
def forward( def forward(
self, self,
hidden_states, hidden_states: torch.Tensor,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states=None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask=None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value=None, past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
output_attentions=False, output_attentions: Optional[bool] = False,
): ) -> Tuple:
self_outputs = self.self( self_outputs = self.self(
hidden_states, hidden_states,
attention_mask, attention_mask,
...@@ -418,7 +418,7 @@ class ElectraIntermediate(nn.Module): ...@@ -418,7 +418,7 @@ class ElectraIntermediate(nn.Module):
else: else:
self.intermediate_act_fn = config.hidden_act self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states) hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states return hidden_states
...@@ -432,7 +432,7 @@ class ElectraOutput(nn.Module): ...@@ -432,7 +432,7 @@ class ElectraOutput(nn.Module):
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob) self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor): def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states) hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor) hidden_states = self.LayerNorm(hidden_states + input_tensor)
...@@ -457,14 +457,14 @@ class ElectraLayer(nn.Module): ...@@ -457,14 +457,14 @@ class ElectraLayer(nn.Module):
def forward( def forward(
self, self,
hidden_states, hidden_states: torch.Tensor,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states=None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask=None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value=None, past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
output_attentions=False, output_attentions: Optional[bool] = False,
): ) -> Tuple:
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2 # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
self_attention_outputs = self.attention( self_attention_outputs = self.attention(
...@@ -535,17 +535,17 @@ class ElectraEncoder(nn.Module): ...@@ -535,17 +535,17 @@ class ElectraEncoder(nn.Module):
def forward( def forward(
self, self,
hidden_states, hidden_states: torch.Tensor,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states=None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask=None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_values=None, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
use_cache=None, use_cache: Optional[bool] = None,
output_attentions=False, output_attentions: Optional[bool] = False,
output_hidden_states=False, output_hidden_states: Optional[bool] = False,
return_dict=True, return_dict: Optional[bool] = True,
): ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
all_hidden_states = () if output_hidden_states else None all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None all_self_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
......
...@@ -231,7 +231,7 @@ class FNetIntermediate(nn.Module): ...@@ -231,7 +231,7 @@ class FNetIntermediate(nn.Module):
else: else:
self.intermediate_act_fn = config.hidden_act self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states) hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states return hidden_states
...@@ -245,7 +245,7 @@ class FNetOutput(nn.Module): ...@@ -245,7 +245,7 @@ class FNetOutput(nn.Module):
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob) self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor): def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states) hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor) hidden_states = self.LayerNorm(hidden_states + input_tensor)
...@@ -323,7 +323,7 @@ class FNetPooler(nn.Module): ...@@ -323,7 +323,7 @@ class FNetPooler(nn.Module):
self.dense = nn.Linear(config.hidden_size, config.hidden_size) self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = nn.Tanh() self.activation = nn.Tanh()
def forward(self, hidden_states): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# We "pool" the model by simply taking the hidden state corresponding # We "pool" the model by simply taking the hidden state corresponding
# to the first token. # to the first token.
first_token_tensor = hidden_states[:, 0] first_token_tensor = hidden_states[:, 0]
...@@ -343,7 +343,7 @@ class FNetPredictionHeadTransform(nn.Module): ...@@ -343,7 +343,7 @@ class FNetPredictionHeadTransform(nn.Module):
self.transform_act_fn = config.hidden_act self.transform_act_fn = config.hidden_act
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
def forward(self, hidden_states): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.transform_act_fn(hidden_states) hidden_states = self.transform_act_fn(hidden_states)
hidden_states = self.LayerNorm(hidden_states) hidden_states = self.LayerNorm(hidden_states)
......
...@@ -166,14 +166,14 @@ class LayoutLMSelfAttention(nn.Module): ...@@ -166,14 +166,14 @@ class LayoutLMSelfAttention(nn.Module):
def forward( def forward(
self, self,
hidden_states, hidden_states: torch.Tensor,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states=None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask=None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value=None, past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
output_attentions=False, output_attentions: Optional[bool] = False,
): ) -> Tuple:
mixed_query_layer = self.query(hidden_states) mixed_query_layer = self.query(hidden_states)
# If this is instantiated as a cross-attention module, the keys # If this is instantiated as a cross-attention module, the keys
...@@ -267,7 +267,7 @@ class LayoutLMSelfOutput(nn.Module): ...@@ -267,7 +267,7 @@ class LayoutLMSelfOutput(nn.Module):
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob) self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor): def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states) hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor) hidden_states = self.LayerNorm(hidden_states + input_tensor)
...@@ -302,14 +302,14 @@ class LayoutLMAttention(nn.Module): ...@@ -302,14 +302,14 @@ class LayoutLMAttention(nn.Module):
def forward( def forward(
self, self,
hidden_states, hidden_states: torch.Tensor,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states=None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask=None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value=None, past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
output_attentions=False, output_attentions: Optional[bool] = False,
): ) -> Tuple:
self_outputs = self.self( self_outputs = self.self(
hidden_states, hidden_states,
attention_mask, attention_mask,
...@@ -334,7 +334,7 @@ class LayoutLMIntermediate(nn.Module): ...@@ -334,7 +334,7 @@ class LayoutLMIntermediate(nn.Module):
else: else:
self.intermediate_act_fn = config.hidden_act self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states) hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states return hidden_states
...@@ -348,7 +348,7 @@ class LayoutLMOutput(nn.Module): ...@@ -348,7 +348,7 @@ class LayoutLMOutput(nn.Module):
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob) self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor): def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states) hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor) hidden_states = self.LayerNorm(hidden_states + input_tensor)
...@@ -373,14 +373,14 @@ class LayoutLMLayer(nn.Module): ...@@ -373,14 +373,14 @@ class LayoutLMLayer(nn.Module):
def forward( def forward(
self, self,
hidden_states, hidden_states: torch.Tensor,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states=None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask=None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value=None, past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
output_attentions=False, output_attentions: Optional[bool] = False,
): ) -> Tuple:
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2 # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
self_attention_outputs = self.attention( self_attention_outputs = self.attention(
...@@ -451,17 +451,17 @@ class LayoutLMEncoder(nn.Module): ...@@ -451,17 +451,17 @@ class LayoutLMEncoder(nn.Module):
def forward( def forward(
self, self,
hidden_states, hidden_states: torch.Tensor,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states=None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask=None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_values=None, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
use_cache=None, use_cache: Optional[bool] = None,
output_attentions=False, output_attentions: Optional[bool] = False,
output_hidden_states=False, output_hidden_states: Optional[bool] = False,
return_dict=True, return_dict: Optional[bool] = True,
): ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
all_hidden_states = () if output_hidden_states else None all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None all_self_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
...@@ -546,7 +546,7 @@ class LayoutLMPooler(nn.Module): ...@@ -546,7 +546,7 @@ class LayoutLMPooler(nn.Module):
self.dense = nn.Linear(config.hidden_size, config.hidden_size) self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = nn.Tanh() self.activation = nn.Tanh()
def forward(self, hidden_states): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# We "pool" the model by simply taking the hidden state corresponding # We "pool" the model by simply taking the hidden state corresponding
# to the first token. # to the first token.
first_token_tensor = hidden_states[:, 0] first_token_tensor = hidden_states[:, 0]
...@@ -566,7 +566,7 @@ class LayoutLMPredictionHeadTransform(nn.Module): ...@@ -566,7 +566,7 @@ class LayoutLMPredictionHeadTransform(nn.Module):
self.transform_act_fn = config.hidden_act self.transform_act_fn = config.hidden_act
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
def forward(self, hidden_states): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.transform_act_fn(hidden_states) hidden_states = self.transform_act_fn(hidden_states)
hidden_states = self.LayerNorm(hidden_states) hidden_states = self.LayerNorm(hidden_states)
...@@ -600,7 +600,7 @@ class LayoutLMOnlyMLMHead(nn.Module): ...@@ -600,7 +600,7 @@ class LayoutLMOnlyMLMHead(nn.Module):
super().__init__() super().__init__()
self.predictions = LayoutLMLMPredictionHead(config) self.predictions = LayoutLMLMPredictionHead(config)
def forward(self, sequence_output): def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
prediction_scores = self.predictions(sequence_output) prediction_scores = self.predictions(sequence_output)
return prediction_scores return prediction_scores
......
...@@ -249,7 +249,7 @@ class LayoutLMv2Intermediate(nn.Module): ...@@ -249,7 +249,7 @@ class LayoutLMv2Intermediate(nn.Module):
else: else:
self.intermediate_act_fn = config.hidden_act self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states) hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states return hidden_states
...@@ -263,7 +263,7 @@ class LayoutLMv2Output(nn.Module): ...@@ -263,7 +263,7 @@ class LayoutLMv2Output(nn.Module):
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob) self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor): def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states) hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor) hidden_states = self.LayerNorm(hidden_states + input_tensor)
......
...@@ -1104,7 +1104,7 @@ class LongformerSelfOutput(nn.Module): ...@@ -1104,7 +1104,7 @@ class LongformerSelfOutput(nn.Module):
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob) self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor): def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states) hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor) hidden_states = self.LayerNorm(hidden_states + input_tensor)
...@@ -1170,7 +1170,7 @@ class LongformerIntermediate(nn.Module): ...@@ -1170,7 +1170,7 @@ class LongformerIntermediate(nn.Module):
else: else:
self.intermediate_act_fn = config.hidden_act self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states) hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states return hidden_states
...@@ -1184,7 +1184,7 @@ class LongformerOutput(nn.Module): ...@@ -1184,7 +1184,7 @@ class LongformerOutput(nn.Module):
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob) self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor): def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states) hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor) hidden_states = self.LayerNorm(hidden_states + input_tensor)
...@@ -1338,7 +1338,7 @@ class LongformerPooler(nn.Module): ...@@ -1338,7 +1338,7 @@ class LongformerPooler(nn.Module):
self.dense = nn.Linear(config.hidden_size, config.hidden_size) self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = nn.Tanh() self.activation = nn.Tanh()
def forward(self, hidden_states): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# We "pool" the model by simply taking the hidden state corresponding # We "pool" the model by simply taking the hidden state corresponding
# to the first token. # to the first token.
first_token_tensor = hidden_states[:, 0] first_token_tensor = hidden_states[:, 0]
......
...@@ -480,7 +480,7 @@ class LukeSelfOutput(nn.Module): ...@@ -480,7 +480,7 @@ class LukeSelfOutput(nn.Module):
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob) self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor): def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states) hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor) hidden_states = self.LayerNorm(hidden_states + input_tensor)
...@@ -544,7 +544,7 @@ class LukeIntermediate(nn.Module): ...@@ -544,7 +544,7 @@ class LukeIntermediate(nn.Module):
else: else:
self.intermediate_act_fn = config.hidden_act self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states) hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states return hidden_states
...@@ -558,7 +558,7 @@ class LukeOutput(nn.Module): ...@@ -558,7 +558,7 @@ class LukeOutput(nn.Module):
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob) self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor): def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states) hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor) hidden_states = self.LayerNorm(hidden_states + input_tensor)
...@@ -708,7 +708,7 @@ class LukePooler(nn.Module): ...@@ -708,7 +708,7 @@ class LukePooler(nn.Module):
self.dense = nn.Linear(config.hidden_size, config.hidden_size) self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = nn.Tanh() self.activation = nn.Tanh()
def forward(self, hidden_states): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# We "pool" the model by simply taking the hidden state corresponding # We "pool" the model by simply taking the hidden state corresponding
# to the first token. # to the first token.
first_token_tensor = hidden_states[:, 0] first_token_tensor = hidden_states[:, 0]
......
...@@ -228,14 +228,14 @@ class MegatronBertSelfAttention(nn.Module): ...@@ -228,14 +228,14 @@ class MegatronBertSelfAttention(nn.Module):
def forward( def forward(
self, self,
hidden_states, hidden_states: torch.Tensor,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states=None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask=None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value=None, past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
output_attentions=False, output_attentions: Optional[bool] = False,
): ) -> Tuple:
mixed_query_layer = self.query(hidden_states) mixed_query_layer = self.query(hidden_states)
# If this is instantiated as a cross-attention module, the keys # If this is instantiated as a cross-attention module, the keys
...@@ -396,7 +396,7 @@ class MegatronBertIntermediate(nn.Module): ...@@ -396,7 +396,7 @@ class MegatronBertIntermediate(nn.Module):
else: else:
self.intermediate_act_fn = config.hidden_act self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states) hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states return hidden_states
...@@ -615,7 +615,7 @@ class MegatronBertPooler(nn.Module): ...@@ -615,7 +615,7 @@ class MegatronBertPooler(nn.Module):
self.dense = nn.Linear(config.hidden_size, config.hidden_size) self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = nn.Tanh() self.activation = nn.Tanh()
def forward(self, hidden_states): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# We "pool" the model by simply taking the hidden state corresponding # We "pool" the model by simply taking the hidden state corresponding
# to the first token. # to the first token.
first_token_tensor = hidden_states[:, 0] first_token_tensor = hidden_states[:, 0]
...@@ -635,7 +635,7 @@ class MegatronBertPredictionHeadTransform(nn.Module): ...@@ -635,7 +635,7 @@ class MegatronBertPredictionHeadTransform(nn.Module):
self.transform_act_fn = config.hidden_act self.transform_act_fn = config.hidden_act
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
def forward(self, hidden_states): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.transform_act_fn(hidden_states) hidden_states = self.transform_act_fn(hidden_states)
hidden_states = self.LayerNorm(hidden_states) hidden_states = self.LayerNorm(hidden_states)
...@@ -669,7 +669,7 @@ class MegatronBertOnlyMLMHead(nn.Module): ...@@ -669,7 +669,7 @@ class MegatronBertOnlyMLMHead(nn.Module):
super().__init__() super().__init__()
self.predictions = MegatronBertLMPredictionHead(config) self.predictions = MegatronBertLMPredictionHead(config)
def forward(self, sequence_output): def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
prediction_scores = self.predictions(sequence_output) prediction_scores = self.predictions(sequence_output)
return prediction_scores return prediction_scores
......
...@@ -259,7 +259,7 @@ class MPNetIntermediate(nn.Module): ...@@ -259,7 +259,7 @@ class MPNetIntermediate(nn.Module):
else: else:
self.intermediate_act_fn = config.hidden_act self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states) hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states return hidden_states
...@@ -273,7 +273,7 @@ class MPNetOutput(nn.Module): ...@@ -273,7 +273,7 @@ class MPNetOutput(nn.Module):
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob) self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor): def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states) hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor) hidden_states = self.LayerNorm(hidden_states + input_tensor)
...@@ -408,7 +408,7 @@ class MPNetPooler(nn.Module): ...@@ -408,7 +408,7 @@ class MPNetPooler(nn.Module):
self.dense = nn.Linear(config.hidden_size, config.hidden_size) self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = nn.Tanh() self.activation = nn.Tanh()
def forward(self, hidden_states): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# We "pool" the model by simply taking the hidden state corresponding # We "pool" the model by simply taking the hidden state corresponding
# to the first token. # to the first token.
first_token_tensor = hidden_states[:, 0] first_token_tensor = hidden_states[:, 0]
......
...@@ -252,7 +252,7 @@ class NystromformerSelfOutput(nn.Module): ...@@ -252,7 +252,7 @@ class NystromformerSelfOutput(nn.Module):
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob) self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor): def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states) hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor) hidden_states = self.LayerNorm(hidden_states + input_tensor)
...@@ -301,7 +301,7 @@ class NystromformerIntermediate(nn.Module): ...@@ -301,7 +301,7 @@ class NystromformerIntermediate(nn.Module):
else: else:
self.intermediate_act_fn = config.hidden_act self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states) hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states return hidden_states
...@@ -315,7 +315,7 @@ class NystromformerOutput(nn.Module): ...@@ -315,7 +315,7 @@ class NystromformerOutput(nn.Module):
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob) self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor): def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states) hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor) hidden_states = self.LayerNorm(hidden_states + input_tensor)
...@@ -417,7 +417,7 @@ class NystromformerPredictionHeadTransform(nn.Module): ...@@ -417,7 +417,7 @@ class NystromformerPredictionHeadTransform(nn.Module):
self.transform_act_fn = config.hidden_act self.transform_act_fn = config.hidden_act
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
def forward(self, hidden_states): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.transform_act_fn(hidden_states) hidden_states = self.transform_act_fn(hidden_states)
hidden_states = self.LayerNorm(hidden_states) hidden_states = self.LayerNorm(hidden_states)
...@@ -451,7 +451,7 @@ class NystromformerOnlyMLMHead(nn.Module): ...@@ -451,7 +451,7 @@ class NystromformerOnlyMLMHead(nn.Module):
super().__init__() super().__init__()
self.predictions = NystromformerLMPredictionHead(config) self.predictions = NystromformerLMPredictionHead(config)
def forward(self, sequence_output): def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
prediction_scores = self.predictions(sequence_output) prediction_scores = self.predictions(sequence_output)
return prediction_scores return prediction_scores
......
...@@ -641,7 +641,7 @@ class QDQBertPooler(nn.Module): ...@@ -641,7 +641,7 @@ class QDQBertPooler(nn.Module):
self.dense = nn.Linear(config.hidden_size, config.hidden_size) self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = nn.Tanh() self.activation = nn.Tanh()
def forward(self, hidden_states): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# We "pool" the model by simply taking the hidden state corresponding # We "pool" the model by simply taking the hidden state corresponding
# to the first token. # to the first token.
first_token_tensor = hidden_states[:, 0] first_token_tensor = hidden_states[:, 0]
...@@ -661,7 +661,7 @@ class QDQBertPredictionHeadTransform(nn.Module): ...@@ -661,7 +661,7 @@ class QDQBertPredictionHeadTransform(nn.Module):
self.transform_act_fn = config.hidden_act self.transform_act_fn = config.hidden_act
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
def forward(self, hidden_states): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.transform_act_fn(hidden_states) hidden_states = self.transform_act_fn(hidden_states)
hidden_states = self.LayerNorm(hidden_states) hidden_states = self.LayerNorm(hidden_states)
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
import math import math
import os import os
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional, Tuple from typing import Optional, Tuple, Union
import torch import torch
from packaging import version from packaging import version
...@@ -265,14 +265,14 @@ class RealmSelfAttention(nn.Module): ...@@ -265,14 +265,14 @@ class RealmSelfAttention(nn.Module):
def forward( def forward(
self, self,
hidden_states, hidden_states: torch.Tensor,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states=None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask=None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value=None, past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
output_attentions=False, output_attentions: Optional[bool] = False,
): ) -> Tuple:
mixed_query_layer = self.query(hidden_states) mixed_query_layer = self.query(hidden_states)
# If this is instantiated as a cross-attention module, the keys # If this is instantiated as a cross-attention module, the keys
...@@ -366,7 +366,7 @@ class RealmSelfOutput(nn.Module): ...@@ -366,7 +366,7 @@ class RealmSelfOutput(nn.Module):
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob) self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor): def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states) hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor) hidden_states = self.LayerNorm(hidden_states + input_tensor)
...@@ -401,14 +401,14 @@ class RealmAttention(nn.Module): ...@@ -401,14 +401,14 @@ class RealmAttention(nn.Module):
def forward( def forward(
self, self,
hidden_states, hidden_states: torch.Tensor,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states=None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask=None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value=None, past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
output_attentions=False, output_attentions: Optional[bool] = False,
): ) -> Tuple:
self_outputs = self.self( self_outputs = self.self(
hidden_states, hidden_states,
attention_mask, attention_mask,
...@@ -433,7 +433,7 @@ class RealmIntermediate(nn.Module): ...@@ -433,7 +433,7 @@ class RealmIntermediate(nn.Module):
else: else:
self.intermediate_act_fn = config.hidden_act self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states) hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states return hidden_states
...@@ -447,7 +447,7 @@ class RealmOutput(nn.Module): ...@@ -447,7 +447,7 @@ class RealmOutput(nn.Module):
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob) self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor): def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states) hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor) hidden_states = self.LayerNorm(hidden_states + input_tensor)
...@@ -472,14 +472,14 @@ class RealmLayer(nn.Module): ...@@ -472,14 +472,14 @@ class RealmLayer(nn.Module):
def forward( def forward(
self, self,
hidden_states, hidden_states: torch.Tensor,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states=None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask=None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value=None, past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
output_attentions=False, output_attentions: Optional[bool] = False,
): ) -> Tuple:
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2 # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
self_attention_outputs = self.attention( self_attention_outputs = self.attention(
...@@ -550,17 +550,17 @@ class RealmEncoder(nn.Module): ...@@ -550,17 +550,17 @@ class RealmEncoder(nn.Module):
def forward( def forward(
self, self,
hidden_states, hidden_states: torch.Tensor,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states=None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask=None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_values=None, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
use_cache=None, use_cache: Optional[bool] = None,
output_attentions=False, output_attentions: Optional[bool] = False,
output_hidden_states=False, output_hidden_states: Optional[bool] = False,
return_dict=True, return_dict: Optional[bool] = True,
): ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
all_hidden_states = () if output_hidden_states else None all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None all_self_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
...@@ -645,7 +645,7 @@ class RealmPooler(nn.Module): ...@@ -645,7 +645,7 @@ class RealmPooler(nn.Module):
self.dense = nn.Linear(config.hidden_size, config.hidden_size) self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = nn.Tanh() self.activation = nn.Tanh()
def forward(self, hidden_states): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# We "pool" the model by simply taking the hidden state corresponding # We "pool" the model by simply taking the hidden state corresponding
# to the first token. # to the first token.
first_token_tensor = hidden_states[:, 0] first_token_tensor = hidden_states[:, 0]
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
import math import math
import os import os
from typing import Optional, Tuple, Union
import torch import torch
import torch.utils.checkpoint import torch.utils.checkpoint
...@@ -165,8 +166,13 @@ class RemBertEmbeddings(nn.Module): ...@@ -165,8 +166,13 @@ class RemBertEmbeddings(nn.Module):
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))) self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
def forward( def forward(
self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0 self,
): input_ids: Optional[torch.LongTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
past_key_values_length: int = 0,
) -> torch.Tensor:
if input_ids is not None: if input_ids is not None:
input_shape = input_ids.size() input_shape = input_ids.size()
else: else:
...@@ -199,7 +205,7 @@ class RemBertPooler(nn.Module): ...@@ -199,7 +205,7 @@ class RemBertPooler(nn.Module):
self.dense = nn.Linear(config.hidden_size, config.hidden_size) self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = nn.Tanh() self.activation = nn.Tanh()
def forward(self, hidden_states): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# We "pool" the model by simply taking the hidden state corresponding # We "pool" the model by simply taking the hidden state corresponding
# to the first token. # to the first token.
first_token_tensor = hidden_states[:, 0] first_token_tensor = hidden_states[:, 0]
...@@ -236,14 +242,14 @@ class RemBertSelfAttention(nn.Module): ...@@ -236,14 +242,14 @@ class RemBertSelfAttention(nn.Module):
def forward( def forward(
self, self,
hidden_states, hidden_states: torch.Tensor,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states=None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask=None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value=None, past_key_value: Tuple[Tuple[torch.FloatTensor]] = None,
output_attentions=False, output_attentions: bool = False,
): ) -> Tuple:
mixed_query_layer = self.query(hidden_states) mixed_query_layer = self.query(hidden_states)
# If this is instantiated as a cross-attention module, the keys # If this is instantiated as a cross-attention module, the keys
...@@ -321,7 +327,7 @@ class RemBertSelfOutput(nn.Module): ...@@ -321,7 +327,7 @@ class RemBertSelfOutput(nn.Module):
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob) self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor): def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states) hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor) hidden_states = self.LayerNorm(hidden_states + input_tensor)
...@@ -357,14 +363,14 @@ class RemBertAttention(nn.Module): ...@@ -357,14 +363,14 @@ class RemBertAttention(nn.Module):
# Copied from transformers.models.bert.modeling_bert.BertAttention.forward # Copied from transformers.models.bert.modeling_bert.BertAttention.forward
def forward( def forward(
self, self,
hidden_states, hidden_states: torch.Tensor,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states=None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask=None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value=None, past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
output_attentions=False, output_attentions: Optional[bool] = False,
): ) -> Tuple:
self_outputs = self.self( self_outputs = self.self(
hidden_states, hidden_states,
attention_mask, attention_mask,
...@@ -389,7 +395,7 @@ class RemBertIntermediate(nn.Module): ...@@ -389,7 +395,7 @@ class RemBertIntermediate(nn.Module):
else: else:
self.intermediate_act_fn = config.hidden_act self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states) hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states return hidden_states
...@@ -403,7 +409,7 @@ class RemBertOutput(nn.Module): ...@@ -403,7 +409,7 @@ class RemBertOutput(nn.Module):
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob) self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor): def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states) hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor) hidden_states = self.LayerNorm(hidden_states + input_tensor)
...@@ -428,14 +434,14 @@ class RemBertLayer(nn.Module): ...@@ -428,14 +434,14 @@ class RemBertLayer(nn.Module):
# Copied from transformers.models.bert.modeling_bert.BertLayer.forward # Copied from transformers.models.bert.modeling_bert.BertLayer.forward
def forward( def forward(
self, self,
hidden_states, hidden_states: torch.Tensor,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states=None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask=None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value=None, past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
output_attentions=False, output_attentions: Optional[bool] = False,
): ) -> Tuple:
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2 # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
self_attention_outputs = self.attention( self_attention_outputs = self.attention(
...@@ -508,17 +514,18 @@ class RemBertEncoder(nn.Module): ...@@ -508,17 +514,18 @@ class RemBertEncoder(nn.Module):
def forward( def forward(
self, self,
hidden_states, hidden_states: torch.Tensor,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states=None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask=None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_values=None, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
use_cache=None, use_cache: Optional[bool] = None,
output_attentions=False, output_attentions: bool = False,
output_hidden_states=False, output_hidden_states: bool = False,
return_dict=True, return_dict: bool = True,
): ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
hidden_states = self.embedding_hidden_mapping_in(hidden_states) hidden_states = self.embedding_hidden_mapping_in(hidden_states)
all_hidden_states = () if output_hidden_states else None all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None all_self_attentions = () if output_attentions else None
...@@ -608,7 +615,7 @@ class RemBertPredictionHeadTransform(nn.Module): ...@@ -608,7 +615,7 @@ class RemBertPredictionHeadTransform(nn.Module):
self.transform_act_fn = config.hidden_act self.transform_act_fn = config.hidden_act
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
def forward(self, hidden_states): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.transform_act_fn(hidden_states) hidden_states = self.transform_act_fn(hidden_states)
hidden_states = self.LayerNorm(hidden_states) hidden_states = self.LayerNorm(hidden_states)
...@@ -623,7 +630,7 @@ class RemBertLMPredictionHead(nn.Module): ...@@ -623,7 +630,7 @@ class RemBertLMPredictionHead(nn.Module):
self.activation = ACT2FN[config.hidden_act] self.activation = ACT2FN[config.hidden_act]
self.LayerNorm = nn.LayerNorm(config.output_embedding_size, eps=config.layer_norm_eps) self.LayerNorm = nn.LayerNorm(config.output_embedding_size, eps=config.layer_norm_eps)
def forward(self, hidden_states): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.activation(hidden_states) hidden_states = self.activation(hidden_states)
hidden_states = self.LayerNorm(hidden_states) hidden_states = self.LayerNorm(hidden_states)
...@@ -637,7 +644,7 @@ class RemBertOnlyMLMHead(nn.Module): ...@@ -637,7 +644,7 @@ class RemBertOnlyMLMHead(nn.Module):
super().__init__() super().__init__()
self.predictions = RemBertLMPredictionHead(config) self.predictions = RemBertLMPredictionHead(config)
def forward(self, sequence_output): def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
prediction_scores = self.predictions(sequence_output) prediction_scores = self.predictions(sequence_output)
return prediction_scores return prediction_scores
...@@ -788,20 +795,20 @@ class RemBertModel(RemBertPreTrainedModel): ...@@ -788,20 +795,20 @@ class RemBertModel(RemBertPreTrainedModel):
) )
def forward( def forward(
self, self,
input_ids=None, input_ids: torch.LongTensor = None,
attention_mask=None, attention_mask: Optional[torch.LongTensor] = None,
token_type_ids=None, token_type_ids: Optional[torch.LongTensor] = None,
position_ids=None, position_ids: Optional[torch.LongTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds=None, inputs_embeds: Optional[torch.FloatTensor] = None,
encoder_hidden_states=None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask=None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_values=None, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
use_cache=None, use_cache: Optional[bool] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
): ) -> Union[Tuple, BaseModelOutputWithPoolingAndCrossAttentions]:
r""" r"""
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
...@@ -941,19 +948,19 @@ class RemBertForMaskedLM(RemBertPreTrainedModel): ...@@ -941,19 +948,19 @@ class RemBertForMaskedLM(RemBertPreTrainedModel):
) )
def forward( def forward(
self, self,
input_ids=None, input_ids: torch.LongTensor = None,
attention_mask=None, attention_mask: Optional[torch.LongTensor] = None,
token_type_ids=None, token_type_ids: Optional[torch.LongTensor] = None,
position_ids=None, position_ids: Optional[torch.LongTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds=None, inputs_embeds: Optional[torch.FloatTensor] = None,
encoder_hidden_states=None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask=None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
labels=None, labels: Optional[torch.LongTensor] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
): ) -> Union[Tuple, MaskedLMOutput]:
r""" r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
...@@ -1039,21 +1046,21 @@ class RemBertForCausalLM(RemBertPreTrainedModel): ...@@ -1039,21 +1046,21 @@ class RemBertForCausalLM(RemBertPreTrainedModel):
@replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC) @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
def forward( def forward(
self, self,
input_ids=None, input_ids: torch.LongTensor = None,
attention_mask=None, attention_mask: Optional[torch.LongTensor] = None,
token_type_ids=None, token_type_ids: Optional[torch.LongTensor] = None,
position_ids=None, position_ids: Optional[torch.LongTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds=None, inputs_embeds: Optional[torch.FloatTensor] = None,
encoder_hidden_states=None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask=None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_values=None, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
labels=None, labels: Optional[torch.LongTensor] = None,
use_cache=None, use_cache: Optional[bool] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
): ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
r""" r"""
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
...@@ -1186,17 +1193,17 @@ class RemBertForSequenceClassification(RemBertPreTrainedModel): ...@@ -1186,17 +1193,17 @@ class RemBertForSequenceClassification(RemBertPreTrainedModel):
) )
def forward( def forward(
self, self,
input_ids=None, input_ids: torch.FloatTensor = None,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids=None, token_type_ids: Optional[torch.LongTensor] = None,
position_ids=None, position_ids: Optional[torch.FloatTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds=None, inputs_embeds: Optional[torch.FloatTensor] = None,
labels=None, labels: Optional[torch.LongTensor] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
): ) -> Union[Tuple, SequenceClassifierOutput]:
r""" r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
...@@ -1283,17 +1290,17 @@ class RemBertForMultipleChoice(RemBertPreTrainedModel): ...@@ -1283,17 +1290,17 @@ class RemBertForMultipleChoice(RemBertPreTrainedModel):
) )
def forward( def forward(
self, self,
input_ids=None, input_ids: torch.FloatTensor = None,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids=None, token_type_ids: Optional[torch.LongTensor] = None,
position_ids=None, position_ids: Optional[torch.FloatTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds=None, inputs_embeds: Optional[torch.FloatTensor] = None,
labels=None, labels: Optional[torch.LongTensor] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
): ) -> Union[Tuple, MultipleChoiceModelOutput]:
r""" r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
...@@ -1376,17 +1383,17 @@ class RemBertForTokenClassification(RemBertPreTrainedModel): ...@@ -1376,17 +1383,17 @@ class RemBertForTokenClassification(RemBertPreTrainedModel):
) )
def forward( def forward(
self, self,
input_ids=None, input_ids: torch.FloatTensor = None,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids=None, token_type_ids: Optional[torch.LongTensor] = None,
position_ids=None, position_ids: Optional[torch.FloatTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds=None, inputs_embeds: Optional[torch.FloatTensor] = None,
labels=None, labels: Optional[torch.LongTensor] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
): ) -> Union[Tuple, TokenClassifierOutput]:
r""" r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
...@@ -1455,18 +1462,18 @@ class RemBertForQuestionAnswering(RemBertPreTrainedModel): ...@@ -1455,18 +1462,18 @@ class RemBertForQuestionAnswering(RemBertPreTrainedModel):
) )
def forward( def forward(
self, self,
input_ids=None, input_ids: torch.FloatTensor = None,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids=None, token_type_ids: Optional[torch.LongTensor] = None,
position_ids=None, position_ids: Optional[torch.FloatTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds=None, inputs_embeds: Optional[torch.FloatTensor] = None,
start_positions=None, start_positions: Optional[torch.LongTensor] = None,
end_positions=None, end_positions: Optional[torch.LongTensor] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
): ) -> Union[Tuple, QuestionAnsweringModelOutput]:
r""" r"""
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for position (index) of the start of the labelled span for computing the token classification loss. Labels for position (index) of the start of the labelled span for computing the token classification loss.
......
...@@ -193,14 +193,14 @@ class RobertaSelfAttention(nn.Module): ...@@ -193,14 +193,14 @@ class RobertaSelfAttention(nn.Module):
def forward( def forward(
self, self,
hidden_states, hidden_states: torch.Tensor,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states=None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask=None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value=None, past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
output_attentions=False, output_attentions: Optional[bool] = False,
): ) -> Tuple:
mixed_query_layer = self.query(hidden_states) mixed_query_layer = self.query(hidden_states)
# If this is instantiated as a cross-attention module, the keys # If this is instantiated as a cross-attention module, the keys
...@@ -294,7 +294,7 @@ class RobertaSelfOutput(nn.Module): ...@@ -294,7 +294,7 @@ class RobertaSelfOutput(nn.Module):
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob) self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor): def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states) hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor) hidden_states = self.LayerNorm(hidden_states + input_tensor)
...@@ -329,14 +329,14 @@ class RobertaAttention(nn.Module): ...@@ -329,14 +329,14 @@ class RobertaAttention(nn.Module):
def forward( def forward(
self, self,
hidden_states, hidden_states: torch.Tensor,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states=None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask=None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value=None, past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
output_attentions=False, output_attentions: Optional[bool] = False,
): ) -> Tuple:
self_outputs = self.self( self_outputs = self.self(
hidden_states, hidden_states,
attention_mask, attention_mask,
...@@ -361,7 +361,7 @@ class RobertaIntermediate(nn.Module): ...@@ -361,7 +361,7 @@ class RobertaIntermediate(nn.Module):
else: else:
self.intermediate_act_fn = config.hidden_act self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states) hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states return hidden_states
...@@ -375,7 +375,7 @@ class RobertaOutput(nn.Module): ...@@ -375,7 +375,7 @@ class RobertaOutput(nn.Module):
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob) self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor): def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states) hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor) hidden_states = self.LayerNorm(hidden_states + input_tensor)
...@@ -400,14 +400,14 @@ class RobertaLayer(nn.Module): ...@@ -400,14 +400,14 @@ class RobertaLayer(nn.Module):
def forward( def forward(
self, self,
hidden_states, hidden_states: torch.Tensor,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states=None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask=None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value=None, past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
output_attentions=False, output_attentions: Optional[bool] = False,
): ) -> Tuple:
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2 # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
self_attention_outputs = self.attention( self_attention_outputs = self.attention(
...@@ -478,17 +478,17 @@ class RobertaEncoder(nn.Module): ...@@ -478,17 +478,17 @@ class RobertaEncoder(nn.Module):
def forward( def forward(
self, self,
hidden_states, hidden_states: torch.Tensor,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states=None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask=None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_values=None, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
use_cache=None, use_cache: Optional[bool] = None,
output_attentions=False, output_attentions: Optional[bool] = False,
output_hidden_states=False, output_hidden_states: Optional[bool] = False,
return_dict=True, return_dict: Optional[bool] = True,
): ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
all_hidden_states = () if output_hidden_states else None all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None all_self_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
...@@ -573,7 +573,7 @@ class RobertaPooler(nn.Module): ...@@ -573,7 +573,7 @@ class RobertaPooler(nn.Module):
self.dense = nn.Linear(config.hidden_size, config.hidden_size) self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = nn.Tanh() self.activation = nn.Tanh()
def forward(self, hidden_states): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# We "pool" the model by simply taking the hidden state corresponding # We "pool" the model by simply taking the hidden state corresponding
# to the first token. # to the first token.
first_token_tensor = hidden_states[:, 0] first_token_tensor = hidden_states[:, 0]
......
...@@ -359,7 +359,7 @@ class RoFormerSelfOutput(nn.Module): ...@@ -359,7 +359,7 @@ class RoFormerSelfOutput(nn.Module):
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob) self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor): def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states) hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor) hidden_states = self.LayerNorm(hidden_states + input_tensor)
...@@ -429,7 +429,7 @@ class RoFormerIntermediate(nn.Module): ...@@ -429,7 +429,7 @@ class RoFormerIntermediate(nn.Module):
else: else:
self.intermediate_act_fn = config.hidden_act self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states) hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states return hidden_states
...@@ -443,7 +443,7 @@ class RoFormerOutput(nn.Module): ...@@ -443,7 +443,7 @@ class RoFormerOutput(nn.Module):
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob) self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor): def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states) hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor) hidden_states = self.LayerNorm(hidden_states + input_tensor)
...@@ -687,7 +687,7 @@ class RoFormerOnlyMLMHead(nn.Module): ...@@ -687,7 +687,7 @@ class RoFormerOnlyMLMHead(nn.Module):
super().__init__() super().__init__()
self.predictions = RoFormerLMPredictionHead(config) self.predictions = RoFormerLMPredictionHead(config)
def forward(self, sequence_output): def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
prediction_scores = self.predictions(sequence_output) prediction_scores = self.predictions(sequence_output)
return prediction_scores return prediction_scores
......
...@@ -924,7 +924,7 @@ class SEWDIntermediate(nn.Module): ...@@ -924,7 +924,7 @@ class SEWDIntermediate(nn.Module):
else: else:
self.intermediate_act_fn = config.hidden_act self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states) hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states return hidden_states
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment