"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "ba52dec47f870ab713317d3e61ba209bb5800783"
Unverified commit fa906a26, authored by Silver, committed by GitHub

Add `min_new_tokens` argument in generate() (implementation based on `MinNewTokensLengthLogitsProcessor`) (#21044)

Adds a new parameter `min_new_tokens` for generate().
parent 125f1375
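
Not part of the commit itself, but a minimal sketch of how the new argument is meant to be used once this lands, assuming a standard causal LM checkpoint such as gpt2:

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("The weather today is", return_tensors="pt")

# min_new_tokens counts only generated tokens; min_length also counts the prompt.
outputs = model.generate(**inputs, min_new_tokens=10, max_new_tokens=30)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))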
src/transformers/generation/configuration_utils.py
@@ -75,7 +75,11 @@ class GenerationConfig(PushToHubMixin):
         max_new_tokens (`int`, *optional*):
             The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt.
         min_length (`int`, *optional*, defaults to 0):
-            The minimum length of the sequence to be generated.
+            The minimum length of the sequence to be generated. Corresponds to the length of the input prompt +
+            `min_new_tokens`. In general, prefer the use of `min_new_tokens`, which ignores the number of tokens in the
+            prompt.
+        min_new_tokens (`int`, *optional*):
+            The minimum numbers of tokens to generate, ignoring the number of tokens in the prompt.
         early_stopping (`bool`, *optional*, defaults to `False`):
             Whether to stop the beam search when at least `num_beams` sentences are finished per batch or not.
         max_time(`float`, *optional*):
@@ -207,6 +211,7 @@ class GenerationConfig(PushToHubMixin):
         self.max_length = kwargs.pop("max_length", 20)
         self.max_new_tokens = kwargs.pop("max_new_tokens", None)
         self.min_length = kwargs.pop("min_length", 0)
+        self.min_new_tokens = kwargs.pop("min_new_tokens", None)
         self.early_stopping = kwargs.pop("early_stopping", False)
         self.max_time = kwargs.pop("max_time", None)
...
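
Because the hunk above pops `min_new_tokens` in `GenerationConfig.__init__`, the value can also be carried in a config object rather than passed as a keyword; a small sketch:

from transformers import GenerationConfig

# min_new_tokens is stored alongside the other length controls
generation_config = GenerationConfig(min_new_tokens=5, max_new_tokens=20)
print(generation_config.min_new_tokens)  # 5
# and later: model.generate(**inputs, generation_config=generation_config)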
src/transformers/generation/utils.py
@@ -48,6 +48,7 @@ from .logits_process import (
     LogitNormalization,
     LogitsProcessorList,
     MinLengthLogitsProcessor,
+    MinNewTokensLengthLogitsProcessor,
     NoBadWordsLogitsProcessor,
     NoRepeatNGramLogitsProcessor,
     PrefixConstrainedLogitsProcessor,
@@ -822,6 +823,16 @@ class GenerationMixin:
             and generation_config.min_length > 0
         ):
             processors.append(MinLengthLogitsProcessor(generation_config.min_length, generation_config.eos_token_id))
+        if (
+            generation_config.min_new_tokens is not None
+            and generation_config.eos_token_id is not None
+            and generation_config.min_new_tokens > 0
+        ):
+            processors.append(
+                MinNewTokensLengthLogitsProcessor(
+                    input_ids_seq_length, generation_config.min_new_tokens, generation_config.eos_token_id
+                )
+            )
         if prefix_allowed_tokens_fn is not None:
             processors.append(
                 PrefixConstrainedLogitsProcessor(
...
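
For reference (not part of the diff): the processor wired in above suppresses the EOS logit until enough tokens have been generated past the prompt. A standalone sketch with made-up token ids:

import torch
from transformers import MinNewTokensLengthLogitsProcessor

eos_token_id = 2
# Prompt is 4 tokens long; require at least 3 newly generated tokens.
processor = MinNewTokensLengthLogitsProcessor(
    prompt_length_to_skip=4, min_new_tokens=3, eos_token_id=eos_token_id
)

input_ids = torch.tensor([[5, 6, 7, 8, 9]])  # 4 prompt tokens + 1 generated
scores = torch.zeros(1, 10)
out = processor(input_ids, scores)
# Only 1 new token so far (< 3), so EOS is forced to -inf.
print(out[0, eos_token_id])  # tensor(-inf)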