Unverified Commit 5e636eee authored by Partho's avatar Partho Committed by GitHub
Browse files

Add type hints for PyTorch UniSpeech, MPNet and Nystromformer (#19039)



* added type hints pytorch unispeech

* added type hints pytorch  MPNet

* added type hints nystromformer

* resolved copy inconsistencies

* make fix-copies
Co-authored-by: default avatarmatt <rocketknight1@gmail.com>
parent 658010c7
...@@ -563,11 +563,11 @@ class Data2VecAudioEncoder(nn.Module): ...@@ -563,11 +563,11 @@ class Data2VecAudioEncoder(nn.Module):
def forward( def forward(
self, self,
hidden_states, hidden_states: torch.tensor,
attention_mask=None, attention_mask: Optional[torch.Tensor] = None,
output_attentions=False, output_attentions: bool = False,
output_hidden_states=False, output_hidden_states: bool = False,
return_dict=True, return_dict: bool = True,
): ):
all_hidden_states = () if output_hidden_states else None all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None all_self_attentions = () if output_attentions else None
......
...@@ -618,7 +618,12 @@ class HubertEncoderLayerStableLayerNorm(nn.Module): ...@@ -618,7 +618,12 @@ class HubertEncoderLayerStableLayerNorm(nn.Module):
self.feed_forward = HubertFeedForward(config) self.feed_forward = HubertFeedForward(config)
self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
def forward(self, hidden_states, attention_mask=None, output_attentions=False): def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
):
attn_residual = hidden_states attn_residual = hidden_states
hidden_states = self.layer_norm(hidden_states) hidden_states = self.layer_norm(hidden_states)
hidden_states, attn_weights, _ = self.attention( hidden_states, attn_weights, _ = self.attention(
...@@ -649,11 +654,11 @@ class HubertEncoder(nn.Module): ...@@ -649,11 +654,11 @@ class HubertEncoder(nn.Module):
def forward( def forward(
self, self,
hidden_states, hidden_states: torch.tensor,
attention_mask=None, attention_mask: Optional[torch.Tensor] = None,
output_attentions=False, output_attentions: bool = False,
output_hidden_states=False, output_hidden_states: bool = False,
return_dict=True, return_dict: bool = True,
): ):
all_hidden_states = () if output_hidden_states else None all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None all_self_attentions = () if output_attentions else None
......
...@@ -323,12 +323,12 @@ class MPNetEncoder(nn.Module): ...@@ -323,12 +323,12 @@ class MPNetEncoder(nn.Module):
def forward( def forward(
self, self,
hidden_states, hidden_states: torch.Tensor,
attention_mask=None, attention_mask: Optional[torch.Tensor] = None,
head_mask=None, head_mask: Optional[torch.Tensor] = None,
output_attentions=False, output_attentions: bool = False,
output_hidden_states=False, output_hidden_states: bool = False,
return_dict=False, return_dict: bool = False,
**kwargs, **kwargs,
): ):
position_bias = self.compute_position_bias(hidden_states) position_bias = self.compute_position_bias(hidden_states)
......
...@@ -354,12 +354,12 @@ class NystromformerEncoder(nn.Module): ...@@ -354,12 +354,12 @@ class NystromformerEncoder(nn.Module):
def forward( def forward(
self, self,
hidden_states, hidden_states: torch.Tensor,
attention_mask=None, attention_mask: Optional[torch.Tensor] = None,
head_mask=None, head_mask: Optional[torch.Tensor] = None,
output_attentions=False, output_attentions: bool = False,
output_hidden_states=False, output_hidden_states: bool = False,
return_dict=True, return_dict: bool = True,
): ):
all_hidden_states = () if output_hidden_states else None all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None all_self_attentions = () if output_attentions else None
......
...@@ -655,7 +655,12 @@ class UniSpeechEncoderLayerStableLayerNorm(nn.Module): ...@@ -655,7 +655,12 @@ class UniSpeechEncoderLayerStableLayerNorm(nn.Module):
self.feed_forward = UniSpeechFeedForward(config) self.feed_forward = UniSpeechFeedForward(config)
self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
def forward(self, hidden_states, attention_mask=None, output_attentions=False): def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
):
attn_residual = hidden_states attn_residual = hidden_states
hidden_states = self.layer_norm(hidden_states) hidden_states = self.layer_norm(hidden_states)
hidden_states, attn_weights, _ = self.attention( hidden_states, attn_weights, _ = self.attention(
...@@ -686,11 +691,11 @@ class UniSpeechEncoder(nn.Module): ...@@ -686,11 +691,11 @@ class UniSpeechEncoder(nn.Module):
def forward( def forward(
self, self,
hidden_states, hidden_states: torch.tensor,
attention_mask=None, attention_mask: Optional[torch.Tensor] = None,
output_attentions=False, output_attentions: bool = False,
output_hidden_states=False, output_hidden_states: bool = False,
return_dict=True, return_dict: bool = True,
): ):
all_hidden_states = () if output_hidden_states else None all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None all_self_attentions = () if output_attentions else None
......
...@@ -669,7 +669,12 @@ class UniSpeechSatEncoderLayerStableLayerNorm(nn.Module): ...@@ -669,7 +669,12 @@ class UniSpeechSatEncoderLayerStableLayerNorm(nn.Module):
self.feed_forward = UniSpeechSatFeedForward(config) self.feed_forward = UniSpeechSatFeedForward(config)
self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
def forward(self, hidden_states, attention_mask=None, output_attentions=False): def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
):
attn_residual = hidden_states attn_residual = hidden_states
hidden_states = self.layer_norm(hidden_states) hidden_states = self.layer_norm(hidden_states)
hidden_states, attn_weights, _ = self.attention( hidden_states, attn_weights, _ = self.attention(
...@@ -700,11 +705,11 @@ class UniSpeechSatEncoder(nn.Module): ...@@ -700,11 +705,11 @@ class UniSpeechSatEncoder(nn.Module):
def forward( def forward(
self, self,
hidden_states, hidden_states: torch.tensor,
attention_mask=None, attention_mask: Optional[torch.Tensor] = None,
output_attentions=False, output_attentions: bool = False,
output_hidden_states=False, output_hidden_states: bool = False,
return_dict=True, return_dict: bool = True,
): ):
all_hidden_states = () if output_hidden_states else None all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None all_self_attentions = () if output_attentions else None
......
...@@ -704,7 +704,12 @@ class Wav2Vec2EncoderLayerStableLayerNorm(nn.Module): ...@@ -704,7 +704,12 @@ class Wav2Vec2EncoderLayerStableLayerNorm(nn.Module):
self.feed_forward = Wav2Vec2FeedForward(config) self.feed_forward = Wav2Vec2FeedForward(config)
self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
def forward(self, hidden_states, attention_mask=None, output_attentions=False): def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
):
attn_residual = hidden_states attn_residual = hidden_states
hidden_states = self.layer_norm(hidden_states) hidden_states = self.layer_norm(hidden_states)
hidden_states, attn_weights, _ = self.attention( hidden_states, attn_weights, _ = self.attention(
...@@ -734,11 +739,11 @@ class Wav2Vec2Encoder(nn.Module): ...@@ -734,11 +739,11 @@ class Wav2Vec2Encoder(nn.Module):
def forward( def forward(
self, self,
hidden_states, hidden_states: torch.tensor,
attention_mask=None, attention_mask: Optional[torch.Tensor] = None,
output_attentions=False, output_attentions: bool = False,
output_hidden_states=False, output_hidden_states: bool = False,
return_dict=True, return_dict: bool = True,
): ):
all_hidden_states = () if output_hidden_states else None all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None all_self_attentions = () if output_attentions else None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment