"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "a3034c7004db43d1082babb7da8606f3676d38b9"
Unverified Commit 781af736 authored by akashe's avatar akashe Committed by GitHub
Browse files

added typehints for RAG pytorch models (#16416)

parent 5b40a37b
...@@ -767,25 +767,25 @@ class RagSequenceForGeneration(RagPreTrainedModel): ...@@ -767,25 +767,25 @@ class RagSequenceForGeneration(RagPreTrainedModel):
@replace_return_docstrings(output_type=RetrievAugLMMarginOutput, config_class=_CONFIG_FOR_DOC) @replace_return_docstrings(output_type=RetrievAugLMMarginOutput, config_class=_CONFIG_FOR_DOC)
def forward( def forward(
self, self,
input_ids=None, input_ids: Optional[torch.LongTensor] = None,
attention_mask=None, attention_mask: Optional[torch.Tensor] = None,
encoder_outputs=None, encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
decoder_input_ids=None, decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask=None, decoder_attention_mask: Optional[torch.BoolTensor] = None,
past_key_values=None, past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
context_input_ids=None, context_input_ids: Optional[torch.LongTensor] = None,
context_attention_mask=None, context_attention_mask: Optional[torch.LongTensor] = None,
doc_scores=None, doc_scores: Optional[torch.FloatTensor] = None,
use_cache=None, use_cache: Optional[bool] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
output_retrieved=None, output_retrieved: Optional[bool] = None,
exclude_bos_score=None, exclude_bos_score: Optional[bool] = None,
reduce_loss=None, reduce_loss: Optional[bool] = None,
labels=None, labels: Optional[torch.LongTensor] = None,
n_docs=None, n_docs: Optional[int] = None,
**kwargs # needs kwargs for generation **kwargs # needs kwargs for generation
): ) -> RetrievAugLMMarginOutput:
r""" r"""
exclude_bos_score (`bool`, *optional*): exclude_bos_score (`bool`, *optional*):
Only relevant if `labels` is passed. If `True`, the score of the BOS token is disregarded when computing Only relevant if `labels` is passed. If `True`, the score of the BOS token is disregarded when computing
...@@ -910,15 +910,15 @@ class RagSequenceForGeneration(RagPreTrainedModel): ...@@ -910,15 +910,15 @@ class RagSequenceForGeneration(RagPreTrainedModel):
self, self,
input_ids: Optional[torch.LongTensor] = None, input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.LongTensor] = None,
context_input_ids=None, context_input_ids: Optional[torch.LongTensor] = None,
context_attention_mask=None, context_attention_mask: Optional[torch.LongTensor] = None,
doc_scores=None, doc_scores: Optional[torch.FloatTensor] = None,
do_deduplication=None, # defaults to True do_deduplication: Optional[bool] = None, # defaults to True
num_return_sequences=None, # defaults to 1 num_return_sequences: Optional[int] = None, # defaults to 1
num_beams=None, # defaults to 1 num_beams: Optional[int] = None, # defaults to 1
n_docs=None, n_docs: Optional[int] = None,
**model_kwargs **model_kwargs
): ) -> torch.LongTensor:
""" """
Implements RAG sequence "thorough" decoding. Read the [`~generation_utils.GenerationMixin.generate`]` Implements RAG sequence "thorough" decoding. Read the [`~generation_utils.GenerationMixin.generate`]`
documentation for more information on how to set other generate input parameters. documentation for more information on how to set other generate input parameters.
...@@ -1234,25 +1234,25 @@ class RagTokenForGeneration(RagPreTrainedModel): ...@@ -1234,25 +1234,25 @@ class RagTokenForGeneration(RagPreTrainedModel):
@replace_return_docstrings(output_type=RetrievAugLMMarginOutput, config_class=_CONFIG_FOR_DOC) @replace_return_docstrings(output_type=RetrievAugLMMarginOutput, config_class=_CONFIG_FOR_DOC)
def forward( def forward(
self, self,
input_ids=None, input_ids: Optional[torch.LongTensor] = None,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
encoder_outputs=None, encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
decoder_input_ids=None, decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask=None, decoder_attention_mask: Optional[torch.BoolTensor] = None,
past_key_values=None, past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
context_input_ids=None, context_input_ids: Optional[torch.LongTensor] = None,
context_attention_mask=None, context_attention_mask: Optional[torch.LongTensor] = None,
doc_scores=None, doc_scores: Optional[torch.FloatTensor] = None,
use_cache=None, use_cache: Optional[bool] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
output_retrieved=None, output_retrieved: Optional[bool] = None,
do_marginalize=None, do_marginalize: Optional[bool] = None,
reduce_loss=None, reduce_loss: Optional[bool] = None,
labels=None, labels: Optional[torch.LongTensor] = None,
n_docs=None, n_docs: Optional[int] = None,
**kwargs # needs kwargs for generation **kwargs # needs kwargs for generation
): ) -> RetrievAugLMMarginOutput:
r""" r"""
do_marginalize (`bool`, *optional*): do_marginalize (`bool`, *optional*):
If `True`, the logits are marginalized over all documents by making use of If `True`, the logits are marginalized over all documents by making use of
...@@ -1377,27 +1377,27 @@ class RagTokenForGeneration(RagPreTrainedModel): ...@@ -1377,27 +1377,27 @@ class RagTokenForGeneration(RagPreTrainedModel):
self, self,
input_ids: Optional[torch.LongTensor] = None, input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.LongTensor] = None,
context_input_ids=None, context_input_ids: Optional[torch.LongTensor] = None,
context_attention_mask=None, context_attention_mask: Optional[torch.LongTensor] = None,
doc_scores=None, doc_scores: Optional[torch.FloatTensor] = None,
max_length=None, max_length: Optional[int] = None,
min_length=None, min_length: Optional[int] = None,
early_stopping=None, early_stopping: Optional[bool] = None,
use_cache=None, use_cache: Optional[bool] = None,
num_beams=None, num_beams: Optional[int] = None,
num_beam_groups=None, num_beam_groups: Optional[int] = None,
diversity_penalty=None, diversity_penalty: Optional[float] = None,
bos_token_id=None, bos_token_id: Optional[int] = None,
pad_token_id=None, pad_token_id: Optional[int] = None,
eos_token_id=None, eos_token_id: Optional[int] = None,
length_penalty=None, length_penalty: Optional[float] = None,
no_repeat_ngram_size=None, no_repeat_ngram_size: Optional[int] = None,
encoder_no_repeat_ngram_size=None, encoder_no_repeat_ngram_size: Optional[int] = None,
repetition_penalty=None, repetition_penalty: Optional[float] = None,
bad_words_ids=None, bad_words_ids: Optional[List[List[int]]] = None,
num_return_sequences=None, num_return_sequences: Optional[int] = None,
decoder_start_token_id=None, decoder_start_token_id: Optional[int] = None,
n_docs=None, n_docs: Optional[int] = None,
prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]] = None, prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]] = None,
logits_processor: Optional[LogitsProcessorList] = LogitsProcessorList(), logits_processor: Optional[LogitsProcessorList] = LogitsProcessorList(),
stopping_criteria: Optional[StoppingCriteriaList] = StoppingCriteriaList(), stopping_criteria: Optional[StoppingCriteriaList] = StoppingCriteriaList(),
...@@ -1406,7 +1406,7 @@ class RagTokenForGeneration(RagPreTrainedModel): ...@@ -1406,7 +1406,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
remove_invalid_values: Optional[bool] = None, remove_invalid_values: Optional[bool] = None,
exponential_decay_length_penalty: Optional[Tuple[Union[int, float]]] = None, exponential_decay_length_penalty: Optional[Tuple[Union[int, float]]] = None,
**model_kwargs **model_kwargs
): ) -> torch.LongTensor:
""" """
Implements RAG token decoding. Implements RAG token decoding.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment