Unverified commit 0ffa22f9, authored by Batese2001, committed by GitHub
Browse files

Added Type Hints for modeling_tf_encoder_decoder.py (#21673)



* Ran Black formatting

* Added imports and reformatted

* Update src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py

---------
Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
parent aa3787c8
@@ -19,13 +19,20 @@
import gc
import os
import tempfile
import warnings
from typing import Optional, Tuple, Union

import numpy as np
import tensorflow as tf

from ...configuration_utils import PretrainedConfig
from ...modeling_tf_outputs import TFBaseModelOutput, TFSeq2SeqLMOutput
from ...modeling_tf_utils import (
    TFCausalLanguageModelingLoss,
    TFModelInputType,
    TFPreTrainedModel,
    get_initializer,
    unpack_inputs,
)
from ...tf_utils import shape_list
from ...utils import (
    DUMMY_INPUTS,
@@ -509,22 +516,22 @@ class TFEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss):
    @replace_return_docstrings(output_type=TFSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
    def call(
        self,
        input_ids: Optional[TFModelInputType] = None,
        attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
        decoder_input_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
        decoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
        encoder_outputs: Optional[Union[np.ndarray, tf.Tensor]] = None,
        past_key_values: Optional[Tuple[Tuple[tf.Tensor]]] = None,
        inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
        decoder_inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
        labels: Optional[Union[np.ndarray, tf.Tensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: bool = False,
        **kwargs,
    ) -> Union[TFSeq2SeqLMOutput, Tuple[tf.Tensor]]:
        r"""
        Returns:
@@ -718,3 +725,7 @@ class TFEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss):
            " respective methods of the wrapped objects (model.encoder.resize_token_embeddings(...) or"
            " model.decoder.resize_token_embeddings(...))"
        )
def _reorder_cache(self, past, beam_idx):
# apply decoder cache reordering here
return self.decoder._reorder_cache(past, beam_idx)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment