Unverified commit 568e5783, authored by Joao Gante, committed by GitHub

Generate: contrastive search uses existing abstractions and conventions (#19896)

parent 803475fb
@@ -54,7 +54,7 @@ from .generation_stopping_criteria import (
     StoppingCriteriaList,
     validate_stopping_criteria,
 )
-from .modeling_outputs import CausalLMOutputWithCrossAttentions, Seq2SeqLMOutput
+from .modeling_outputs import CausalLMOutputWithPast, Seq2SeqLMOutput
 from .models.auto import (
     MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING,
     MODEL_FOR_CAUSAL_LM_MAPPING,
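Note for readers: the decoder-only branch later in this commit rebuilds a lightweight per-step output, and `CausalLMOutputWithPast` carries exactly the fields it needs (`past_key_values`, `hidden_states`, `attentions`), so the unused cross-attention slots of `CausalLMOutputWithCrossAttentions` are dropped. A minimal sketch (toy tensors, not part of this diff) of the container:

```python
# Minimal sketch (toy tensors, not from this diff) of the output container the
# decoder-only branch now uses; it is a ModelOutput dataclass with only the
# fields contrastive search needs.
import torch
from transformers.modeling_outputs import CausalLMOutputWithPast

outputs = CausalLMOutputWithPast(
    logits=torch.randn(1, 1, 50257),  # [batch, seq, vocab]
    past_key_values=None,             # the decoding cache goes here
    hidden_states=None,               # per-layer hidden states go here
)
print(outputs.logits.shape)     # attribute access
print(outputs["logits"].shape)  # dict-style access also works
```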
@@ -1882,28 +1882,34 @@ class GenerationMixin:
             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

-            # if the first step in the loop, encode all the prefix and obtain three parameters: (1) past_key_values;
-            # (2) last_hidden_states; (3) logit_for_next_step
+            # if the first step in the loop, encode all the prefix and obtain: (1) past_key_values;
+            # (2) last_hidden_states; (3) logit_for_next_step; and (4) update model kwargs for the next step
             if model_kwargs.get("past") is None:
                 # encode the given prefix and prepare model inputs; encoder-decoder models process the prefix and save
                 # the `encoder_outputs`
-                output = self(**model_inputs, output_hidden_states=True, output_attentions=True)
+                outputs = self(
+                    **model_inputs, return_dict=True, output_hidden_states=True, output_attentions=output_attentions
+                )

                 # past_key_values is required for fast decoding
-                if "past_key_values" not in output:
+                if "past_key_values" not in outputs:
                     raise ValueError(
                         f"{self.__class__.__name__} cannot return `past_key_values` and can therefore **not** be used "
                         "for contrastive search."
                     )
-                past_key_values = output.past_key_values
+                past_key_values = outputs.past_key_values

                 # last decoder hidden states will be used to compute the degeneration penalty (cosine similarity with
                 # previous tokens)
                 if self.config.is_encoder_decoder:
-                    last_hidden_states = output.decoder_hidden_states[-1]
+                    last_hidden_states = outputs.decoder_hidden_states[-1]
                 else:
-                    last_hidden_states = output.hidden_states[-1]
+                    last_hidden_states = outputs.hidden_states[-1]

                 # next logit for contrastive search to select top-k candidate tokens
-                logit_for_next_step = output.logits[:, -1, :]
+                logit_for_next_step = outputs.logits[:, -1, :]
+
+                model_kwargs = self._update_model_kwargs_for_generation(
+                    outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
+                )

             # contrastive_search main logic start:
             # contrastive search decoding consists of two steps: (1) candidate tokens recall; (2) candidate re-rank by
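For context, here is a self-contained sketch of what this first iteration does, assuming GPT-2 as the model (the model choice and tensors are illustrative, not part of this diff): a single forward pass over the prefix yields the cache, the last-layer hidden states, and the next-token logits.

```python
# Standalone sketch (assuming GPT-2; not from this diff) of the prefix
# encoding step: one forward pass returns all three quantities the loop needs.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("The quick brown fox", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs, return_dict=True, output_hidden_states=True, use_cache=True)

past_key_values = outputs.past_key_values       # (1) cache for fast decoding
last_hidden_states = outputs.hidden_states[-1]  # (2) [batch, seq, hidden], for the penalty
logit_for_next_step = outputs.logits[:, -1, :]  # (3) [batch, vocab], for top-k recall
```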
@@ -1918,6 +1924,18 @@ class GenerationMixin:
             _, top_k_ids = torch.topk(logit_for_next_step, dim=-1, k=top_k)
             top_k_probs = torch.gather(next_probs, dim=1, index=top_k_ids)

+            # Store scores, attentions and hidden_states when required
+            if return_dict_in_generate:
+                if output_scores:
+                    scores += (logit_for_next_step,)
+                if output_hidden_states:
+                    decoder_hidden_states += (
+                        (outputs.decoder_hidden_states,)
+                        if self.config.is_encoder_decoder
+                        else (outputs.hidden_states,)
+                    )
+
             # enlarge the past_key_values
             new_key_values = []
             for layer in past_key_values:
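The recall step above keeps the k most probable continuations per batch row. A toy sketch of the same calls (illustrative shapes, not from this diff):

```python
# Sketch of the candidate-recall step: keep the k most probable next tokens
# and gather their probabilities for the confidence term of the score.
import torch

top_k = 4
logit_for_next_step = torch.randn(2, 100)                        # [batch, vocab]
next_probs = torch.softmax(logit_for_next_step, dim=-1)
_, top_k_ids = torch.topk(logit_for_next_step, dim=-1, k=top_k)  # [batch, k]
top_k_probs = torch.gather(next_probs, dim=1, index=top_k_ids)   # [batch, k]
```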
@@ -1937,10 +1955,7 @@ class GenerationMixin:

             # build next attention mask
             if "attention_mask" in model_inputs:
-                attention_mask = model_inputs["attention_mask"]  # [B, S]
-                # decoder-only model need the full attention mask, not only the mask for the last token
-                if self.config.is_encoder_decoder is False:
-                    attention_mask = torch.cat([attention_mask, attention_mask.new_ones((bsz, 1))], dim=-1)
+                attention_mask = model_kwargs["attention_mask"]  # [B, S]
                 attention_mask = attention_mask.unsqueeze(1).expand(-1, top_k, -1).reshape(-1, attention_mask.size(-1))
             else:
                 attention_mask = None
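The removed `torch.cat` is no longer needed because `_update_model_kwargs_for_generation`, now invoked each step, already grows the attention mask by one position; only the replication across the top_k candidates remains. A toy sketch of that replication pattern (illustrative shapes, not from this diff):

```python
# Sketch of the mask replication: each batch row is repeated top_k times so
# the k candidate continuations can be scored in one batched forward pass.
import torch

bsz, seq_len, top_k = 2, 5, 3
attention_mask = torch.ones(bsz, seq_len)         # [B, S]
expanded = (
    attention_mask.unsqueeze(1)                   # [B, 1, S]
    .expand(-1, top_k, -1)                        # [B, K, S] (view, no copy)
    .reshape(-1, attention_mask.size(-1))         # [B*K, S]
)
assert expanded.shape == (bsz * top_k, seq_len)
```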
@@ -1958,27 +1973,26 @@ class GenerationMixin:
                 encoder_outputs=encoder_outputs,
             )

             # compute the candidate tokens by the language model and collect their hidden_states
-            output = self(output_hidden_states=True, **next_model_inputs)
-            past_key_values = output.past_key_values
+            outputs = self(
+                **next_model_inputs, return_dict=True, output_hidden_states=True, output_attentions=output_attentions
+            )
+            past_key_values = outputs.past_key_values

-            logits = output.logits[:, -1, :]
+            logits = outputs.logits[:, -1, :]
             # name is different for encoder-decoder and decoder-only models
             if self.config.is_encoder_decoder:
-                next_hidden = output.decoder_hidden_states[-1]
-                full_hidden_states = output.decoder_hidden_states
+                next_hidden = outputs.decoder_hidden_states[-1]
+                full_hidden_states = outputs.decoder_hidden_states
             else:
-                next_hidden = output.hidden_states[-1]
-                full_hidden_states = output.hidden_states
+                next_hidden = outputs.hidden_states[-1]
+                full_hidden_states = outputs.hidden_states
             context_hidden = (
                 last_hidden_states.unsqueeze(1).expand(-1, top_k, -1, -1).reshape(bsz * top_k, seqlen, embed_dim)
             )

             # compute the degeneration penalty and re-rank the candidates based on the degeneration penalty and the
             # model confidence
-            # the scores and index of the selected tokens are returned
-            selected_scores, selected_idx = ranking_fast(
-                context_hidden, next_hidden, top_k_probs, penalty_alpha, top_k
-            )
+            selected_idx = _ranking_fast(context_hidden, next_hidden, top_k_probs, penalty_alpha, top_k)

             # prepare for the next step: (1) next token_id; (2) past_key_values; (3) last_hidden_states for computing
             # the degeneration penalty; (4) logits for selecting next top-k candidates; (5) selected tokens scores
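After the batched candidate pass, every `[B*K, ...]` tensor is regrouped to `[B, K, ...]` so one candidate per batch row can be picked with `selected_idx`. A toy sketch of this regroup-and-select pattern (illustrative tensors, not from this diff):

```python
# Sketch of the regroup-and-select pattern: flat candidate rows are grouped
# back into [B, K, ...] and the winning candidate per batch row is indexed out.
import torch

bsz, top_k, vocab = 2, 3, 10
logits = torch.randn(bsz * top_k, vocab)                      # flat candidate logits
selected_idx = torch.tensor([1, 2])                           # winner per batch row
regrouped = torch.stack(torch.split(logits, top_k))           # [B, K, vocab]
logit_for_next_step = regrouped[range(bsz), selected_idx, :]  # [B, vocab]
```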
@@ -1988,11 +2002,11 @@ class GenerationMixin:
             next_hidden = next_hidden[range(bsz), selected_idx, :]
             last_hidden_states = torch.cat([last_hidden_states, next_hidden.unsqueeze(1)], dim=1)

-            decoder_hidden_states_one_step = []
+            decoder_hidden_states = []
             for layer in full_hidden_states:
                 layer = torch.stack(torch.split(layer.squeeze(dim=1), top_k))
                 layer = layer[range(bsz), selected_idx, :]
-                decoder_hidden_states_one_step.append(layer)
+                decoder_hidden_states.append(layer)

             # select the past_key_value
             new_key_values = []
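The body of the cache-selection loop is elided by the hunk boundary above; assuming it mirrors the visible hidden-state loop, the pattern would look roughly like this hypothetical sketch (toy shapes, not the library's exact code):

```python
# Hypothetical sketch of slicing the enlarged cache down to the selected
# candidate: regroup [B*K, ...] into [B, K, ...], then index per batch row.
import torch

bsz, top_k, heads, seq, dim = 2, 3, 4, 5, 8
selected_idx = torch.tensor([0, 2])
key = torch.randn(bsz * top_k, heads, seq, dim)  # one cache tensor of one layer
key = torch.stack(torch.split(key, top_k))       # [B, K, heads, seq, dim]
key = key[range(bsz), selected_idx, ...]         # [B, heads, seq, dim]
```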
@@ -2009,21 +2023,24 @@ class GenerationMixin:
             past_key_values = new_key_values

             logit_for_next_step = torch.stack(torch.split(logits, top_k))[range(bsz), selected_idx, :]

-            # contrastive_search main logic end::
-            # after running the above codes, we update following parameters: next_tokens, past_key_values,
-            # logit_for_next_step, selected_score, decoder_hidden_states_one_step
+            # Rebuilds the relevant parts of the model output for the selected token, for use in the next iteration
+            if self.config.is_encoder_decoder:
+                outputs = Seq2SeqLMOutput(
+                    past_key_values=past_key_values,
+                    decoder_hidden_states=decoder_hidden_states,
+                )
+            else:
+                outputs = CausalLMOutputWithPast(
+                    past_key_values=past_key_values,
+                    hidden_states=decoder_hidden_states,
+                    attentions=model_kwargs["attention_mask"],
+                )
+
+            # contrastive_search main logic end

             if synced_gpus and this_peer_finished:
                 continue  # don't waste resources running the code we don't need

-            # Store scores, attentions and hidden_states when required
-            if return_dict_in_generate:
-                if output_scores:
-                    scores += (selected_scores,)
-                if output_hidden_states:
-                    decoder_hidden_states += (decoder_hidden_states_one_step,)
-
             # finished sentences should have their next token be a padding token
             if eos_token_id is not None:
                 if pad_token_id is None:
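For reference, the convention this comment refers to, as a simplified sketch with assumed tensors (mirroring `generate`'s usual bookkeeping): once a row has emitted EOS, every later position is filled with the pad token via a 0/1 mask of unfinished sequences.

```python
# Simplified sketch of the padding convention for finished sequences.
import torch

pad_token_id = 0
next_tokens = torch.tensor([5, 2, 7])
unfinished_sequences = torch.tensor([1, 0, 1])  # 0 = already hit EOS
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
print(next_tokens)  # tensor([5, 0, 7])
```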
@@ -2032,14 +2049,6 @@ class GenerationMixin:
             # update generated ids, model inputs, and length for next step
             input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)

-            if self.config.is_encoder_decoder:
-                outputs = Seq2SeqLMOutput(
-                    past_key_values=past_key_values,
-                )
-            else:
-                outputs = CausalLMOutputWithCrossAttentions(
-                    past_key_values=past_key_values, attentions=model_kwargs["attention_mask"]
-                )
             model_kwargs = self._update_model_kwargs_for_generation(
                 outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
             )
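Reusing `_update_model_kwargs_for_generation` is the point of this commit: the output object rebuilt earlier in the loop feeds the shared helper, replacing hand-rolled bookkeeping. A rough, simplified sketch of what that helper conventionally does for decoder-only models (not the library's exact code):

```python
# Rough sketch (simplified, not the library's exact implementation) of the
# per-step kwargs update: carry the cache forward and grow the mask by one.
import torch

def update_model_kwargs(outputs, model_kwargs):
    model_kwargs["past"] = outputs.past_key_values  # cache for the next forward pass
    if "attention_mask" in model_kwargs:
        mask = model_kwargs["attention_mask"]
        model_kwargs["attention_mask"] = torch.cat(
            [mask, mask.new_ones((mask.shape[0], 1))], dim=-1
        )
    return model_kwargs
```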
@@ -3884,17 +3893,18 @@ def top_k_top_p_filtering(
     return logits


-def ranking_fast(
+def _ranking_fast(
     context_hidden: torch.FloatTensor,
     next_hidden: torch.FloatTensor,
     next_top_k_probs: torch.FloatTensor,
     alpha: float,
     beam_width: int,
-) -> Tuple[torch.FloatTensor]:
+) -> torch.FloatTensor:
     """
-    context_hidden: bsz*beam x seqlen x embed_dim next_hidden: bsz*beam x 1 x embed_dim next_top_k_probs: bsz x beam
+    Reranks the top_k candidates based on a degeneration penalty (cosine similarity with previous tokens), as
+    described in the paper "A Contrastive Framework for Neural Text Generation". Returns the index of the best
+    candidate for each row in the batch.
     """
-    _, context_len, embed_dim = context_hidden.size()
     norm_context_hidden = context_hidden / context_hidden.norm(dim=2, keepdim=True)
     norm_next_hidden = next_hidden / next_hidden.norm(dim=2, keepdim=True)
     cosine_matrix = torch.matmul(norm_context_hidden, norm_next_hidden.transpose(1, 2)).squeeze(-1)  # [B*K, S]
@@ -3902,5 +3912,5 @@ def ranking_fast(
     next_top_k_probs = next_top_k_probs.view(-1)  # [B*K]
     contrastive_score = (1.0 - alpha) * next_top_k_probs - alpha * degeneration_penalty
     contrastive_score = torch.stack(torch.split(contrastive_score, beam_width))  # [B, K]
-    selected_scores, selected_idx = contrastive_score.max(dim=-1)  # [B]
-    return torch.log(selected_scores), selected_idx
+    _, selected_idx = contrastive_score.max(dim=-1)  # [B]
+    return selected_idx
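To summarize the scoring rule that `_ranking_fast` implements, here is a self-contained sketch (toy tensors; the `degeneration_penalty` line sits between the two hunks above and is assumed unchanged): each candidate scores `(1 - alpha) * model_confidence - alpha * max_cosine_similarity` against all previous hidden states, and the argmax per batch row is returned.

```python
# Self-contained sketch of the contrastive scoring rule, with toy tensors.
import torch

bsz, beam_width, seqlen, embed_dim = 1, 2, 4, 8
context_hidden = torch.randn(bsz * beam_width, seqlen, embed_dim)  # previous tokens
next_hidden = torch.randn(bsz * beam_width, 1, embed_dim)          # candidate tokens
next_top_k_probs = torch.rand(bsz, beam_width)                     # model confidence
alpha = 0.6

norm_context = context_hidden / context_hidden.norm(dim=2, keepdim=True)
norm_next = next_hidden / next_hidden.norm(dim=2, keepdim=True)
cosine_matrix = torch.matmul(norm_context, norm_next.transpose(1, 2)).squeeze(-1)  # [B*K, S]
degeneration_penalty, _ = torch.max(cosine_matrix, dim=-1)                          # [B*K]
score = (1.0 - alpha) * next_top_k_probs.view(-1) - alpha * degeneration_penalty    # [B*K]
selected_idx = torch.stack(torch.split(score, beam_width)).max(dim=-1)[1]           # [B]
```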