Unverified commit b4ddd267 authored by Joao Gante, committed by GitHub

TF generate refactor - XLA sample (#16713)

parent 02de7a8e
@@ -346,6 +346,8 @@ class TFGenerationMixin:
     A class containing all of the functions supporting generation, to be used as a mixin in [`TFPreTrainedModel`].
     """

+    seed_generator = tf.random.Generator.from_non_deterministic_state()
+
     def prepare_inputs_for_generation(self, inputs, **kwargs):
         """
         Implement in subclasses of [`TFPreTrainedModel`] for custom behavior to prepare inputs in the generate method.
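Note: the `seed_generator` class attribute added above is a stateful TF random generator; later in this diff it supplies a fresh two-integer seed whenever the caller does not pass `seed` explicitly. A minimal standalone sketch of that fallback (illustrative only, not part of the commit):

import tensorflow as tf

seed_generator = tf.random.Generator.from_non_deterministic_state()

# make_seeds(count=1) returns an int64 tensor of shape (2, 1); taking [:, 0] yields the
# two-integer seed format expected by tf.random.stateless_* ops.
sample_seed = tf.cast(seed_generator.make_seeds(count=1)[:, 0], dtype=tf.int32)
print(sample_seed.shape)  # (2,)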
@@ -585,6 +587,7 @@ class TFGenerationMixin:
             attention_mask=attention_mask,
             decoder_start_token_id=decoder_start_token_id,
             use_cache=use_cache,
+            seed=model_kwargs.pop("seed", None),
             output_scores=output_scores,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
@@ -1288,6 +1291,7 @@ class TFGenerationMixin:
         attention_mask=None,
         decoder_start_token_id=None,
         use_cache=None,
+        seed=None,
         output_scores=None,
         output_attentions=None,
         output_hidden_states=None,
@@ -1365,6 +1369,9 @@ class TFGenerationMixin:
             use_cache (`bool`, *optional*, defaults to `True`):
                 Whether or not the model should use the past last key/values attentions (if applicable to the model) to
                 speed up decoding.
+            seed (`List[int]`, *optional*):
+                Random seed to control sampling, containing two integers, used when `do_sample` is `True`. See the
+                `seed` argument from stateless functions in `tf.random`.
             output_attentions (`bool`, *optional*, defaults to `False`):
                 Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                 returned tensors for more details.
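The `seed` documented above uses the two-integer seed format of TF's stateless RNG ops, which is what makes sampled generation reproducible under XLA. A short illustrative sketch of that determinism with `tf.random.stateless_categorical` on toy logits (not library code):

import tensorflow as tf

# Toy next-token logits: a batch of 2 sequences over a 5-token vocabulary.
logits = tf.math.log([[0.1, 0.4, 0.2, 0.2, 0.1],
                      [0.3, 0.1, 0.1, 0.4, 0.1]])
seed = [42, 0]  # same format as `model.generate(..., do_sample=True, seed=[42, 0])`

# Identical logits + identical seed -> identical draws on every call.
tokens_a = tf.random.stateless_categorical(logits, num_samples=1, seed=seed, dtype=tf.int32)
tokens_b = tf.random.stateless_categorical(logits, num_samples=1, seed=seed, dtype=tf.int32)
assert bool(tf.reduce_all(tokens_a == tokens_b))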
@@ -1590,6 +1597,7 @@ class TFGenerationMixin:
                 max_length=max_length,
                 pad_token_id=pad_token_id,
                 eos_token_id=eos_token_id,
+                seed=seed,
                 output_scores=output_scores,
                 return_dict_in_generate=return_dict_in_generate,
                 **model_kwargs,
@@ -1723,7 +1731,7 @@ class TFGenerationMixin:
         **model_kwargs,
     ) -> Tuple[tf.Tensor, Dict[str, Any]]:
         expanded_return_idx = tf.reshape(
-            tf.tile(tf.reshape(tf.range(input_ids.shape[0]), (-1, 1)), (1, expand_size)), (-1)
+            tf.tile(tf.reshape(tf.range(input_ids.shape[0]), (-1, 1)), (1, expand_size)), (-1,)
         )
         input_ids = tf.gather(input_ids, expanded_return_idx, axis=0)
@@ -2123,6 +2131,7 @@ class TFGenerationMixin:
         max_length: Optional[int] = None,
         pad_token_id: Optional[int] = None,
         eos_token_id: Optional[int] = None,
+        seed: Optional[Tuple[int, int]] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         output_scores: Optional[bool] = None,
@@ -2149,6 +2158,9 @@ class TFGenerationMixin:
                 The id of the *padding* token.
             eos_token_id (`int`, *optional*):
                 The id of the *end-of-sequence* token.
+            seed (`List[int]`, *optional*):
+                Random seed to control sampling, containing two integers, used when `do_sample` is `True`. See the
+                `seed` argument from stateless functions in `tf.random`.
             output_attentions (`bool`, *optional*, defaults to `False`):
                 Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                 returned tensors for more details.
@@ -2210,7 +2222,7 @@ class TFGenerationMixin:
         >>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))
         ```"""

-        # init values
+        # 1. init greedy_search values
         logits_processor = logits_processor if logits_processor is not None else TFLogitsProcessorList()
         logits_warper = logits_warper if logits_warper is not None else TFLogitsProcessorList()
@@ -2224,97 +2236,155 @@ class TFGenerationMixin:
         return_dict_in_generate = (
             return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
         )
+        use_xla = not tf.executing_eagerly()

-        # init attention / hidden states / scores tuples
-        scores = () if (return_dict_in_generate and output_scores) else None
-        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
-        cross_attentions = () if (return_dict_in_generate and output_attentions) else None
-        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
+        # 2. init `attentions`, `hidden_states`, and `scores` tuples
+        scores = [] if (return_dict_in_generate and output_scores) else None
+        decoder_attentions = [] if (return_dict_in_generate and output_attentions) else None
+        cross_attentions = [] if (return_dict_in_generate and output_attentions) else None
+        decoder_hidden_states = [] if (return_dict_in_generate and output_hidden_states) else None

-        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
-        if return_dict_in_generate and self.config.is_encoder_decoder:
-            encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
-            encoder_hidden_states = (
-                model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
-            )
+        # 3. init tensors to use for "xla-compileable" generate function
+        # define bsz, seq_length
+        batch_size, cur_len = input_ids.shape

-        # keep track of which sequences are already finished
-        unfinished_sequences = tf.ones_like(input_ids[:, 0])
-        cur_len = input_ids.shape[-1]
+        # initialize `generated`, `finished_sequences`
+        generated = tf.TensorArray(
+            element_shape=(batch_size,),
+            dtype=tf.int32,
+            dynamic_size=False,
+            size=max_length,
+            clear_after_read=False,
+        )
+        finished_sequences = tf.zeros((batch_size,), dtype=tf.bool)

-        while cur_len < max_length:
-            # prepare model inputs
-            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
+        # write prompt to generated
+        for i in range(cur_len):
+            generated = generated.write(i, input_ids[:, i])

-            # forward pass to get next token
-            outputs = self(
-                **model_inputs,
-                return_dict=True,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-            )
-            next_token_logits = outputs.logits[:, -1, :]
-
-            # pre-process distribution
-            next_token_scores = logits_processor(input_ids, next_token_logits, cur_len=cur_len)
-            next_token_scores = logits_warper(input_ids, next_token_scores)
+        # 4. define "xla-compile-able" stop-condition and auto-regressive function
+        def sample_cond_fn(generated, finished_sequences, next_tokens, cur_len, model_kwargs):
+            return ~tf.reduce_all(finished_sequences)
+
+        def sample_body_fn(generated, finished_sequences, next_tokens, cur_len, model_kwargs):
+            # TODO(pvp, Joao) - `use_xla` can be removed here as soon as `position_ids` are corrected for the non-xla case in gpt2's `prepare_inputs_for_generation`.
+            model_inputs = self.prepare_inputs_for_generation(next_tokens, use_xla=use_xla, **model_kwargs)
+            # forward pass to get next token logits
+            outputs = self(
+                **model_inputs,
+                return_dict=True,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+            )
+            next_token_logits = outputs.logits[:, -1]

             # Store scores, attentions and hidden_states when required
-            if return_dict_in_generate:
+            if not use_xla and return_dict_in_generate:
                 if output_scores:
-                    scores += (next_token_scores,)
-                if output_attentions:
-                    decoder_attentions += (
-                        (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
-                    )
-                    if self.config.is_encoder_decoder:
-                        cross_attentions += (outputs.cross_attentions,)
-
-                if output_hidden_states:
-                    decoder_hidden_states += (
-                        (outputs.decoder_hidden_states,)
-                        if self.config.is_encoder_decoder
-                        else (outputs.hidden_states,)
-                    )
+                    scores.append(next_token_logits)
+                if output_attentions and self.config.is_encoder_decoder:
+                    decoder_attentions.append(outputs.decoder_attentions)
+                elif output_attentions and not self.config.is_encoder_decoder:
+                    decoder_attentions.append(outputs.attentions)
+                if self.config.is_encoder_decoder:
+                    cross_attentions.append(outputs.cross_attentions)
+
+                if output_hidden_states and self.config.is_encoder_decoder:
+                    decoder_hidden_states.append(outputs.decoder_hidden_states)
+                elif output_hidden_states and self.config.is_encoder_decoder:
+                    decoder_hidden_states.append(outputs.hidden_states)

-            # sample
-            next_tokens = tf.squeeze(
-                tf.random.categorical(logits=next_token_scores, num_samples=1, dtype=tf.int32), axis=1
-            )
+            # pre-process distribution
+            # TODO(pvp, joao, matt) - all the logits processors/wrappers need to be adapted
+            # to be XLA compatible
+            input_ids = None
+            if not use_xla:
+                input_ids = tf.reshape(generated.concat(), (-1, batch_size))
+                input_ids = tf.transpose(input_ids[:cur_len])
+            next_tokens_scores = logits_processor(input_ids, next_token_logits, cur_len=cur_len)
+            next_tokens_scores = logits_warper(input_ids, next_tokens_scores)
+
+            # sample
+            if seed is not None:
+                sample_seed = seed
+            else:
+                sample_seed = tf.cast(self.seed_generator.make_seeds(count=1)[:, 0], dtype=tf.int32)
+            next_tokens = tf.squeeze(
+                tf.random.stateless_categorical(
+                    logits=next_tokens_scores, num_samples=1, seed=sample_seed, dtype=tf.int32
+                ),
+                axis=1,
+            )

-            # finished sentences should have their next token be a padding token
             if eos_token_id is not None:
                 if pad_token_id is None:
                     raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
-                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
+                unfinished_seq = 1 - tf.cast(finished_sequences, tf.int32)
+                next_tokens = next_tokens * unfinished_seq + pad_token_id * (1 - unfinished_seq)
+            finished_sequences = finished_sequences | (next_tokens == eos_token_id)

-            # update generated ids, model inputs, and length for next step
-            input_ids = tf.concat([input_ids, next_tokens[:, None]], axis=-1)
-            model_kwargs = self._update_model_kwargs_for_generation(
-                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
-            )
-            cur_len = cur_len + 1
+            # update `generated` and `cur_len`
+            generated = generated.write(cur_len, next_tokens)
+            next_tokens = next_tokens[:, None]
+            cur_len += 1

-            # if eos_token was found in one sentence, set sentence to finished
-            if eos_token_id is not None:
-                eos_in_sents = next_tokens == eos_token_id
-                # if sentence is unfinished and the token to add is eos
-                is_sents_unfinished_and_token_to_add_is_eos = tf.math.multiply(
-                    unfinished_sequences, tf.cast(eos_in_sents, tf.int32)
-                )
-
-                # unfinished_sequences is set to zero if eos in sentence
-                unfinished_sequences -= is_sents_unfinished_and_token_to_add_is_eos
-
-            # stop when each sentence is finished, or if we exceed the maximum length
-            if tf.math.reduce_max(unfinished_sequences) == 0:
-                break
+            # update model_kwargs
+            if use_xla:
+                model_kwargs = self._update_model_kwargs_for_xla_generation(outputs, model_kwargs, cur_len, max_length)
+            else:
+                model_kwargs = self._update_model_kwargs_for_generation(
+                    outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
+                )
+                # if we don't cache past key values we need the whole input
+                if model_kwargs.get("past", None) is None:
+                    # let's throw out `past` since we don't want `None` tensors
+                    model_kwargs.pop("past", None)
+                    next_tokens = tf.reshape(generated.concat(), (-1, batch_size))
+                    next_tokens = tf.transpose(next_tokens[:cur_len])
+
+            return generated, finished_sequences, next_tokens, cur_len, model_kwargs
+
+        # 5. run generation
+        # 1st generation step has to be run before to initialize `past`
+        generated, finished_sequences, next_tokens, cur_len, model_kwargs = sample_body_fn(
+            generated, finished_sequences, input_ids, cur_len, model_kwargs
+        )
+
+        # 2-to-n generation steps can then be run in autoregressive fashion
+        # only in case 1st generation step does NOT yield EOS token though
+        if sample_cond_fn(generated, finished_sequences, next_tokens, cur_len, model_kwargs):
+            maximum_iterations = max_length - cur_len - 1
+            generated, _, _, cur_len, _ = tf.while_loop(
+                sample_cond_fn,
+                sample_body_fn,
+                (generated, finished_sequences, next_tokens, cur_len, model_kwargs),
+                maximum_iterations=maximum_iterations,
+            )
+
+        # 6. prepare outputs
+        output_ids = tf.transpose(tf.reshape(generated.concat(), (-1, batch_size)))
+
+        if not use_xla:
+            # cut for backward compatibility
+            output_ids = output_ids[:, :cur_len]

         if return_dict_in_generate:
             if self.config.is_encoder_decoder:
+                # if model is an encoder-decoder, retrieve encoder attention weights
+                # and hidden states
+                encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
+                encoder_hidden_states = (
+                    model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
+                )
+
+                scores = tuple(scores) if scores is not None else None
+                decoder_attentions = tuple(decoder_attentions) if decoder_attentions is not None else None
+                cross_attentions = tuple(cross_attentions) if cross_attentions is not None else None
+                decoder_hidden_states = tuple(decoder_hidden_states) if decoder_hidden_states is not None else None
+
                 return TFSampleEncoderDecoderOutput(
-                    sequences=input_ids,
+                    sequences=output_ids,
                     scores=scores,
                     encoder_attentions=encoder_attentions,
                     encoder_hidden_states=encoder_hidden_states,
@@ -2324,13 +2394,13 @@ class TFGenerationMixin:
                 )
             else:
                 return TFSampleDecoderOnlyOutput(
-                    sequences=input_ids,
+                    sequences=output_ids,
                     scores=scores,
                     attentions=decoder_attentions,
                     hidden_states=decoder_hidden_states,
                 )
         else:
-            return input_ids
+            return output_ids

     def beam_search(
         self,
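As the comments in the new `sample` code indicate, the refactor runs the first generation step in Python to initialize `past`, then iterates `sample_body_fn` under `tf.while_loop`, writing tokens into a fixed-size `tf.TensorArray` so the whole loop can be XLA-compiled. A stripped-down sketch of that control-flow pattern with a dummy forward pass (all names and constants here are illustrative, not the library implementation):

import tensorflow as tf

vocab_size, batch_size, prompt_len, max_length = 50, 2, 3, 8
seed = tf.constant([42, 0], dtype=tf.int32)

def dummy_logits(last_tokens):
    # Stand-in for a model forward pass: derive logits from the last token only.
    return tf.one_hot((last_tokens * 7 + 3) % vocab_size, vocab_size) * 5.0

@tf.function(jit_compile=True)  # the same function also runs eagerly without this decorator
def toy_sample(input_ids):
    # Fixed-size buffer for the generated ids, one TensorArray slot per position.
    generated = tf.TensorArray(
        tf.int32, size=max_length, dynamic_size=False, clear_after_read=False, element_shape=(batch_size,)
    )
    for i in range(prompt_len):  # write the prompt
        generated = generated.write(i, input_ids[:, i])
    finished = tf.zeros((batch_size,), dtype=tf.bool)

    def cond_fn(generated, finished, cur_len):
        return ~tf.reduce_all(finished)

    def body_fn(generated, finished, cur_len):
        logits = dummy_logits(generated.read(cur_len - 1))
        next_tokens = tf.squeeze(
            tf.random.stateless_categorical(logits, num_samples=1, seed=seed, dtype=tf.int32), axis=1
        )
        finished = finished | (next_tokens >= vocab_size)  # placeholder EOS check, never true here
        generated = generated.write(cur_len, next_tokens)
        return generated, finished, cur_len + 1

    generated, finished, cur_len = tf.while_loop(
        cond_fn,
        body_fn,
        (generated, finished, tf.constant(prompt_len)),
        maximum_iterations=max_length - prompt_len,
    )
    # (max_length * batch_size,) -> (max_length, batch_size) -> (batch_size, max_length)
    return tf.transpose(tf.reshape(generated.concat(), (-1, batch_size)))

print(toy_sample(tf.ones((batch_size, prompt_len), dtype=tf.int32)))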
@@ -2575,8 +2645,8 @@ class TFGenerationMixin:
             sequences,
             scores,
             is_sent_finished,
-            model_kwargs,
             input_ids_length,
+            model_kwargs,
         ):
             """
             Beam Search termination condition function -- halts the generation loop if any of these conditions becomes

@@ -2604,8 +2674,8 @@ class TFGenerationMixin:
             sequences,
             scores,
             is_sent_finished,
-            model_kwargs,
             input_ids_length,
+            model_kwargs,
             intermediary_running_sequences=None,
         ):
             """

@@ -2781,8 +2851,8 @@ class TFGenerationMixin:
                 next_sequences,
                 next_scores,
                 next_is_sent_finished,
-                next_model_kwargs,
                 next_input_ids_length,
+                next_model_kwargs,
             )

         # 5. run generation

@@ -2799,8 +2869,8 @@ class TFGenerationMixin:
             sequences,
             scores,
             is_sent_finished,
-            model_kwargs,
             input_ids_length,
+            model_kwargs,
         ) = beam_search_body_fn(
             cur_len,
             running_sequences,

@@ -2808,8 +2878,8 @@ class TFGenerationMixin:
             sequences,
             scores,
             is_sent_finished,
-            model_kwargs,
             input_ids_length,
+            model_kwargs,
         )

         # 2-to-n generation steps can then be run in autoregressive fashion (only in case 1st generation step does

@@ -2821,8 +2891,8 @@ class TFGenerationMixin:
             sequences,
             scores,
             is_sent_finished,
-            model_kwargs,
             input_ids_length,
+            model_kwargs,
         ):
             maximum_iterations = max_length - cur_len
             cur_len, running_sequences, running_scores, sequences, scores, is_sent_finished, _, _ = tf.while_loop(

@@ -2835,8 +2905,8 @@ class TFGenerationMixin:
                 sequences,
                 scores,
                 is_sent_finished,
-                model_kwargs,
                 input_ids_length,
+                model_kwargs,
             ),
             maximum_iterations=maximum_iterations,
         )
@@ -447,19 +447,6 @@ class TFGPT2ModelTest(TFModelTesterMixin, TFCoreModelTesterMixin, unittest.TestCase):

 @require_tf
 class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
-    @slow
-    def test_lm_generate_distilgpt2(self):
-        model = TFGPT2LMHeadModel.from_pretrained("distilgpt2")
-        input_ids = tf.convert_to_tensor([[464, 1893]], dtype=tf.int32)  # The president
-
-        # The president of the United States, and the president of the United Kingdom, have been in the White
-        # fmt: off
-        expected_output_ids = [464, 1893, 286, 262, 1578, 1829, 11, 290, 262, 1893, 286, 262, 1578, 7526, 11, 423, 587, 287, 262, 2635]
-        # fmt: on
-        output_ids = model.generate(input_ids, do_sample=False)
-
-        self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids)
-
     @slow
     def test_lm_generate_greedy_distilgpt2_batch_special(self):
         model = TFGPT2LMHeadModel.from_pretrained("distilgpt2")
@@ -506,18 +493,18 @@ class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
             "temperature": 1.5,
             "top_k": 500,
             "top_p": 0.9,
+            "seed": [42, 0],  # seed set -> deterministic sampling sequence -> deterministic generation
         }

         # forces the generation to happen on CPU, to avoid GPU-related quirks
         with tf.device(":/CPU:0"):
-            tf.random.set_seed(42)  # deterministic sampling sequence -> deterministic generation
             output_ids = model.generate(input_ids, **generation_kwargs)

         output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)

         expected_output_string = [
-            "Today is a beautiful day and this makes finding holiday travel easier for you to do other project\nOh",
-            "Yesterday was an enjoyable but especially great note though it certainly upset many Democrats who say",
+            "Today is a beautiful day and we will make you feel very hot/terrific in all",
+            "Yesterday was another solid success as news coverage became standard American domestic television hit.",
         ]
         self.assertListEqual(output_strings, expected_output_string)
@@ -561,7 +548,7 @@ class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
         self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids)

     @slow
-    def test_lm_generate_gpt2_xla(self):
+    def test_lm_generate_gpt2_xla_greedy(self):
         """This test gives the exact same results as the non-xla test above"""
         model = TFGPT2LMHeadModel.from_pretrained("gpt2")
         input_ids = tf.convert_to_tensor([[464, 3290]], dtype=tf.int32)  # The dog

@@ -574,3 +561,16 @@ class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
         output_ids = xla_generate(input_ids, do_sample=False)

         self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids)
+
+    @slow
+    def test_lm_generate_gpt2_xla_sample(self):
+        model = TFGPT2LMHeadModel.from_pretrained("gpt2")
+        input_ids = tf.convert_to_tensor([[464, 3290]], dtype=tf.int32)  # The dog
+
+        # fmt: off
+        expected_output_ids = [464, 3290, 550, 284, 307, 4376, 287, 281, 4044, 1363, 329, 734, 812, 878, 852, 4376, 757, 329, 2267, 0]
+        # fmt: on
+        xla_generate = tf.function(model.generate, jit_compile=True)
+        output_ids = xla_generate(input_ids, do_sample=True, seed=[42, 0])
+
+        self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids)
@@ -524,6 +524,35 @@ class TFT5GenerationIntegrationTests(unittest.TestCase):

         self.assertListEqual(expected_output_string, output_strings)

+    @slow
+    def test_sample_xla_generate_simple(self):
+        model = TFT5ForConditionalGeneration.from_pretrained("t5-small")
+        tokenizer = T5Tokenizer.from_pretrained("t5-small")
+
+        sentence = "Translate English to German: Today is a beautiful day."
+        input_ids = tokenizer(sentence, return_tensors="tf", padding=True).input_ids
+
+        # XLA reorder ops, which causes operations like FP matmul to have slightly different results, causing
+        # divergences in generate -- especially with sampling.
+        expected_output_string = ["Heute ist ein schöner Tag."]
+        expected_output_string_xla = ["Heute ist ein schöne Tage."]
+        # However, notice that the first tokens are the same, for the same seed
+        assert expected_output_string[0][:15] == expected_output_string_xla[0][:15]
+
+        # forces the generation to happen on CPU, to avoid GPU-related quirks
+        with tf.device(":/CPU:0"):
+            # seed set -> deterministic sampling sequence -> deterministic generation
+            output_ids = model.generate(input_ids, do_sample=True, seed=[42, 0])
+        output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
+        self.assertListEqual(expected_output_string, output_strings)
+
+        # forces the generation to happen on CPU, to avoid GPU-related quirks
+        with tf.device(":/CPU:0"):
+            xla_generate = tf.function(model.generate, jit_compile=True)
+            # seed set -> deterministic sampling sequence -> deterministic generation
+            output_ids_xla = xla_generate(input_ids, do_sample=True, seed=[42, 0])
+        output_strings_xla = tokenizer.batch_decode(output_ids_xla, skip_special_tokens=True)
+        self.assertListEqual(expected_output_string_xla, output_strings_xla)
+
     @slow
     def test_sample_generate(self):
         model = TFT5ForConditionalGeneration.from_pretrained("t5-small")
@@ -540,16 +569,16 @@ class TFT5GenerationIntegrationTests(unittest.TestCase):
             "temperature": 0.8,
             "top_k": 500,
             "top_p": 0.9,
+            "seed": [20, 0],  # seed set -> deterministic sampling sequence -> deterministic generation
         }

         # forces the generation to happen on CPU, to avoid GPU-related quirks
         with tf.device(":/CPU:0"):
-            tf.random.set_seed(42)  # deterministic sampling sequence -> deterministic generation
             output_ids = model.generate(input_ids, **generation_kwargs)

         output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)

-        expected_output_string = ["i love her I really love my heart", "die Transformatoren sind wirklich erstaunlich"]
+        expected_output_string = ["- I really love my way of this.", "die Transformatoren sind wirklich erstaunlich"]

         self.assertListEqual(expected_output_string, output_strings)