Unverified commit cca6e6fe, authored by Matt and committed by GitHub

Cast TF generate() inputs (#19232)



* Just stick a couple of casts into generate()

* Cast decoder_input_ids too

* Don't accidentally cast floats

* Move to _generate()

* Move to after input validation
Co-authored-by: Your Name <you@example.com>
parent 01eb34ab
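In short, the new block casts `generate()`'s integer-like inputs (`input_ids`, `attention_mask`, and `decoder_input_ids`) to `tf.int32`, while leaving floating-point inputs untouched, such as the pixel values some image models pass in place of token ids. A minimal standalone sketch of that guard (the helper name is made up for illustration and is not part of the commit):

```python
import numpy as np
import tensorflow as tf


def cast_if_not_float(tensor):
    """Sketch of the dtype guard: cast integer-like inputs to tf.int32, leave floats alone."""
    if isinstance(tensor, tf.Tensor) and tensor.dtype.is_floating:
        return tensor  # e.g. pixel values from an image model: keep the float dtype
    if isinstance(tensor, np.ndarray) and np.issubdtype(tensor.dtype, np.floating):
        return tensor
    return tf.cast(tensor, tf.int32)  # int64/int32 token ids and masks end up as int32


token_ids = cast_if_not_float(np.array([[101, 2023, 102]], dtype=np.int64))
pixel_values = cast_if_not_float(tf.random.uniform((1, 224, 224, 3)))
print(token_ids.dtype, pixel_values.dtype)  # int32, float32
```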
@@ -1533,11 +1533,35 @@ class TFGenerationMixin:
         # generate sequences without allowing bad_words to be generated
         outputs = model.generate(input_ids=input_ids, max_length=100, do_sample=True, bad_words_ids=bad_words_ids)
         ```"""
         # 0. Validate the `.generate()` call
         self._validate_model_class()
         self._validate_model_kwargs(model_kwargs.copy())
-        # 1. Set generation parameters if not already defined
+        # 1. Cast input dtypes to tf.int32 unless they're floats (which happens for some image models)
+        if input_ids is not None:
+            if isinstance(input_ids, tf.Tensor) and input_ids.dtype.is_floating:
+                pass
+            elif isinstance(input_ids, np.ndarray) and np.issubdtype(input_ids.dtype, np.floating):
+                pass
+            else:
+                input_ids = tf.cast(input_ids, tf.int32)
+        if attention_mask is not None:
+            attention_mask = tf.cast(attention_mask, tf.int32)
+        if "decoder_input_ids" in model_kwargs:
+            if (
+                isinstance(model_kwargs["decoder_input_ids"], tf.Tensor)
+                and model_kwargs["decoder_input_ids"].dtype.is_floating
+            ):
+                pass
+            elif isinstance(model_kwargs["decoder_input_ids"], np.ndarray) and np.issubdtype(
+                model_kwargs["decoder_input_ids"].dtype, np.floating
+            ):
+                pass
+            else:
+                model_kwargs["decoder_input_ids"] = tf.cast(model_kwargs["decoder_input_ids"], tf.int32)
+        # 2. Set generation parameters if not already defined
         length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
         early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
@@ -1582,12 +1606,12 @@ class TFGenerationMixin:
                 "The selected model does not support Graph mode nor XLA generation (e.g. from tf.function())"
             )
-        # 2. Define model inputs
+        # 3. Define model inputs
         input_ids = self._prepare_model_inputs(input_ids, bos_token_id)
         # input_ids now has to be defined and cannot be None anymore
         batch_size = shape_list(input_ids)[0]
-        # 3. Prepare other model kwargs
+        # 4. Prepare other model kwargs
         if output_attentions is not None:
             model_kwargs["output_attentions"] = output_attentions
         if output_hidden_states is not None:
@@ -1613,7 +1637,7 @@ class TFGenerationMixin:
                 "generation results, please set `padding_side='left'` when initializing the tokenizer."
             )
-        # 4. Prepare model inputs which will be used for auto-regressive generation
+        # 5. Prepare model inputs which will be used for auto-regressive generation
         if self.config.is_encoder_decoder:
             # if encoder-decoder, we create encoder_outputs and add to `model_kwargs`
             model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(input_ids, model_kwargs)
@@ -1625,7 +1649,7 @@ class TFGenerationMixin:
                 model_kwargs=model_kwargs,
             )
-        # 5. Prepare `max_length` depending on other stopping criteria.
+        # 6. Prepare `max_length` depending on other stopping criteria.
         input_ids_seq_length = input_ids.shape[-1]
         if max_length is None and max_new_tokens is None:
             warnings.warn(
@@ -1661,13 +1685,13 @@ class TFGenerationMixin:
                 "`max_new_tokens`."
             )
-        # 6. determine generation mode
+        # 7. determine generation mode
         # TODO(Matt, Joao, Patrick) - add more use cases here
         is_greedy_gen_mode = (num_beams == 1) and do_sample is False
         is_sample_gen_mode = (num_beams == 1) and do_sample is True
         is_beam_gen_mode = (num_beams > 1) and do_sample is False
-        # 7. prepare distribution pre_processing samplers
+        # 8. prepare distribution pre_processing samplers
         logits_processor = self._get_logits_processor(
             repetition_penalty=repetition_penalty,
             no_repeat_ngram_size=no_repeat_ngram_size,
@@ -1679,7 +1703,7 @@ class TFGenerationMixin:
             forced_eos_token_id=forced_eos_token_id,
         )
-        # 8. go into different generation modes
+        # 9. go into different generation modes
         if is_greedy_gen_mode:
             if num_return_sequences > 1:
                 raise ValueError(
@@ -1697,10 +1721,10 @@ class TFGenerationMixin:
                 **model_kwargs,
             )
         elif is_sample_gen_mode:
-            # 9. prepare logits warper
+            # 10. prepare logits warper
             logits_warper = self._get_logits_warper(top_k=top_k, top_p=top_p, temperature=temperature)
-            # 10. expand input_ids with `num_return_sequences` additional sequences per batch
+            # 11. expand input_ids with `num_return_sequences` additional sequences per batch
             input_ids, model_kwargs = self._expand_inputs_for_generation(
                 input_ids,
                 expand_size=num_return_sequences,
@@ -1708,7 +1732,7 @@ class TFGenerationMixin:
                 **model_kwargs,
             )
-            # 11. run sample
+            # 12. run sample
             return self.sample(
                 input_ids,
                 logits_processor=logits_processor,
@@ -1729,7 +1753,7 @@ class TFGenerationMixin:
                     f"num_beams >= num_return_sequences, got {num_beams} and {num_return_sequences} (respectively)"
                 )
-            # 9. broadcast inputs to the desired number of beams
+            # 10. broadcast inputs to the desired number of beams
            input_ids = self._expand_to_num_beams(input_ids, num_beams=num_beams)
             if "encoder_outputs" in model_kwargs:
@@ -1742,7 +1766,7 @@ class TFGenerationMixin:
                     model_kwargs["attention_mask"], num_beams=num_beams
                 )
-            # 10. run beam search
+            # 11. run beam search
             return self.beam_search(
                 input_ids,
                 max_length=max_length,
......
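As a rough usage sketch (the checkpoint name and tokenizer calls are illustrative, not part of this diff): with these casts in place, int64 NumPy ids coming straight from a tokenizer can be passed to `generate()` without a manual conversion.

```python
from transformers import AutoTokenizer, TFAutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = TFAutoModelForSeq2SeqLM.from_pretrained("t5-small")

# return_tensors="np" yields int64 arrays; generate() now narrows them to int32 internally
inputs = tokenizer("translate English to German: Hello, how are you?", return_tensors="np")
output_ids = model.generate(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    max_length=32,
)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```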