"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "ab0ddc99e853c974949d823dbfaa732202696f3e"
Unverified Commit 9b630168 authored by Debjit Bhowal's avatar Debjit Bhowal Committed by GitHub
Browse files

Added type hints for TF: rag model (#19284)

* Added type hints for TF: rag model

* Used `TFModelInputType` in place of `tf.Tensor` for `input_ids`

* Reformatted with black
parent ac5ea74e
......@@ -16,13 +16,19 @@
"""TFRAG model implementation."""
from dataclasses import dataclass
from typing import List, Optional, Tuple
from typing import List, Optional, Tuple, Union
import numpy as np
import tensorflow as tf
from ...configuration_utils import PretrainedConfig
from ...modeling_tf_utils import TFCausalLanguageModelingLoss, TFPreTrainedModel, shape_list, unpack_inputs
from ...modeling_tf_utils import (
TFCausalLanguageModelingLoss,
TFModelInputType,
TFPreTrainedModel,
shape_list,
unpack_inputs,
)
from ...utils import ModelOutput, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
from .configuration_rag import RagConfig
from .retrieval_rag import RagRetriever
......@@ -491,7 +497,7 @@ class TFRagModel(TFRagPreTrainedModel):
config: Optional[PretrainedConfig] = None,
question_encoder: Optional[TFPreTrainedModel] = None,
generator: Optional[TFPreTrainedModel] = None,
retriever: Optional = None,
retriever: Optional[RagRetriever] = None,
load_weight_prefix: Optional[str] = None,
**kwargs,
):
......@@ -538,22 +544,22 @@ class TFRagModel(TFRagPreTrainedModel):
@replace_return_docstrings(output_type=TFRetrievAugLMOutput, config_class=_CONFIG_FOR_DOC)
def call(
self,
input_ids=None,
attention_mask=None,
encoder_outputs=None,
decoder_input_ids=None,
decoder_attention_mask=None,
past_key_values=None,
doc_scores=None,
context_input_ids=None,
context_attention_mask=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
output_retrieved=None,
n_docs=None,
return_dict=None,
training=False,
input_ids: Optional[TFModelInputType] = None,
attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
encoder_outputs: Optional[Union[np.ndarray, tf.Tensor]] = None,
decoder_input_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
decoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
doc_scores: Optional[Union[np.ndarray, tf.Tensor]] = None,
context_input_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
context_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
output_retrieved: Optional[bool] = None,
n_docs: Optional[int] = None,
return_dict: Optional[bool] = None,
training: bool = False,
**kwargs
):
r"""
......@@ -726,7 +732,7 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss
config: Optional[PretrainedConfig] = None,
question_encoder: Optional[TFPreTrainedModel] = None,
generator: Optional[TFPreTrainedModel] = None,
retriever: Optional = None,
retriever: Optional[RagRetriever] = None,
**kwargs,
):
assert config is not None or (
......@@ -828,25 +834,25 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss
@replace_return_docstrings(output_type=TFRetrievAugLMMarginOutput, config_class=_CONFIG_FOR_DOC)
def call(
self,
input_ids=None,
attention_mask=None,
decoder_input_ids=None,
decoder_attention_mask=None,
encoder_outputs=None,
past_key_values=None,
doc_scores=None,
context_input_ids=None,
context_attention_mask=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
output_retrieved=None,
n_docs=None,
do_marginalize=None,
labels=None,
reduce_loss=None,
return_dict=None,
training=False,
input_ids: Optional[TFModelInputType] = None,
attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
decoder_input_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
decoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
encoder_outputs: Optional[Union[np.ndarray, tf.Tensor]] = None,
past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
doc_scores: Optional[Union[np.ndarray, tf.Tensor]] = None,
context_input_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
context_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
output_retrieved: Optional[bool] = None,
n_docs: Optional[int] = None,
do_marginalize: Optional[bool] = None,
labels: Optional[Union[np.ndarray, tf.Tensor]] = None,
reduce_loss: Optional[bool] = None,
return_dict: Optional[bool] = None,
training: bool = False,
**kwargs # needs kwargs for generation
):
r"""
......@@ -980,7 +986,7 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss
def generate(
self,
input_ids: Optional[tf.Tensor] = None,
input_ids: Optional[TFModelInputType] = None,
attention_mask: Optional[tf.Tensor] = None,
context_input_ids=None,
context_attention_mask=None,
......@@ -1381,7 +1387,7 @@ class TFRagSequenceForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingL
config: Optional[PretrainedConfig] = None,
question_encoder: Optional[TFPreTrainedModel] = None,
generator: Optional[TFPreTrainedModel] = None,
retriever: Optional = None,
retriever: Optional[RagRetriever] = None,
**kwargs,
):
assert config is not None or (
......@@ -1425,27 +1431,27 @@ class TFRagSequenceForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingL
@replace_return_docstrings(output_type=TFRetrievAugLMMarginOutput, config_class=_CONFIG_FOR_DOC)
def call(
self,
input_ids=None,
attention_mask=None,
decoder_input_ids=None,
decoder_attention_mask=None,
encoder_outputs=None,
past_key_values=None,
doc_scores=None,
context_input_ids=None,
context_attention_mask=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
output_retrieved=None,
n_docs=None,
exclude_bos_score=None,
labels=None,
reduce_loss=None,
return_dict=None,
training=False,
input_ids: Optional[TFModelInputType] = None,
attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
decoder_input_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
decoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
encoder_outputs: Optional[Union[np.ndarray, tf.Tensor]] = None,
past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
doc_scores: Optional[Union[np.ndarray, tf.Tensor]] = None,
context_input_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
context_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
output_retrieved: Optional[bool] = None,
n_docs: Optional[int] = None,
exclude_bos_score: Optional[bool] = None,
labels: Optional[Union[np.ndarray, tf.Tensor]] = None,
reduce_loss: Optional[bool] = None,
return_dict: Optional[bool] = None,
training: bool = False,
**kwargs # needs kwargs for generation
):
) -> Union[Tuple[tf.Tensor], TFRetrievAugLMMarginOutput]:
r"""
exclude_bos_score (`bool`, *optional*):
Only relevant if `labels` is passed. If `True`, the score of the BOS token is disregarded when computing
......@@ -1657,7 +1663,7 @@ class TFRagSequenceForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingL
def generate(
self,
input_ids: Optional[tf.Tensor] = None,
input_ids: Optional[TFModelInputType] = None,
attention_mask: Optional[tf.Tensor] = None,
context_input_ids=None,
context_attention_mask=None,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment