Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
781af736
"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "980211a63a2a07057a97b1eb47b7b09d7eda2bcd"
Unverified
Commit
781af736
authored
Mar 29, 2022
by
akashe
Committed by
GitHub
Mar 29, 2022
Browse files
added typehints for RAG pytorch models (#16416)
parent
5b40a37b
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
66 additions
and
66 deletions
+66
-66
src/transformers/models/rag/modeling_rag.py
src/transformers/models/rag/modeling_rag.py
+66
-66
No files found.
src/transformers/models/rag/modeling_rag.py
View file @
781af736
...
@@ -767,25 +767,25 @@ class RagSequenceForGeneration(RagPreTrainedModel):
...
@@ -767,25 +767,25 @@ class RagSequenceForGeneration(RagPreTrainedModel):
@
replace_return_docstrings
(
output_type
=
RetrievAugLMMarginOutput
,
config_class
=
_CONFIG_FOR_DOC
)
@
replace_return_docstrings
(
output_type
=
RetrievAugLMMarginOutput
,
config_class
=
_CONFIG_FOR_DOC
)
def
forward
(
def
forward
(
self
,
self
,
input_ids
=
None
,
input_ids
:
Optional
[
torch
.
LongTensor
]
=
None
,
attention_mask
=
None
,
attention_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
encoder_outputs
=
None
,
encoder_outputs
:
Optional
[
Tuple
[
Tuple
[
torch
.
Tensor
]]]
=
None
,
decoder_input_ids
=
None
,
decoder_input_ids
:
Optional
[
torch
.
LongTensor
]
=
None
,
decoder_attention_mask
=
None
,
decoder_attention_mask
:
Optional
[
torch
.
BoolTensor
]
=
None
,
past_key_values
=
None
,
past_key_values
:
Optional
[
Tuple
[
Tuple
[
torch
.
Tensor
]]]
=
None
,
context_input_ids
=
None
,
context_input_ids
:
Optional
[
torch
.
LongTensor
]
=
None
,
context_attention_mask
=
None
,
context_attention_mask
:
Optional
[
torch
.
LongTensor
]
=
None
,
doc_scores
=
None
,
doc_scores
:
Optional
[
torch
.
FloatTensor
]
=
None
,
use_cache
=
None
,
use_cache
:
Optional
[
bool
]
=
None
,
output_attentions
=
None
,
output_attentions
:
Optional
[
bool
]
=
None
,
output_hidden_states
=
None
,
output_hidden_states
:
Optional
[
bool
]
=
None
,
output_retrieved
=
None
,
output_retrieved
:
Optional
[
bool
]
=
None
,
exclude_bos_score
=
None
,
exclude_bos_score
:
Optional
[
bool
]
=
None
,
reduce_loss
=
None
,
reduce_loss
:
Optional
[
bool
]
=
None
,
labels
=
None
,
labels
:
Optional
[
torch
.
LongTensor
]
=
None
,
n_docs
=
None
,
n_docs
:
Optional
[
int
]
=
None
,
**
kwargs
# needs kwargs for generation
**
kwargs
# needs kwargs for generation
):
)
->
RetrievAugLMMarginOutput
:
r
"""
r
"""
exclude_bos_score (`bool`, *optional*):
exclude_bos_score (`bool`, *optional*):
Only relevant if `labels` is passed. If `True`, the score of the BOS token is disregarded when computing
Only relevant if `labels` is passed. If `True`, the score of the BOS token is disregarded when computing
...
@@ -910,15 +910,15 @@ class RagSequenceForGeneration(RagPreTrainedModel):
...
@@ -910,15 +910,15 @@ class RagSequenceForGeneration(RagPreTrainedModel):
self
,
self
,
input_ids
:
Optional
[
torch
.
LongTensor
]
=
None
,
input_ids
:
Optional
[
torch
.
LongTensor
]
=
None
,
attention_mask
:
Optional
[
torch
.
LongTensor
]
=
None
,
attention_mask
:
Optional
[
torch
.
LongTensor
]
=
None
,
context_input_ids
=
None
,
context_input_ids
:
Optional
[
torch
.
LongTensor
]
=
None
,
context_attention_mask
=
None
,
context_attention_mask
:
Optional
[
torch
.
LongTensor
]
=
None
,
doc_scores
=
None
,
doc_scores
:
Optional
[
torch
.
FloatTensor
]
=
None
,
do_deduplication
=
None
,
# defaults to True
do_deduplication
:
Optional
[
bool
]
=
None
,
# defaults to True
num_return_sequences
=
None
,
# defaults to 1
num_return_sequences
:
Optional
[
int
]
=
None
,
# defaults to 1
num_beams
=
None
,
# defaults to 1
num_beams
:
Optional
[
int
]
=
None
,
# defaults to 1
n_docs
=
None
,
n_docs
:
Optional
[
int
]
=
None
,
**
model_kwargs
**
model_kwargs
):
)
->
torch
.
LongTensor
:
"""
"""
Implements RAG sequence "thorough" decoding. Read the [`~generation_utils.GenerationMixin.generate`]`
Implements RAG sequence "thorough" decoding. Read the [`~generation_utils.GenerationMixin.generate`]`
documentation for more information on how to set other generate input parameters.
documentation for more information on how to set other generate input parameters.
...
@@ -1234,25 +1234,25 @@ class RagTokenForGeneration(RagPreTrainedModel):
...
@@ -1234,25 +1234,25 @@ class RagTokenForGeneration(RagPreTrainedModel):
@
replace_return_docstrings
(
output_type
=
RetrievAugLMMarginOutput
,
config_class
=
_CONFIG_FOR_DOC
)
@
replace_return_docstrings
(
output_type
=
RetrievAugLMMarginOutput
,
config_class
=
_CONFIG_FOR_DOC
)
def
forward
(
def
forward
(
self
,
self
,
input_ids
=
None
,
input_ids
:
Optional
[
torch
.
LongTensor
]
=
None
,
attention_mask
=
None
,
attention_mask
:
Optional
[
torch
.
FloatTensor
]
=
None
,
encoder_outputs
=
None
,
encoder_outputs
:
Optional
[
Tuple
[
Tuple
[
torch
.
Tensor
]]]
=
None
,
decoder_input_ids
=
None
,
decoder_input_ids
:
Optional
[
torch
.
LongTensor
]
=
None
,
decoder_attention_mask
=
None
,
decoder_attention_mask
:
Optional
[
torch
.
BoolTensor
]
=
None
,
past_key_values
=
None
,
past_key_values
:
Optional
[
Tuple
[
Tuple
[
torch
.
Tensor
]]]
=
None
,
context_input_ids
=
None
,
context_input_ids
:
Optional
[
torch
.
LongTensor
]
=
None
,
context_attention_mask
=
None
,
context_attention_mask
:
Optional
[
torch
.
LongTensor
]
=
None
,
doc_scores
=
None
,
doc_scores
:
Optional
[
torch
.
FloatTensor
]
=
None
,
use_cache
=
None
,
use_cache
:
Optional
[
bool
]
=
None
,
output_attentions
=
None
,
output_attentions
:
Optional
[
bool
]
=
None
,
output_hidden_states
=
None
,
output_hidden_states
:
Optional
[
bool
]
=
None
,
output_retrieved
=
None
,
output_retrieved
:
Optional
[
bool
]
=
None
,
do_marginalize
=
None
,
do_marginalize
:
Optional
[
bool
]
=
None
,
reduce_loss
=
None
,
reduce_loss
:
Optional
[
bool
]
=
None
,
labels
=
None
,
labels
:
Optional
[
torch
.
LongTensor
]
=
None
,
n_docs
=
None
,
n_docs
:
Optional
[
int
]
=
None
,
**
kwargs
# needs kwargs for generation
**
kwargs
# needs kwargs for generation
):
)
->
RetrievAugLMMarginOutput
:
r
"""
r
"""
do_marginalize (`bool`, *optional*):
do_marginalize (`bool`, *optional*):
If `True`, the logits are marginalized over all documents by making use of
If `True`, the logits are marginalized over all documents by making use of
...
@@ -1377,27 +1377,27 @@ class RagTokenForGeneration(RagPreTrainedModel):
...
@@ -1377,27 +1377,27 @@ class RagTokenForGeneration(RagPreTrainedModel):
self
,
self
,
input_ids
:
Optional
[
torch
.
LongTensor
]
=
None
,
input_ids
:
Optional
[
torch
.
LongTensor
]
=
None
,
attention_mask
:
Optional
[
torch
.
LongTensor
]
=
None
,
attention_mask
:
Optional
[
torch
.
LongTensor
]
=
None
,
context_input_ids
=
None
,
context_input_ids
:
Optional
[
torch
.
LongTensor
]
=
None
,
context_attention_mask
=
None
,
context_attention_mask
:
Optional
[
torch
.
LongTensor
]
=
None
,
doc_scores
=
None
,
doc_scores
:
Optional
[
torch
.
FloatTensor
]
=
None
,
max_length
=
None
,
max_length
:
Optional
[
int
]
=
None
,
min_length
=
None
,
min_length
:
Optional
[
int
]
=
None
,
early_stopping
=
None
,
early_stopping
:
Optional
[
bool
]
=
None
,
use_cache
=
None
,
use_cache
:
Optional
[
bool
]
=
None
,
num_beams
=
None
,
num_beams
:
Optional
[
int
]
=
None
,
num_beam_groups
=
None
,
num_beam_groups
:
Optional
[
int
]
=
None
,
diversity_penalty
=
None
,
diversity_penalty
:
Optional
[
float
]
=
None
,
bos_token_id
=
None
,
bos_token_id
:
Optional
[
int
]
=
None
,
pad_token_id
=
None
,
pad_token_id
:
Optional
[
int
]
=
None
,
eos_token_id
=
None
,
eos_token_id
:
Optional
[
int
]
=
None
,
length_penalty
=
None
,
length_penalty
:
Optional
[
float
]
=
None
,
no_repeat_ngram_size
=
None
,
no_repeat_ngram_size
:
Optional
[
int
]
=
None
,
encoder_no_repeat_ngram_size
=
None
,
encoder_no_repeat_ngram_size
:
Optional
[
int
]
=
None
,
repetition_penalty
=
None
,
repetition_penalty
:
Optional
[
float
]
=
None
,
bad_words_ids
=
None
,
bad_words_ids
:
Optional
[
List
[
List
[
int
]]]
=
None
,
num_return_sequences
=
None
,
num_return_sequences
:
Optional
[
int
]
=
None
,
decoder_start_token_id
=
None
,
decoder_start_token_id
:
Optional
[
int
]
=
None
,
n_docs
=
None
,
n_docs
:
Optional
[
int
]
=
None
,
prefix_allowed_tokens_fn
:
Callable
[[
int
,
torch
.
Tensor
],
List
[
int
]]
=
None
,
prefix_allowed_tokens_fn
:
Callable
[[
int
,
torch
.
Tensor
],
List
[
int
]]
=
None
,
logits_processor
:
Optional
[
LogitsProcessorList
]
=
LogitsProcessorList
(),
logits_processor
:
Optional
[
LogitsProcessorList
]
=
LogitsProcessorList
(),
stopping_criteria
:
Optional
[
StoppingCriteriaList
]
=
StoppingCriteriaList
(),
stopping_criteria
:
Optional
[
StoppingCriteriaList
]
=
StoppingCriteriaList
(),
...
@@ -1406,7 +1406,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
...
@@ -1406,7 +1406,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
remove_invalid_values
:
Optional
[
bool
]
=
None
,
remove_invalid_values
:
Optional
[
bool
]
=
None
,
exponential_decay_length_penalty
:
Optional
[
Tuple
[
Union
[
int
,
float
]]]
=
None
,
exponential_decay_length_penalty
:
Optional
[
Tuple
[
Union
[
int
,
float
]]]
=
None
,
**
model_kwargs
**
model_kwargs
):
)
->
torch
.
LongTensor
:
"""
"""
Implements RAG token decoding.
Implements RAG token decoding.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment