"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "9f8fa4e9730b8e658bcd5625610cc70f3a019818"
Unverified commit cd4c5c90, authored by Matt, committed by GitHub

TF XLA greedy generation (#15786)



* First attempt at TF XLA generation

* Fix comments

* Update XLA greedy generate with direct XLA calls

* Support attention mask, prepare_inputs_for_generation no longer hardcoded for greedy

* Handle position_ids correctly

* make xla generate work for non xla case

* force using xla generate

* refactor

* more fixes

* finish cleaning

* finish

* finish

* clean gpt2 tests

* add gpt2 tests

* correct more cases

* up

* finish

* finish

* more fixes

* flake 8 stuff

* final rag fix

* Update src/transformers/models/rag/modeling_tf_rag.py

* finish t5 as well

* finish

* Update src/transformers/generation_utils.py
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent e5bc438c
```diff
@@ -260,7 +260,6 @@ class TFRepetitionPenaltyLogitsProcessor(TFLogitsProcessor):
         return tf.convert_to_tensor(token_penalties, dtype=tf.float32)

     def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor) -> tf.Tensor:
         score_penalties = self._create_score_penalties(input_ids, scores)
         scores = tf.math.multiply(scores, score_penalties)
```
```diff
@@ -1484,9 +1484,12 @@ class TFGenerationMixin:
         batch_size = input_ids.shape[0]

         # 3. Prepare other model kwargs
-        model_kwargs["output_attentions"] = output_attentions
-        model_kwargs["output_hidden_states"] = output_hidden_states
-        model_kwargs["use_cache"] = use_cache
+        if output_attentions is not None:
+            model_kwargs["output_attentions"] = output_attentions
+        if output_hidden_states is not None:
+            model_kwargs["output_hidden_states"] = output_hidden_states
+        if use_cache is not None:
+            model_kwargs["use_cache"] = use_cache

         requires_attention_mask = "encoder_outputs" not in model_kwargs
```
```diff
@@ -1533,7 +1536,6 @@ class TFGenerationMixin:
                 raise ValueError(
                     f"num_return_sequences has to be 1, but is {num_return_sequences} when doing greedy search."
                 )
             # 8. run greedy search
             return self.greedy_search(
                 input_ids,
@@ -1545,7 +1547,6 @@ class TFGenerationMixin:
                 return_dict_in_generate=return_dict_in_generate,
                 **model_kwargs,
             )
         elif is_sample_gen_mode:
             # 8. prepare logits warper
             logits_warper = self._get_logits_warper(top_k=top_k, top_p=top_p, temperature=temperature)
```
```diff
@@ -1571,15 +1572,13 @@ class TFGenerationMixin:
                 **model_kwargs,
             )
+        # TODO(Matt, Joao, Patrick) - add more sub-generation methods here

     def _prepare_attention_mask_for_generation(
         self,
         input_ids: tf.Tensor,
         pad_token_id: int,
     ) -> tf.Tensor:
         # prepare `attention_mask` if not passed
-        if (pad_token_id is not None) and (pad_token_id in input_ids.numpy()):
+        if (pad_token_id is not None) and tf.math.reduce_any(input_ids == pad_token_id):
             return tf.cast(tf.math.not_equal(input_ids, pad_token_id), dtype=tf.int32)
         else:
             return tf.ones(input_ids.shape[:2], dtype=tf.int32)
```
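The switch from `pad_token_id in input_ids.numpy()` to `tf.math.reduce_any(input_ids == pad_token_id)` matters because `.numpy()` only exists in eager mode; inside a `tf.function` (and therefore under XLA) the check has to stay in graph. A minimal sketch of the graph-compatible pattern, with an assumed `pad_token_id` of 0:

```python
import tensorflow as tf

pad_token_id = 0  # assumed value, for illustration only
input_ids = tf.constant([[5, 7, 0, 0], [3, 6, 2, 9]], dtype=tf.int32)

@tf.function(jit_compile=True)
def make_attention_mask(ids):
    # tf.math.reduce_any works on symbolic tensors; `pad_token_id in ids.numpy()`
    # would raise here because .numpy() is unavailable while tracing.
    has_pad = tf.math.reduce_any(ids == pad_token_id)
    return tf.cond(
        has_pad,
        lambda: tf.cast(tf.math.not_equal(ids, pad_token_id), tf.int32),
        lambda: tf.ones(tf.shape(ids), dtype=tf.int32),
    )

print(make_attention_mask(input_ids).numpy())  # [[1 1 0 0], [1 1 1 1]]
```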
```diff
@@ -1717,6 +1716,14 @@ class TFGenerationMixin:
         return model_kwargs

+    def _update_model_kwargs_for_xla_generation(
+        self, outputs: ModelOutput, model_kwargs: Dict[str, Any], current_pos: tf.Tensor, max_length: int
+    ) -> Dict[str, Any]:
+        raise NotImplementedError(
+            f"{self.__class__} is not compileable with XLA at the moment. You should implement a "
+            "`_update_model_kwargs_for_xla_generation` in the respective modeling file for XLA-compatible generation."
+        )
+
     def _get_logits_warper(
         self,
         top_k: Optional[int] = None,
```
```diff
@@ -1773,7 +1780,7 @@ class TFGenerationMixin:
             processors.append(TFNoRepeatNGramLogitsProcessor(no_repeat_ngram_size))
         if bad_words_ids is not None:
             processors.append(TFNoBadWordsLogitsProcessor(bad_words_ids, eos_token_id))
-        if min_length is not None and eos_token_id is not None and min_length > -1:
+        if min_length is not None and eos_token_id is not None and min_length > 0:
             processors.append(TFMinLengthLogitsProcessor(min_length, eos_token_id))
         return processors
```
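The guard change from `min_length > -1` to `min_length > 0` simply avoids instantiating a processor that can never mask anything: with `min_length == 0` the EOS token is never suppressed. A tiny illustration of what the processor does, assuming the two-argument `__call__` signature used in this version and a top-level export of `TFMinLengthLogitsProcessor`:

```python
import tensorflow as tf
from transformers import TFMinLengthLogitsProcessor

# Illustrative values: EOS id 2, vocabulary of 10, only 2 tokens generated so far.
processor = TFMinLengthLogitsProcessor(min_length=3, eos_token_id=2)
input_ids = tf.constant([[5, 8]])
scores = tf.zeros((1, 10))

masked = processor(input_ids, scores)
print(masked[0, 2].numpy())  # -inf: EOS stays blocked until the sequence reaches min_length
```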
```diff
@@ -1858,7 +1865,8 @@ class TFGenerationMixin:
         >>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))
         ```"""

-        # init values
+        # 1. init greedy_search values
+
         logits_processor = logits_processor if logits_processor is not None else TFLogitsProcessorList()
         pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
@@ -1871,94 +1879,153 @@ class TFGenerationMixin:
         return_dict_in_generate = (
             return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
         )
+        use_xla = not tf.executing_eagerly()

-        # init attention / hidden states / scores tuples
-        scores = () if (return_dict_in_generate and output_scores) else None
-        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
-        cross_attentions = () if (return_dict_in_generate and output_attentions) else None
-        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
-
-        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
-        if return_dict_in_generate and self.config.is_encoder_decoder:
-            encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
-            encoder_hidden_states = (
-                model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
-            )
-
-        # keep track of which sequences are already finished
-        unfinished_sequences = tf.ones_like(input_ids[:, 0])
-        cur_len = input_ids.shape[-1]
-
-        while cur_len < max_length:
-            # prepare model inputs
-            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
-
-            # forward pass to get next token
+        # 2. init `attentions`, `hidden_states`, and `scores` tuples
+        scores = [] if (return_dict_in_generate and output_scores) else None
+        decoder_attentions = [] if (return_dict_in_generate and output_attentions) else None
+        cross_attentions = [] if (return_dict_in_generate and output_attentions) else None
+        decoder_hidden_states = [] if (return_dict_in_generate and output_hidden_states) else None
+
+        # 3. init tensors to use for "xla-compileable" generate function
+        # define bsz, seq_length
+        batch_size, seq_length = input_ids.shape
+
+        # initialize `generated`, `finished_sequences`, and `current_pos`
+        generated = tf.TensorArray(
+            element_shape=(batch_size,),
+            dtype=tf.int32,
+            dynamic_size=False,
+            size=max_length,
+            clear_after_read=False,
+        )
+        # write prompt to generated
+        for i in range(seq_length):
+            generated = generated.write(i, input_ids[:, i])
+
+        finished_sequences = tf.zeros((batch_size,), dtype=tf.bool)
+        current_pos = tf.ones(shape=(1,), dtype=tf.int32) * seq_length
+
+        # 4. define "xla-compile-able" stop-condition and auto-regressive function
+        # define condition fn
+        def greedy_search_cond_fn(generated, finished_sequences, next_tokens, current_pos, model_kwargs):
+            """state termination condition fn."""
+            return ~tf.reduce_all(finished_sequences)
+
+        # define condition fn
+        def greedy_search_body_fn(generated, finished_sequences, next_tokens, current_pos, model_kwargs):
+            """state update fn."""
+            # TODO(pvp, Joao) - `use_xla` can be removed here as soon as `position_ids` are corrected for the non-xla case in gpt2's `prepare_inputs_for_generation`.
+            model_inputs = self.prepare_inputs_for_generation(next_tokens, use_xla=use_xla, **model_kwargs)
+            # forward pass to get next token logits
             outputs = self(
                 **model_inputs,
                 return_dict=True,
                 output_attentions=output_attentions,
                 output_hidden_states=output_hidden_states,
             )
-            next_token_logits = outputs.logits[:, -1, :]
+            next_token_logits = outputs.logits[:, -1]

             # Store scores, attentions and hidden_states when required
-            if return_dict_in_generate:
-                if output_scores:
-                    scores += (next_token_logits,)
-                if output_attentions:
-                    decoder_attentions += (
-                        (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
-                    )
-                    if self.config.is_encoder_decoder:
-                        cross_attentions += (outputs.cross_attentions,)
-                if output_hidden_states:
-                    decoder_hidden_states += (
-                        (outputs.decoder_hidden_states,)
-                        if self.config.is_encoder_decoder
-                        else (outputs.hidden_states,)
-                    )
+            if not use_xla and return_dict_in_generate:
+                if output_scores:
+                    scores.append(next_token_logits)
+                if output_attentions and self.config.is_encoder_decoder:
+                    decoder_attentions.append(outputs.decoder_attentions)
+                elif output_attentions and not self.config.is_encoder_decoder:
+                    decoder_attentions.append(outputs.attentions)
+                    if self.config.is_encoder_decoder:
+                        cross_attentions.append(outputs.cross_attentions)
+                if output_hidden_states and self.config.is_encoder_decoder:
+                    decoder_hidden_states.append(outputs.decoder_hidden_states)
+                elif output_hidden_states and self.config.is_encoder_decoder:
+                    decoder_hidden_states.append(outputs.hidden_states)

             # pre-process distribution
+            # TODO(pvp, joao, matt) - all the logits processors need to be adapted
+            # to be XLA compatible
+            input_ids = None
+            if not use_xla:
+                input_ids = tf.reshape(generated.concat(), (-1, batch_size))
+                input_ids = tf.transpose(input_ids[: current_pos[0]])
             next_tokens_scores = logits_processor(input_ids, next_token_logits)

             # argmax
-            next_tokens = tf.cast(tf.argmax(next_tokens_scores, axis=-1), tf.int32)
+            next_tokens = tf.argmax(next_tokens_scores, axis=-1, output_type=tf.int32)

-            # finished sentences should have their next token be a padding token
             if eos_token_id is not None:
                 if pad_token_id is None:
                     raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
-                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
+                unfinished_seq = 1 - tf.cast(finished_sequences, tf.int32)
+                next_tokens = next_tokens * unfinished_seq + pad_token_id * (1 - unfinished_seq)
+            finished_sequences = finished_sequences | (next_tokens == eos_token_id)
+
+            # update `generated` and `current_pos`
+            generated = generated.write(current_pos[0], next_tokens)
+            next_tokens = next_tokens[:, None]
+            current_pos += 1
+
+            # update model_kwargs
+            if use_xla:
+                model_kwargs = self._update_model_kwargs_for_xla_generation(
+                    outputs, model_kwargs, current_pos, max_length
+                )
+            else:
+                model_kwargs = self._update_model_kwargs_for_generation(
+                    outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
+                )
+                # if we don't cache past key values we need the whole input
+                if model_kwargs.get("past", None) is None:
+                    # let's throw out `past` since we don't want `None` tensors
+                    model_kwargs.pop("past", None)
+                    next_tokens = tf.reshape(generated.concat(), (-1, batch_size))
+                    next_tokens = tf.transpose(next_tokens[: current_pos[0]])

-            # update generated ids, model inputs, and length for next step
-            input_ids = tf.concat([input_ids, next_tokens[:, None]], axis=-1)
-            model_kwargs = self._update_model_kwargs_for_generation(
-                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
-            )
-            cur_len = cur_len + 1
-
-            # if eos_token was found in one sentence, set sentence to finished
-            if eos_token_id is not None:
-                eos_in_sents = next_tokens == eos_token_id
-                # if sentence is unfinished and the token to add is eos
-                is_sents_unfinished_and_token_to_add_is_eos = tf.math.multiply(
-                    unfinished_sequences, tf.cast(eos_in_sents, tf.int32)
-                )
-                # unfinished_sequences is set to zero if eos in sentence
-                unfinished_sequences -= is_sents_unfinished_and_token_to_add_is_eos
-
-            # stop when each sentence is finished, or if we exceed the maximum length
-            if tf.math.reduce_max(unfinished_sequences) == 0:
-                break
+            return generated, finished_sequences, next_tokens, current_pos, model_kwargs
+
+        # 5. run generation
+        # 1st generation step has to be run before to initialize `past`
+        generated, finished_sequences, next_tokens, current_pos, model_kwargs = greedy_search_body_fn(
+            generated, finished_sequences, input_ids, current_pos, model_kwargs
+        )
+
+        # 2-to-n generation steps can then be run in autoregressive fashion
+        # only in case 1st generation step does NOT yield EOS token though
+        if greedy_search_cond_fn(generated, finished_sequences, next_tokens, current_pos, model_kwargs):
+            maximum_iterations = max_length - seq_length - 1
+            generated, _, _, current_pos, _ = tf.while_loop(
+                greedy_search_cond_fn,
+                greedy_search_body_fn,
+                (generated, finished_sequences, next_tokens, current_pos, model_kwargs),
+                maximum_iterations=maximum_iterations,
+            )
+
+        # 6. prepare outputs
+        output_ids = tf.transpose(tf.reshape(generated.concat(), (-1, batch_size)))
+
+        if not use_xla:
+            # cut for backward compatibility
+            output_ids = output_ids[:, : current_pos[0]]

         if return_dict_in_generate:
             if self.config.is_encoder_decoder:
+                # if model is an encoder-decoder, retrieve encoder attention weights
+                # and hidden states
+                encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
+                encoder_hidden_states = (
+                    model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
+                )
+
+                scores = tuple(scores) if scores is not None else None
+                decoder_attentions = tuple(decoder_attentions) if decoder_attentions is not None else None
+                cross_attentions = tuple(cross_attentions) if cross_attentions is not None else None
+                decoder_hidden_states = tuple(decoder_hidden_states) if decoder_hidden_states is not None else None
+
                 return TFGreedySearchEncoderDecoderOutput(
-                    sequences=input_ids,
+                    sequences=output_ids,
                     scores=scores,
                     encoder_attentions=encoder_attentions,
                     encoder_hidden_states=encoder_hidden_states,
@@ -1968,13 +2035,13 @@ class TFGenerationMixin:
                 )
             else:
                 return TFGreedySearchDecoderOnlyOutput(
-                    sequences=input_ids,
+                    sequences=output_ids,
                     scores=scores,
                     attentions=decoder_attentions,
                     hidden_states=decoder_hidden_states,
                 )
         else:
-            return input_ids
+            return output_ids

     def sample(
         self,
```
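The core of the rewrite above is replacing the Python `while` loop with a fixed-size `tf.TensorArray` plus `tf.while_loop`, so every tensor keeps a static shape and XLA can compile the whole loop. A stripped-down sketch of that structure (the "model" here is a stand-in that picks pseudo-random tokens, not a real forward pass):

```python
import tensorflow as tf

vocab_size, max_length, eos_token_id = 50, 8, 3

@tf.function(jit_compile=True)
def toy_greedy(prompt):  # prompt: (batch, 1), shape known at trace time
    batch_size = prompt.shape[0]
    generated = tf.TensorArray(
        tf.int32, size=max_length, dynamic_size=False, clear_after_read=False, element_shape=(batch_size,)
    )
    generated = generated.write(0, prompt[:, 0])
    finished = tf.zeros((batch_size,), dtype=tf.bool)

    def cond_fn(generated, finished, tokens, pos):
        return ~tf.reduce_all(finished)

    def body_fn(generated, finished, tokens, pos):
        # stand-in for prepare_inputs_for_generation + model forward pass
        logits = tf.random.stateless_uniform((batch_size, vocab_size), seed=(0, 1))
        next_tokens = tf.argmax(logits, axis=-1, output_type=tf.int32)
        finished = finished | (next_tokens == eos_token_id)
        generated = generated.write(pos, next_tokens)
        return generated, finished, next_tokens[:, None], pos + 1

    generated, *_ = tf.while_loop(
        cond_fn, body_fn, (generated, finished, prompt, 1), maximum_iterations=max_length - 1
    )
    return tf.transpose(generated.stack())  # (batch, max_length)

print(toy_greedy(tf.constant([[7], [9]], dtype=tf.int32)))
```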
```diff
@@ -18,7 +18,9 @@
 from dataclasses import dataclass
 from typing import List, Optional, Tuple

+import numpy as np
 import tensorflow as tf
+from tensorflow.compiler.tf2xla.python.xla import dynamic_update_slice

 from ...activations_tf import get_tf_activation
 from ...file_utils import (
```
```diff
@@ -851,7 +853,7 @@ class TFGPT2LMHeadModel(TFGPT2PreTrainedModel, TFCausalLanguageModelingLoss):
     def set_output_embeddings(self, value):
         self.set_input_embeddings(value)

-    def prepare_inputs_for_generation(self, inputs, past=None, use_cache=None, **kwargs):
+    def prepare_inputs_for_generation(self, inputs, past=None, use_cache=None, use_xla=False, **kwargs):
         # TODO: (Joao) after the TF generator is complete, update GPT2 TF generation to match PT's. NB -- some GPT2
         # tests will need to be fixed after the change
@@ -859,7 +861,81 @@ class TFGPT2LMHeadModel(TFGPT2PreTrainedModel, TFCausalLanguageModelingLoss):
         if past:
             inputs = tf.expand_dims(inputs[:, -1], -1)

-        return {"input_ids": inputs, "past_key_values": past, "use_cache": use_cache}
+        # TODO(pvp, Joao) - this `if use_xla` statement can be removed, but is left
+        # for a future PR to not change too many things for now.
+        # All statements in this if case apply for both xla and non-xla (as they already do in PyTorch)
+        position_ids = None
+        attention_mask = None
+        if use_xla:
+            attention_mask = kwargs.get("attention_mask", None)
+            if past is not None and attention_mask is not None:
+                position_ids = tf.reduce_sum(attention_mask, axis=1, keepdims=True) - 1
+            elif attention_mask is not None:
+                position_ids = tf.math.cumsum(attention_mask, axis=1, exclusive=True)
+
+        return {
+            "input_ids": inputs,
+            "attention_mask": attention_mask,
+            "position_ids": position_ids,
+            "past_key_values": past,
+            "use_cache": use_cache,
+        }
```
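With `use_xla=True` the model only ever receives the freshly generated token, so `position_ids` have to be recovered from the attention mask: an exclusive cumulative sum numbers the prompt positions on the first pass, and afterwards the count of attended positions minus one gives the position of the token being decoded (the mask layout with padded zeros and a trailing one comes from the cache-update helper shown next). Illustrative numbers only:

```python
import tensorflow as tf

# Prompt step (no past yet): exclusive cumsum numbers the attended positions 0, 1, 2, ...
attention_mask = tf.constant([[1, 1, 1, 0]], dtype=tf.int32)
print(tf.math.cumsum(attention_mask, axis=1, exclusive=True).numpy())  # [[0 1 2 3]]

# Decoding step (past present): mask = prompt ones, padded zeros, then a one for the new token.
extended_mask = tf.constant([[1, 1, 1, 0, 0, 1]], dtype=tf.int32)
print((tf.reduce_sum(extended_mask, axis=1, keepdims=True) - 1).numpy())  # [[3]] -> position of the 4th token
```

The hunk continues below with the new `_update_model_kwargs_for_xla_generation` helper for GPT-2.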
```diff
+    def _update_model_kwargs_for_xla_generation(self, outputs, model_kwargs, current_pos, max_length):
+        # TODO(Pvp, Joao, Matt) - this function can be cleaned a bit and refactored
+        # quite some duplicated code patterns it seems
+        # also the `attention_mask` is currently used in a somewhat hacky to
+        # correctly influence the `past_key_values` - not sure if this is the way to go
+        # Let's keep that for a future PR.
+        past = outputs.past_key_values
+        is_past_initialized = model_kwargs.pop("past", None) is not None
+        attention_mask = model_kwargs.pop("attention_mask")
+        batch_size = attention_mask.shape[0]
+
+        if not is_past_initialized:
+            # past[0].shape[3] is seq_length of prompt
+            num_padding_values = max_length - past[0].shape[3] - 1
+
+            padding_values = np.zeros((5, 2), dtype=np.int32)
+            padding_values[3, 1] = num_padding_values
+            padding_values = tf.constant(padding_values)
+
+            new_past = list(past)
+            for i in range(len(past)):
+                new_past[i] = tf.pad(past[i], padding_values)
+
+            # Zeros for the currently-unfilled locations in the past tensor, ones for the actual input_ids
+            attention_mask = tf.concat(
+                [
+                    attention_mask,
+                    tf.zeros((batch_size, num_padding_values), dtype=attention_mask.dtype),
+                    tf.ones((batch_size, 1), dtype=attention_mask.dtype),
+                ],
+                axis=1,
+            )
+        else:
+            new_past = [None for _ in range(len(past))]
+            slice_start_base = tf.constant([0, 0, 0, 1, 0])
+            attention_mask_update_slice = tf.ones((batch_size, 1), dtype=attention_mask.dtype)
+            # correct 5 here
+            new_past_index = current_pos - 1
+
+            for i in range(len(past)):
+                update_slice = past[i][:, :, :, -1:]
+                # Write the last slice to the first open location in the padded past array
+                # and then truncate the last slice off the array
+                new_past[i] = dynamic_update_slice(
+                    past[i][:, :, :, :-1], update_slice, slice_start_base * new_past_index
+                )
+
+            update_start = tf.constant([0, 1], dtype=tf.int32) * new_past_index
+            attention_mask = dynamic_update_slice(attention_mask, attention_mask_update_slice, update_start)
+
+        # set `attention_mask` and `past`
+        model_kwargs["attention_mask"] = attention_mask
+        model_kwargs["past"] = tuple(new_past)
+
+        return model_kwargs
+
     @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
```
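`dynamic_update_slice` is what keeps the cache shape static: the past key/values are padded out to `max_length` once, and each subsequent step overwrites exactly one slot instead of concatenating. A toy sketch on a made-up `(batch, max_len, dim)` buffer (the real GPT-2 cache is 5-dimensional with the sequence on axis 3, as the indices above assume):

```python
import tensorflow as tf
from tensorflow.compiler.tf2xla.python.xla import dynamic_update_slice

batch, max_len, dim = 2, 6, 4

@tf.function(jit_compile=True)
def write_step(cache, new_slice, step):
    # start index per axis of `cache`; only the length axis moves from step to step
    start_indices = tf.stack([0, step, 0])
    return dynamic_update_slice(cache, new_slice, start_indices)

cache = tf.zeros((batch, max_len, dim))              # pre-allocated at its final size
new_slice = tf.ones((batch, 1, dim))                 # e.g. this step's key/value slice
cache = write_step(cache, new_slice, tf.constant(3))
print(cache[0, :, 0].numpy())                        # [0. 0. 0. 1. 0. 0.] -> only slot 3 changed
```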
```diff
@@ -1309,9 +1309,13 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss):
             min_length=min_length,
             eos_token_id=eos_token_id,
         )

-        # TODO(Patrick) clean-up once generate is fully cleaned up
         model_kwargs["attention_mask"] = context_attention_mask
+
+        # TODO(Patrick) remove once generate is fully cleaned up
+        if model_kwargs.get("encoder_attentions", None) is None:
+            model_kwargs.pop("encoder_attentions", None)
+        if model_kwargs.get("encoder_hidden_states", None) is None:
+            model_kwargs.pop("encoder_hidden_states", None)
+
         model_kwargs.pop("output_hidden_states", None)
         model_kwargs.pop("output_attentions", None)
         model_kwargs.pop("output_scores", None)
```
```diff
@@ -21,7 +21,9 @@ import math
 import warnings
 from typing import Tuple

+import numpy as np
 import tensorflow as tf
+from tensorflow.compiler.tf2xla.python.xla import dynamic_update_slice

 from ...activations_tf import get_tf_activation
 from ...file_utils import (
```
```diff
@@ -1545,6 +1547,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModelingLoss):
         input_ids,
         past=None,
         attention_mask=None,
+        decoder_attention_mask=None,
         head_mask=None,
         decoder_head_mask=None,
         use_cache=None,
@@ -1562,11 +1565,76 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModelingLoss):
             "past_key_values": past,
             "encoder_outputs": encoder_outputs,
             "attention_mask": attention_mask,
+            "decoder_attention_mask": decoder_attention_mask,
             "head_mask": head_mask,
             "decoder_head_mask": decoder_head_mask,
             "use_cache": use_cache,
         }

+    def _update_model_kwargs_for_xla_generation(self, outputs, model_kwargs, current_pos, max_length):
+        # TODO(Pvp, Joao, Matt) - this function can be cleaned a bit and refactored
+        # quite some duplicated code patterns it seems
+        # also the `attention_mask` is currently used in a somewhat hacky to
+        # correctly influence the `past_key_values` - not sure if this is the way to go
+        # Let's keep that for a future PR.
+        past = outputs.past_key_values
+        is_past_initialized = model_kwargs.pop("past", None) is not None
+        decoder_attention_mask = model_kwargs.pop("decoder_attention_mask", None)
+        batch_size = past[0][0].shape[0]
+
+        if not is_past_initialized:
+            # past[0].shape[3] is seq_length of prompt
+            num_padding_values = max_length - past[0][0].shape[2] - 1
+
+            padding_values = np.zeros((4, 2), dtype=np.int32)
+            padding_values[2, 1] = num_padding_values
+            padding_values = tf.constant(padding_values)
+
+            new_past = ()
+            for past_layer in past:
+                new_past_layer = list(past_layer)
+                for i in range(len(new_past_layer[:2])):
+                    new_past_layer[i] = tf.pad(past_layer[i], padding_values)
+                new_past += (tuple(new_past_layer),)
+
+            # 1 one for decoder_start_token_id, Zeros for the currently-unfilled locations in the past tensor, ones for the actual input_ids
+            decoder_attention_mask = tf.concat(
+                [
+                    tf.ones((batch_size, 1), dtype=tf.int32),
+                    tf.zeros((batch_size, num_padding_values), dtype=tf.int32),
+                    tf.ones((batch_size, 1), dtype=tf.int32),
+                ],
+                axis=1,
+            )
+        else:
+            slice_start_base = tf.constant([0, 0, 1, 0])
+            decoder_attention_mask_update_slice = tf.ones((batch_size, 1), dtype=decoder_attention_mask.dtype)
+            # correct 5 here
+            new_past_index = current_pos - 1
+
+            new_past = ()
+            for past_layer in past:
+                new_past_layer = list(past_layer)
+                for i in range(len(new_past_layer[:2])):
+                    update_slice = past_layer[i][:, :, -1:]
+                    # Write the last slice to the first open location in the padded past array
+                    # and then truncate the last slice off the array
+                    new_past_layer[i] = dynamic_update_slice(
+                        past_layer[i][:, :, :-1], update_slice, slice_start_base * new_past_index
+                    )
+                new_past += (tuple(new_past_layer),)
+
+            update_start = tf.constant([0, 1], dtype=tf.int32) * new_past_index
+            decoder_attention_mask = dynamic_update_slice(
+                decoder_attention_mask, decoder_attention_mask_update_slice, update_start
+            )
+
+        # set `attention_mask` and `past`
+        model_kwargs["decoder_attention_mask"] = decoder_attention_mask
+        model_kwargs["past"] = new_past
+
+        return model_kwargs
+
     def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor):
         return self._shift_right(labels)
```
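The T5 variant only pads and updates `new_past_layer[:2]` because of the cache layout it assumes: each decoder layer caches four tensors, and only the decoder self-attention key/value grow with the generated sequence, while the cross-attention key/value keep the fixed encoder length. A shape-only sketch of that assumption (all sizes illustrative):

```python
import tensorflow as tf

def fake_t5_past_layer(batch=2, heads=8, head_dim=64, enc_len=12, max_dec_len=10):
    self_k = tf.zeros((batch, heads, max_dec_len, head_dim))   # grows per decoded token (padded to max_length)
    self_v = tf.zeros((batch, heads, max_dec_len, head_dim))
    cross_k = tf.zeros((batch, heads, enc_len, head_dim))      # fixed: computed once from the encoder output
    cross_v = tf.zeros((batch, heads, enc_len, head_dim))
    return (self_k, self_v, cross_k, cross_v)

past = tuple(fake_t5_past_layer() for _ in range(6))           # one tuple per decoder layer
print(len(past), [int(t.shape[2]) for t in past[0]])           # 6 [10, 10, 12, 12]
```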
```diff
@@ -660,29 +660,16 @@ class GPT2ModelLanguageGenerationTest(unittest.TestCase):
         else:
             model.gradient_checkpointing_disable()
         model.to(torch_device)
-        input_ids = torch.tensor([[464, 3290]], dtype=torch.long, device=torch_device)  # The dog
-        expected_output_ids = [
-            464,
-            3290,
-            373,
-            1043,
-            287,
-            257,
-            2214,
-            1474,
-            262,
-            16246,
-            286,
-            2688,
-            290,
-            2688,
-            27262,
-            13,
-            198,
-            198,
-            464,
-            3290,
-        ]  # The dog was found in a field near the intersection of West and West Streets.\n\nThe dog
+
+        # The dog
+        input_ids = torch.tensor([[464, 3290]], dtype=torch.long, device=torch_device)
+
+        # The dog was found in a field near the intersection of West and West Streets.\n\nThe dog
+        # fmt: off
+        expected_output_ids = [
+            464, 3290, 373, 1043, 287, 257, 2214, 1474, 262, 16246, 286, 2688, 290, 2688, 27262, 13, 198, 198, 464, 3290,
+        ]
+        # fmt: on
         output_ids = model.generate(input_ids, do_sample=False)
         if verify_outputs:
             self.assertListEqual(output_ids[0].tolist(), expected_output_ids)
```
```diff
@@ -294,6 +294,21 @@ class TFGPT2ModelTester:
             result = model(inputs)
             self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))

+    def create_and_check_gpt2_xla_generate(self, config, input_ids, *args):
+        config.eos_token_id = None
+        config.max_length = 10
+        model = TFGPT2LMHeadModel(config=config)
+
+        # make sure there are no pad tokens in prompt
+        input_ids = tf.where(input_ids != config.pad_token_id, input_ids, config.pad_token_id - 1)
+
+        generated = model.generate(input_ids)
+
+        generate_xla = tf.function(model.generate, jit_compile=True)
+        generated_xla = generate_xla(input_ids)
+
+        self.parent.assertListEqual(generated.numpy().tolist(), generated_xla.numpy().tolist())
+
     def create_and_check_gpt2_double_head(
         self, config, input_ids, input_mask, head_mask, token_type_ids, mc_token_ids, *args
     ):
@@ -393,6 +408,10 @@ class TFGPT2ModelTest(TFModelTesterMixin, TFCoreModelTesterMixin, unittest.TestCase):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_gpt2_lm_head(*config_and_inputs)

+    def test_gpt2_xla_generate(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_gpt2_xla_generate(*config_and_inputs)
+
     def test_gpt2_double_head(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_gpt2_double_head(*config_and_inputs)
```
```diff
@@ -513,3 +532,18 @@ class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
         # fmt: on
         output_ids = model.generate(input_ids, do_sample=False)
         self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids)
+
+    @slow
+    def test_lm_generate_gpt2_xla(self):
+        """This test gives the exact same results as the non-xla test above"""
+        model = TFGPT2LMHeadModel.from_pretrained("gpt2")
+        input_ids = tf.convert_to_tensor([[464, 3290]], dtype=tf.int32)  # The dog
+
+        # The dog was found in a field near the intersection of West and West Streets.\n\nThe dog
+        # fmt: off
+        expected_output_ids = [464, 3290, 373, 1043, 287, 257, 2214, 1474, 262, 16246, 286, 2688, 290, 2688, 27262, 13, 198, 198, 464, 3290]
+        # fmt: on
+
+        xla_generate = tf.function(model.generate, jit_compile=True)
+        output_ids = xla_generate(input_ids, do_sample=False)
+
+        self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids)
```
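For end users, the pattern exercised by these tests is simply wrapping `generate` in `tf.function(..., jit_compile=True)`. One practical note (general `tf.function` behaviour, not specific to this PR): the first call pays the tracing and XLA compilation cost, and a prompt with a different shape triggers a fresh compilation, so reusing the wrapped function on same-shaped inputs is what makes it pay off. A usage-style sketch with the public `gpt2` checkpoint:

```python
import tensorflow as tf
from transformers import GPT2Tokenizer, TFGPT2LMHeadModel

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = TFGPT2LMHeadModel.from_pretrained("gpt2")

xla_generate = tf.function(model.generate, jit_compile=True)

input_ids = tokenizer("The dog", return_tensors="tf").input_ids
output_ids = xla_generate(input_ids, do_sample=False)  # slow: traces and compiles
output_ids = xla_generate(input_ids, do_sample=False)  # fast: reuses the compiled program
print(tokenizer.batch_decode(output_ids, skip_special_tokens=True))
```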
```diff
@@ -227,6 +227,23 @@ class TFT5ModelTester:
         # test that outputs are equal for slice
         tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-3)

+    def create_and_check_t5_xla_generate(self, config, input_ids, *args):
+        config.eos_token_id = None
+        config.max_length = 10
+        config.do_sample = False
+        config.num_beams = 1
+        model = TFT5ForConditionalGeneration(config=config)
+
+        # make sure there are no pad tokens in prompt
+        input_ids = tf.where(input_ids != config.pad_token_id, input_ids, config.pad_token_id + 5)
+
+        generated = model.generate(input_ids)
+
+        generate_xla = tf.function(model.generate, jit_compile=True)
+        generated_xla = generate_xla(input_ids)
+
+        self.parent.assertListEqual(generated.numpy().tolist(), generated_xla.numpy().tolist())
+
     def prepare_config_and_inputs_for_common(self):
         config_and_inputs = self.prepare_config_and_inputs()
         (config, input_ids, input_mask, token_labels) = config_and_inputs
@@ -280,6 +297,10 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_t5_decoder_model_past_large_inputs(*config_and_inputs)

+    def test_t5_model_xla_generate(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_t5_xla_generate(*config_and_inputs)
+
     def test_model_common_attributes(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@@ -454,6 +475,27 @@ class TFT5EncoderOnlyModelTest(TFModelTesterMixin, unittest.TestCase):
 @require_sentencepiece
 @require_tokenizers
 class TFT5GenerationIntegrationTests(unittest.TestCase):
+    @slow
+    def test_greedy_xla_generate_simple(self):
+        model = TFT5ForConditionalGeneration.from_pretrained("t5-small")
+        tokenizer = T5Tokenizer.from_pretrained("t5-small")
+
+        sentence = "Translate English to German: Today is a beautiful day."
+        input_ids = tokenizer(sentence, return_tensors="tf", padding=True).input_ids
+
+        xla_generate = tf.function(model.generate, jit_compile=True)
+
+        output_ids = model.generate(input_ids)
+        output_ids_xla = xla_generate(input_ids)
+
+        output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
+        output_strings_xla = tokenizer.batch_decode(output_ids_xla, skip_special_tokens=True)
+
+        expected_output_string = ["Heute ist ein schöner Tag."]
+
+        self.assertListEqual(expected_output_string, output_strings)
+        self.assertListEqual(expected_output_string, output_strings_xla)
+
     @slow
     def test_greedy_generate(self):
         model = TFT5ForConditionalGeneration.from_pretrained("t5-small")
```