Unverified Commit 975dd2bb authored by Joao Gante, committed by GitHub

TF: GPT-2 generation supports left-padding (#17426)

* TF GPT-2 now properly works with left padding

* throw a warning when eos token == pad token and there is no attention mask
parent c1a13861
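For context on what the diff below enables, here is a minimal usage sketch. The checkpoint, prompts, and generation settings are illustrative (they mirror the tests further down) and are not prescribed by the commit:

from transformers import GPT2Tokenizer, TFGPT2LMHeadModel

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 ships without a pad token
tokenizer.padding_side = "left"            # pad on the left so generation continues the prompt
model = TFGPT2LMHeadModel.from_pretrained("gpt2")

sentences = ["Today is a beautiful day and", "Yesterday was"]
inputs = tokenizer(sentences, return_tensors="tf", padding=True)

# Passing the attention_mask alongside the padded input_ids is what makes
# left-padded generation reliable; omitting it (with pad == eos) now triggers
# the warning added in this commit.
output_ids = model.generate(**inputs, max_length=20, do_sample=False)
print(tokenizer.batch_decode(output_ids, skip_special_tokens=True))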
@@ -1498,8 +1498,14 @@ class TFGenerationMixin:
             )
         if pad_token_id is None and eos_token_id is not None:
+            if attention_mask is None:
+                logger.warning(
+                    "The attention mask and the pad token id were not set. As a consequence, you may observe "
+                    "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
+                )
             logger.warning(f"Setting `pad_token_id` to {eos_token_id} (first `eos_token_id`) to generate sequence")
             pad_token_id = eos_token_id
         if min_length is not None and min_length > max_length:
             raise ValueError(
                 f"Unfeasable length constraints: the minimum length ({min_length}) is larger than the maximum "
@@ -1525,7 +1531,9 @@ class TFGenerationMixin:
         requires_attention_mask = "encoder_outputs" not in model_kwargs
         if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask:
-            model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(input_ids, pad_token_id)
+            model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
+                input_ids, pad_token_id, eos_token_id
+            )
         # 4. Prepare model inputs which will be used for auto-regressive generation
         if self.config.is_encoder_decoder:
@@ -1653,12 +1661,17 @@ class TFGenerationMixin:
     def _prepare_attention_mask_for_generation(
         self,
         inputs: tf.Tensor,
-        pad_token_id: int,
+        pad_token_id: Optional[int],
+        eos_token_id: Optional[int],
     ) -> tf.Tensor:
         is_input_ids = len(inputs.shape) == 2 and inputs.dtype in (tf.int32, tf.int64)
         is_pad_token_in_inputs = (pad_token_id is not None) and tf.math.reduce_any(inputs == pad_token_id)
+        is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or (
+            (eos_token_id is not None) and (pad_token_id != eos_token_id)
+        )
         # Check if input is input_ids and padded -> only then is attention_mask defined
-        if is_input_ids and is_pad_token_in_inputs:
+        if is_input_ids and is_pad_token_in_inputs and is_pad_token_not_equal_to_eos_token_id:
             return tf.cast(tf.math.not_equal(inputs, pad_token_id), dtype=tf.int32)
         else:
             return tf.ones(inputs.shape[:2], dtype=tf.int32)
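For intuition, a minimal sketch (not part of the commit; the token ids are made up) of what the mask inference above produces for a left-padded batch, and why it is skipped when the pad token doubles as the EOS token:

import tensorflow as tf

pad_token_id = 0
inputs = tf.constant([[0, 0, 11, 12],
                      [21, 22, 23, 24]])  # first row is left-padded

# pad_token_id != eos_token_id: padding positions can be recovered from the ids
mask = tf.cast(tf.math.not_equal(inputs, pad_token_id), dtype=tf.int32)
print(mask.numpy())  # [[0 0 1 1]
                     #  [1 1 1 1]]

# pad_token_id == eos_token_id: padding cannot be told apart from genuine EOS
# tokens, so the new code falls back to an all-ones mask and generate() warns instead.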
@@ -1954,6 +1967,7 @@ class TFGenerationMixin:
         # 1. init greedy_search values
         logits_processor = logits_processor if logits_processor is not None else TFLogitsProcessorList()
+        max_length = max_length if max_length is not None else self.config.max_length
         pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
         eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
         output_scores = output_scores if output_scores is not None else self.config.output_scores
@@ -1973,10 +1987,9 @@ class TFGenerationMixin:
         decoder_hidden_states = [] if (return_dict_in_generate and output_hidden_states) else None
         # 3. init tensors to use for "xla-compileable" generate function
-        # define bsz, seq_length
-        batch_size, seq_length = input_ids.shape
+        batch_size, cur_len = input_ids.shape
-        # initialize `generated`, `finished_sequences`, and `current_pos`
+        # initialize `generated` (pre-populated with `pad_token_id`), `finished_sequences`
         generated = tf.TensorArray(
             element_shape=(batch_size,),
             dtype=tf.int32,
@@ -1984,25 +1997,26 @@ class TFGenerationMixin:
             size=max_length,
             clear_after_read=False,
         )
+        if pad_token_id:  # ignores the cases when it is 0 or None
+            for i in range(max_length):
+                generated = generated.write(i, tf.broadcast_to(pad_token_id, (batch_size,)))
         # write prompt to generated
-        for i in range(seq_length):
+        for i in range(cur_len):
             generated = generated.write(i, input_ids[:, i])
         finished_sequences = tf.zeros((batch_size,), dtype=tf.bool)
-        current_pos = tf.ones(shape=(1,), dtype=tf.int32) * seq_length
         # 4. define "xla-compile-able" stop-condition and auto-regressive function
         # define condition fn
-        def greedy_search_cond_fn(generated, finished_sequences, next_tokens, current_pos, model_kwargs):
+        def greedy_search_cond_fn(generated, finished_sequences, next_tokens, cur_len, model_kwargs):
             """state termination condition fn."""
             return ~tf.reduce_all(finished_sequences)
         # define condition fn
-        def greedy_search_body_fn(generated, finished_sequences, next_tokens, current_pos, model_kwargs):
+        def greedy_search_body_fn(generated, finished_sequences, next_tokens, cur_len, model_kwargs):
             """state update fn."""
-            # TODO(pvp, Joao) - `use_xla` can be removed here as soon as `position_ids` are corrected for the non-xla case in gpt2's `prepare_inputs_for_generation`.
-            model_inputs = self.prepare_inputs_for_generation(next_tokens, use_xla=use_xla, **model_kwargs)
+            model_inputs = self.prepare_inputs_for_generation(next_tokens, **model_kwargs)
            # forward pass to get next token logits
             outputs = self(
                 **model_inputs,
@@ -2029,13 +2043,8 @@ class TFGenerationMixin:
                     decoder_hidden_states.append(outputs.hidden_states)
             # pre-process distribution
-            # TODO(pvp, joao, matt) - all the logits processors need to be adapted
-            # to be XLA compatible
-            input_ids = None
-            if not use_xla:
-                input_ids = tf.reshape(generated.concat(), (-1, batch_size))
-                input_ids = tf.transpose(input_ids[: current_pos[0]])
-            next_tokens_scores = logits_processor(input_ids, next_token_logits, current_pos[0])
+            input_ids = tf.transpose(tf.reshape(generated.concat(), (-1, batch_size)))
+            next_tokens_scores = logits_processor(input_ids, next_token_logits, cur_len)
             # argmax
             next_tokens = tf.argmax(next_tokens_scores, axis=-1, output_type=tf.int32)
@@ -2047,16 +2056,14 @@ class TFGenerationMixin:
                 next_tokens = next_tokens * unfinished_seq + pad_token_id * (1 - unfinished_seq)
                 finished_sequences = finished_sequences | (next_tokens == eos_token_id)
-            # update `generated` and `current_pos`
-            generated = generated.write(current_pos[0], next_tokens)
+            # update `generated` and `cur_len`
+            generated = generated.write(cur_len, next_tokens)
             next_tokens = next_tokens[:, None]
-            current_pos += 1
+            cur_len += 1
             # update model_kwargs
             if use_xla:
-                model_kwargs = self._update_model_kwargs_for_xla_generation(
-                    outputs, model_kwargs, current_pos, max_length
-                )
+                model_kwargs = self._update_model_kwargs_for_xla_generation(outputs, model_kwargs, cur_len, max_length)
             else:
                 model_kwargs = self._update_model_kwargs_for_generation(
                     outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
@@ -2067,24 +2074,24 @@ class TFGenerationMixin:
                     model_kwargs.pop("past", None)
             next_tokens = tf.reshape(generated.concat(), (-1, batch_size))
-            next_tokens = tf.transpose(next_tokens[: current_pos[0]])
-            return generated, finished_sequences, next_tokens, current_pos, model_kwargs
+            next_tokens = tf.transpose(next_tokens[:cur_len])
+            return generated, finished_sequences, next_tokens, cur_len, model_kwargs
         # 5. run generation
         # 1st generation step has to be run before to initialize `past`
-        generated, finished_sequences, next_tokens, current_pos, model_kwargs = greedy_search_body_fn(
-            generated, finished_sequences, input_ids, current_pos, model_kwargs
+        generated, finished_sequences, next_tokens, cur_len, model_kwargs = greedy_search_body_fn(
+            generated, finished_sequences, input_ids, cur_len, model_kwargs
         )
         # 2-to-n generation steps can then be run in autoregressive fashion
         # only in case 1st generation step does NOT yield EOS token though
-        if greedy_search_cond_fn(generated, finished_sequences, next_tokens, current_pos, model_kwargs):
-            maximum_iterations = max_length - seq_length - 1
-            generated, _, _, current_pos, _ = tf.while_loop(
+        if greedy_search_cond_fn(generated, finished_sequences, next_tokens, cur_len, model_kwargs):
+            maximum_iterations = max_length - cur_len
+            generated, _, _, cur_len, _ = tf.while_loop(
                 greedy_search_cond_fn,
                 greedy_search_body_fn,
-                (generated, finished_sequences, next_tokens, current_pos, model_kwargs),
+                (generated, finished_sequences, next_tokens, cur_len, model_kwargs),
                 maximum_iterations=maximum_iterations,
             )
@@ -2093,7 +2100,7 @@ class TFGenerationMixin:
         if not use_xla:
             # cut for backward compatibility
-            output_ids = output_ids[:, : current_pos[0]]
+            output_ids = output_ids[:, :cur_len]
         if return_dict_in_generate:
             if self.config.is_encoder_decoder:
@@ -2231,6 +2238,7 @@ class TFGenerationMixin:
         logits_processor = logits_processor if logits_processor is not None else TFLogitsProcessorList()
         logits_warper = logits_warper if logits_warper is not None else TFLogitsProcessorList()
+        max_length = max_length if max_length is not None else self.config.max_length
         pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
         eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
         output_scores = output_scores if output_scores is not None else self.config.output_scores
@@ -2250,10 +2258,9 @@ class TFGenerationMixin:
         decoder_hidden_states = [] if (return_dict_in_generate and output_hidden_states) else None
         # 3. init tensors to use for "xla-compileable" generate function
-        # define bsz, seq_length
         batch_size, cur_len = input_ids.shape
-        # initialize `generated`, `finished_sequences`
+        # initialize `generated` (pre-populated with `pad_token_id`), `finished_sequences`
         generated = tf.TensorArray(
             element_shape=(batch_size,),
             dtype=tf.int32,
@@ -2261,19 +2268,22 @@ class TFGenerationMixin:
             size=max_length,
             clear_after_read=False,
         )
-        finished_sequences = tf.zeros((batch_size,), dtype=tf.bool)
+        if pad_token_id:  # ignores the cases when it is 0 or None
+            for i in range(max_length):
+                generated = generated.write(i, tf.broadcast_to(pad_token_id, (batch_size,)))
         # write prompt to generated
         for i in range(cur_len):
             generated = generated.write(i, input_ids[:, i])
+        finished_sequences = tf.zeros((batch_size,), dtype=tf.bool)
         # 4. define "xla-compile-able" stop-condition and auto-regressive function
         def sample_cond_fn(generated, finished_sequences, next_tokens, cur_len, model_kwargs):
             return ~tf.reduce_all(finished_sequences)
         def sample_body_fn(generated, finished_sequences, next_tokens, cur_len, model_kwargs):
-            # TODO(pvp, Joao) - `use_xla` can be removed here as soon as `position_ids` are corrected for the non-xla case in gpt2's `prepare_inputs_for_generation`.
-            model_inputs = self.prepare_inputs_for_generation(next_tokens, use_xla=use_xla, **model_kwargs)
+            model_inputs = self.prepare_inputs_for_generation(next_tokens, **model_kwargs)
             # forward pass to get next token logits
             outputs = self(
                 **model_inputs,
@@ -2300,12 +2310,7 @@ class TFGenerationMixin:
                     decoder_hidden_states.append(outputs.hidden_states)
             # pre-process distribution
-            # TODO(pvp, joao, matt) - all the logits processors/wrappers need to be adapted
-            # to be XLA compatible
-            input_ids = None
-            if not use_xla:
-                input_ids = tf.reshape(generated.concat(), (-1, batch_size))
-                input_ids = tf.transpose(input_ids[:cur_len])
+            input_ids = tf.transpose(tf.reshape(generated.concat(), (-1, batch_size)))
             next_tokens_scores = logits_processor(input_ids, next_token_logits, cur_len)
             next_tokens_scores = logits_warper(input_ids, next_tokens_scores, cur_len)
@@ -2359,7 +2364,7 @@ class TFGenerationMixin:
         # 2-to-n generation steps can then be run in autoregressive fashion
         # only in case 1st generation step does NOT yield EOS token though
         if sample_cond_fn(generated, finished_sequences, next_tokens, cur_len, model_kwargs):
-            maximum_iterations = max_length - cur_len - 1
+            maximum_iterations = max_length - cur_len
             generated, _, _, cur_len, _ = tf.while_loop(
                 sample_cond_fn,
                 sample_body_fn,
@@ -2613,12 +2618,13 @@ class TFGenerationMixin:
             size=max_length,
             clear_after_read=False,
         )
-        for i in range(max_length):
-            sequences = sequences.write(i, tf.broadcast_to(pad_token_id, (batch_size, num_beams)))
-            running_sequences = running_sequences.write(i, tf.broadcast_to(pad_token_id, (batch_size, num_beams)))
-            intermediary_running_sequences = intermediary_running_sequences.write(
-                i, tf.broadcast_to(pad_token_id, (batch_size, num_beams * 2))
-            )
+        if pad_token_id:  # ignores the cases when it is 0 or None
+            for i in range(max_length):
+                sequences = sequences.write(i, tf.broadcast_to(pad_token_id, (batch_size, num_beams)))
+                running_sequences = running_sequences.write(i, tf.broadcast_to(pad_token_id, (batch_size, num_beams)))
+                intermediary_running_sequences = intermediary_running_sequences.write(
+                    i, tf.broadcast_to(pad_token_id, (batch_size, num_beams * 2))
+                )
         # write prompt to running_sequences
         for i in range(cur_len):
@@ -2699,9 +2705,7 @@ class TFGenerationMixin:
                 (0, 0, cur_len - input_ids_length),
                 (batch_size, num_beams, input_ids_length),
             )
-            model_inputs = self.prepare_inputs_for_generation(
-                flatten_beam_dim(input_token), use_xla=use_xla, **model_kwargs
-            )
+            model_inputs = self.prepare_inputs_for_generation(flatten_beam_dim(input_token), **model_kwargs)
             model_outputs = self(
                 **model_inputs,
                 return_dict=True,
......
@@ -490,8 +490,8 @@ class GenerationMixin:
     def _prepare_attention_mask_for_generation(
         self,
         inputs: torch.Tensor,
-        pad_token_id: int,
-        eos_token_id: int,
+        pad_token_id: Optional[int],
+        eos_token_id: Optional[int],
     ) -> torch.LongTensor:
         is_input_ids = len(inputs.shape) == 2 and inputs.dtype in [torch.int, torch.long]
         is_pad_token_in_inputs = (pad_token_id is not None) and (pad_token_id in inputs)
@@ -1137,7 +1137,11 @@ class GenerationMixin:
             eos_token_id = self.config.decoder.eos_token_id
         if pad_token_id is None and eos_token_id is not None:
-            # special case if pad_token_id is not defined
+            if model_kwargs.get("attention_mask", None) is None:
+                logger.warning(
+                    "The attention mask and the pad token id were not set. As a consequence, you may observe "
+                    "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
+                )
             logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
             pad_token_id = eos_token_id
......
@@ -813,25 +813,21 @@ class TFGPT2LMHeadModel(TFGPT2PreTrainedModel, TFCausalLanguageModelingLoss):
     def set_output_embeddings(self, value):
         self.set_input_embeddings(value)
-    def prepare_inputs_for_generation(self, inputs, past=None, use_cache=None, use_xla=False, **kwargs):
-        # TODO: (Joao) after the TF generator is complete, update GPT2 TF generation to match PT's. NB -- some GPT2
-        # tests will need to be fixed after the change
+    def prepare_inputs_for_generation(self, inputs, past=None, use_cache=None, **kwargs):
+        token_type_ids = kwargs.get("token_type_ids", None)
         # only last token for inputs_ids if past is defined in kwargs
         if past:
             inputs = tf.expand_dims(inputs[:, -1], -1)
+            if token_type_ids is not None:
+                token_type_ids = tf.expand_dims(token_type_ids[:, -1], -1)
-        # TODO(pvp, Joao) - this `if use_xla` statement can be removed, but is left
-        # for a future PR to not change too many things for now.
-        # All statements in this if case apply for both xla and non-xla (as they already do in PyTorch)
-        position_ids = None
-        attention_mask = None
-        if use_xla:
-            attention_mask = kwargs.get("attention_mask", None)
-            if past is not None and attention_mask is not None:
-                position_ids = tf.reduce_sum(attention_mask, axis=1, keepdims=True) - 1
-            elif attention_mask is not None:
-                position_ids = tf.math.cumsum(attention_mask, axis=1, exclusive=True)
+        position_ids = kwargs.get("position_ids", None)
+        attention_mask = kwargs.get("attention_mask", None)
+        if attention_mask is not None and position_ids is None:
+            position_ids = tf.math.cumsum(attention_mask, axis=-1, exclusive=True)
+            if past:
+                position_ids = tf.expand_dims(position_ids[:, -1], -1)
         return {
             "input_ids": inputs,
@@ -839,6 +835,7 @@ class TFGPT2LMHeadModel(TFGPT2PreTrainedModel, TFCausalLanguageModelingLoss):
             "position_ids": position_ids,
             "past": past,
             "use_cache": use_cache,
+            "token_type_ids": token_type_ids,
         }
     def _update_model_kwargs_for_xla_generation(self, outputs, model_kwargs, current_pos, max_length):
......
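As a side note (not part of the commit), the `position_ids` fallback in `prepare_inputs_for_generation` above can be checked in isolation; for a left-padded attention mask, the exclusive cumulative sum assigns position 0 to the first real token of each row:

import tensorflow as tf

attention_mask = tf.constant([[0, 0, 1, 1],
                              [1, 1, 1, 1]])  # first row is left-padded

position_ids = tf.math.cumsum(attention_mask, axis=-1, exclusive=True)
print(position_ids.numpy())  # [[0 0 0 1]
                             #  [0 1 2 3]]

# With a cache (`past`), only the position id of the last token is kept:
print(tf.expand_dims(position_ids[:, -1], -1).numpy())  # [[1], [3]]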
@@ -456,7 +456,7 @@ class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
         tokenizer.padding_side = "left"
         sentences = ["Today is a beautiful day and", "Yesterday was"]
-        input_ids = tokenizer(sentences, return_tensors="tf", padding=True).input_ids
+        input_ids = tokenizer(sentences, return_tensors="tf", padding=True)
         generation_kwargs = {
             "bad_words_ids": [tokenizer("is").input_ids, tokenizer("angry about").input_ids],
@@ -465,12 +465,12 @@ class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
             "repetition_penalty": 1.3,
         }
-        output_ids = model.generate(input_ids, **generation_kwargs)
+        output_ids = model.generate(**input_ids, **generation_kwargs)
         output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
         expected_output_string = [
             "Today is a beautiful day and I am so happy to be able take part in this amazing event.",
-            "Yesterday was a very busy day for the first time since I started writing this post",
+            "Yesterday was a very interesting time for the world to see how much of this is",
         ]
         self.assertListEqual(output_strings, expected_output_string)
@@ -483,7 +483,7 @@ class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
         tokenizer.padding_side = "left"
         sentences = ["Today is a beautiful day and", "Yesterday was"]
-        input_ids = tokenizer(sentences, return_tensors="tf", padding=True).input_ids
+        input_ids = tokenizer(sentences, return_tensors="tf", padding=True)
         generation_kwargs = {
             "do_sample": True,
@@ -498,13 +498,13 @@ class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
         # forces the generation to happen on CPU, to avoid GPU-related quirks
         with tf.device(":/CPU:0"):
-            output_ids = model.generate(input_ids, **generation_kwargs)
+            output_ids = model.generate(**input_ids, **generation_kwargs)
         output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
         expected_output_string = [
-            "Today is a beautiful day and we will make you feel very hot/terrific in all",
-            "Yesterday was another solid success as news coverage became standard American domestic television hit.",
+            "Today is a beautiful day and we will make you feel very hot/terrific in all your",
+            "Yesterday was known by national television networks as Le Big Show or Wild Dog Jeopard",
         ]
         self.assertListEqual(output_strings, expected_output_string)
@@ -517,7 +517,7 @@ class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
         tokenizer.padding_side = "left"
         sentences = ["Today is a beautiful day and", "Yesterday was"]
-        input_ids = tokenizer(sentences, return_tensors="tf", padding=True).input_ids
+        input_ids = tokenizer(sentences, return_tensors="tf", padding=True)
         generation_kwargs = {
             "bad_words_ids": [tokenizer("is").input_ids, tokenizer("angry about").input_ids],
@@ -526,37 +526,69 @@ class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
             "num_beams": 2,
         }
-        output_ids = model.generate(input_ids, **generation_kwargs)
+        output_ids = model.generate(**input_ids, **generation_kwargs)
         output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
         expected_output_string = [
             "Today is a beautiful day and a great day for all of us.\n\nI’m",
-            "Yesterday was the first day of the year for the second time in a row,",
+            "Yesterday was the first time that a person has been arrested in the United States for",
         ]
         self.assertListEqual(output_strings, expected_output_string)
+    @slow
+    def test_lm_generate_distilgpt2_left_padding(self):
+        """Tests that the generated text is the same, regarless of left padding"""
+        model = TFGPT2LMHeadModel.from_pretrained("distilgpt2")
+        tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2")
+        tokenizer.pad_token = tokenizer.eos_token
+        tokenizer.padding_side = "left"
+        generation_kwargs = {
+            "bad_words_ids": [tokenizer("is").input_ids, tokenizer("angry about").input_ids],
+            "no_repeat_ngram_size": 2,
+            "do_sample": False,
+            "repetition_penalty": 1.3,
+        }
+        expected_output_string = (
+            "Today is a beautiful day and I am so happy to be able take part in this amazing event."
+        )
+        sentences = ["Today is a beautiful day and"]
+        input_ids = tokenizer(sentences, return_tensors="tf", padding=True)
+        # using default length
+        output_ids = model.generate(**input_ids, **generation_kwargs)
+        output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
+        self.assertEqual(output_strings[0], expected_output_string)
+        sentences = ["Today is a beautiful day and", "This is a very long input that we absolutely don't care about"]
+        input_ids = tokenizer(sentences, return_tensors="tf", padding=True)
+        # longer max length to capture the full length (remember: it is left padded)
+        output_ids = model.generate(**input_ids, **generation_kwargs, max_length=27)
+        output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
+        self.assertEqual(output_strings[0], expected_output_string)
     @slow
     def test_lm_generate_gpt2_greedy_xla(self):
-        # TODO (Joao): convert this to an example with a batch size>1 with different input lengths that works (and fix
-        # the underlying problem)
         model = TFGPT2LMHeadModel.from_pretrained("gpt2")
         tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
         tokenizer.pad_token = tokenizer.eos_token
         tokenizer.padding_side = "left"
-        sentences = ["The dog"]
+        sentences = ["The dog", "The flying machine"]
         expected_output_strings = [
-            "The dog was found in a field near the intersection of West and West Streets.\n\nThe dog",
+            "The dog was found in a field near the intersection of West and West Streets.\n\nThe",
+            "The flying machine is a small, lightweight, and lightweight aircraft that can be used for any type of",
         ]
-        input_ids = tokenizer(sentences, return_tensors="tf", padding=True).input_ids
-        output_ids = model.generate(input_ids, do_sample=False)
+        input_ids = tokenizer(sentences, return_tensors="tf", padding=True)
+        output_ids = model.generate(**input_ids, do_sample=False)
         output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
         self.assertListEqual(output_strings, expected_output_strings)
         xla_generate = tf.function(model.generate, jit_compile=True)
-        output_ids = xla_generate(input_ids, do_sample=False)
+        output_ids = xla_generate(**input_ids, do_sample=False)
         output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
         self.assertListEqual(output_strings, expected_output_strings)
@@ -574,21 +606,24 @@ class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
         tokenizer.pad_token = tokenizer.eos_token
         tokenizer.padding_side = "left"
-        sentence = ["The dog"]
+        sentence = ["The dog", "The flying machine"]
         expected_output_string = [
             "The dog owner asked why did our vet decide there needed to be extra ventilation inside because most"
-            " puppies"
+            " puppies",
+            "The flying machine was made by an artist who found it difficult to control it as it did not use",
         ]
         expected_output_string_xla = [
-            "The dog has been named in connection with the murder of a 20-year-old man in!"
+            "The dog has been named in connection with the murder of a 20-year-old man in",
+            "The flying machine is a new and improved system to operate and operate a new system and system "
+            "system system",
         ]
-        input_ids = tokenizer(sentence, return_tensors="tf", padding=True).input_ids
-        output_ids = model.generate(input_ids, do_sample=True, seed=[7, 0])
+        input_ids = tokenizer(sentence, return_tensors="tf", padding=True)
+        output_ids = model.generate(**input_ids, do_sample=True, seed=[7, 0])
         output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
         self.assertListEqual(output_strings, expected_output_string)
         xla_generate = tf.function(model.generate, jit_compile=True)
-        output_ids = xla_generate(input_ids, do_sample=True, seed=[7, 0])
+        output_ids = xla_generate(**input_ids, do_sample=True, seed=[7, 0])
         output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
         self.assertListEqual(output_strings, expected_output_string_xla)