Unverified Commit 8f1f59ce authored by Ian C's avatar Ian C Committed by GitHub
Browse files

Add type hints for Whisper models (#20396)

* Initial commit

* Add type hints for two major classes

* Run make fixup

* Fix output type for Whisper

* Run isort to fix imports
parent 53357e81
......@@ -17,8 +17,9 @@
import math
import random
from typing import Dict, Optional, Tuple
from typing import Dict, Optional, Tuple, Union
import numpy as np
import tensorflow as tf
from ...activations_tf import get_tf_activation
......@@ -28,7 +29,13 @@ from ...modeling_tf_outputs import (
TFSeq2SeqLMOutput,
TFSeq2SeqModelOutput,
)
from ...modeling_tf_utils import TFCausalLanguageModelingLoss, TFPreTrainedModel, keras_serializable, unpack_inputs
from ...modeling_tf_utils import (
TFCausalLanguageModelingLoss,
TFModelInputType,
TFPreTrainedModel,
keras_serializable,
unpack_inputs,
)
from ...tf_utils import shape_list, stable_softmax
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
from .configuration_whisper import WhisperConfig
......@@ -1117,26 +1124,26 @@ class TFWhisperModel(TFWhisperPreTrainedModel):
return self.model.encoder
@add_start_docstrings_to_model_forward(WHISPER_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=TFSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
@replace_return_docstrings(output_type=TFSeq2SeqModelOutput, config_class=_CONFIG_FOR_DOC)
@unpack_inputs
def call(
self,
input_features=None,
decoder_input_ids=None,
decoder_attention_mask=None,
decoder_position_ids=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs=None,
past_key_values=None,
decoder_inputs_embeds=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
training=False,
):
input_features: Optional[TFModelInputType] = None,
decoder_input_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
decoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
decoder_position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
decoder_head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
cross_attn_head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
encoder_outputs: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
decoder_inputs_embeds: Optional[Tuple[Union[np.ndarray, tf.Tensor]]] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
training: bool = False,
) -> Union[Tuple[tf.Tensor], TFSeq2SeqModelOutput]:
r"""
Returns:
......@@ -1236,23 +1243,23 @@ class TFWhisperForConditionalGeneration(TFWhisperPreTrainedModel, TFCausalLangua
@unpack_inputs
def call(
self,
input_features=None,
decoder_input_ids=None,
decoder_attention_mask=None,
decoder_position_ids=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs=None,
past_key_values=None,
decoder_inputs_embeds=None,
labels=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
training=False,
):
input_features: Optional[TFModelInputType] = None,
decoder_input_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
decoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
decoder_position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
decoder_head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
cross_attn_head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
encoder_outputs: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
decoder_inputs_embeds: Optional[Tuple[Union[np.ndarray, tf.Tensor]]] = None,
labels: Optional[Union[np.ndarray, tf.Tensor]] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
training: bool = False,
) -> Union[Tuple[tf.Tensor], TFSeq2SeqLMOutput]:
r"""
labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the language modeling loss. Indices should either be in `[0, ..., config.vocab_size]`
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment