"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "c45ef1c0d196ada1af50c72e0dbd2a8d310b59b2"
Unverified Commit e893b1ef authored by Joao Gante's avatar Joao Gante Committed by GitHub
Browse files

Generate: improve docstrings for custom stopping criteria (#26863)

improve docstrings
parent ef42cb62
......@@ -23,7 +23,8 @@ STOPPING_CRITERIA_INPUTS_DOCSTRING = r"""
[What are input IDs?](../glossary#input-ids)
scores (`torch.FloatTensor` of shape `(batch_size, config.vocab_size)`):
Prediction scores of a language modeling head. These can be scores for each vocabulary token before SoftMax
or scores for each vocabulary token after SoftMax.
or scores for each vocabulary token after SoftMax. If this stopping criteria depends on the `scores` input,
make sure you pass `return_dict_in_generate=True, output_scores=True` to `generate`.
kwargs (`Dict[str, Any]`, *optional*):
Additional stopping criteria specific kwargs.
......@@ -34,7 +35,11 @@ STOPPING_CRITERIA_INPUTS_DOCSTRING = r"""
class StoppingCriteria(ABC):
"""Abstract base class for all stopping criteria that can be applied during generation."""
"""Abstract base class for all stopping criteria that can be applied during generation.
If your stopping criteria depends on the `scores` input, make sure you pass `return_dict_in_generate=True,
output_scores=True` to `generate`.
"""
@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
......
......@@ -1397,7 +1397,9 @@ class GenerationMixin:
stopping_criteria (`StoppingCriteriaList`, *optional*):
Custom stopping criteria that complement the default stopping criteria built from arguments and a
generation config. If a stopping criteria is passed that is already created with the arguments or a
generation config an error is thrown. This feature is intended for advanced users.
generation config an error is thrown. If your stopping criteria depends on the `scores` input, make
sure you pass `return_dict_in_generate=True, output_scores=True` to `generate`. This feature is
intended for advanced users.
prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], List[int]]`, *optional*):
If provided, this function constraints the beam search to allowed tokens only at each step. If not
provided no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment