"ml/git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "4fb47ed36862cead8d9455df8e34ae398d83e29f"
Unverified Commit 8f1f59ce authored by Ian C, committed by GitHub

Add type hints for Whisper models (#20396)

* Initial commit

* Add type hints for two major classes

* Run make fixup

* Fix output type for Whisper

* Run isort to fix imports
parent 53357e81
@@ -17,8 +17,9 @@
 import math
 import random
-from typing import Dict, Optional, Tuple
+from typing import Dict, Optional, Tuple, Union

+import numpy as np
 import tensorflow as tf

 from ...activations_tf import get_tf_activation
@@ -28,7 +29,13 @@ from ...modeling_tf_outputs import (
     TFSeq2SeqLMOutput,
     TFSeq2SeqModelOutput,
 )
-from ...modeling_tf_utils import TFCausalLanguageModelingLoss, TFPreTrainedModel, keras_serializable, unpack_inputs
+from ...modeling_tf_utils import (
+    TFCausalLanguageModelingLoss,
+    TFModelInputType,
+    TFPreTrainedModel,
+    keras_serializable,
+    unpack_inputs,
+)
 from ...tf_utils import shape_list, stable_softmax
 from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
 from .configuration_whisper import WhisperConfig
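For context, `TFModelInputType` (newly imported above from `modeling_tf_utils`) is a type alias covering the containers TF models accept as their first input. A rough sketch of its shape follows; the exact members are version-dependent, so treat this as an approximation and consult the installed transformers source for the real definition.

# Rough, version-dependent sketch of the TFModelInputType alias from
# modeling_tf_utils; this reconstruction is illustrative, not canonical.
from typing import Dict, List, Union

import numpy as np
import tensorflow as tf

TFModelInputType = Union[
    List[tf.Tensor],
    List[np.ndarray],
    Dict[str, tf.Tensor],
    Dict[str, np.ndarray],
    tf.Tensor,
    np.ndarray,
]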
@@ -1117,26 +1124,26 @@ class TFWhisperModel(TFWhisperPreTrainedModel):
         return self.model.encoder

     @add_start_docstrings_to_model_forward(WHISPER_INPUTS_DOCSTRING)
-    @replace_return_docstrings(output_type=TFSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
+    @replace_return_docstrings(output_type=TFSeq2SeqModelOutput, config_class=_CONFIG_FOR_DOC)
     @unpack_inputs
     def call(
         self,
-        input_features=None,
-        decoder_input_ids=None,
-        decoder_attention_mask=None,
-        decoder_position_ids=None,
-        head_mask=None,
-        decoder_head_mask=None,
-        cross_attn_head_mask=None,
-        encoder_outputs=None,
-        past_key_values=None,
-        decoder_inputs_embeds=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-        training=False,
-    ):
+        input_features: Optional[TFModelInputType] = None,
+        decoder_input_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        decoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        decoder_position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        decoder_head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        cross_attn_head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        encoder_outputs: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
+        past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
+        decoder_inputs_embeds: Optional[Tuple[Union[np.ndarray, tf.Tensor]]] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        training: bool = False,
+    ) -> Union[Tuple[tf.Tensor], TFSeq2SeqModelOutput]:
         r"""
         Returns:
@@ -1236,23 +1243,23 @@ class TFWhisperForConditionalGeneration(TFWhisperPreTrainedModel, TFCausalLanguageModelingLoss):
     @unpack_inputs
     def call(
         self,
-        input_features=None,
-        decoder_input_ids=None,
-        decoder_attention_mask=None,
-        decoder_position_ids=None,
-        head_mask=None,
-        decoder_head_mask=None,
-        cross_attn_head_mask=None,
-        encoder_outputs=None,
-        past_key_values=None,
-        decoder_inputs_embeds=None,
-        labels=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-        training=False,
-    ):
+        input_features: Optional[TFModelInputType] = None,
+        decoder_input_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        decoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        decoder_position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        decoder_head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        cross_attn_head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        encoder_outputs: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
+        past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
+        decoder_inputs_embeds: Optional[Tuple[Union[np.ndarray, tf.Tensor]]] = None,
+        labels: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        training: bool = False,
+    ) -> Union[Tuple[tf.Tensor], TFSeq2SeqLMOutput]:
         r"""
         labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
             Labels for computing the language modeling loss. Indices should either be in `[0, ..., config.vocab_size]`
...
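As a quick sanity check of the new annotations, here is a minimal, hypothetical usage sketch. It assumes a transformers install with TF support and the `openai/whisper-tiny` checkpoint; the input shapes and values are illustrative only, not part of this commit.

# Minimal sketch exercising the newly annotated TFWhisperModel.call().
# Assumes transformers with TF support and the openai/whisper-tiny
# checkpoint; shapes and values below are illustrative only.
import tensorflow as tf
from transformers import TFWhisperModel

model = TFWhisperModel.from_pretrained("openai/whisper-tiny")

# input_features satisfies Optional[TFModelInputType]: a plain tf.Tensor
# of log-mel features, shape (batch, num_mel_bins, num_frames).
input_features = tf.zeros((1, 80, 3000))
decoder_input_ids = tf.constant([[model.config.decoder_start_token_id]])

# With return_dict=True the result is a TFSeq2SeqModelOutput, matching
# the corrected @replace_return_docstrings output_type above.
outputs = model(
    input_features=input_features,
    decoder_input_ids=decoder_input_ids,
    return_dict=True,
)
print(outputs.last_hidden_state.shape)  # (batch, decoder_len, d_model)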