Unverified Commit 781af736 authored by akashe's avatar akashe Committed by GitHub
Browse files

added typehints for RAG pytorch models (#16416)

parent 5b40a37b
......@@ -767,25 +767,25 @@ class RagSequenceForGeneration(RagPreTrainedModel):
@replace_return_docstrings(output_type=RetrievAugLMMarginOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids=None,
attention_mask=None,
encoder_outputs=None,
decoder_input_ids=None,
decoder_attention_mask=None,
past_key_values=None,
context_input_ids=None,
context_attention_mask=None,
doc_scores=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
output_retrieved=None,
exclude_bos_score=None,
reduce_loss=None,
labels=None,
n_docs=None,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask: Optional[torch.BoolTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
context_input_ids: Optional[torch.LongTensor] = None,
context_attention_mask: Optional[torch.LongTensor] = None,
doc_scores: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
output_retrieved: Optional[bool] = None,
exclude_bos_score: Optional[bool] = None,
reduce_loss: Optional[bool] = None,
labels: Optional[torch.LongTensor] = None,
n_docs: Optional[int] = None,
**kwargs # needs kwargs for generation
):
) -> RetrievAugLMMarginOutput:
r"""
exclude_bos_score (`bool`, *optional*):
Only relevant if `labels` is passed. If `True`, the score of the BOS token is disregarded when computing
......@@ -910,15 +910,15 @@ class RagSequenceForGeneration(RagPreTrainedModel):
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
context_input_ids=None,
context_attention_mask=None,
doc_scores=None,
do_deduplication=None, # defaults to True
num_return_sequences=None, # defaults to 1
num_beams=None, # defaults to 1
n_docs=None,
context_input_ids: Optional[torch.LongTensor] = None,
context_attention_mask: Optional[torch.LongTensor] = None,
doc_scores: Optional[torch.FloatTensor] = None,
do_deduplication: Optional[bool] = None, # defaults to True
num_return_sequences: Optional[int] = None, # defaults to 1
num_beams: Optional[int] = None, # defaults to 1
n_docs: Optional[int] = None,
**model_kwargs
):
) -> torch.LongTensor:
"""
Implements RAG sequence "thorough" decoding. Read the [`~generation_utils.GenerationMixin.generate`]`
documentation for more information on how to set other generate input parameters.
......@@ -1234,25 +1234,25 @@ class RagTokenForGeneration(RagPreTrainedModel):
@replace_return_docstrings(output_type=RetrievAugLMMarginOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids=None,
attention_mask=None,
encoder_outputs=None,
decoder_input_ids=None,
decoder_attention_mask=None,
past_key_values=None,
context_input_ids=None,
context_attention_mask=None,
doc_scores=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
output_retrieved=None,
do_marginalize=None,
reduce_loss=None,
labels=None,
n_docs=None,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask: Optional[torch.BoolTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
context_input_ids: Optional[torch.LongTensor] = None,
context_attention_mask: Optional[torch.LongTensor] = None,
doc_scores: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
output_retrieved: Optional[bool] = None,
do_marginalize: Optional[bool] = None,
reduce_loss: Optional[bool] = None,
labels: Optional[torch.LongTensor] = None,
n_docs: Optional[int] = None,
**kwargs # needs kwargs for generation
):
) -> RetrievAugLMMarginOutput:
r"""
do_marginalize (`bool`, *optional*):
If `True`, the logits are marginalized over all documents by making use of
......@@ -1377,27 +1377,27 @@ class RagTokenForGeneration(RagPreTrainedModel):
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
context_input_ids=None,
context_attention_mask=None,
doc_scores=None,
max_length=None,
min_length=None,
early_stopping=None,
use_cache=None,
num_beams=None,
num_beam_groups=None,
diversity_penalty=None,
bos_token_id=None,
pad_token_id=None,
eos_token_id=None,
length_penalty=None,
no_repeat_ngram_size=None,
encoder_no_repeat_ngram_size=None,
repetition_penalty=None,
bad_words_ids=None,
num_return_sequences=None,
decoder_start_token_id=None,
n_docs=None,
context_input_ids: Optional[torch.LongTensor] = None,
context_attention_mask: Optional[torch.LongTensor] = None,
doc_scores: Optional[torch.FloatTensor] = None,
max_length: Optional[int] = None,
min_length: Optional[int] = None,
early_stopping: Optional[bool] = None,
use_cache: Optional[bool] = None,
num_beams: Optional[int] = None,
num_beam_groups: Optional[int] = None,
diversity_penalty: Optional[float] = None,
bos_token_id: Optional[int] = None,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
length_penalty: Optional[float] = None,
no_repeat_ngram_size: Optional[int] = None,
encoder_no_repeat_ngram_size: Optional[int] = None,
repetition_penalty: Optional[float] = None,
bad_words_ids: Optional[List[List[int]]] = None,
num_return_sequences: Optional[int] = None,
decoder_start_token_id: Optional[int] = None,
n_docs: Optional[int] = None,
prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]] = None,
logits_processor: Optional[LogitsProcessorList] = LogitsProcessorList(),
stopping_criteria: Optional[StoppingCriteriaList] = StoppingCriteriaList(),
......@@ -1406,7 +1406,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
remove_invalid_values: Optional[bool] = None,
exponential_decay_length_penalty: Optional[Tuple[Union[int, float]]] = None,
**model_kwargs
):
) -> torch.LongTensor:
"""
Implements RAG token decoding.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment