"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "700329493dba077654cd771044ec303a265a8d79"
Unverified commit 588faad1, authored by Joao Gante, committed by GitHub

Generate: TF XLA beam sample (#20927)

* beam sample in beam search

* RAG now works with the updated beam search

* delete legacy (non-XLA) generation code related to beam sample
parent 375801d5
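The "beam sample" mode folded into beam search here differs from plain beam search only in how candidate continuations are picked at each step: instead of a deterministic top-k over the accumulated beam scores, candidates are drawn stochastically. A minimal sketch of that selection step, using a hypothetical helper (this is not the transformers API, and it samples with replacement for simplicity):

import tensorflow as tf

def pick_next_beams(log_probs, num_beams, do_sample, seed=(0, 0)):
    """Selects 2 * num_beams candidate continuations per batch item.

    `log_probs` has shape (batch, num_beams * vocab_size): the accumulated
    score of every one-token extension of every live beam. With
    do_sample=False this is plain beam search; with do_sample=True the
    candidates are sampled instead, which is the essence of beam sample.
    Hypothetical helper, not the transformers API.
    """
    k = 2 * num_beams  # extra candidates so finished beams can be replaced
    if do_sample:
        # stateless sampling keeps the op reproducible and XLA-friendly;
        # unlike top-k, tf.random.stateless_categorical samples with replacement
        indices = tf.random.stateless_categorical(log_probs, k, seed=seed)
        scores = tf.gather(log_probs, indices, axis=1, batch_dims=1)
    else:
        scores, indices = tf.math.top_k(log_probs, k=k)
    return scores, indices

Each selected index then decodes to a (beam, token) pair via index // vocab_size and index % vocab_size, after which the surviving beams are gathered into place — the job of the `_gather_beams` override added in the first hunk below.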
@@ -817,6 +817,32 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss
         return reordered_past

+    @staticmethod
+    def _gather_beams(nested, beam_indices, batch_axis=0):
+        """
+        RAG-specific `_gather_beams`: gathers the beam slices indexed by `beam_indices` into a new beam array. If a
+        nested tensor's shape does not match the beam indices, it is the cache; in that case, the extra `n_docs`
+        dimension is split out before gathering and folded back in afterwards.
+        """
+
+        def gather_fn(tensor):
+            is_rag_cache = tensor.shape[0] != beam_indices.shape[0]
+            if is_rag_cache:
+                n_docs = tensor.shape[0] // beam_indices.shape[0]
+                batch_size = beam_indices.shape[0]
+                # reshape into (batch size, num beams, n_docs, ...), the cache format expected by RAG
+                tensor = tf.reshape(tensor, (batch_size, -1, n_docs, *tensor.shape[2:]))
+
+            gathered_tensor = tf.gather(params=tensor, indices=beam_indices, axis=1, batch_dims=1)
+
+            if is_rag_cache:
+                # reshape back into the shape expected by beam search
+                gathered_tensor = tf.reshape(gathered_tensor, (batch_size * n_docs, -1, *gathered_tensor.shape[3:]))
+
+            return gathered_tensor
+
+        return tf.nest.map_structure(gather_fn, nested)
+
     def marginalize(self, seq_logits, doc_scores, n_docs=None):
         n_docs = n_docs if n_docs is not None else self.config.n_docs
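To make the shapes in the hunk above concrete: for ordinary (non-cache) tensors, `gather_fn` is a single batched gather that reorders the beams of each batch item independently. A small self-contained illustration with toy shapes (not part of the commit):

import tensorflow as tf

batch_size, num_beams, seq_len = 2, 3, 4

# toy beam-major tensor: slice [b, k] is beam k of batch item b
sequences = tf.reshape(
    tf.range(batch_size * num_beams * seq_len),
    (batch_size, num_beams, seq_len),
)
# keep beams (2, 0, 1) for batch item 0 and (1, 1, 0) for batch item 1
beam_indices = tf.constant([[2, 0, 1], [1, 1, 0]])

# one gather per batch item -- what gather_fn does to non-cache tensors
reordered = tf.gather(params=sequences, indices=beam_indices, axis=1, batch_dims=1)
print(reordered.shape)  # (2, 3, 4)

# a RAG cache tensor folds n_docs into its leading dimension, giving shape
# (batch_size * n_docs, num_beams, ...); its dim 0 is 4 here while
# beam_indices has 2 rows -- the mismatch gather_fn uses to detect the cache
n_docs = 2
cache = tf.zeros((batch_size * n_docs, num_beams, seq_len))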
@@ -1129,12 +1155,6 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss
             return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
         )

-        model_kwargs["output_scores"] = output_scores
-        model_kwargs["output_attentions"] = output_attentions
-        model_kwargs["output_hidden_states"] = output_hidden_states
-        model_kwargs["encoder_attentions"] = None
-        model_kwargs["encoder_hidden_states"] = None
-
         # retrieve docs
         if self.retriever is not None and context_input_ids is None:
             question_hidden_states = self.question_encoder(input_ids, attention_mask=attention_mask)[0]
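The removal above goes hand in hand with the next hunk: the output flags (`output_scores`, `output_attentions`, `output_hidden_states`) are now passed to `greedy_search`/`beam_search` as explicit arguments rather than smuggled through `model_kwargs`. The `pre_processor` built in that hunk bundles the configured logits processors; as a point of reference, here is a simplified sketch of the kind of transform one such processor applies, a min-length EOS mask (illustrative only, not the transformers implementation):

import tensorflow as tf

def min_length_process(input_ids, scores, min_length, eos_token_id):
    """Masks the EOS logit until min_length tokens have been generated --
    the same job as one of the processors _get_logits_processor bundles
    into pre_processor. Simplified sketch, not the library code."""
    cur_len = tf.shape(input_ids)[-1]
    # a vector that is dtype.min at the EOS position and 0 elsewhere
    eos_mask = tf.one_hot(
        eos_token_id,
        depth=tf.shape(scores)[-1],
        on_value=scores.dtype.min,
        off_value=0.0,
    )
    return tf.cond(cur_len < min_length, lambda: scores + eos_mask, lambda: scores)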
@@ -1211,71 +1231,55 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss
         doc_scores = tf.repeat(doc_scores, num_beams, axis=0)

         # define start_len & additional parameters
-        cur_len = 1
-        vocab_size = self.config.generator.vocab_size
         model_kwargs["doc_scores"] = doc_scores
         model_kwargs["encoder_outputs"] = encoder_outputs
+        model_kwargs["attention_mask"] = context_attention_mask
         model_kwargs["n_docs"] = n_docs

-        # not needed. TODO(PVP): change after generate refactor
-        do_sample = False
-        temperature = self.config.temperature
-        top_k = self.config.top_k
-        top_p = self.config.top_p
-        repetition_penalty = self.config.repetition_penalty
-
-        if num_beams > 1:
-            return self._generate_beam_search(
-                decoder_input_ids,
-                cur_len=cur_len,
-                max_length=max_length,
-                min_length=min_length,
-                do_sample=do_sample,
-                early_stopping=early_stopping,
-                temperature=temperature,
-                top_k=top_k,
-                top_p=top_p,
-                repetition_penalty=repetition_penalty,
-                no_repeat_ngram_size=no_repeat_ngram_size,
-                bad_words_ids=bad_words_ids,
-                pad_token_id=pad_token_id,
-                eos_token_id=eos_token_id,
-                batch_size=batch_size,
-                num_return_sequences=num_return_sequences,
-                length_penalty=length_penalty,
-                num_beams=num_beams,
-                vocab_size=vocab_size,
-                attention_mask=context_attention_mask,
-                use_cache=use_cache,
-                forced_bos_token_id=None,
-                forced_eos_token_id=None,
-                return_dict_in_generate=return_dict_in_generate,
-                **model_kwargs,  # encoder_outputs is here as in Pytorch's version
-            )
-        else:
-            pre_processor = self._get_logits_processor(
-                repetition_penalty=repetition_penalty,
-                no_repeat_ngram_size=no_repeat_ngram_size,
-                bad_words_ids=bad_words_ids,
-                min_length=min_length,
-                max_length=max_length,
-                eos_token_id=eos_token_id,
-                forced_bos_token_id=None,
-                forced_eos_token_id=None,
-                input_ids_seq_length=tf.shape(decoder_input_ids)[-1],
-            )
-            model_kwargs["attention_mask"] = context_attention_mask
-
-            if model_kwargs.get("encoder_attentions", None) is None:
-                model_kwargs.pop("encoder_attentions", None)
-            if model_kwargs.get("encoder_hidden_states", None) is None:
-                model_kwargs.pop("encoder_hidden_states", None)
-
-            model_kwargs.pop("output_hidden_states", None)
-            model_kwargs.pop("output_attentions", None)
-            model_kwargs.pop("output_scores", None)
-
-            return self.greedy_search(
+        pre_processor = self._get_logits_processor(
+            repetition_penalty=self.config.repetition_penalty,
+            no_repeat_ngram_size=no_repeat_ngram_size,
+            bad_words_ids=bad_words_ids,
+            min_length=min_length,
+            max_length=max_length,
+            eos_token_id=eos_token_id,
+            forced_bos_token_id=self.config.generator.forced_bos_token_id,
+            forced_eos_token_id=self.config.generator.forced_eos_token_id,
+            input_ids_seq_length=tf.shape(decoder_input_ids)[-1],
+        )
+
+        if num_beams == 1:
+            return self.greedy_search(
+                input_ids=decoder_input_ids,
+                max_length=max_length,
+                pad_token_id=pad_token_id,
+                eos_token_id=eos_token_id,
+                logits_processor=pre_processor,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                output_scores=output_scores,
+                return_dict_in_generate=return_dict_in_generate,
+                **model_kwargs,
+            )
+        elif num_beams > 1:
+            if num_beams < num_return_sequences:
+                raise ValueError(
+                    "Beam search decoding cannot return more sequences than it has beams. Please set "
+                    f"num_beams >= num_return_sequences, got {num_beams} and {num_return_sequences} (respectively)"
+                )
+
+            def unflatten_beam_dim(tensor):
+                """Unflattens the first, flat batch*beam dimension of a non-scalar array."""
+                shape = shape_list(tensor)
+                return tf.reshape(tensor, [-1, num_beams] + shape[1:])
+
+            decoder_input_ids = unflatten_beam_dim(decoder_input_ids)
+            model_kwargs["attention_mask"] = unflatten_beam_dim(model_kwargs["attention_mask"])
+            model_kwargs["encoder_outputs"]["last_hidden_state"] = unflatten_beam_dim(
+                model_kwargs["encoder_outputs"]["last_hidden_state"]
+            )
+
+            return self.beam_search(
                 input_ids=decoder_input_ids,
                 max_length=max_length,
                 pad_token_id=pad_token_id,
@@ -1287,6 +1291,8 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss
                 return_dict_in_generate=return_dict_in_generate,
                 **model_kwargs,
             )
+        else:
+            raise ValueError(f"`num_beams` has to be an integer strictly superior to 0 (≥ 1), but is {num_beams}")

     def get_input_embeddings(self):
         return self.rag.generator.get_input_embeddings()
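The `unflatten_beam_dim` helper in the diff exists because the XLA-compatible `beam_search` expects its inputs with an explicit `(batch, beams, ...)` layout, while RAG's retrieval path produces tensors whose leading dimension is already flattened to `batch * beams` (via `tf.repeat`). A runnable sketch of that reshape, with the transformers `shape_list` helper replaced by a plain stand-in:

import tensorflow as tf

num_beams = 3
batch_size, seq_len = 2, 5

# decoder inputs arrive flattened to (batch * beams, seq), as tf.repeat left them
flat_ids = tf.zeros((batch_size * num_beams, seq_len), dtype=tf.int32)

def unflatten_beam_dim(tensor):
    """Mirrors the helper in the diff: (batch * beams, ...) -> (batch, beams, ...)."""
    shape = tensor.shape.as_list()  # stand-in for transformers' shape_list
    return tf.reshape(tensor, [-1, num_beams] + shape[1:])

unflat = unflatten_beam_dim(flat_ids)
print(unflat.shape)  # (2, 3, 5)

# the inverse flattening, applied when tensors are handed back to the model
flat_again = tf.reshape(unflat, [batch_size * num_beams, seq_len])

Because this is a pure reshape rather than a gather, it is essentially free under XLA, so the conversion can happen on every generate call without cost.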