"docs/vscode:/vscode.git/clone" did not exist on "ae710425d2d8edf4d197bf893b90ed0546054701"
Unverified Commit b4ddd267 authored by Joao Gante, committed by GitHub

TF generate refactor - XLA sample (#16713)

parent 02de7a8e
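In short: `generate` now draws sampled tokens with stateless RNG ops, so sampling can be seeded per call and compiled with XLA. A minimal usage sketch, mirroring the tests added in this commit (the checkpoint name and prompt are illustrative only):

import tensorflow as tf
from transformers import GPT2Tokenizer, TFGPT2LMHeadModel

model = TFGPT2LMHeadModel.from_pretrained("gpt2")
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
input_ids = tokenizer("The dog", return_tensors="tf").input_ids

# jit_compile=True traces generate() into a single XLA-compiled graph
xla_generate = tf.function(model.generate, jit_compile=True)

# `seed` is a pair of integers, as used by the stateless ops in tf.random;
# re-running with the same seed on the same hardware reproduces the same sample
output_ids = xla_generate(input_ids, do_sample=True, seed=[42, 0])
print(tokenizer.batch_decode(output_ids, skip_special_tokens=True))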
@@ -346,6 +346,8 @@ class TFGenerationMixin:
A class containing all of the functions supporting generation, to be used as a mixin in [`TFPreTrainedModel`].
"""
seed_generator = tf.random.Generator.from_non_deterministic_state()
def prepare_inputs_for_generation(self, inputs, **kwargs):
"""
Implement in subclasses of [`TFPreTrainedModel`] for custom behavior to prepare inputs in the generate method.
@@ -585,6 +587,7 @@ class TFGenerationMixin:
attention_mask=attention_mask,
decoder_start_token_id=decoder_start_token_id,
use_cache=use_cache,
seed=model_kwargs.pop("seed", None),
output_scores=output_scores,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
@@ -1288,6 +1291,7 @@ class TFGenerationMixin:
attention_mask=None,
decoder_start_token_id=None,
use_cache=None,
seed=None,
output_scores=None,
output_attentions=None,
output_hidden_states=None,
@@ -1365,6 +1369,9 @@ class TFGenerationMixin:
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should use the past last key/values attentions (if applicable to the model) to
speed up decoding.
seed (`List[int]`, *optional*):
Random seed to control sampling, containing two integers, used when `do_sample` is `True`. See the
`seed` argument of the stateless functions in `tf.random`.
output_attentions (`bool`, *optional*, defaults to `False`):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more details.
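For intuition (not part of this diff), the difference between the old global seeding and the new per-call `seed` pair can be sketched with plain TensorFlow ops; the logits below are made up:

import tensorflow as tf

logits = tf.math.log([[0.1, 0.2, 0.3, 0.4]])

# old style: the draw depends on the global, stateful RNG
tf.random.set_seed(42)
stateful_draw = tf.random.categorical(logits, num_samples=1)

# new style: the draw is a pure function of (logits, seed), so it is
# reproducible and safe to use inside an XLA-compiled function
draw_a = tf.random.stateless_categorical(logits, num_samples=1, seed=[42, 0])
draw_b = tf.random.stateless_categorical(logits, num_samples=1, seed=[42, 0])
# draw_a and draw_b are always equal; repeated stateful draws are not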
@@ -1590,6 +1597,7 @@ class TFGenerationMixin:
max_length=max_length,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
seed=seed,
output_scores=output_scores,
return_dict_in_generate=return_dict_in_generate,
**model_kwargs,
@@ -1723,7 +1731,7 @@ class TFGenerationMixin:
**model_kwargs,
) -> Tuple[tf.Tensor, Dict[str, Any]]:
expanded_return_idx = tf.reshape(
tf.tile(tf.reshape(tf.range(input_ids.shape[0]), (-1, 1)), (1, expand_size)), (-1)
tf.tile(tf.reshape(tf.range(input_ids.shape[0]), (-1, 1)), (1, expand_size)), (-1,)
)
input_ids = tf.gather(input_ids, expanded_return_idx, axis=0)
@@ -2123,6 +2131,7 @@ class TFGenerationMixin:
max_length: Optional[int] = None,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
seed: Optional[Tuple[int, int]] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
output_scores: Optional[bool] = None,
@@ -2149,6 +2158,9 @@ class TFGenerationMixin:
The id of the *padding* token.
eos_token_id (`int`, *optional*):
The id of the *end-of-sequence* token.
seed (`List[int]`, *optional*):
Random seed to control sampling, containing two integers, used when `do_sample` is `True`. See the
`seed` argument of the stateless functions in `tf.random`.
output_attentions (`bool`, *optional*, defaults to `False`):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more details.
@@ -2210,7 +2222,7 @@ class TFGenerationMixin:
>>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))
```"""
# init values
# 1. init sample values
logits_processor = logits_processor if logits_processor is not None else TFLogitsProcessorList()
logits_warper = logits_warper if logits_warper is not None else TFLogitsProcessorList()
@@ -2224,97 +2236,155 @@ class TFGenerationMixin:
return_dict_in_generate = (
return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
)
use_xla = not tf.executing_eagerly()
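# (Illustrative aside, not part of this diff: `tf.executing_eagerly()` returns
#  False while a function is being traced by tf.function, which is how this
#  flag detects that generate() is running as a compiled graph rather than
#  eagerly.)
#     print(tf.executing_eagerly())      # True at the top level
#     @tf.function
#     def traced():
#         print(tf.executing_eagerly())  # False while tracing
#     traced()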
# init attention / hidden states / scores tuples
scores = () if (return_dict_in_generate and output_scores) else None
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
# 2. init `attentions`, `hidden_states`, and `scores` tuples
scores = [] if (return_dict_in_generate and output_scores) else None
decoder_attentions = [] if (return_dict_in_generate and output_attentions) else None
cross_attentions = [] if (return_dict_in_generate and output_attentions) else None
decoder_hidden_states = [] if (return_dict_in_generate and output_hidden_states) else None
# if model is an encoder-decoder, retrieve encoder attention weights and hidden states
if return_dict_in_generate and self.config.is_encoder_decoder:
encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
encoder_hidden_states = (
model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
# 3. init tensors to use for "xla-compilable" generate function
# define bsz, seq_length
batch_size, cur_len = input_ids.shape
# initialize `generated`, `finished_sequences`
generated = tf.TensorArray(
element_shape=(batch_size,),
dtype=tf.int32,
dynamic_size=False,
size=max_length,
clear_after_read=False,
)
finished_sequences = tf.zeros((batch_size,), dtype=tf.bool)
# keep track of which sequences are already finished
unfinished_sequences = tf.ones_like(input_ids[:, 0])
cur_len = input_ids.shape[-1]
# write prompt to generated
for i in range(cur_len):
generated = generated.write(i, input_ids[:, i])
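# (Illustrative aside, not part of this diff: a self-contained toy of the
#  TensorArray pattern used above. XLA wants static shapes, so tokens are
#  written one position at a time into a buffer pre-sized to max_length and
#  reassembled later with concat/reshape/transpose.)
#     buf = tf.TensorArray(
#         element_shape=(2,), dtype=tf.int32, dynamic_size=False, size=3, clear_after_read=False
#     )
#     ids = tf.constant([[7, 8, 9], [10, 11, 12]], dtype=tf.int32)  # (batch=2, len=3)
#     for i in range(3):
#         buf = buf.write(i, ids[:, i])  # one write per position
#     back = tf.transpose(tf.reshape(buf.concat(), (-1, 2)))  # equals `ids` again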
while cur_len < max_length:
# prepare model inputs
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
# 4. define "xla-compilable" stop-condition and auto-regressive function
def sample_cond_fn(generated, finished_sequences, next_tokens, cur_len, model_kwargs):
return ~tf.reduce_all(finished_sequences)
# forward pass to get next token
def sample_body_fn(generated, finished_sequences, next_tokens, cur_len, model_kwargs):
# TODO(pvp, Joao) - `use_xla` can be removed here as soon as `position_ids` are corrected for the non-xla case in gpt2's `prepare_inputs_for_generation`.
model_inputs = self.prepare_inputs_for_generation(next_tokens, use_xla=use_xla, **model_kwargs)
# forward pass to get next token logits
outputs = self(
**model_inputs,
return_dict=True,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
next_token_logits = outputs.logits[:, -1, :]
# pre-process distribution
next_token_scores = logits_processor(input_ids, next_token_logits, cur_len=cur_len)
next_token_scores = logits_warper(input_ids, next_token_scores)
next_token_logits = outputs.logits[:, -1]
# Store scores, attentions and hidden_states when required
if return_dict_in_generate:
if not use_xla and return_dict_in_generate:
if output_scores:
scores += (next_token_scores,)
if output_attentions:
decoder_attentions += (
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
)
scores.append(next_token_logits)
if output_attentions and self.config.is_encoder_decoder:
decoder_attentions.append(outputs.decoder_attentions)
elif output_attentions and not self.config.is_encoder_decoder:
decoder_attentions.append(outputs.attentions)
if self.config.is_encoder_decoder:
cross_attentions += (outputs.cross_attentions,)
cross_attentions.append(outputs.cross_attentions)
if output_hidden_states:
decoder_hidden_states += (
(outputs.decoder_hidden_states,)
if self.config.is_encoder_decoder
else (outputs.hidden_states,)
)
if output_hidden_states and self.config.is_encoder_decoder:
decoder_hidden_states.append(outputs.decoder_hidden_states)
elif output_hidden_states and not self.config.is_encoder_decoder:
decoder_hidden_states.append(outputs.hidden_states)
# pre-process distribution
# TODO(pvp, joao, matt) - all the logits processors/wrappers need to be adapted
# to be XLA compatible
input_ids = None
if not use_xla:
input_ids = tf.reshape(generated.concat(), (-1, batch_size))
input_ids = tf.transpose(input_ids[:cur_len])
next_tokens_scores = logits_processor(input_ids, next_token_logits, cur_len=cur_len)
next_tokens_scores = logits_warper(input_ids, next_tokens_scores)
# sample
if seed is not None:
sample_seed = seed
else:
sample_seed = tf.cast(self.seed_generator.make_seeds(count=1)[:, 0], dtype=tf.int32)
next_tokens = tf.squeeze(
tf.random.categorical(logits=next_token_scores, num_samples=1, dtype=tf.int32), axis=1
tf.random.stateless_categorical(
logits=next_tokens_scores, num_samples=1, seed=sample_seed, dtype=tf.int32
),
axis=1,
)
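# (Illustrative aside, not part of this diff: what the fallback branch above
#  yields when no `seed` is passed. Generator.make_seeds(count=1) returns an
#  int64 tensor of shape (2, 1); slicing [:, 0] and casting to int32 gives the
#  two-integer seed that tf.random.stateless_categorical expects.)
#     gen = tf.random.Generator.from_non_deterministic_state()
#     auto_seed = tf.cast(gen.make_seeds(count=1)[:, 0], dtype=tf.int32)  # shape (2,)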
# finished sentences should have their next token be a padding token
if eos_token_id is not None:
if pad_token_id is None:
raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
unfinished_seq = 1 - tf.cast(finished_sequences, tf.int32)
next_tokens = next_tokens * unfinished_seq + pad_token_id * (1 - unfinished_seq)
finished_sequences = finished_sequences | (next_tokens == eos_token_id)
# update `generated` and `cur_len`
generated = generated.write(cur_len, next_tokens)
next_tokens = next_tokens[:, None]
cur_len += 1
# update generated ids, model inputs, and length for next step
input_ids = tf.concat([input_ids, next_tokens[:, None]], axis=-1)
# update model_kwargs
if use_xla:
model_kwargs = self._update_model_kwargs_for_xla_generation(outputs, model_kwargs, cur_len, max_length)
else:
model_kwargs = self._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
)
cur_len = cur_len + 1
# if we don't cache past key values we need the whole input
if model_kwargs.get("past", None) is None:
# let's throw out `past` since we don't want `None` tensors
model_kwargs.pop("past", None)
# if eos_token was found in one sentence, set sentence to finished
if eos_token_id is not None:
eos_in_sents = next_tokens == eos_token_id
# if sentence is unfinished and the token to add is eos
is_sents_unfinished_and_token_to_add_is_eos = tf.math.multiply(
unfinished_sequences, tf.cast(eos_in_sents, tf.int32)
next_tokens = tf.reshape(generated.concat(), (-1, batch_size))
next_tokens = tf.transpose(next_tokens[:cur_len])
return generated, finished_sequences, next_tokens, cur_len, model_kwargs
# 5. run generation
# 1st generation step has to be run before to initialize `past`
generated, finished_sequences, next_tokens, cur_len, model_kwargs = sample_body_fn(
generated, finished_sequences, input_ids, cur_len, model_kwargs
)
# unfinished_sequences is set to zero if eos in sentence
unfinished_sequences -= is_sents_unfinished_and_token_to_add_is_eos
# 2-to-n generation steps can then be run in autoregressive fashion
# only in case 1st generation step does NOT yield EOS token though
if sample_cond_fn(generated, finished_sequences, next_tokens, cur_len, model_kwargs):
maximum_iterations = max_length - cur_len - 1
generated, _, _, cur_len, _ = tf.while_loop(
sample_cond_fn,
sample_body_fn,
(generated, finished_sequences, next_tokens, cur_len, model_kwargs),
maximum_iterations=maximum_iterations,
)
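# (Illustrative aside, not part of this diff: the bare-bones shape of the
#  cond/body loop used above. The first step runs outside the loop so `past`
#  gets concrete shapes before tracing, and maximum_iterations gives the
#  compiler a static bound on the remaining steps.)
#     def cond(i, x):
#         return i < 10
#     def body(i, x):
#         return i + 1, x + 1
#     final_i, final_x = tf.while_loop(cond, body, (tf.constant(0), tf.constant(0)), maximum_iterations=3)
#     # stops after 3 iterations even though cond alone would allow 10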
# stop when each sentence is finished, or if we exceed the maximum length
if tf.math.reduce_max(unfinished_sequences) == 0:
break
# 6. prepare outputs
output_ids = tf.transpose(tf.reshape(generated.concat(), (-1, batch_size)))
if not use_xla:
# cut for backward compatibility
output_ids = output_ids[:, :cur_len]
if return_dict_in_generate:
if self.config.is_encoder_decoder:
# if model is an encoder-decoder, retrieve encoder attention weights
# and hidden states
encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
encoder_hidden_states = (
model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
)
scores = tuple(scores) if scores is not None else None
decoder_attentions = tuple(decoder_attentions) if decoder_attentions is not None else None
cross_attentions = tuple(cross_attentions) if cross_attentions is not None else None
decoder_hidden_states = tuple(decoder_hidden_states) if decoder_hidden_states is not None else None
return TFSampleEncoderDecoderOutput(
sequences=input_ids,
sequences=output_ids,
scores=scores,
encoder_attentions=encoder_attentions,
encoder_hidden_states=encoder_hidden_states,
@@ -2324,13 +2394,13 @@ class TFGenerationMixin:
)
else:
return TFSampleDecoderOnlyOutput(
sequences=input_ids,
sequences=output_ids,
scores=scores,
attentions=decoder_attentions,
hidden_states=decoder_hidden_states,
)
else:
return input_ids
return output_ids
def beam_search(
self,
@@ -2575,8 +2645,8 @@ class TFGenerationMixin:
sequences,
scores,
is_sent_finished,
model_kwargs,
input_ids_length,
model_kwargs,
):
"""
Beam Search termination condition function -- halts the generation loop if any of these conditions becomes
@@ -2604,8 +2674,8 @@ class TFGenerationMixin:
sequences,
scores,
is_sent_finished,
model_kwargs,
input_ids_length,
model_kwargs,
intermediary_running_sequences=None,
):
"""
@@ -2781,8 +2851,8 @@ class TFGenerationMixin:
next_sequences,
next_scores,
next_is_sent_finished,
next_model_kwargs,
next_input_ids_length,
next_model_kwargs,
)
# 5. run generation
@@ -2799,8 +2869,8 @@ class TFGenerationMixin:
sequences,
scores,
is_sent_finished,
model_kwargs,
input_ids_length,
model_kwargs,
) = beam_search_body_fn(
cur_len,
running_sequences,
@@ -2808,8 +2878,8 @@ class TFGenerationMixin:
sequences,
scores,
is_sent_finished,
model_kwargs,
input_ids_length,
model_kwargs,
)
# 2-to-n generation steps can then be run in autoregressive fashion (only in case 1st generation step does
@@ -2821,8 +2891,8 @@ class TFGenerationMixin:
sequences,
scores,
is_sent_finished,
model_kwargs,
input_ids_length,
model_kwargs,
):
maximum_iterations = max_length - cur_len
cur_len, running_sequences, running_scores, sequences, scores, is_sent_finished, _, _ = tf.while_loop(
@@ -2835,8 +2905,8 @@ class TFGenerationMixin:
sequences,
scores,
is_sent_finished,
model_kwargs,
input_ids_length,
model_kwargs,
),
maximum_iterations=maximum_iterations,
)
@@ -447,19 +447,6 @@ class TFGPT2ModelTest(TFModelTesterMixin, TFCoreModelTesterMixin, unittest.TestCase):
@require_tf
class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
@slow
def test_lm_generate_distilgpt2(self):
model = TFGPT2LMHeadModel.from_pretrained("distilgpt2")
input_ids = tf.convert_to_tensor([[464, 1893]], dtype=tf.int32) # The president
# The president of the United States, and the president of the United Kingdom, have been in the White
# fmt: off
expected_output_ids = [464, 1893, 286, 262, 1578, 1829, 11, 290, 262, 1893, 286, 262, 1578, 7526, 11, 423, 587, 287, 262, 2635]
# fmt: on
output_ids = model.generate(input_ids, do_sample=False)
self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids)
@slow
def test_lm_generate_greedy_distilgpt2_batch_special(self):
model = TFGPT2LMHeadModel.from_pretrained("distilgpt2")
@@ -506,18 +493,18 @@ class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
"temperature": 1.5,
"top_k": 500,
"top_p": 0.9,
"seed": [42, 0], # seed set -> deterministic sampling sequence -> deterministic generation
}
# forces the generation to happen on CPU, to avoid GPU-related quirks
with tf.device(":/CPU:0"):
tf.random.set_seed(42) # deterministic sampling sequence -> deterministic generation
output_ids = model.generate(input_ids, **generation_kwargs)
output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
expected_output_string = [
"Today is a beautiful day and this makes finding holiday travel easier for you to do other project\nOh",
"Yesterday was an enjoyable but especially great note though it certainly upset many Democrats who say",
"Today is a beautiful day and we will make you feel very hot/terrific in all",
"Yesterday was another solid success as news coverage became standard American domestic television hit.",
]
self.assertListEqual(output_strings, expected_output_string)
@@ -561,7 +548,7 @@ class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids)
@slow
def test_lm_generate_gpt2_xla(self):
def test_lm_generate_gpt2_xla_greedy(self):
"""This test gives the exact same results as the non-xla test above"""
model = TFGPT2LMHeadModel.from_pretrained("gpt2")
input_ids = tf.convert_to_tensor([[464, 3290]], dtype=tf.int32) # The dog
@@ -574,3 +561,16 @@ class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
output_ids = xla_generate(input_ids, do_sample=False)
self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids)
@slow
def test_lm_generate_gpt2_xla_sample(self):
model = TFGPT2LMHeadModel.from_pretrained("gpt2")
input_ids = tf.convert_to_tensor([[464, 3290]], dtype=tf.int32) # The dog
# fmt: off
expected_output_ids = [464, 3290, 550, 284, 307, 4376, 287, 281, 4044, 1363, 329, 734, 812, 878, 852, 4376, 757, 329, 2267, 0]
# fmt: on
xla_generate = tf.function(model.generate, jit_compile=True)
output_ids = xla_generate(input_ids, do_sample=True, seed=[42, 0])
self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids)
@@ -524,6 +524,35 @@ class TFT5GenerationIntegrationTests(unittest.TestCase):
self.assertListEqual(expected_output_string, output_strings)
@slow
def test_sample_xla_generate_simple(self):
model = TFT5ForConditionalGeneration.from_pretrained("t5-small")
tokenizer = T5Tokenizer.from_pretrained("t5-small")
sentence = "Translate English to German: Today is a beautiful day."
input_ids = tokenizer(sentence, return_tensors="tf", padding=True).input_ids
# XLA reorders ops, which causes operations like FP matmul to have slightly different results, leading to
# divergences in generate -- especially with sampling.
expected_output_string = ["Heute ist ein schöner Tag."]
expected_output_string_xla = ["Heute ist ein schöne Tage."]
# However, notice that the first tokens are the same, for the same seed
assert expected_output_string[0][:15] == expected_output_string_xla[0][:15]
# forces the generation to happen on CPU, to avoid GPU-related quirks
with tf.device(":/CPU:0"):
# seed set -> deterministic sampling sequence -> deterministic generation
output_ids = model.generate(input_ids, do_sample=True, seed=[42, 0])
output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
self.assertListEqual(expected_output_string, output_strings)
# forces the generation to happen on CPU, to avoid GPU-related quirks
with tf.device(":/CPU:0"):
xla_generate = tf.function(model.generate, jit_compile=True)
# seed set -> deterministic sampling sequence -> deterministic generation
output_ids_xla = xla_generate(input_ids, do_sample=True, seed=[42, 0])
output_strings_xla = tokenizer.batch_decode(output_ids_xla, skip_special_tokens=True)
self.assertListEqual(expected_output_string_xla, output_strings_xla)
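The CPU/XLA mismatch tolerated above comes down to floating-point non-associativity: XLA may fuse or reorder the same arithmetic, nudging low-order bits, and with sampling a nudged logit can eventually flip a token. A toy illustration (values chosen only to make the effect visible in float32):

import numpy as np

a, b, c = np.float32(1e8), np.float32(-1e8), np.float32(1.0)
print((a + b) + c)  # 1.0
print(a + (b + c))  # 0.0 -- same numbers, different evaluation order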
@slow
def test_sample_generate(self):
model = TFT5ForConditionalGeneration.from_pretrained("t5-small")
@@ -540,16 +569,16 @@ class TFT5GenerationIntegrationTests(unittest.TestCase):
"temperature": 0.8,
"top_k": 500,
"top_p": 0.9,
"seed": [20, 0], # seed set -> deterministic sampling sequence -> deterministic generation
}
# forces the generation to happen on CPU, to avoid GPU-related quirks
with tf.device(":/CPU:0"):
tf.random.set_seed(42) # deterministic sampling sequence -> deterministic generation
output_ids = model.generate(input_ids, **generation_kwargs)
output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
expected_output_string = ["i love her I really love my heart", "die Transformatoren sind wirklich erstaunlich"]
expected_output_string = ["- I really love my way of this.", "die Transformatoren sind wirklich erstaunlich"]
self.assertListEqual(expected_output_string, output_strings)