Unverified Commit d39f794e authored by Joao Gante, committed by GitHub

Generate: contrastive search cosmetic tweaks (#19871)

parent 0a772491
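For orientation: this commit touches `generate()`'s contrastive search path, which is activated by passing `penalty_alpha > 0` together with `top_k > 1`. A minimal usage sketch (the `gpt2` checkpoint, the prompt, and the generation length are illustrative choices, not taken from this commit):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# "gpt2" and max_new_tokens=64 are illustrative, not part of this diff
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("DeepMind Company is", return_tensors="pt")
# penalty_alpha > 0 together with top_k > 1 routes generate() to contrastive search
outputs = model.generate(**inputs, penalty_alpha=0.6, top_k=4, max_new_tokens=64)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```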
@@ -100,8 +100,9 @@ class GreedySearchDecoderOnlyOutput(ModelOutput):
 @dataclass
 class ContrastiveSearchEncoderDecoderOutput(ModelOutput):
     """
-    Args:
     Base class for outputs of decoder-only generation models using contrastive search.
+
+    Args:
         sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
             The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
             if all batches finished early due to the `eos_token_id`.
@@ -110,7 +111,7 @@ class ContrastiveSearchEncoderDecoderOutput(ModelOutput):
             Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
             at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
             each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
-        decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`
+        decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`:
         is passed or when `config.output_hidden_states=True`):
             Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
             `torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`.
@@ -124,8 +125,9 @@ class ContrastiveSearchEncoderDecoderOutput(ModelOutput):
 @dataclass
 class ContrastiveSearchDecoderOnlyOutput(ModelOutput):
     """
-    Args:
     Base class for outputs of decoder-only generation models using contrastive search.
+
+    Args:
         sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
             The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
             if all batches finished early due to the `eos_token_id`.
@@ -433,6 +435,8 @@ GreedySearchOutput = Union[GreedySearchEncoderDecoderOutput, GreedySearchDecoder
 SampleOutput = Union[SampleEncoderDecoderOutput, SampleDecoderOnlyOutput]
 BeamSearchOutput = Union[BeamSearchEncoderDecoderOutput, BeamSearchDecoderOnlyOutput]
 BeamSampleOutput = Union[BeamSampleEncoderDecoderOutput, BeamSampleDecoderOnlyOutput]
+ContrastiveSearchOutput = Union[ContrastiveSearchEncoderDecoderOutput, ContrastiveSearchDecoderOnlyOutput]
+GenerateOutput = Union[GreedySearchOutput, SampleOutput, BeamSearchOutput, BeamSampleOutput, ContrastiveSearchOutput]


 class GenerationMixin:
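The new `ContrastiveSearchOutput` and `GenerateOutput` aliases let callers handle every return shape of `generate()` uniformly. A hedged sketch of downstream usage (the helper name here is made up for illustration):

```python
import torch

def extract_sequences(result):
    """Hypothetical helper: generate() returns a GenerateOutput subclass when
    return_dict_in_generate=True, otherwise a plain LongTensor of token ids."""
    if isinstance(result, torch.Tensor):
        return result
    return result.sequences
```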
@@ -1010,7 +1014,7 @@ class GenerationMixin:
         begin_suppress_tokens: Optional[List[int]] = None,
         forced_decoder_ids: Optional[List[List[int]]] = None,
         **model_kwargs,
-    ) -> Union[GreedySearchOutput, SampleOutput, BeamSearchOutput, BeamSampleOutput, torch.LongTensor]:
+    ) -> Union[GenerateOutput, torch.LongTensor]:
         r"""
         Generates sequences of token ids for models with a language modeling head. The method supports the following
@@ -1766,7 +1770,7 @@ class GenerationMixin:
         return_dict_in_generate: Optional[bool] = None,
         synced_gpus: Optional[bool] = False,
         **model_kwargs,
-    ) -> Union[GreedySearchOutput, torch.LongTensor]:
+    ) -> Union[ContrastiveSearchOutput, torch.LongTensor]:
         r"""
         Generates sequences of token ids for models with a language modeling head using **contrastive search** and can
         be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
@@ -1781,6 +1785,10 @@ class GenerationMixin:
             logits_processor (`LogitsProcessorList`, *optional*):
                 An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
                 used to modify the prediction scores of the language modeling head applied at each generation step.
+            logits_warper (`LogitsProcessorList`, *optional*):
+                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used
+                to warp the prediction score distribution of the language modeling head applied before multinomial
+                sampling at each generation step.
             stopping_criteria (`StoppingCriteriaList`, *optional*):
                 An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
                 used to tell if the generation loop should stop.
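Both arguments documented here are built as `LogitsProcessorList`s. A small illustrative example (the particular processors and parameter values are assumptions, not taken from this diff):

```python
from transformers import (
    LogitsProcessorList,
    MinLengthLogitsProcessor,
    TemperatureLogitsWarper,
    TopKLogitsWarper,
)

# processors edit scores outright (e.g. forbid EOS before a minimum length) ...
logits_processor = LogitsProcessorList([MinLengthLogitsProcessor(10, eos_token_id=50256)])
# ... while warpers reshape the distribution before any sampling step
logits_warper = LogitsProcessorList([TemperatureLogitsWarper(0.7), TopKLogitsWarper(50)])
```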
@@ -1817,7 +1825,6 @@ class GenerationMixin:
         >>> from transformers import (
         ...     AutoTokenizer,
         ...     AutoModelForCausalLM,
-        ...     MinLengthLogitsProcessor,
         ...     StoppingCriteriaList,
         ...     MaxLengthCriteria,
         ... )
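The doctest these imports belong to (elided by the hunk) calls `contrastive_search` directly. A hedged reconstruction of such a call, with an arbitrary checkpoint and prompt:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, MaxLengthCriteria, StoppingCriteriaList

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
model.config.pad_token_id = model.config.eos_token_id  # gpt2 has no pad token

input_ids = tokenizer("The whispered legend of", return_tensors="pt").input_ids
stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=64)])
# top_k sets how many candidates are recalled per step; penalty_alpha weights
# the degeneration penalty against model confidence when re-ranking them
output = model.contrastive_search(
    input_ids, top_k=4, penalty_alpha=0.6, stopping_criteria=stopping_criteria
)
print(tokenizer.batch_decode(output, skip_special_tokens=True))
```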
@@ -1859,7 +1866,6 @@ class GenerationMixin:
         this_peer_finished = False  # used by synced_gpus only
-        step_counter = 0
         while True:
             if synced_gpus:
                 # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
@@ -1875,20 +1881,23 @@ class GenerationMixin:
             model_kwargs["use_cache"] = True
             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

-            # if the first step in the loop, encode all the prefix and obtain three parameters: (1) past_key_values; (2) last_hidden_states; (3) logit_for_next_step
-            if step_counter == 0:
-                # encode the given prefix and prepare model inputs; encoder-decoder model process the prefix and save the `encoder_outputs`
+            # if the first step in the loop, encode all the prefix and obtain three parameters: (1) past_key_values;
+            # (2) last_hidden_states; (3) logit_for_next_step
+            if model_kwargs.get("past") is None:
+                # encode the given prefix and prepare model inputs; encoder-decoder model process the prefix and save
+                # the `encoder_outputs`
                 output = self(**model_inputs, output_hidden_states=True, output_attentions=True)

-                # past_key_values is activated for fast decoding
+                # past_key_values is required for fast decoding
                 if "past_key_values" not in output:
                     raise ValueError(
-                        "self.__class__ cannot return `past_key_values` and can therefore **not** be used for"
-                        " contrastive search."
+                        f"{self.__class__.__name__} cannot return `past_key_values` and can therefore **not** be used "
+                        "for contrastive search."
                     )
                 past_key_values = output.past_key_values

-                # last decoder hidden states will be used to compute the degeneration penalty (cosine similarity with previous tokens)
+                # last decoder hidden states will be used to compute the degeneration penalty (cosine similarity with
+                # previous tokens)
                 if self.config.is_encoder_decoder:
                     last_hidden_states = output.decoder_hidden_states[-1]
                 else:
@@ -1897,7 +1906,8 @@ class GenerationMixin:
             logit_for_next_step = output.logits[:, -1, :]

             # contrastive_search main logic start:
-            # contrastive search decoding consists of two steps: (1) candidate tokens recall; (2) candidate re-rank by degeneration penalty
+            # contrastive search decoding consists of two steps: (1) candidate tokens recall; (2) candidate re-rank by
+            # degeneration penalty
             bsz, seqlen, embed_dim = last_hidden_states.size()

             # logits processor
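Step (1), candidate recall, is just top-k selection over the processed next-token distribution. A shape-only sketch with toy tensors:

```python
import torch

# toy stand-ins: batch of 2 sequences, vocabulary of 10 (random values, shapes only)
logit_for_next_step = torch.randn(2, 10)
top_k = 4

# candidate recall: keep the k most probable tokens and their probabilities
next_probs = torch.nn.functional.softmax(logit_for_next_step, dim=-1)
top_k_probs, top_k_ids = torch.topk(next_probs, dim=-1, k=top_k)
```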
@@ -1949,12 +1959,6 @@ class GenerationMixin:
             )

             # compute the candidate tokens by the language model and collects their hidden_states
             output = self(output_hidden_states=True, **next_model_inputs)
-            if "past_key_values" not in output:
-                raise ValueError(
-                    "self.__class__ cannot return `past_key_values` and can therefore **not** be used for contrastive"
-                    " search."
-                )
             past_key_values = output.past_key_values
             logits = output.logits[:, -1, :]
@@ -1969,13 +1973,16 @@ class GenerationMixin:
                 last_hidden_states.unsqueeze(1).expand(-1, top_k, -1, -1).reshape(bsz * top_k, seqlen, embed_dim)
             )

-            # compute the degeneratin penalty and re-rank the candidates based on the degeneration penalty and the model confidence
+            # compute the degeneratin penalty and re-rank the candidates based on the degeneration penalty and the
+            # model confidence
             # the scores and index of the selected tokens are returned
             selected_scores, selected_idx = ranking_fast(
                 context_hidden, next_hidden, top_k_probs, penalty_alpha, top_k
             )

-            # prepare for the next step: (1) next token_id; (2) past_key_values; (3) last_hidden_states for computing the degeneration penalty; (4) logits for selecting next top-k candidates; (5) selected tokens scores (model confidence minus degeneration penalty); (6) decoder hidden_states
+            # prepare for the next step: (1) next token_id; (2) past_key_values; (3) last_hidden_states for computing
+            # the degeneration penalty; (4) logits for selecting next top-k candidates; (5) selected tokens scores
+            # (model confidence minus degeneration penalty); (6) decoder hidden_states
             next_tokens = top_k_ids[range(len(top_k_ids)), selected_idx]
             next_hidden = torch.stack(torch.split(next_hidden.squeeze(dim=1), top_k))
             next_hidden = next_hidden[range(bsz), selected_idx, :]
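Step (2), the re-rank, implements the contrastive objective from the contrastive search paper: each candidate v is scored as (1 - alpha) * p(v | context) - alpha * max_j cos(h_v, h_j), where h_j ranges over the context token hidden states. A simplified sketch of what a ranking function like `ranking_fast` computes (an illustrative reimplementation, not this file's exact code):

```python
import torch

def contrastive_rank(context_hidden, next_hidden, top_k_probs, alpha):
    # context_hidden: (batch*k, seq_len, dim); next_hidden: (batch*k, 1, dim)
    ctx = context_hidden / context_hidden.norm(dim=2, keepdim=True)
    nxt = next_hidden / next_hidden.norm(dim=2, keepdim=True)
    # cosine similarity of each candidate against every previous token
    cosine = torch.matmul(ctx, nxt.transpose(1, 2)).squeeze(-1)  # (batch*k, seq_len)
    degeneration_penalty = cosine.max(dim=-1).values             # (batch*k,)
    # model confidence minus the alpha-weighted degeneration penalty
    return (1.0 - alpha) * top_k_probs.view(-1) - alpha * degeneration_penalty
```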
@@ -2003,7 +2010,8 @@ class GenerationMixin:
             logit_for_next_step = torch.stack(torch.split(logits, top_k))[range(bsz), selected_idx, :]

             # contrastive_search main logic end::
-            # after running the above codes, we update following parameters: next_tokens, past_key_values, logit_for_next_step, selected_score, decoder_hidden_states_one_step
+            # after running the above codes, we update following parameters: next_tokens, past_key_values,
+            # logit_for_next_step, selected_score, decoder_hidden_states_one_step
             if synced_gpus and this_peer_finished:
                 continue  # don't waste resources running the code we don't need
@@ -2047,10 +2055,6 @@ class GenerationMixin:
             else:
                 this_peer_finished = True
-
-            # prepare model inputs
-            model_kwargs["past_key_values"] = past_key_values
-            step_counter += 1

         if return_dict_in_generate:
             if self.config.is_encoder_decoder:
                 return ContrastiveSearchEncoderDecoderOutput(
...