Unverified Commit d7754c43 authored by Ryan Chan, committed by GitHub

Type hints MCTCT (#19618)



* add type hints to mctct

* run auto style corrections

* change torch.bool to bool

* Update src/transformers/models/mctct/modeling_mctct.py
Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>

* Remove optional tags for attention_mask and head_mask

* fix optional tags

* Update src/transformers/models/mctct/modeling_mctct.py
Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
parent 8aad4363
@@ -17,7 +17,7 @@
 import math
 import random
-from typing import Optional
+from typing import Optional, Tuple, Union

 import torch
 import torch.utils.checkpoint
@@ -566,13 +566,13 @@ class MCTCTEncoder(MCTCTPreTrainedModel):
     def forward(
         self,
-        input_features,
-        attention_mask,
-        head_mask,
-        output_attentions=False,
-        output_hidden_states=False,
-        return_dict=True,
-    ):
+        input_features: torch.Tensor,
+        attention_mask: torch.Tensor,
+        head_mask: torch.Tensor,
+        output_attentions: bool = False,
+        output_hidden_states: bool = False,
+        return_dict: bool = True,
+    ) -> Union[Tuple, BaseModelOutput]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -680,13 +680,13 @@ class MCTCTModel(MCTCTPreTrainedModel):
     )
     def forward(
         self,
-        input_features,
-        attention_mask=None,
-        head_mask=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_features: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, BaseModelOutput]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
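In practice, the new `-> Union[Tuple, BaseModelOutput]` annotation on `MCTCTModel.forward` reflects that `return_dict` selects which of the two types is returned. A minimal sketch of this, assuming a transformers version that still ships MCTCT and using a default, randomly initialized config with illustrative input shapes (not taken from the commit):

import torch
from transformers import MCTCTConfig, MCTCTModel
from transformers.modeling_outputs import BaseModelOutput

config = MCTCTConfig()  # default config; weights are randomly initialized
model = MCTCTModel(config).eval()

# Illustrative shape: (batch_size, time, input_feat_per_channel)
input_features = torch.randn(1, 200, config.input_feat_per_channel)

with torch.no_grad():
    as_output = model(input_features, return_dict=True)
    as_tuple = model(input_features, return_dict=False)

assert isinstance(as_output, BaseModelOutput)  # dataclass branch of the Union
assert isinstance(as_tuple, tuple)  # plain-tuple branch of the Union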
@@ -751,14 +751,14 @@ class MCTCTForCTC(MCTCTPreTrainedModel):
     )
     def forward(
         self,
-        input_features,
-        attention_mask=None,
-        head_mask=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-        labels=None,
-    ):
+        input_features: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        labels: Optional[torch.LongTensor] = None,
+    ) -> Union[Tuple, CausalLMOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*):
             Labels for connectionist temporal classification. Note that `target_length` has to be smaller or equal to
@@ -783,7 +783,6 @@ class MCTCTForCTC(MCTCTPreTrainedModel):
         loss = None
         if labels is not None:
-
             if labels.max() >= self.config.vocab_size:
                 raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}")
...
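Similarly, the `labels: Optional[torch.LongTensor]` annotation on `MCTCTForCTC.forward` marks the argument that triggers the CTC loss computation guarded above. A hedged usage sketch; the default config, random features, and target length below are illustrative assumptions, not part of the commit:

import torch
from transformers import MCTCTConfig, MCTCTForCTC

config = MCTCTConfig()  # default config; weights are randomly initialized
model = MCTCTForCTC(config).eval()

# (batch_size, time, feature_dim) features, (batch_size, target_length) labels;
# label values stay below config.vocab_size, satisfying the ValueError guard above.
input_features = torch.randn(2, 200, config.input_feat_per_channel)
labels = torch.randint(0, config.vocab_size, (2, 20), dtype=torch.long)

outputs = model(input_features, labels=labels)
print(outputs.loss)  # CTC loss is only computed when labels is not None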