Unverified Commit 9b630168 authored by Debjit Bhowal, committed by GitHub

Added type hints for TF: rag model (#19284)

* Added type hints for TF: rag model

* TFModelInputType used in place of tf.Tensor for model inputs (see the sketch below)

* Reformatted with black
parent ac5ea74e
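
For reference, a minimal sketch of the annotation pattern this commit applies. The class and method body below are placeholders of my own; only the hint style mirrors the diff that follows.

    # A minimal sketch of the hint pattern; ExampleModel is hypothetical,
    # only the annotation style mirrors the actual diff below.
    from typing import Optional, Union

    import numpy as np
    import tensorflow as tf

    from transformers.modeling_tf_utils import TFModelInputType


    class ExampleModel(tf.keras.Model):
        # before: def call(self, input_ids=None, attention_mask=None, training=False):
        def call(
            self,
            input_ids: Optional[TFModelInputType] = None,
            attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
            training: bool = False,
        ):
            # TFModelInputType describes the accepted input containers for the
            # primary input, while Union[np.ndarray, tf.Tensor] admits NumPy
            # arrays alongside TensorFlow tensors for the other tensor arguments.
            ...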
@@ -16,13 +16,19 @@
 """TFRAG model implementation."""

 from dataclasses import dataclass
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, Union

 import numpy as np
 import tensorflow as tf

 from ...configuration_utils import PretrainedConfig
-from ...modeling_tf_utils import TFCausalLanguageModelingLoss, TFPreTrainedModel, shape_list, unpack_inputs
+from ...modeling_tf_utils import (
+    TFCausalLanguageModelingLoss,
+    TFModelInputType,
+    TFPreTrainedModel,
+    shape_list,
+    unpack_inputs,
+)
 from ...utils import ModelOutput, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
 from .configuration_rag import RagConfig
 from .retrieval_rag import RagRetriever
@@ -491,7 +497,7 @@ class TFRagModel(TFRagPreTrainedModel):
         config: Optional[PretrainedConfig] = None,
         question_encoder: Optional[TFPreTrainedModel] = None,
         generator: Optional[TFPreTrainedModel] = None,
-        retriever: Optional = None,
+        retriever: Optional[RagRetriever] = None,
         load_weight_prefix: Optional[str] = None,
         **kwargs,
     ):
@@ -538,22 +544,22 @@ class TFRagModel(TFRagPreTrainedModel):
     @replace_return_docstrings(output_type=TFRetrievAugLMOutput, config_class=_CONFIG_FOR_DOC)
     def call(
         self,
-        input_ids=None,
-        attention_mask=None,
-        encoder_outputs=None,
-        decoder_input_ids=None,
-        decoder_attention_mask=None,
-        past_key_values=None,
-        doc_scores=None,
-        context_input_ids=None,
-        context_attention_mask=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        output_retrieved=None,
-        n_docs=None,
-        return_dict=None,
-        training=False,
+        input_ids: Optional[TFModelInputType] = None,
+        attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        encoder_outputs: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        decoder_input_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        decoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
+        doc_scores: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        context_input_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        context_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        output_retrieved: Optional[bool] = None,
+        n_docs: Optional[int] = None,
+        return_dict: Optional[bool] = None,
+        training: bool = False,
         **kwargs
     ):
         r"""
@@ -726,7 +732,7 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss
         config: Optional[PretrainedConfig] = None,
         question_encoder: Optional[TFPreTrainedModel] = None,
         generator: Optional[TFPreTrainedModel] = None,
-        retriever: Optional = None,
+        retriever: Optional[RagRetriever] = None,
         **kwargs,
     ):
         assert config is not None or (
@@ -828,25 +834,25 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss
     @replace_return_docstrings(output_type=TFRetrievAugLMMarginOutput, config_class=_CONFIG_FOR_DOC)
     def call(
         self,
-        input_ids=None,
-        attention_mask=None,
-        decoder_input_ids=None,
-        decoder_attention_mask=None,
-        encoder_outputs=None,
-        past_key_values=None,
-        doc_scores=None,
-        context_input_ids=None,
-        context_attention_mask=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        output_retrieved=None,
-        n_docs=None,
-        do_marginalize=None,
-        labels=None,
-        reduce_loss=None,
-        return_dict=None,
-        training=False,
+        input_ids: Optional[TFModelInputType] = None,
+        attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        decoder_input_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        decoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        encoder_outputs: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
+        doc_scores: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        context_input_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        context_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        output_retrieved: Optional[bool] = None,
+        n_docs: Optional[int] = None,
+        do_marginalize: Optional[bool] = None,
+        labels: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        reduce_loss: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        training: bool = False,
         **kwargs  # needs kwargs for generation
     ):
         r"""
@@ -980,7 +986,7 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss

     def generate(
         self,
-        input_ids: Optional[tf.Tensor] = None,
+        input_ids: Optional[TFModelInputType] = None,
         attention_mask: Optional[tf.Tensor] = None,
         context_input_ids=None,
         context_attention_mask=None,
@@ -1381,7 +1387,7 @@ class TFRagSequenceForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss
         config: Optional[PretrainedConfig] = None,
         question_encoder: Optional[TFPreTrainedModel] = None,
         generator: Optional[TFPreTrainedModel] = None,
-        retriever: Optional = None,
+        retriever: Optional[RagRetriever] = None,
         **kwargs,
     ):
         assert config is not None or (
@@ -1425,27 +1431,27 @@ class TFRagSequenceForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss
     @replace_return_docstrings(output_type=TFRetrievAugLMMarginOutput, config_class=_CONFIG_FOR_DOC)
     def call(
         self,
-        input_ids=None,
-        attention_mask=None,
-        decoder_input_ids=None,
-        decoder_attention_mask=None,
-        encoder_outputs=None,
-        past_key_values=None,
-        doc_scores=None,
-        context_input_ids=None,
-        context_attention_mask=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        output_retrieved=None,
-        n_docs=None,
-        exclude_bos_score=None,
-        labels=None,
-        reduce_loss=None,
-        return_dict=None,
-        training=False,
+        input_ids: Optional[TFModelInputType] = None,
+        attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        decoder_input_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        decoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        encoder_outputs: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
+        doc_scores: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        context_input_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        context_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        output_retrieved: Optional[bool] = None,
+        n_docs: Optional[int] = None,
+        exclude_bos_score: Optional[bool] = None,
+        labels: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        reduce_loss: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        training: bool = False,
         **kwargs  # needs kwargs for generation
-    ):
+    ) -> Union[Tuple[tf.Tensor], TFRetrievAugLMMarginOutput]:
         r"""
         exclude_bos_score (`bool`, *optional*):
             Only relevant if `labels` is passed. If `True`, the score of the BOS token is disregarded when computing
@@ -1657,7 +1663,7 @@ class TFRagSequenceForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss

     def generate(
         self,
-        input_ids: Optional[tf.Tensor] = None,
+        input_ids: Optional[TFModelInputType] = None,
         attention_mask: Optional[tf.Tensor] = None,
         context_input_ids=None,
         context_attention_mask=None,
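
After this change, call sites can mix NumPy arrays and TensorFlow tensors for the tensor-typed arguments. Below is a hedged usage sketch; the checkpoint name and retriever settings follow the standard RAG documentation rather than anything in this diff.

    import numpy as np

    from transformers import RagRetriever, RagTokenizer, TFRagTokenForGeneration

    tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
    retriever = RagRetriever.from_pretrained(
        "facebook/rag-token-nq", index_name="exact", use_dummy_dataset=True
    )
    model = TFRagTokenForGeneration.from_pretrained(
        "facebook/rag-token-nq", retriever=retriever, from_pt=True
    )

    inputs = tokenizer("who holds the record in 100m freestyle?", return_tensors="tf")
    # With the widened Union[np.ndarray, tf.Tensor] hints, a NumPy attention
    # mask is as valid an input as a tf.Tensor.
    outputs = model(
        input_ids=inputs["input_ids"],
        attention_mask=np.asarray(inputs["attention_mask"]),
        training=False,
    )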