Unverified commit fa906a26, authored by Silver and committed by GitHub
Browse files

Add `min_new_tokens` argument in generate() (implementation based on...

Add `min_new_tokens` argument in generate() (implementation based on `MinNewTokensLengthLogitsProcessor`) (#21044)

add a new parameter min_new_tokens for generate()
parent 125f1375
......@@ -75,7 +75,11 @@ class GenerationConfig(PushToHubMixin):
max_new_tokens (`int`, *optional*):
The maximum number of tokens to generate, ignoring the number of tokens in the prompt.
min_length (`int`, *optional*, defaults to 0):
The minimum length of the sequence to be generated.
The minimum length of the sequence to be generated. Corresponds to the length of the input prompt +
`min_new_tokens`. In general, prefer the use of `min_new_tokens`, which ignores the number of tokens in the
prompt.
min_new_tokens (`int`, *optional*):
The minimum number of tokens to generate, ignoring the number of tokens in the prompt.
early_stopping (`bool`, *optional*, defaults to `False`):
Whether to stop the beam search when at least `num_beams` sentences are finished per batch or not.
max_time (`float`, *optional*):
......@@ -207,6 +211,7 @@ class GenerationConfig(PushToHubMixin):
self.max_length = kwargs.pop("max_length", 20)
self.max_new_tokens = kwargs.pop("max_new_tokens", None)
self.min_length = kwargs.pop("min_length", 0)
self.min_new_tokens = kwargs.pop("min_new_tokens", None)
self.early_stopping = kwargs.pop("early_stopping", False)
self.max_time = kwargs.pop("max_time", None)
......
......@@ -48,6 +48,7 @@ from .logits_process import (
LogitNormalization,
LogitsProcessorList,
MinLengthLogitsProcessor,
MinNewTokensLengthLogitsProcessor,
NoBadWordsLogitsProcessor,
NoRepeatNGramLogitsProcessor,
PrefixConstrainedLogitsProcessor,
......@@ -822,6 +823,16 @@ class GenerationMixin:
and generation_config.min_length > 0
):
processors.append(MinLengthLogitsProcessor(generation_config.min_length, generation_config.eos_token_id))
if (
generation_config.min_new_tokens is not None
and generation_config.eos_token_id is not None
and generation_config.min_new_tokens > 0
):
processors.append(
MinNewTokensLengthLogitsProcessor(
input_ids_seq_length, generation_config.min_new_tokens, generation_config.eos_token_id
)
)
if prefix_allowed_tokens_fn is not None:
processors.append(
PrefixConstrainedLogitsProcessor(
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment