"sgl-router/src/vscode:/vscode.git/clone" did not exist on "d3be97104b09bfcabc7de507bfe8a79455ebce30"
Unverified Commit 568e5783 authored by Joao Gante, committed by GitHub

Generate: contrastive search uses existing abstractions and conventions (#19896)

parent 803475fb
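This commit aligns contrastive search with the conventions used by the other decoding strategies: forward passes now request return_dict=True, intermediate results are rebuilt as standard model output dataclasses (CausalLMOutputWithPast rather than CausalLMOutputWithCrossAttentions for decoder-only models), the ranking helper becomes the private _ranking_fast and returns only the selected indices, and model kwargs are updated through the shared _update_model_kwargs_for_generation helper. For context, a minimal usage sketch of the user-facing entry point, assuming a transformers release that includes contrastive search (around v4.24); the model choice and parameter values here are illustrative:

# Minimal usage sketch: contrastive search is triggered by passing
# `penalty_alpha` together with `top_k` to `generate()`.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("DeepMind Company is", return_tensors="pt")
# penalty_alpha trades model confidence against the degeneration penalty;
# top_k is the number of candidate tokens re-ranked at each step.
output_ids = model.generate(**inputs, penalty_alpha=0.6, top_k=4, max_new_tokens=64)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))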
@@ -54,7 +54,7 @@ from .generation_stopping_criteria import (
StoppingCriteriaList,
validate_stopping_criteria,
)
from .modeling_outputs import CausalLMOutputWithCrossAttentions, Seq2SeqLMOutput
from .modeling_outputs import CausalLMOutputWithPast, Seq2SeqLMOutput
from .models.auto import (
MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING,
MODEL_FOR_CAUSAL_LM_MAPPING,
@@ -1882,28 +1882,34 @@ class GenerationMixin:
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
# if the first step in the loop, encode all the prefix and obtain three parameters: (1) past_key_values;
# (2) last_hidden_states; (3) logit_for_next_step
# (2) last_hidden_states; (3) logit_for_next_step; (4) update model kwargs for the next step
if model_kwargs.get("past") is None:
# encode the given prefix and prepare model inputs; encoder-decoder model process the prefix and save
# the `encoder_outputs`
output = self(**model_inputs, output_hidden_states=True, output_attentions=True)
outputs = self(
**model_inputs, return_dict=True, output_hidden_states=True, output_attentions=output_attentions
)
# past_key_values is required for fast decoding
if "past_key_values" not in output:
if "past_key_values" not in outputs:
raise ValueError(
f"{self.__class__.__name__} cannot return `past_key_values` and can therefore **not** be used "
"for contrastive search."
)
past_key_values = output.past_key_values
past_key_values = outputs.past_key_values
# last decoder hidden states will be used to compute the degeneration penalty (cosine similarity with
# previous tokens)
if self.config.is_encoder_decoder:
last_hidden_states = output.decoder_hidden_states[-1]
last_hidden_states = outputs.decoder_hidden_states[-1]
else:
last_hidden_states = output.hidden_states[-1]
last_hidden_states = outputs.hidden_states[-1]
# next logit for contrastive search to select top-k candidate tokens
logit_for_next_step = output.logits[:, -1, :]
logit_for_next_step = outputs.logits[:, -1, :]
model_kwargs = self._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
)
# contrastive_search main logic start:
# contrastive search decoding consists of two steps: (1) candidate tokens recall; (2) candidate re-rank by
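The _update_model_kwargs_for_generation call added here replaces hand-rolled bookkeeping (the manual attention-mask extension removed further down in this diff). Roughly, the shared helper behaves like the simplified sketch below; this is an illustration of the idea, not the exact implementation, which also handles extras such as token_type_ids:

import torch

def update_model_kwargs_sketch(outputs, model_kwargs, is_encoder_decoder=False):
    # carry the cache forward so the next forward pass only sees the new token
    if "past_key_values" in outputs:
        model_kwargs["past"] = outputs.past_key_values
    # decoder-only models grow the attention mask by one position per step
    if not is_encoder_decoder and "attention_mask" in model_kwargs:
        mask = model_kwargs["attention_mask"]
        model_kwargs["attention_mask"] = torch.cat(
            [mask, mask.new_ones((mask.shape[0], 1))], dim=-1
        )
    return model_kwargs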
@@ -1918,6 +1924,18 @@ class GenerationMixin:
_, top_k_ids = torch.topk(logit_for_next_step, dim=-1, k=top_k)
top_k_probs = torch.gather(next_probs, dim=1, index=top_k_ids)
# Store scores, attentions and hidden_states when required
if return_dict_in_generate:
if output_scores:
scores += (logit_for_next_step,)
if output_hidden_states:
decoder_hidden_states += (
(outputs.decoder_hidden_states,)
if self.config.is_encoder_decoder
else (outputs.hidden_states,)
)
# enlarge the past_key_values
new_key_values = []
for layer in past_key_values:
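The loop truncated above tiles every cached key/value tensor top_k times along the batch dimension, so all candidate tokens can be scored in a single forward pass. A sketch of the tiling trick with dummy shapes (the real cache layout varies by model); repeat_interleave is equivalent to the unsqueeze/expand/reshape idiom used elsewhere in this diff:

import torch

bsz, n_heads, seqlen, head_dim, top_k = 2, 4, 7, 16, 4
item = torch.randn(bsz, n_heads, seqlen, head_dim)  # one cached key or value tensor
tiled = item.repeat_interleave(top_k, dim=0)        # rows b0,b0,b0,b0,b1,b1,...
print(tiled.shape)  # torch.Size([8, 4, 7, 16])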
@@ -1937,10 +1955,7 @@ class GenerationMixin:
# build next attention mask
if "attention_mask" in model_inputs:
attention_mask = model_inputs["attention_mask"] # [B, S]
# decoder-only model need the full attention mask, not only the mask for the last token
if self.config.is_encoder_decoder is False:
attention_mask = torch.cat([attention_mask, attention_mask.new_ones((bsz, 1))], dim=-1)
attention_mask = model_kwargs["attention_mask"] # [B, S]
attention_mask = attention_mask.unsqueeze(1).expand(-1, top_k, -1).reshape(-1, attention_mask.size(-1))
else:
attention_mask = None
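The replacement line reads the full mask from model_kwargs, which the shared helper now keeps current, and tiles it from [B, S] to [B*K, S] so it lines up with the tiled candidates. A shape demo with made-up sizes:

import torch

B, K, S = 2, 4, 5
attention_mask = torch.ones(B, S)
tiled = attention_mask.unsqueeze(1).expand(-1, K, -1).reshape(-1, S)
print(tiled.shape)  # torch.Size([8, 5]), each row repeated K times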
@@ -1958,27 +1973,26 @@ class GenerationMixin:
encoder_outputs=encoder_outputs,
)
# compute the candidate tokens by the language model and collects their hidden_states
output = self(output_hidden_states=True, **next_model_inputs)
past_key_values = output.past_key_values
outputs = self(
**next_model_inputs, return_dict=True, output_hidden_states=True, output_attentions=output_attentions
)
past_key_values = outputs.past_key_values
logits = output.logits[:, -1, :]
logits = outputs.logits[:, -1, :]
# name is different for encoder-decoder and decoder-only models
if self.config.is_encoder_decoder:
next_hidden = output.decoder_hidden_states[-1]
full_hidden_states = output.decoder_hidden_states
next_hidden = outputs.decoder_hidden_states[-1]
full_hidden_states = outputs.decoder_hidden_states
else:
next_hidden = output.hidden_states[-1]
full_hidden_states = output.hidden_states
next_hidden = outputs.hidden_states[-1]
full_hidden_states = outputs.hidden_states
context_hidden = (
last_hidden_states.unsqueeze(1).expand(-1, top_k, -1, -1).reshape(bsz * top_k, seqlen, embed_dim)
)
# compute the degeneration penalty and re-rank the candidates based on the degeneration penalty and the
# model confidence
# the scores and index of the selected tokens are returned
selected_scores, selected_idx = ranking_fast(
context_hidden, next_hidden, top_k_probs, penalty_alpha, top_k
)
selected_idx = _ranking_fast(context_hidden, next_hidden, top_k_probs, penalty_alpha, top_k)
# prepare for the next step: (1) next token_id; (2) past_key_values; (3) last_hidden_states for computing
# the degeneration penalty; (4) logits for selecting next top-k candidates; (5) selected tokens scores
@@ -1988,11 +2002,11 @@ class GenerationMixin:
next_hidden = next_hidden[range(bsz), selected_idx, :]
last_hidden_states = torch.cat([last_hidden_states, next_hidden.unsqueeze(1)], dim=1)
decoder_hidden_states_one_step = []
decoder_hidden_states = []
for layer in full_hidden_states:
layer = torch.stack(torch.split(layer.squeeze(dim=1), top_k))
layer = layer[range(bsz), selected_idx, :]
decoder_hidden_states_one_step.append(layer)
decoder_hidden_states.append(layer)
# select the past_key_value
new_key_values = []
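This hunk leans on a recurring idiom: tensors of shape [B*K, ...] are regrouped with torch.stack(torch.split(t, top_k)) into [B, K, ...] and then indexed with [range(bsz), selected_idx] to keep exactly one candidate per sequence. A self-contained demo with dummy sizes:

import torch

bsz, top_k, dim = 2, 3, 5
flat = torch.arange(bsz * top_k * dim, dtype=torch.float).reshape(bsz * top_k, dim)
grouped = torch.stack(torch.split(flat, top_k))  # [B, K, D]
selected_idx = torch.tensor([1, 2])              # winning candidate per sequence
chosen = grouped[range(bsz), selected_idx, :]    # [B, D]
print(chosen.shape)  # torch.Size([2, 5])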
@@ -2009,21 +2023,24 @@ class GenerationMixin:
past_key_values = new_key_values
logit_for_next_step = torch.stack(torch.split(logits, top_k))[range(bsz), selected_idx, :]
# contrastive_search main logic end::
# after running the above codes, we update following parameters: next_tokens, past_key_values,
# logit_for_next_step, selected_score, decoder_hidden_states_one_step
# Rebuilds the relevant parts of the model output for the selected token, for use in the next iteration
if self.config.is_encoder_decoder:
outputs = Seq2SeqLMOutput(
past_key_values=past_key_values,
decoder_hidden_states=decoder_hidden_states,
)
else:
outputs = CausalLMOutputWithPast(
past_key_values=past_key_values,
hidden_states=decoder_hidden_states,
attentions=model_kwargs["attention_mask"],
)
# contrastive_search main logic end
if synced_gpus and this_peer_finished:
continue # don't waste resources running the code we don't need
# Store scores, attentions and hidden_states when required
if return_dict_in_generate:
if output_scores:
scores += (selected_scores,)
if output_hidden_states:
decoder_hidden_states += (decoder_hidden_states_one_step,)
# finished sentences should have their next token be a padding token
if eos_token_id is not None:
if pad_token_id is None:
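The truncated check above enforces that a pad token exists whenever an eos token is defined: once a sequence finishes, its subsequent tokens are overwritten with padding. A sketch of the standard masking idiom from the generation loop; the unfinished_sequences name (1 for active rows, 0 for finished ones) is assumed from the surrounding code:

import torch

next_tokens = torch.tensor([5, 9])
unfinished_sequences = torch.tensor([1, 0])  # second sequence already hit EOS
pad_token_id = 0
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
print(next_tokens)  # tensor([5, 0])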
@@ -2032,14 +2049,6 @@ class GenerationMixin:
# update generated ids, model inputs, and length for next step
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
if self.config.is_encoder_decoder:
outputs = Seq2SeqLMOutput(
past_key_values=past_key_values,
)
else:
outputs = CausalLMOutputWithCrossAttentions(
past_key_values=past_key_values, attentions=model_kwargs["attention_mask"]
)
model_kwargs = self._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
)
@@ -3884,17 +3893,18 @@ def top_k_top_p_filtering(
return logits
def ranking_fast(
def _ranking_fast(
context_hidden: torch.FloatTensor,
next_hidden: torch.FloatTensor,
next_top_k_probs: torch.FloatTensor,
alpha: float,
beam_width: int,
) -> Tuple[torch.FloatTensor]:
) -> torch.FloatTensor:
"""
context_hidden: bsz*beam x seqlen x embed_dim
next_hidden: bsz*beam x 1 x embed_dim
next_top_k_probs: bsz x beam
Reranks the top_k candidates based on a degeneration penalty (cosine similarity with previous tokens), as described
in the paper "A Contrastive Framework for Neural Text Generation". Returns the index of the best candidate for each
row in the batch.
"""
_, context_len, embed_dim = context_hidden.size()
norm_context_hidden = context_hidden / context_hidden.norm(dim=2, keepdim=True)
norm_next_hidden = next_hidden / next_hidden.norm(dim=2, keepdim=True)
cosine_matrix = torch.matmul(norm_context_hidden, norm_next_hidden.transpose(1, 2)).squeeze(-1) # [B*K, S]
@@ -3902,5 +3912,5 @@ def ranking_fast(
next_top_k_probs = next_top_k_probs.view(-1) # [B*K]
contrastive_score = (1.0 - alpha) * next_top_k_probs - alpha * degeneration_penalty
contrastive_score = torch.stack(torch.split(contrastive_score, beam_width)) # [B, K]
selected_scores, selected_idx = contrastive_score.max(dim=-1) # [B]
return torch.log(selected_scores), selected_idx
_, selected_idx = contrastive_score.max(dim=-1) # [B]
return selected_idx
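The renamed _ranking_fast now returns only the selected indices. Per the lines above, each candidate scores (1 - alpha) * model_confidence - alpha * degeneration_penalty, where the penalty is the maximum cosine similarity between the candidate's hidden state and every previous hidden state. A self-contained toy run mirroring the function body on random tensors (shapes follow the docstring conventions above):

import torch

bsz, top_k, seqlen, dim, alpha = 2, 4, 6, 8, 0.6
context_hidden = torch.randn(bsz * top_k, seqlen, dim)   # [B*K, S, D]
next_hidden = torch.randn(bsz * top_k, 1, dim)           # [B*K, 1, D]
next_top_k_probs = torch.rand(bsz, top_k)                # [B, K]

norm_ctx = context_hidden / context_hidden.norm(dim=2, keepdim=True)
norm_next = next_hidden / next_hidden.norm(dim=2, keepdim=True)
cosine = torch.matmul(norm_ctx, norm_next.transpose(1, 2)).squeeze(-1)  # [B*K, S]
degeneration_penalty, _ = cosine.max(dim=-1)                            # [B*K]
score = (1.0 - alpha) * next_top_k_probs.view(-1) - alpha * degeneration_penalty
selected_idx = torch.stack(torch.split(score, top_k)).max(dim=-1)[1]    # [B]
print(selected_idx)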