"...git@developer.sourcefind.cn:chenpangpang/open-webui.git" did not exist on "0810a2648f18649f80102e04c45f26a0f1788efc"
Unverified commit 49433310, authored by Joao Gante, committed by GitHub

Generate: TF can now accept custom logits processors (#21454)

parent e215e6de
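In practice, this change lets TF users hand a pre-built `TFLogitsProcessorList` to `generate`, mirroring the existing PyTorch API. A minimal usage sketch (the tiny checkpoint name is borrowed from the tests below; any TF seq2seq checkpoint works the same way):

```python
from transformers import (
    AutoTokenizer,
    TFAutoModelForSeq2SeqLM,
    TFLogitsProcessorList,
    TFMinLengthLogitsProcessor,
)

tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
model = TFAutoModelForSeq2SeqLM.from_pretrained("hf-internal-testing/tiny-random-bart")
input_ids = tokenizer(
    "Justin Timberlake and Jessica Biel, welcome to parenthood.", return_tensors="tf"
).input_ids

# Build the custom processors up front and pass them to `generate`.
logits_processor = TFLogitsProcessorList(
    [TFMinLengthLogitsProcessor(min_length=10, eos_token_id=model.config.eos_token_id)]
)

# If the generation config would create the same processor type (e.g. `min_length` is
# also set there), `generate` raises a ValueError -- see `_merge_criteria_processor_list` below.
output_ids = model.generate(input_ids, logits_processor=logits_processor)
print(tokenizer.batch_decode(output_ids, skip_special_tokens=True))
```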
......@@ -532,6 +532,7 @@ class TFGenerationMixin:
self,
input_ids: Optional[tf.Tensor] = None,
generation_config: Optional[GenerationConfig] = None,
logits_processor: Optional[TFLogitsProcessorList] = None,
seed=None,
**kwargs,
) -> Union[TFGenerateOutput, tf.Tensor]:
......@@ -560,6 +561,10 @@ class TFGenerationMixin:
priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
default values, whose documentation should be checked to parameterize generation.
logits_processor (`TFLogitsProcessorList`, *optional*):
Custom logits processors that complement the default logits processors built from arguments and the
generation config. If a logits processor is passed that was already created from the arguments or the
generation config, an error is thrown. This feature is intended for advanced users.
seed (`List[int]`, *optional*):
Random seed to control sampling, containing two integers, used when `do_sample` is `True`. See the
`seed` argument from stateless functions in `tf.random`.
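The `logits_processor` argument documented above targets advanced users, who would typically implement their own `TFLogitsProcessor` subclass and pass an instance through it. A hypothetical sketch (the class name and the ban-one-token logic are invented for illustration; TF processors are invoked as `processor(input_ids, scores, cur_len)`):

```python
import tensorflow as tf

from transformers import TFLogitsProcessor, TFLogitsProcessorList


class HypotheticalBanTokenProcessor(TFLogitsProcessor):
    """Illustrative only: pushes one token id to -inf at every generation step."""

    def __init__(self, banned_token_id: int):
        self.banned_token_id = banned_token_id

    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
        # -inf at the banned column, 0 elsewhere; broadcasts over the batch dimension.
        vocab_size = tf.shape(scores)[-1]
        mask = tf.one_hot(self.banned_token_id, vocab_size, on_value=float("-inf"), off_value=0.0)
        return scores + mask


# Passed to `generate` via the new argument:
# model.generate(input_ids, logits_processor=TFLogitsProcessorList([HypotheticalBanTokenProcessor(13)]))
```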
......@@ -638,6 +643,8 @@ class TFGenerationMixin:
model_kwargs["decoder_input_ids"] = tf.cast(model_kwargs["decoder_input_ids"], tf.int32)
# 3. Set generation parameters if not already defined
logits_processor = logits_processor if logits_processor is not None else TFLogitsProcessorList()
if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:
if model_kwargs.get("attention_mask") is None:
logger.warning(
......@@ -755,6 +762,7 @@ class TFGenerationMixin:
logits_processor = self._get_logits_processor(
generation_config=generation_config,
input_ids_seq_length=input_ids_seq_length,
logits_processor=logits_processor,
)
# 10. go into different generation modes
......@@ -1194,6 +1202,7 @@ class TFGenerationMixin:
self,
generation_config: GenerationConfig,
input_ids_seq_length: int,
logits_processor: Optional[TFLogitsProcessorList],
) -> TFLogitsProcessorList:
"""
This class returns a [`TFLogitsProcessorList`] list object that contains all relevant [`TFLogitsProcessor`]
......@@ -1240,8 +1249,31 @@ class TFGenerationMixin:
)
if generation_config.forced_decoder_ids is not None:
processors.append(TFForceTokensLogitsProcessor(generation_config.forced_decoder_ids))
processors = self._merge_criteria_processor_list(processors, logits_processor)
return processors
def _merge_criteria_processor_list(
self,
default_list: TFLogitsProcessorList,
custom_list: TFLogitsProcessorList,
) -> TFLogitsProcessorList:
# nothing custom was passed -- keep the defaults as-is
if len(custom_list) == 0:
return default_list
# raise if a custom processor duplicates the type of one built from the arguments / generation config
for default in default_list:
for custom in custom_list:
if type(custom) is type(default):
object_type = "logits processor"
raise ValueError(
f"A custom {object_type} of type {type(custom)} with values {custom} has been passed to"
f" `generate`, but it has already been created with the values {default}. {default} has been"
" created by passing the corresponding arguments to generate or by the model's config default"
f" values. If you just want to change the default values of {object_type} consider passing"
f" them as arguments to `generate` instead of using a custom {object_type}."
)
# custom processors run after the default ones
default_list.extend(custom_list)
return default_list
def greedy_search(
self,
input_ids: tf.Tensor,
......
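The merge helper above mirrors its PyTorch counterpart: custom processors are appended after the defaults, and a type collision between a custom processor and a default one raises. A standalone sketch of that duplicate check, outside any model (values are illustrative):

```python
from transformers import TFLogitsProcessorList, TFMinLengthLogitsProcessor

default_list = TFLogitsProcessorList([TFMinLengthLogitsProcessor(min_length=5, eos_token_id=0)])
custom_list = TFLogitsProcessorList([TFMinLengthLogitsProcessor(min_length=10, eos_token_id=0)])

# The same processor type appears on both sides, so `_merge_criteria_processor_list`
# would raise a ValueError instead of silently keeping two conflicting copies.
print(any(type(c) is type(d) for c in custom_list for d in default_list))  # True
```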
......@@ -23,6 +23,7 @@ import numpy as np
import tensorflow as tf
from ...configuration_utils import PretrainedConfig
from ...generation import TFLogitsProcessorList
from ...modeling_tf_utils import (
TFCausalLanguageModelingLoss,
TFModelInputType,
......@@ -1002,6 +1003,7 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss
doc_scores=None,
n_docs=None,
generation_config=None,
logits_processor=TFLogitsProcessorList(),
**kwargs
):
"""
......@@ -1045,6 +1047,10 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss
priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
default values, whose documentation should be checked to parameterize generation.
logits_processor (`TFLogitsProcessorList`, *optional*):
Custom logits processors that complement the default logits processors built from arguments and the
model's config. If a logits processor is passed that was already created from the arguments or the
model's config, an error is thrown.
kwargs:
Ad hoc parametrization of `generation_config` and/or additional model-specific kwargs that will be
forwarded to the `forward` function of the model.
......@@ -1149,6 +1155,7 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss
pre_processor = self._get_logits_processor(
generation_config=generation_config,
input_ids_seq_length=tf.shape(decoder_input_ids)[-1],
logits_processor=logits_processor,
)
if generation_config.num_beams == 1:
......
......@@ -12,6 +12,8 @@ class GenerationIntegrationTestsMixin:
# To be populated by the child classes
framework_dependent_parameters = {
"AutoModelForSeq2SeqLM": None,
"LogitsProcessorList": None,
"MinLengthLogitsProcessor": None,
"create_tensor_fn": None,
"return_tensors": None,
}
......@@ -39,3 +41,23 @@ class GenerationIntegrationTestsMixin:
# however, valid model_kwargs are accepted
valid_model_kwargs = {"attention_mask": create_tensor_fn(np.zeros_like(input_ids))}
model.generate(input_ids, **valid_model_kwargs)
def test_custom_logits_processor(self):
model_cls = self.framework_dependent_parameters["AutoModelForSeq2SeqLM"]
logits_processor_list_cls = self.framework_dependent_parameters["LogitsProcessorList"]
min_length_logits_processor_cls = self.framework_dependent_parameters["MinLengthLogitsProcessor"]
return_tensors = self.framework_dependent_parameters["return_tensors"]
bart_tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
bart_model = model_cls.from_pretrained("hf-internal-testing/tiny-random-bart", min_length=1)
input_ids = bart_tokenizer(article, return_tensors=return_tensors).input_ids
logits_processor = logits_processor_list_cls()
logits_processor.append(min_length_logits_processor_cls(min_length=10, eos_token_id=0))
# `min_length` may not be defined both via the model config and via the `logits_processor` list
with self.assertRaises(ValueError):
bart_model.generate(input_ids, logits_processor=logits_processor)
bart_model.config.min_length = None
bart_model.generate(input_ids, logits_processor=logits_processor)
......@@ -25,7 +25,13 @@ from .test_framework_agnostic import GenerationIntegrationTestsMixin
if is_tf_available():
import tensorflow as tf
from transformers import TFAutoModelForCausalLM, TFAutoModelForSeq2SeqLM, tf_top_k_top_p_filtering
from transformers import (
TFAutoModelForCausalLM,
TFAutoModelForSeq2SeqLM,
TFLogitsProcessorList,
TFMinLengthLogitsProcessor,
tf_top_k_top_p_filtering,
)
@require_tf
......@@ -132,6 +138,8 @@ class TFGenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTests
if is_tf_available():
framework_dependent_parameters = {
"AutoModelForSeq2SeqLM": TFAutoModelForSeq2SeqLM,
"LogitsProcessorList": TFLogitsProcessorList,
"MinLengthLogitsProcessor": TFMinLengthLogitsProcessor,
"create_tensor_fn": tf.convert_to_tensor,
"return_tensors": "tf",
}
......
......@@ -1797,12 +1797,15 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
if is_torch_available():
framework_dependent_parameters = {
"AutoModelForSeq2SeqLM": AutoModelForSeq2SeqLM,
"LogitsProcessorList": LogitsProcessorList,
"MinLengthLogitsProcessor": MinLengthLogitsProcessor,
"create_tensor_fn": torch.tensor,
"return_tensors": "pt",
}
@slow
def test_diverse_beam_search(self):
# PT-only test: TF doesn't have a diverse beam search implementation
article = """Justin Timberlake and Jessica Biel, welcome to parenthood.
The celebrity couple announced the arrival of their son, Silas Randall Timberlake, in statements to People.
"Silas was the middle name of Timberlake's maternal grandfather Bill Bomar, who died in 2012, while Randall is the musician's own middle name, as well as his father's first," People reports.
......@@ -1836,6 +1839,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
)
def test_max_length_backward_compat_greedy(self):
# PT-only test: TF doesn't have StoppingCriteria
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to(
......@@ -1862,6 +1866,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
)
def test_max_length_backward_compat_sample(self):
# PT-only test: TF doesn't have StoppingCriteria
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to(
......@@ -1888,6 +1893,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
)
def test_max_length_backward_compat_beam_search(self):
# PT-only test: TF doesn't have StoppingCriteria
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to(
......@@ -1918,6 +1924,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
)
def test_max_length_backward_compat_group_beam_search(self):
# PT-only test: TF doesn't have StoppingCriteria & group beam search
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to(
......@@ -1952,6 +1959,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
)
def test_max_length_warning_if_different(self):
# PT-only test: TF doesn't have StoppingCriteria
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to(
......@@ -2035,6 +2043,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
)
def test_custom_stopping_criteria_overload_error(self):
# PT-only test: TF doesn't have StoppingCriteria
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device)
......@@ -2048,6 +2057,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
bart_model.generate(input_ids, stopping_criteria=stopping_criteria, max_length=32)
def test_custom_stopping_criteria(self):
# PT-only test: TF doesn't have StoppingCriteria
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device)
......@@ -2070,7 +2080,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
)
def test_stop_sequence_stopping_criteria(self):
# PT-only test: TF doesn't have StoppingCriteria
prompt = """Hello I believe in"""
generator = pipeline("text-generation", model="hf-internal-testing/tiny-random-bart")
output = generator(prompt)
......@@ -2088,23 +2098,6 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
output = generator(prompt, stop_sequence=" number")
self.assertEqual(output, [{"generated_text": "Hello I believe in in in number"}])
def test_custom_logits_processor(self):
bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random", min_length=1).to(
torch_device
)
input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
logits_processor = LogitsProcessorList()
logits_processor.append(MinLengthLogitsProcessor(min_length=10, eos_token_id=0))
# `min_length` may not be defined both via the model config and via the `logits_processor` list
with self.assertRaises(ValueError):
bart_model.generate(input_ids, logits_processor=logits_processor)
bart_model.config.min_length = None
bart_model.generate(input_ids, logits_processor=logits_processor)
def test_max_new_tokens_encoder_decoder(self):
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
......