"vscode:/vscode.git/clone" did not exist on "2831826bc60ee11b86179252ffc8401cb03c5904"
Unverified commit 5e636eee authored by Partho, committed by GitHub

Add type hints for PyTorch UniSpeech, MPNet and Nystromformer (#19039)



* added type hints pytorch unispeech

* added type hints pytorch MPNet

* added type hints nystromformer

* resolved copy inconsistencies

* make fix-copies
Co-authored-by: matt <rocketknight1@gmail.com>
parent 658010c7
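
All of the hunks below apply the same annotation pattern to encoder-style forward methods. As a minimal standalone sketch of that pattern (the class name and body here are illustrative, not taken from the diff):

from typing import Optional

import torch
from torch import nn


class ExampleEncoder(nn.Module):
    # Illustrative module showing the annotation style this commit applies.
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # The annotations only document the expected types; Python does not
        # enforce them at runtime, so behavior is unchanged.
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        return hidden_states, all_hidden_states, all_self_attentions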
@@ -563,11 +563,11 @@ class Data2VecAudioEncoder(nn.Module):
     def forward(
         self,
-        hidden_states,
-        attention_mask=None,
-        output_attentions=False,
-        output_hidden_states=False,
-        return_dict=True,
+        hidden_states: torch.tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+        output_hidden_states: bool = False,
+        return_dict: bool = True,
     ):
         all_hidden_states = () if output_hidden_states else None
         all_self_attentions = () if output_attentions else None
@@ -618,7 +618,12 @@ class HubertEncoderLayerStableLayerNorm(nn.Module):
         self.feed_forward = HubertFeedForward(config)
         self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
 
-    def forward(self, hidden_states, attention_mask=None, output_attentions=False):
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+    ):
         attn_residual = hidden_states
         hidden_states = self.layer_norm(hidden_states)
         hidden_states, attn_weights, _ = self.attention(
@@ -649,11 +654,11 @@ class HubertEncoder(nn.Module):
     def forward(
         self,
-        hidden_states,
-        attention_mask=None,
-        output_attentions=False,
-        output_hidden_states=False,
-        return_dict=True,
+        hidden_states: torch.tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+        output_hidden_states: bool = False,
+        return_dict: bool = True,
     ):
         all_hidden_states = () if output_hidden_states else None
         all_self_attentions = () if output_attentions else None
@@ -323,12 +323,12 @@ class MPNetEncoder(nn.Module):
     def forward(
         self,
-        hidden_states,
-        attention_mask=None,
-        head_mask=None,
-        output_attentions=False,
-        output_hidden_states=False,
-        return_dict=False,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+        output_hidden_states: bool = False,
+        return_dict: bool = False,
         **kwargs,
     ):
         position_bias = self.compute_position_bias(hidden_states)
@@ -354,12 +354,12 @@ class NystromformerEncoder(nn.Module):
     def forward(
         self,
-        hidden_states,
-        attention_mask=None,
-        head_mask=None,
-        output_attentions=False,
-        output_hidden_states=False,
-        return_dict=True,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+        output_hidden_states: bool = False,
+        return_dict: bool = True,
     ):
         all_hidden_states = () if output_hidden_states else None
         all_self_attentions = () if output_attentions else None
@@ -655,7 +655,12 @@ class UniSpeechEncoderLayerStableLayerNorm(nn.Module):
         self.feed_forward = UniSpeechFeedForward(config)
         self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
 
-    def forward(self, hidden_states, attention_mask=None, output_attentions=False):
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+    ):
         attn_residual = hidden_states
         hidden_states = self.layer_norm(hidden_states)
         hidden_states, attn_weights, _ = self.attention(
@@ -686,11 +691,11 @@ class UniSpeechEncoder(nn.Module):
     def forward(
         self,
-        hidden_states,
-        attention_mask=None,
-        output_attentions=False,
-        output_hidden_states=False,
-        return_dict=True,
+        hidden_states: torch.tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+        output_hidden_states: bool = False,
+        return_dict: bool = True,
     ):
         all_hidden_states = () if output_hidden_states else None
         all_self_attentions = () if output_attentions else None
@@ -669,7 +669,12 @@ class UniSpeechSatEncoderLayerStableLayerNorm(nn.Module):
         self.feed_forward = UniSpeechSatFeedForward(config)
         self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
 
-    def forward(self, hidden_states, attention_mask=None, output_attentions=False):
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+    ):
         attn_residual = hidden_states
         hidden_states = self.layer_norm(hidden_states)
         hidden_states, attn_weights, _ = self.attention(
@@ -700,11 +705,11 @@ class UniSpeechSatEncoder(nn.Module):
     def forward(
         self,
-        hidden_states,
-        attention_mask=None,
-        output_attentions=False,
-        output_hidden_states=False,
-        return_dict=True,
+        hidden_states: torch.tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+        output_hidden_states: bool = False,
+        return_dict: bool = True,
     ):
         all_hidden_states = () if output_hidden_states else None
         all_self_attentions = () if output_attentions else None
@@ -704,7 +704,12 @@ class Wav2Vec2EncoderLayerStableLayerNorm(nn.Module):
         self.feed_forward = Wav2Vec2FeedForward(config)
         self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
 
-    def forward(self, hidden_states, attention_mask=None, output_attentions=False):
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+    ):
         attn_residual = hidden_states
         hidden_states = self.layer_norm(hidden_states)
         hidden_states, attn_weights, _ = self.attention(
@@ -734,11 +739,11 @@ class Wav2Vec2Encoder(nn.Module):
     def forward(
         self,
-        hidden_states,
-        attention_mask=None,
-        output_attentions=False,
-        output_hidden_states=False,
-        return_dict=True,
+        hidden_states: torch.tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+        output_hidden_states: bool = False,
+        return_dict: bool = True,
     ):
         all_hidden_states = () if output_hidden_states else None
         all_self_attentions = () if output_attentions else None
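
For context, a hedged sketch of how these typed signatures get exercised through the public model classes; the encoder modules above are not normally called directly, and the checkpoint name below is an assumption chosen for illustration:

# Sketch under the assumption that `transformers` and `torch` are installed
# and the "microsoft/mpnet-base" checkpoint is reachable.
import torch
from transformers import AutoTokenizer, MPNetModel

tokenizer = AutoTokenizer.from_pretrained("microsoft/mpnet-base")
model = MPNetModel.from_pretrained("microsoft/mpnet-base")

inputs = tokenizer("Type hints document the expected tensors and flags.", return_tensors="pt")

with torch.no_grad():
    # The attention mask (Optional[torch.Tensor]) and the boolean output_* flags
    # are eventually forwarded to MPNetEncoder.forward, annotated in the diff above.
    outputs = model(**inputs, output_attentions=True, output_hidden_states=True)

print(outputs.last_hidden_state.shape)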