Unverified commit 5cce3076, authored by Joao Gante, committed by GitHub

TF: generate without `tf.TensorArray` (#17801)

parent ab223fc1
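The diff below drops the `tf.TensorArray`-based token buffers in `greedy_search`, `sample`, and `beam_search`, and instead pre-pads a plain `tf.Tensor` to `max_length` once, writing each newly generated token into it with `tf.tensor_scatter_nd_update` so the buffer keeps a static shape under XLA. A minimal, self-contained sketch of that write pattern (toy values only, not the actual `TFGenerationMixin` code):

```python
import tensorflow as tf

# Toy dimensions; in the real code these come from `input_ids.shape` and the `generate()` arguments.
batch_size, cur_len, max_length, pad_token_id = 2, 3, 8, 0
prompt = tf.constant([[11, 12, 13], [21, 22, 23]], dtype=tf.int32)  # stand-in for `input_ids`

# Pre-pad the buffer once to its final length, so its shape stays static for XLA.
padding = tf.ones((batch_size, max_length - cur_len), dtype=tf.int32) * (pad_token_id or 0)
generated = tf.concat([prompt, padding], axis=-1)  # shape (batch_size, max_length)

# Each decoding step scatters the freshly generated tokens into column `cur_len`.
next_tokens = tf.constant([5, 7], dtype=tf.int32)  # pretend these came from argmax/sampling
update_indices = tf.stack([tf.range(batch_size), tf.broadcast_to(cur_len, [batch_size])], axis=-1)
generated = tf.tensor_scatter_nd_update(generated, update_indices, next_tokens)
print(generated)  # column 3 now holds [5, 7]; the remaining columns still hold pad_token_id
```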
@@ -16,7 +16,6 @@
 import inspect
 from dataclasses import dataclass
-from functools import partial
 from typing import Any, Dict, List, Optional, Tuple, Union
 import numpy as np
@@ -1979,6 +1978,8 @@ class TFGenerationMixin:
 return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
 )
 use_xla = not tf.executing_eagerly()
+# some models, like XLNet, need more than the last token in the presence of past
+needs_full_input = "use_mems" in set(inspect.signature(self.prepare_inputs_for_generation).parameters.keys())
 # 2. init `attentions`, `hidden_states`, and `scores` tuples
 scores = [] if (return_dict_in_generate and output_scores) else None
@@ -1989,34 +1990,25 @@ class TFGenerationMixin:
 # 3. init tensors to use for "xla-compileable" generate function
 batch_size, cur_len = input_ids.shape
-# initialize `generated` (pre-populated with `pad_token_id`), `finished_sequences`
-generated = tf.TensorArray(
-    element_shape=(batch_size,),
-    dtype=tf.int32,
-    dynamic_size=False,
-    size=max_length,
-    clear_after_read=False,
-)
-if pad_token_id:  # ignores the cases when it is 0 or None
-    for i in range(max_length):
-        generated = generated.write(i, tf.broadcast_to(pad_token_id, (batch_size,)))
-# write prompt to generated
-for i in range(cur_len):
-    generated = generated.write(i, input_ids[:, i])
+# initialize `generated` (`input_ids` padded with `pad_token_id`), `finished_sequences`
+input_ids_padding = tf.ones((batch_size, max_length - cur_len), dtype=tf.int32) * (pad_token_id or 0)
+generated = tf.concat([input_ids, input_ids_padding], axis=-1)
 finished_sequences = tf.zeros((batch_size,), dtype=tf.bool)
 # 4. define "xla-compile-able" stop-condition and auto-regressive function
 # define condition fn
-def greedy_search_cond_fn(generated, finished_sequences, next_tokens, cur_len, model_kwargs):
+def greedy_search_cond_fn(generated, finished_sequences, cur_len, model_kwargs):
     """state termination condition fn."""
     return ~tf.reduce_all(finished_sequences)
 # define condition fn
-def greedy_search_body_fn(generated, finished_sequences, next_tokens, cur_len, model_kwargs):
+def greedy_search_body_fn(generated, finished_sequences, cur_len, model_kwargs):
     """state update fn."""
-    model_inputs = self.prepare_inputs_for_generation(next_tokens, **model_kwargs)
+    if model_kwargs.get("past") is None or needs_full_input:
+        input_ids = generated[:, :cur_len]
+    else:
+        input_ids = tf.expand_dims(generated[:, cur_len - 1], -1)
+    model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
     # forward pass to get next token logits
     outputs = self(
         **model_inputs,
@@ -2043,8 +2035,7 @@ class TFGenerationMixin:
 decoder_hidden_states.append(outputs.hidden_states)
 # pre-process distribution
-input_ids = tf.transpose(tf.reshape(generated.concat(), (-1, batch_size)))
-next_tokens_scores = logits_processor(input_ids, next_token_logits, cur_len)
+next_tokens_scores = logits_processor(generated, next_token_logits, cur_len)
 # argmax
 next_tokens = tf.argmax(next_tokens_scores, axis=-1, output_type=tf.int32)
@@ -2057,8 +2048,8 @@ class TFGenerationMixin:
 finished_sequences = finished_sequences | (next_tokens == eos_token_id)
 # update `generated` and `cur_len`
-generated = generated.write(cur_len, next_tokens)
-next_tokens = next_tokens[:, None]
+update_indices = tf.stack([tf.range(batch_size), tf.broadcast_to(cur_len, [batch_size])], axis=-1)
+generated = tf.tensor_scatter_nd_update(tensor=generated, indices=update_indices, updates=next_tokens)
 cur_len += 1
 # update model_kwargs
@@ -2073,34 +2064,29 @@ class TFGenerationMixin:
 # let's throw out `past` since we don't want `None` tensors
 model_kwargs.pop("past", None)
-next_tokens = tf.reshape(generated.concat(), (-1, batch_size))
-next_tokens = tf.transpose(next_tokens[:cur_len])
-return generated, finished_sequences, next_tokens, cur_len, model_kwargs
+return generated, finished_sequences, cur_len, model_kwargs
 # 5. run generation
 # 1st generation step has to be run before to initialize `past`
-generated, finished_sequences, next_tokens, cur_len, model_kwargs = greedy_search_body_fn(
-    generated, finished_sequences, input_ids, cur_len, model_kwargs
+generated, finished_sequences, cur_len, model_kwargs = greedy_search_body_fn(
+    generated, finished_sequences, cur_len, model_kwargs
 )
 # 2-to-n generation steps can then be run in autoregressive fashion
 # only in case 1st generation step does NOT yield EOS token though
-if greedy_search_cond_fn(generated, finished_sequences, next_tokens, cur_len, model_kwargs):
+if greedy_search_cond_fn(generated, finished_sequences, cur_len, model_kwargs):
     maximum_iterations = max_length - cur_len
-    generated, _, _, cur_len, _ = tf.while_loop(
+    generated, _, cur_len, _ = tf.while_loop(
         greedy_search_cond_fn,
         greedy_search_body_fn,
-        (generated, finished_sequences, next_tokens, cur_len, model_kwargs),
+        (generated, finished_sequences, cur_len, model_kwargs),
         maximum_iterations=maximum_iterations,
     )
 # 6. prepare outputs
-output_ids = tf.transpose(tf.reshape(generated.concat(), (-1, batch_size)))
 if not use_xla:
     # cut for backward compatibility
-    output_ids = output_ids[:, :cur_len]
+    generated = generated[:, :cur_len]
 if return_dict_in_generate:
     if self.config.is_encoder_decoder:
@@ -2117,7 +2103,7 @@ class TFGenerationMixin:
 decoder_hidden_states = tuple(decoder_hidden_states) if decoder_hidden_states is not None else None
 return TFGreedySearchEncoderDecoderOutput(
-    sequences=output_ids,
+    sequences=generated,
     scores=scores,
     encoder_attentions=encoder_attentions,
     encoder_hidden_states=encoder_hidden_states,
@@ -2127,13 +2113,13 @@ class TFGenerationMixin:
 )
 else:
     return TFGreedySearchDecoderOnlyOutput(
-        sequences=output_ids,
+        sequences=generated,
         scores=scores,
         attentions=decoder_attentions,
         hidden_states=decoder_hidden_states,
     )
 else:
-    return output_ids
+    return generated
 def sample(
     self,
@@ -2250,6 +2236,8 @@ class TFGenerationMixin:
 return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
 )
 use_xla = not tf.executing_eagerly()
+# some models, like XLNet, need more than the last token in the presence of past
+needs_full_input = "use_mems" in set(inspect.signature(self.prepare_inputs_for_generation).parameters.keys())
 # 2. init `attentions`, `hidden_states`, and `scores` tuples
 scores = [] if (return_dict_in_generate and output_scores) else None
@@ -2261,29 +2249,20 @@ class TFGenerationMixin:
 batch_size, cur_len = input_ids.shape
 # initialize `generated` (pre-populated with `pad_token_id`), `finished_sequences`
-generated = tf.TensorArray(
-    element_shape=(batch_size,),
-    dtype=tf.int32,
-    dynamic_size=False,
-    size=max_length,
-    clear_after_read=False,
-)
-if pad_token_id:  # ignores the cases when it is 0 or None
-    for i in range(max_length):
-        generated = generated.write(i, tf.broadcast_to(pad_token_id, (batch_size,)))
-# write prompt to generated
-for i in range(cur_len):
-    generated = generated.write(i, input_ids[:, i])
+input_ids_padding = tf.ones((batch_size, max_length - cur_len), dtype=tf.int32) * (pad_token_id or 0)
+generated = tf.concat([input_ids, input_ids_padding], axis=-1)
 finished_sequences = tf.zeros((batch_size,), dtype=tf.bool)
 # 4. define "xla-compile-able" stop-condition and auto-regressive function
-def sample_cond_fn(generated, finished_sequences, next_tokens, cur_len, model_kwargs):
+def sample_cond_fn(generated, finished_sequences, cur_len, model_kwargs):
     return ~tf.reduce_all(finished_sequences)
-def sample_body_fn(generated, finished_sequences, next_tokens, cur_len, model_kwargs):
-    model_inputs = self.prepare_inputs_for_generation(next_tokens, **model_kwargs)
+def sample_body_fn(generated, finished_sequences, cur_len, model_kwargs):
+    if model_kwargs.get("past") is None or needs_full_input:
+        input_ids = generated[:, :cur_len]
+    else:
+        input_ids = tf.expand_dims(generated[:, cur_len - 1], -1)
+    model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
     # forward pass to get next token logits
     outputs = self(
@@ -2310,9 +2289,8 @@ class TFGenerationMixin:
 decoder_hidden_states.append(outputs.hidden_states)
 # pre-process distribution
-input_ids = tf.transpose(tf.reshape(generated.concat(), (-1, batch_size)))
-next_tokens_scores = logits_processor(input_ids, next_token_logits, cur_len)
-next_tokens_scores = logits_warper(input_ids, next_tokens_scores, cur_len)
+next_tokens_scores = logits_processor(generated, next_token_logits, cur_len)
+next_tokens_scores = logits_warper(generated, next_tokens_scores, cur_len)
 # sample
 if seed is not None:
@@ -2334,8 +2312,8 @@ class TFGenerationMixin:
 finished_sequences = finished_sequences | (next_tokens == eos_token_id)
 # update `generated` and `cur_len`
-generated = generated.write(cur_len, next_tokens)
-next_tokens = next_tokens[:, None]
+update_indices = tf.stack([tf.range(batch_size), tf.broadcast_to(cur_len, [batch_size])], axis=-1)
+generated = tf.tensor_scatter_nd_update(tensor=generated, indices=update_indices, updates=next_tokens)
 cur_len += 1
 # update model_kwargs
@@ -2350,34 +2328,29 @@ class TFGenerationMixin:
 # let's throw out `past` since we don't want `None` tensors
 model_kwargs.pop("past", None)
-next_tokens = tf.reshape(generated.concat(), (-1, batch_size))
-next_tokens = tf.transpose(next_tokens[:cur_len])
-return generated, finished_sequences, next_tokens, cur_len, model_kwargs
+return generated, finished_sequences, cur_len, model_kwargs
 # 5. run generation
 # 1st generation step has to be run before to initialize `past`
-generated, finished_sequences, next_tokens, cur_len, model_kwargs = sample_body_fn(
-    generated, finished_sequences, input_ids, cur_len, model_kwargs
+generated, finished_sequences, cur_len, model_kwargs = sample_body_fn(
+    generated, finished_sequences, cur_len, model_kwargs
 )
 # 2-to-n generation steps can then be run in autoregressive fashion
 # only in case 1st generation step does NOT yield EOS token though
-if sample_cond_fn(generated, finished_sequences, next_tokens, cur_len, model_kwargs):
+if sample_cond_fn(generated, finished_sequences, cur_len, model_kwargs):
     maximum_iterations = max_length - cur_len
-    generated, _, _, cur_len, _ = tf.while_loop(
+    generated, _, cur_len, _ = tf.while_loop(
         sample_cond_fn,
         sample_body_fn,
-        (generated, finished_sequences, next_tokens, cur_len, model_kwargs),
+        (generated, finished_sequences, cur_len, model_kwargs),
         maximum_iterations=maximum_iterations,
     )
 # 6. prepare outputs
-output_ids = tf.transpose(tf.reshape(generated.concat(), (-1, batch_size)))
 if not use_xla:
     # cut for backward compatibility
-    output_ids = output_ids[:, :cur_len]
+    generated = generated[:, :cur_len]
 if return_dict_in_generate:
     if self.config.is_encoder_decoder:
@@ -2394,7 +2367,7 @@ class TFGenerationMixin:
 decoder_hidden_states = tuple(decoder_hidden_states) if decoder_hidden_states is not None else None
 return TFSampleEncoderDecoderOutput(
-    sequences=output_ids,
+    sequences=generated,
     scores=scores,
     encoder_attentions=encoder_attentions,
     encoder_hidden_states=encoder_hidden_states,
@@ -2404,13 +2377,13 @@ class TFGenerationMixin:
 )
 else:
     return TFSampleDecoderOnlyOutput(
-        sequences=output_ids,
+        sequences=generated,
         scores=scores,
         attentions=decoder_attentions,
         hidden_states=decoder_hidden_states,
     )
 else:
-    return output_ids
+    return generated
 def beam_search(
     self,
@@ -2585,6 +2558,8 @@ class TFGenerationMixin:
 # GPT2 and other models has a slightly different cache structure, with a different batch axis
 model_name = str(self.decoder) if "EncoderDecoder" in str(self) else str(self)
 cache_batch_axis = 1 if any([model_prefix in model_name for model_prefix in ("TFGPT2", "TFCTRL")]) else 0
+# some models, like XLNet, need more than the last token in the presence of past
+needs_full_input = "use_mems" in set(inspect.signature(self.prepare_inputs_for_generation).parameters.keys())
 # 2. init `attentions`, `hidden_states`, and `scores` tuples
 scores = [] if (return_dict_in_generate and output_scores) else None
@@ -2594,41 +2569,13 @@ class TFGenerationMixin:
 # 3. init tensors to use for "xla-compileable" generate function
 batch_size, num_beams, cur_len = input_ids.shape
-input_ids_length = cur_len
 # per batch, beam-item holding current token in loop, pre-populated with `pad_token_id`
-sequences = tf.TensorArray(
-    element_shape=(batch_size, num_beams),
-    dtype=tf.int32,
-    dynamic_size=False,
-    size=max_length,
-    clear_after_read=False,
-)
-running_sequences = tf.TensorArray(
-    element_shape=(batch_size, num_beams),
-    dtype=tf.int32,
-    dynamic_size=False,
-    size=max_length,
-    clear_after_read=False,
-)
-intermediary_running_sequences = tf.TensorArray(
-    element_shape=(batch_size, num_beams * 2),
-    dtype=tf.int32,
-    dynamic_size=False,
-    size=max_length,
-    clear_after_read=False,
-)
-if pad_token_id:  # ignores the cases when it is 0 or None
-    for i in range(max_length):
-        sequences = sequences.write(i, tf.broadcast_to(pad_token_id, (batch_size, num_beams)))
-        running_sequences = running_sequences.write(i, tf.broadcast_to(pad_token_id, (batch_size, num_beams)))
-        intermediary_running_sequences = intermediary_running_sequences.write(
-            i, tf.broadcast_to(pad_token_id, (batch_size, num_beams * 2))
-        )
-# write prompt to running_sequences
-for i in range(cur_len):
-    running_sequences = running_sequences.write(i, input_ids[:, :, i])
+input_ids_padding = tf.ones((batch_size, num_beams, max_length - cur_len), dtype=tf.int32) * (
+    pad_token_id or 0
+)
+running_sequences = tf.concat([input_ids, input_ids_padding], axis=-1)
+sequences = tf.ones((batch_size, num_beams, max_length), dtype=tf.int32) * (pad_token_id or 0)
 # per batch,beam-item state bit indicating if sentence has finished.
 is_sent_finished = tf.zeros((batch_size, num_beams), dtype=tf.bool)
@@ -2656,7 +2603,6 @@ class TFGenerationMixin:
 sequences,
 scores,
 is_sent_finished,
-input_ids_length,
 model_kwargs,
 ):
 """
@@ -2685,27 +2631,18 @@ class TFGenerationMixin:
 sequences,
 scores,
 is_sent_finished,
-input_ids_length,
 model_kwargs,
-intermediary_running_sequences=None,
 ):
 """
 Beam Search iterative update function -- each iteration adds a new token and updates the best sequences
 seen so far
 """
-# TODO (joao): this loop is probably faster with gather/scatters, instead of using `tf.TensorArray`.
-# Alternativelly, attempt to rewrite function with permuted axis, when enabling XLA.
 # 1. Forward current tokens
-# TF places the dynamic dimension (seq_len) in the first axis, we want it in the last
-running_sequences_seq_last = tf.transpose(running_sequences.stack(), perm=[1, 2, 0])
-input_token = tf.slice(
-    running_sequences_seq_last,
-    (0, 0, cur_len - input_ids_length),
-    (batch_size, num_beams, input_ids_length),
-)
-model_inputs = self.prepare_inputs_for_generation(flatten_beam_dim(input_token), **model_kwargs)
+if model_kwargs.get("past") is None or needs_full_input:
+    input_ids = running_sequences[:, :, :cur_len]
+else:
+    input_ids = tf.expand_dims(running_sequences[:, :, cur_len - 1], -1)
+model_inputs = self.prepare_inputs_for_generation(flatten_beam_dim(input_ids), **model_kwargs)
 model_outputs = self(
     **model_inputs,
     return_dict=True,
@@ -2734,9 +2671,7 @@ class TFGenerationMixin:
 # get log probabilities from logits, process logits with processors (*e.g.* min_length, ...), and
 # add new logprobs to existing running logprobs scores.
 log_probs = tf.nn.log_softmax(logits)
-log_probs = logits_processor(
-    flatten_beam_dim(running_sequences_seq_last), flatten_beam_dim(log_probs), cur_len
-)
+log_probs = logits_processor(flatten_beam_dim(running_sequences), flatten_beam_dim(log_probs), cur_len)
 log_probs = unflatten_beam_dim(log_probs, batch_size, num_beams)
 log_probs = log_probs + tf.expand_dims(running_scores, axis=2)
 vocab_size = log_probs.shape[2]
@@ -2755,23 +2690,28 @@ class TFGenerationMixin:
 beams_to_keep = 2 * num_beams
 topk_log_probs, topk_indices = tf.math.top_k(log_probs, k=beams_to_keep)
 topk_beam_indices = topk_indices // vocab_size
-topk_running_sequences_seq_last = gather_beams(running_sequences_seq_last, topk_beam_indices)
+topk_running_sequences = gather_beams(running_sequences, topk_beam_indices)
 topk_ids = topk_indices % vocab_size
 # writes the new token
-intermediary_running_sequences = intermediary_running_sequences.unstack(
-    tf.transpose(topk_running_sequences_seq_last, perm=[2, 0, 1])
-)
-topk_sequences = intermediary_running_sequences.write(cur_len, topk_ids)
-topk_sequences_seq_last = tf.transpose(topk_sequences.stack(), perm=[1, 2, 0])
+indices_batch = tf.repeat(tf.range(batch_size), [beams_to_keep])
+indices_beam = tf.tile(tf.range(beams_to_keep), [batch_size])
+update_indices = tf.stack(
+    [indices_batch, indices_beam, tf.broadcast_to(cur_len, [batch_size * beams_to_keep])], axis=-1
+)
+topk_sequences = tf.tensor_scatter_nd_update(
+    tensor=topk_running_sequences,
+    indices=update_indices,
+    updates=tf.reshape(topk_ids, [batch_size * beams_to_keep]),
+)
 # 4. Check which sequences have ended
 # Update current sequences: Did the top `num_beams` sequences reach an end marker?
 # To prevent these just finished sequences from being added to the current sequences
 # set of active beam search sequences, set their log probs to a very large negative value.
-eos_in_next_token = topk_sequences_seq_last[:, :, cur_len] == eos_token_id
+eos_in_next_token = topk_sequences[:, :, cur_len] == eos_token_id
 if eos_token_id is None:
-    eos_in_next_token = tf.broadcast_to(eos_in_next_token, topk_sequences_seq_last[:, :, cur_len].shape)
+    eos_in_next_token = tf.broadcast_to(eos_in_next_token, topk_sequences[:, :, cur_len].shape)
 did_topk_just_finished = eos_in_next_token & tf.broadcast_to(
     tf.concat((tf.ones((num_beams), dtype=tf.bool), tf.zeros((num_beams), dtype=tf.bool)), axis=0),
     eos_in_next_token.shape,
@@ -2785,8 +2725,8 @@ class TFGenerationMixin:
 # Determine the top k beam indices (from top 2*k beams) from log probs and gather top k beams
 # (from top 2*k beams).
 next_topk_indices = tf.math.top_k(running_topk_log_probs, k=num_beams)[1]
-next_running_sequences_seq_last, next_running_scores = gather_beams(
-    [topk_sequences_seq_last, running_topk_log_probs], next_topk_indices
+next_running_sequences, next_running_scores = gather_beams(
+    [topk_sequences, running_topk_log_probs], next_topk_indices
 )
 # 6. Process topk logits
@@ -2807,18 +2747,18 @@ class TFGenerationMixin:
 # 7. Get scores, sequences, is sentence finished for next.
 # Combine sequences, scores, and flags along the beam dimension and compare new finished sequence scores
 # to existing finished scores and select the best from the new set of beams
-sequences_seq_last = tf.transpose(sequences.stack(), perm=[1, 2, 0])
-merged_sequences = tf.concat([sequences_seq_last, topk_sequences_seq_last], axis=1)
+merged_sequences = tf.concat([sequences, topk_sequences], axis=1)
 merged_scores = tf.concat([scores, topk_log_probs], axis=1)
 merged_is_sent_finished = tf.concat([is_sent_finished, did_topk_just_finished], axis=1)
 topk_merged_indices = tf.math.top_k(merged_scores, k=num_beams)[1]
-next_sequences_seq_last, next_scores, next_is_sent_finished = gather_beams(
+next_sequences, next_scores, next_is_sent_finished = gather_beams(
     [merged_sequences, merged_scores, merged_is_sent_finished], topk_merged_indices
 )
 # 8. Prepare data for the next iteration
 # Determine the top k beam indices from the original set of all beams. With these, gather the top k
 # beam-associated caches.
+cur_len = cur_len + 1
 if "past_key_values" in model_outputs:
     cache = tf.nest.map_structure(
         lambda tensor: unflatten_beam_dim(tensor, batch_size, num_beams, batch_axis=cache_batch_axis),
@@ -2841,35 +2781,20 @@ class TFGenerationMixin:
 # if we don't cache past key values we need the whole input
 if model_kwargs.get("past", None) is None:
-    next_input_ids_length = cur_len + 1
     # let's throw out `past` since we don't want `None` tensors
     model_kwargs.pop("past", None)
-else:
-    next_input_ids_length = 1
-# 9. Prepare the `tf.TensorArray` for the next iteration
-next_sequences = sequences.unstack(tf.transpose(next_sequences_seq_last, perm=[2, 0, 1]))
-next_running_sequences = running_sequences.unstack(
-    tf.transpose(next_running_sequences_seq_last, perm=[2, 0, 1])
-)
 return (
-    cur_len + 1,
+    cur_len,
     next_running_sequences,
     next_running_scores,
     next_sequences,
     next_scores,
     next_is_sent_finished,
-    next_input_ids_length,
     next_model_kwargs,
 )
 # 5. run generation
-# Adds the `intermediary_running_sequences` TensorArray into the body, needed as a scratchpad
-beam_search_body_fn = partial(
-    beam_search_body_fn, intermediary_running_sequences=intermediary_running_sequences
-)
 # 1st generation step has to be run before to initialize `past` (if active)
 (
     cur_len,
@@ -2878,66 +2803,38 @@ class TFGenerationMixin:
     sequences,
     scores,
     is_sent_finished,
-    input_ids_length,
     model_kwargs,
 ) = beam_search_body_fn(
-    cur_len,
-    running_sequences,
-    running_scores,
-    sequences,
-    scores,
-    is_sent_finished,
-    input_ids_length,
-    model_kwargs,
+    cur_len, running_sequences, running_scores, sequences, scores, is_sent_finished, model_kwargs
 )
 # 2-to-n generation steps can then be run in autoregressive fashion (only in case 1st generation step does
 # NOT yield EOS token though)
 if beam_search_cond_fn(
-    cur_len,
-    running_sequences,
-    running_scores,
-    sequences,
-    scores,
-    is_sent_finished,
-    input_ids_length,
-    model_kwargs,
+    cur_len, running_sequences, running_scores, sequences, scores, is_sent_finished, model_kwargs
 ):
     maximum_iterations = max_length - cur_len
-    cur_len, running_sequences, running_scores, sequences, scores, is_sent_finished, _, _ = tf.while_loop(
+    cur_len, running_sequences, running_scores, sequences, scores, is_sent_finished, _ = tf.while_loop(
         beam_search_cond_fn,
         beam_search_body_fn,
-        (
-            cur_len,
-            running_sequences,
-            running_scores,
-            sequences,
-            scores,
-            is_sent_finished,
-            input_ids_length,
-            model_kwargs,
-        ),
+        (cur_len, running_sequences, running_scores, sequences, scores, is_sent_finished, model_kwargs),
        maximum_iterations=maximum_iterations,
    )
 # 6. prepare outputs
-# convert the sequneces to tf.Tensor with shape (batch_size, num_beams, seq_len)
-sequences_seq_last = tf.transpose(sequences.stack(), perm=[1, 2, 0])
-running_sequences_seq_last = tf.transpose(running_sequences.stack(), perm=[1, 2, 0])
 # Account for the edge-case where there are no finished sequences for a particular batch item. If so, return
 # running sequences for that batch item.
 none_finished = tf.math.reduce_any(is_sent_finished, axis=1)
-sequences_seq_last = tf.where(none_finished[:, None, None], sequences_seq_last, running_sequences_seq_last)
+sequences = tf.where(none_finished[:, None, None], sequences, running_sequences)
 scores = tf.where(none_finished[:, None], scores, running_scores)
 # Take best beams for each batch (the score is sorted in ascending order)
-sequences_seq_last = flatten_beam_dim(sequences_seq_last[:, :num_return_sequences, :])
+sequences = flatten_beam_dim(sequences[:, :num_return_sequences, :])
 scores = flatten_beam_dim(scores[:, :num_return_sequences])
 if not use_xla:
     # Cut for backward compatibility
-    sequences_seq_last = sequences_seq_last[:, :cur_len]
+    sequences = sequences[:, :cur_len]
 if return_dict_in_generate:
     if self.config.is_encoder_decoder:
@@ -2948,7 +2845,7 @@ class TFGenerationMixin:
 )
 return TFBeamSearchEncoderDecoderOutput(
-    sequences=sequences_seq_last,
+    sequences=sequences,
     scores=scores,
     encoder_attentions=encoder_attentions,
     encoder_hidden_states=encoder_hidden_states,
@@ -2958,13 +2855,13 @@ class TFGenerationMixin:
 )
 else:
     return TFBeamSearchDecoderOnlyOutput(
-        sequences=sequences_seq_last,
+        sequences=sequences,
        scores=scores,
        attentions=decoder_attentions,
        hidden_states=decoder_hidden_states,
    )
 else:
-    return sequences_seq_last
+    return sequences
 def _create_next_token_logits_penalties(input_ids, logits, repetition_penalty):
......
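For reference, the overall decoding loop these hunks converge on: a pre-padded buffer, a bounded `tf.while_loop`, and per-step scatter updates. The toy below is illustrative only; `toy_next_token`, `greedy_decode`, and the token values are stand-ins (the real loop calls the model and `logits_processor`), but the control flow mirrors the new greedy/sample path.

```python
import tensorflow as tf

def toy_next_token(generated, cur_len):
    # Stand-in for the model forward pass + argmax: just increments the previous token.
    return (generated[:, cur_len - 1] + 1) % 100

def greedy_decode(prompt, max_length, pad_token_id=0, eos_token_id=99):
    batch_size, cur_len = prompt.shape
    padding = tf.ones((batch_size, max_length - cur_len), dtype=tf.int32) * pad_token_id
    generated = tf.concat([prompt, padding], axis=-1)  # static (batch_size, max_length) shape
    finished = tf.zeros((batch_size,), dtype=tf.bool)

    def cond_fn(generated, finished, cur_len):
        return ~tf.reduce_all(finished)

    def body_fn(generated, finished, cur_len):
        next_tokens = toy_next_token(generated, cur_len)
        finished = finished | (next_tokens == eos_token_id)
        indices = tf.stack([tf.range(batch_size), tf.broadcast_to(cur_len, [batch_size])], axis=-1)
        generated = tf.tensor_scatter_nd_update(generated, indices, next_tokens)
        return generated, finished, cur_len + 1

    # `maximum_iterations` bounds the loop so it stays XLA-compilable even if EOS never appears.
    generated, _, _ = tf.while_loop(
        cond_fn, body_fn, (generated, finished, cur_len), maximum_iterations=max_length - cur_len
    )
    return generated

print(greedy_decode(tf.constant([[1, 2], [3, 4]], dtype=tf.int32), max_length=6))
```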
@@ -874,8 +874,9 @@ class TFGPT2LMHeadModel(TFGPT2PreTrainedModel, TFCausalLanguageModelingLoss):
 new_past = [None for _ in range(len(past))]
 slice_start_base = tf.constant([0, 0, 0, 1, 0])
 attention_mask_update_slice = tf.ones((batch_size, 1), dtype=attention_mask.dtype)
-# correct 5 here
-new_past_index = current_pos - 1
+# -1 because current_pos has already been incremented before this function
+# -1 again because last index = len - 1
+new_past_index = current_pos - 2
 for i in range(len(past)):
     update_slice = past[i][:, :, :, -1:]
......
@@ -1202,7 +1202,6 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss):
 def prepare_inputs_for_generation(self, inputs, past=None, use_mems=None, **kwargs):
     # Add dummy token at the end (no attention on this one)
     effective_batch_size = inputs.shape[0]
     dummy_token = tf.zeros((effective_batch_size, 1), dtype=inputs.dtype)
@@ -1212,12 +1211,12 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss):
 offset = 2
 if past:
-    inputs = tf.concat([inputs[:, -offset:], dummy_token], axis=1)
+    input_ids = tf.concat([inputs[:, -offset:], dummy_token], axis=1)
 else:
-    inputs = tf.concat([inputs, dummy_token], axis=1)
+    input_ids = tf.concat([inputs, dummy_token], axis=1)
 # Build permutation mask so that previous tokens don't see last token
-sequence_length = inputs.shape[1]
+sequence_length = input_ids.shape[1]
 perm_mask = tf.zeros((effective_batch_size, sequence_length, sequence_length - 1))
 perm_mask_seq_end = tf.ones((effective_batch_size, sequence_length, 1))
 perm_mask = tf.concat([perm_mask, perm_mask_seq_end], axis=-1)
@@ -1228,7 +1227,7 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss):
 target_mapping = tf.concat([target_mapping, target_mapping_seq_end], axis=-1)
 inputs = {
-    "input_ids": inputs,
+    "input_ids": input_ids,
     "perm_mask": perm_mask,
     "target_mapping": target_mapping,
     "use_mems": use_mems,
......
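The `needs_full_input` check added in the three generation loops keys off this kind of signature: a model whose `prepare_inputs_for_generation` exposes `use_mems` (XLNet-style memory caching) is fed the full prefix `generated[:, :cur_len]` on every step, rather than only the last token, even when `past` is set. A rough sketch of that detection with a stub function in place of a real model (stub name and body are hypothetical):

```python
import inspect

def prepare_inputs_for_generation(inputs, past=None, use_mems=None, **kwargs):
    """Stub with an XLNet-like signature; only the parameter names matter for the check."""
    return {"input_ids": inputs}

# Mirrors the check in the diff: the presence of a `use_mems` parameter flags the model
# as needing the whole generated prefix at each decoding step.
needs_full_input = "use_mems" in set(inspect.signature(prepare_inputs_for_generation).parameters.keys())
print(needs_full_input)  # True for this stub; False for a GPT-2-style signature without `use_mems`
```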