Unverified Commit d4d4447d authored by RafaelWO's avatar RafaelWO Committed by GitHub
Browse files

fixed prefix_allowed_tokens_fn docstring in generate() (#10862)

parent 7ef40120
......@@ -776,9 +776,9 @@ class GenerationMixin:
enabled.
prefix_allowed_tokens_fn: (:obj:`Callable[[int, torch.Tensor], List[int]]`, `optional`):
If provided, this function constraints the beam search to allowed tokens only at each step. If not
provided no constraint is applied. This function takes 2 arguments :obj:`inputs_ids` and the batch ID
:obj:`batch_id`. It has to return a list with the allowed tokens for the next generation step
conditioned on the previously generated tokens :obj:`inputs_ids` and the batch ID :obj:`batch_id`. This
provided no constraint is applied. This function takes 2 arguments: the batch ID :obj:`batch_id` and
:obj:`input_ids`. It has to return a list with the allowed tokens for the next generation step
conditioned on the batch ID :obj:`batch_id` and the previously generated tokens :obj:`inputs_ids`. This
argument is useful for constrained generation conditioned on the prefix, as described in
`Autoregressive Entity Retrieval <https://arxiv.org/abs/2010.00904>`__.
output_attentions (:obj:`bool`, `optional`, defaults to `False`):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment