Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
e893b1ef
Unverified
Commit
e893b1ef
authored
Oct 18, 2023
by
Joao Gante
Committed by
GitHub
Oct 18, 2023
Browse files
Generate: improve docstrings for custom stopping criteria (#26863)
improve docstrings
parent
ef42cb62
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
10 additions
and
3 deletions
+10
-3
src/transformers/generation/stopping_criteria.py
src/transformers/generation/stopping_criteria.py
+7
-2
src/transformers/generation/utils.py
src/transformers/generation/utils.py
+3
-1
No files found.
src/transformers/generation/stopping_criteria.py
View file @
e893b1ef
...
@@ -23,7 +23,8 @@ STOPPING_CRITERIA_INPUTS_DOCSTRING = r"""
...
@@ -23,7 +23,8 @@ STOPPING_CRITERIA_INPUTS_DOCSTRING = r"""
[What are input IDs?](../glossary#input-ids)
[What are input IDs?](../glossary#input-ids)
scores (`torch.FloatTensor` of shape `(batch_size, config.vocab_size)`):
scores (`torch.FloatTensor` of shape `(batch_size, config.vocab_size)`):
Prediction scores of a language modeling head. These can be scores for each vocabulary token before SoftMax
Prediction scores of a language modeling head. These can be scores for each vocabulary token before SoftMax
or scores for each vocabulary token after SoftMax.
or scores for each vocabulary token after SoftMax. If this stopping criteria depends on the `scores` input,
make sure you pass `return_dict_in_generate=True, output_scores=True` to `generate`.
kwargs (`Dict[str, Any]`, *optional*):
kwargs (`Dict[str, Any]`, *optional*):
Additional stopping criteria specific kwargs.
Additional stopping criteria specific kwargs.
...
@@ -34,7 +35,11 @@ STOPPING_CRITERIA_INPUTS_DOCSTRING = r"""
...
@@ -34,7 +35,11 @@ STOPPING_CRITERIA_INPUTS_DOCSTRING = r"""
class
StoppingCriteria
(
ABC
):
class
StoppingCriteria
(
ABC
):
"""Abstract base class for all stopping criteria that can be applied during generation."""
"""Abstract base class for all stopping criteria that can be applied during generation.
If your stopping criteria depends on the `scores` input, make sure you pass `return_dict_in_generate=True,
output_scores=True` to `generate`.
"""
@
add_start_docstrings
(
STOPPING_CRITERIA_INPUTS_DOCSTRING
)
@
add_start_docstrings
(
STOPPING_CRITERIA_INPUTS_DOCSTRING
)
def
__call__
(
self
,
input_ids
:
torch
.
LongTensor
,
scores
:
torch
.
FloatTensor
,
**
kwargs
)
->
bool
:
def
__call__
(
self
,
input_ids
:
torch
.
LongTensor
,
scores
:
torch
.
FloatTensor
,
**
kwargs
)
->
bool
:
...
...
src/transformers/generation/utils.py
View file @
e893b1ef
...
@@ -1397,7 +1397,9 @@ class GenerationMixin:
...
@@ -1397,7 +1397,9 @@ class GenerationMixin:
stopping_criteria (`StoppingCriteriaList`, *optional*):
stopping_criteria (`StoppingCriteriaList`, *optional*):
Custom stopping criteria that complement the default stopping criteria built from arguments and a
Custom stopping criteria that complement the default stopping criteria built from arguments and a
generation config. If a stopping criteria is passed that is already created with the arguments or a
generation config. If a stopping criteria is passed that is already created with the arguments or a
generation config an error is thrown. This feature is intended for advanced users.
generation config an error is thrown. If your stopping criteria depends on the `scores` input, make
sure you pass `return_dict_in_generate=True, output_scores=True` to `generate`. This feature is
intended for advanced users.
prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], List[int]]`, *optional*):
prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], List[int]]`, *optional*):
If provided, this function constraints the beam search to allowed tokens only at each step. If not
If provided, this function constraints the beam search to allowed tokens only at each step. If not
provided no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and
provided no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment