"git@developer.sourcefind.cn:wuxk1/megatron-lm.git" did not exist on "62424a74133a0cf7f0472c4d6b2a3a5dc00fd68e"
Unverified Commit 71786b10 authored by GMFTBY's avatar GMFTBY Committed by GitHub
Browse files

Adding the state-of-the-art contrastive search decoding methods for the...

Adding the state-of-the-art contrastive search decoding methods for the codebase of generation_utils.py (#19477)

* add: the contrastive search for generaton_utils

* add: testing scripts for contrastive search under examples/text-generation

* update the quality of codes

* revise the docstring; make the generation_contrastive_search.py scripts;

* revise the examples/pytorch/text-generation/run_generation_contrastive_search.py to the auto-APIs format

* revise the necessary documents

* fix: revise the docstring of generation_contrastive_search.py

* Fix the code indentation

* fix: revise the nits and examples in contrastive_search docstring.

* fix the copyright

* delete generation_contrastive_search.py

* revise the logic in contrastive_search

* update the intergration test and the docstring

* run the tests over

* add the slow decorate to the contrastive_search intergrate test

* add more test

* do the style, quality, consistency checks
parent fc5fdc10
......@@ -14,6 +14,7 @@ specific language governing permissions and limitations under the License.
This page lists all the utility functions used by [`~generation_utils.GenerationMixin.generate`],
[`~generation_utils.GenerationMixin.greedy_search`],
[`~generation_utils.GenerationMixin.contrastive_search`],
[`~generation_utils.GenerationMixin.sample`],
[`~generation_utils.GenerationMixin.beam_search`],
[`~generation_utils.GenerationMixin.beam_sample`],
......
......@@ -26,6 +26,7 @@ Each framework has a generate method for auto-regressive text generation impleme
- sample
- beam_search
- beam_sample
- contrastive_search
- group_beam_search
- constrained_beam_search
......
#!/usr/bin/env python
# coding=utf-8
# Copyright 2022 University of Cambridge, Tencent AI Lab, DeepMind and The University of Hong Kong Authors and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" The examples of running contrastive search on the auto-APIs;
Running this example:
python run_generation_contrastive_search.py --model_name_or_path=gpt2-large --penalty_alpha=0.6 --k=4 --length=256
"""
import argparse
import logging
import numpy as np
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
logger = logging.getLogger(__name__)
def set_seed(args):
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if args.n_gpu > 0:
torch.cuda.manual_seed_all(args.seed)
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_name_or_path",
default=None,
type=str,
required=True,
)
parser.add_argument("--prompt", type=str, default="")
parser.add_argument("--length", type=int, default=20)
parser.add_argument("--stop_token", type=str, default=None, help="Token at which text generation is stopped")
parser.add_argument(
"--temperature",
type=float,
default=1.0,
help="temperature of 1.0 has no effect, lower tend toward greedy sampling",
)
parser.add_argument(
"--repetition_penalty", type=float, default=1.0, help="primarily useful for CTRL model; in that case, use 1.2"
)
parser.add_argument("--k", type=int, default=0)
parser.add_argument("--penalty_alpha", type=float, default=0.0)
parser.add_argument("--p", type=float, default=0.9)
parser.add_argument("--prefix", type=str, default="", help="Text added prior to input.")
parser.add_argument("--padding_text", type=str, default="", help="Deprecated, the use of `--prefix` is preferred.")
parser.add_argument("--xlm_language", type=str, default="", help="Optional language when used with the XLM model.")
parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
parser.add_argument(
"--fp16",
action="store_true",
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
)
args = parser.parse_args()
args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
args.n_gpu = 0 if args.no_cuda else torch.cuda.device_count()
logger.warning(f"device: {args.device}, n_gpu: {args.n_gpu}, 16-bits training: {args.fp16}")
set_seed(args)
# Initialize the model and tokenizer
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path)
# tokenizer = GPT2Tokenizer.from_pretrained(args.model_name_or_path)
# model = OPTForCausalLM.from_pretrained(args.model_name_or_path)
model.to(args.device)
if args.fp16:
model.half()
logger.info(args)
prompt_text = args.prompt if args.prompt else input("Model prompt >>> ")
inputs = tokenizer(prompt_text, return_tensors="pt", add_special_tokens=False)
inputs = {key: value.to(args.device) for key, value in inputs.items()}
output_sequences = model.generate(
**inputs,
max_length=args.length + len(inputs["input_ids"][0]),
penalty_alpha=args.penalty_alpha,
top_k=args.k,
)
generated_sequences = []
for generated_sequence_idx, generated_sequence in enumerate(output_sequences):
print(f"=== GENERATED SEQUENCE {generated_sequence_idx + 1} ===")
generated_sequence = generated_sequence.tolist()
# Decode text
text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True, add_special_tokens=False)
# Remove all text after the stop token
text = text[: text.find(args.stop_token) if args.stop_token else None]
# Add the prompt at the beginning of the sequence. Remove the excess text that was used for pre-processing
total_sequence = (
prompt_text + text[len(tokenizer.decode(inputs["input_ids"][0], clean_up_tokenization_spaces=True)) :]
)
generated_sequences.append(total_sequence)
print(total_sequence)
return generated_sequences
if __name__ == "__main__":
main()
......@@ -54,6 +54,7 @@ from .generation_stopping_criteria import (
StoppingCriteriaList,
validate_stopping_criteria,
)
from .modeling_outputs import CausalLMOutputWithCrossAttentions, Seq2SeqLMOutput
from .models.auto import (
MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING,
MODEL_FOR_CAUSAL_LM_MAPPING,
......@@ -96,6 +97,54 @@ class GreedySearchDecoderOnlyOutput(ModelOutput):
hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
@dataclass
class ContrastiveSearchEncoderDecoderOutput(ModelOutput):
"""
Args:
Base class for outputs of decoder-only generation models using contrastive search.
sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
if all batches finished early due to the `eos_token_id`.
scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when
`config.output_scores=True`):
Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`
is passed or when `config.output_hidden_states=True`):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
`torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`.
"""
sequences: torch.LongTensor = None
scores: Optional[Tuple[torch.FloatTensor]] = None
decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
@dataclass
class ContrastiveSearchDecoderOnlyOutput(ModelOutput):
"""
Args:
Base class for outputs of decoder-only generation models using contrastive search.
sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
if all batches finished early due to the `eos_token_id`.
scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when
`config.output_scores=True`):
Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is
passed or when `config.output_hidden_states=True`):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
`torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`.
"""
sequences: torch.LongTensor = None
scores: Optional[Tuple[torch.FloatTensor]] = None
hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
@dataclass
class GreedySearchEncoderDecoderOutput(ModelOutput):
"""
......@@ -393,6 +442,8 @@ class GenerationMixin:
The class exposes [`~generation_utils.GenerationMixin.generate`], which can be used for:
- *greedy decoding* by calling [`~generation_utils.GenerationMixin.greedy_search`] if `num_beams=1` and
`do_sample=False`.
- *contrastive search* by calling [`~generation_utils.GenerationMixin.contrastive_search`] if `penalty_alpha>0`
and `top_k>1`
- *multinomial sampling* by calling [`~generation_utils.GenerationMixin.sample`] if `num_beams=1` and
`do_sample=True`.
- *beam-search decoding* by calling [`~generation_utils.GenerationMixin.beam_search`] if `num_beams>1` and
......@@ -921,6 +972,7 @@ class GenerationMixin:
early_stopping: Optional[bool] = None,
num_beams: Optional[int] = None,
temperature: Optional[float] = None,
penalty_alpha: Optional[float] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
typical_p: Optional[float] = None,
......@@ -966,6 +1018,8 @@ class GenerationMixin:
- *greedy decoding* by calling [`~generation_utils.GenerationMixin.greedy_search`] if `num_beams=1` and
`do_sample=False`.
- *contrastive search* by calling [`~generation_utils.GenerationMixin.contrastive_search`] if
`penalty_alpha>0.` and `top_k>1`
- *multinomial sampling* by calling [`~generation_utils.GenerationMixin.sample`] if `num_beams=1` and
`do_sample=True`.
- *beam-search decoding* by calling [`~generation_utils.GenerationMixin.beam_search`] if `num_beams>1` and
......@@ -1011,6 +1065,8 @@ class GenerationMixin:
Number of beams for beam search. 1 means no beam search.
temperature (`float`, *optional*, defaults to `model.config.temperature` or 1.0 if the config does not set any value):
The value used to module the next token probabilities.
penalty_alpha (`float`, *optional*, defaults to `model.config.penalty_alpha` or None if the config does not set any value):
The values balance the model confidence and the degeneration penalty in contrastive search decoding.
top_k (`int`, *optional*, defaults to `model.config.top_k` or 50 if the config does not set any value):
The number of highest probability vocabulary tokens to keep for top-k-filtering.
top_p (`float`, *optional*, defaults to `model.config.top_p` or 1.0 if the config does not set any value):
......@@ -1329,19 +1385,45 @@ class GenerationMixin:
# 6. determine generation mode
is_constraint_gen_mode = constraints is not None or force_words_ids is not None
is_contrastive_search_gen_mode = (
top_k is not None and top_k > 1 and do_sample is False and penalty_alpha is not None and penalty_alpha > 0
)
is_greedy_gen_mode = (
(num_beams == 1) and (num_beam_groups == 1) and do_sample is False and not is_constraint_gen_mode
(num_beams == 1)
and (num_beam_groups == 1)
and do_sample is False
and not is_constraint_gen_mode
and not is_contrastive_search_gen_mode
)
is_sample_gen_mode = (
(num_beams == 1) and (num_beam_groups == 1) and do_sample is True and not is_constraint_gen_mode
(num_beams == 1)
and (num_beam_groups == 1)
and do_sample is True
and not is_constraint_gen_mode
and not is_contrastive_search_gen_mode
)
is_beam_gen_mode = (
(num_beams > 1) and (num_beam_groups == 1) and do_sample is False and not is_constraint_gen_mode
(num_beams > 1)
and (num_beam_groups == 1)
and do_sample is False
and not is_constraint_gen_mode
and not is_contrastive_search_gen_mode
)
is_beam_sample_gen_mode = (
(num_beams > 1) and (num_beam_groups == 1) and do_sample is True and not is_constraint_gen_mode
(num_beams > 1)
and (num_beam_groups == 1)
and do_sample is True
and not is_constraint_gen_mode
and not is_contrastive_search_gen_mode
)
is_group_beam_gen_mode = (
(num_beams > 1)
and (num_beam_groups > 1)
and not is_constraint_gen_mode
and not is_contrastive_search_gen_mode
)
is_group_beam_gen_mode = (num_beams > 1) and (num_beam_groups > 1) and not is_constraint_gen_mode
if num_beam_groups > num_beams:
raise ValueError("`num_beam_groups` has to be smaller or equal to `num_beams`")
......@@ -1411,6 +1493,27 @@ class GenerationMixin:
**model_kwargs,
)
elif is_contrastive_search_gen_mode:
if num_return_sequences > 1:
raise ValueError(
f"num_return_sequences has to be 1, but is {num_return_sequences} when doing contrastive search."
)
return self.contrastive_search(
input_ids,
top_k=top_k,
penalty_alpha=penalty_alpha,
logits_processor=logits_processor,
stopping_criteria=stopping_criteria,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
output_scores=output_scores,
return_dict_in_generate=return_dict_in_generate,
synced_gpus=synced_gpus,
**model_kwargs,
)
elif is_sample_gen_mode:
# 10. prepare logits warper
logits_warper = self._get_logits_warper(
......@@ -1646,6 +1749,324 @@ class GenerationMixin:
**model_kwargs,
)
@torch.no_grad()
def contrastive_search(
self,
input_ids: torch.LongTensor,
top_k: Optional[int] = 1,
penalty_alpha: Optional[float] = 0,
logits_processor: Optional[LogitsProcessorList] = None,
logits_warper: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
output_scores: Optional[bool] = None,
return_dict_in_generate: Optional[bool] = None,
synced_gpus: Optional[bool] = False,
**model_kwargs,
) -> Union[GreedySearchOutput, torch.LongTensor]:
r"""
Generates sequences of token ids for models with a language modeling head using **contrastive search** and can
be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
Parameters:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
The sequence used as a prompt for the generation.
top_k (`int`, *optional*, defaults to 1):
The size of the candidate set that is used to re-rank for contrastive search
penalty_alpha (`float`, *optional*, defaults to 0):
The degeneration penalty for contrastive search; activate when it is larger than 0
logits_processor (`LogitsProcessorList`, *optional*):
An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
used to modify the prediction scores of the language modeling head applied at each generation step.
stopping_criteria (`StoppingCriteriaList`, *optional*):
An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
used to tell if the generation loop should stop.
pad_token_id (`int`, *optional*):
The id of the *padding* token.
eos_token_id (`int`, *optional*):
The id of the *end-of-sequence* token.
output_attentions (`bool`, *optional*, defaults to `False`):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more details.
output_hidden_states (`bool`, *optional*, defaults to `False`):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
for more details.
output_scores (`bool`, *optional*, defaults to `False`):
Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
return_dict_in_generate (`bool`, *optional*, defaults to `False`):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
synced_gpus (`bool`, *optional*, defaults to `False`):
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
model_kwargs:
Additional model specific keyword arguments will be forwarded to the `forward` function of the model.
If model is an encoder-decoder model the kwargs should include `encoder_outputs`.
Return:
[`~generation_utils.ContrastiveSearchDecoderOnlyOutput`],
[`~generation_utils.ContrastiveSearchEncoderDecoderOutput`] or `torch.LongTensor`: A `torch.LongTensor`
containing the generated tokens (default behaviour) or a
[`~generation_utils.ContrastiveSearchDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
`return_dict_in_generate=True` or a [`~generation_utils.ContrastiveSearchEncoderDecoderOutput`] if
`model.config.is_encoder_decoder=True`.
Examples:
```python
>>> from transformers import (
... AutoTokenizer,
... AutoModelForCausalLM,
... MinLengthLogitsProcessor,
... StoppingCriteriaList,
... MaxLengthCriteria,
... )
>>> tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
>>> model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
>>> # set pad_token_id to eos_token_id because GPT2 does not have a PAD token
>>> model.config.pad_token_id = model.config.eos_token_id
>>> input_prompt = "DeepMind Company is"
>>> input_ids = tokenizer(input_prompt, return_tensors="pt")
>>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=64)])
>>> outputs = model.contrastive_search(
... **input_ids, penalty_alpha=0.6, top_k=4, stopping_criteria=stopping_criteria
... )
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
["DeepMind Company is a company that focuses on the development and commercialization of artificial intelligence (AI). DeepMind’s mission is to help people understand and solve problems that are difficult to solve in the world today.\n\nIn this post, we talk about the benefits of deep learning in business and how it"]
```"""
# init values
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
output_scores = output_scores if output_scores is not None else self.config.output_scores
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict_in_generate = (
return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
)
# init attention / hidden states / scores tuples
scores = () if (return_dict_in_generate and output_scores) else None
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
# keep track of which sequences are already finished
unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
this_peer_finished = False # used by synced_gpus only
step_counter = 0
while True:
if synced_gpus:
# Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
# The following logic allows an early break if all peers finished generating their sequence
this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
# send 0.0 if we finished, 1.0 otherwise
dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
# did all peers finish? the reduced sum will be 0.0 then
if this_peer_finished_flag.item() == 0.0:
break
# prepare inputs
model_kwargs["use_cache"] = True
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
# if the first step in the loop, encode all the prefix and obtain three parameters: (1) past_key_values; (2) last_hidden_states; (3) logit_for_next_step
if step_counter == 0:
# encode the given prefix and prepare model inputs; encoder-decoder model process the prefix and save the `encoder_outputs`
output = self(**model_inputs, output_hidden_states=True, output_attentions=True)
# past_key_values is activated for fast decoding
if "past_key_values" not in output:
raise ValueError(
"self.__class__ cannot return `past_key_values` and can therefore **not** be used for"
" contrastive search."
)
past_key_values = output.past_key_values
# last decoder hidden states will be used to compute the degeneration penalty (cosine similarity with previous tokens)
if self.config.is_encoder_decoder:
last_hidden_states = output.decoder_hidden_states[-1]
else:
last_hidden_states = output.hidden_states[-1]
# next logit for contrastive search to select top-k candidate tokens
logit_for_next_step = output.logits[:, -1, :]
# contrastive_search main logic start:
# contrastive search decoding consists of two steps: (1) candidate tokens recall; (2) candidate re-rank by degeneration penalty
bsz, seqlen, embed_dim = last_hidden_states.size()
# logits processor
logit_for_next_step = logits_processor(input_ids, logit_for_next_step)
logit_for_next_step = logits_warper(input_ids, logit_for_next_step)
next_probs = nn.functional.softmax(logit_for_next_step, dim=-1)
_, top_k_ids = torch.topk(logit_for_next_step, dim=-1, k=top_k)
top_k_probs = torch.gather(next_probs, dim=1, index=top_k_ids)
# enlarge the past_key_values
new_key_values = []
for layer in past_key_values:
items = []
# item is either the key or the value matrix
for item in layer:
bsz, num_head, seq_len, esz = item.size()
item = (
item.unsqueeze(1)
.expand(-1, top_k, -1, -1, -1)
.reshape(bsz * top_k, num_head, seq_len, esz)
.contiguous()
) # [bsz*beam, num_head, seq_len, esz]
items.append(item)
new_key_values.append(items)
past_key_values = new_key_values
# build next attention mask
if "attention_mask" in model_inputs:
attention_mask = model_inputs["attention_mask"] # [B, S]
# decoder-only model need the full attention mask, not only the mask for the last token
if self.config.is_encoder_decoder is False:
attention_mask = torch.cat([attention_mask, attention_mask.new_ones((bsz, 1))], dim=-1)
attention_mask = attention_mask.unsqueeze(1).expand(-1, top_k, -1).reshape(-1, attention_mask.size(-1))
else:
attention_mask = None
# encoder-decoder model also contains the `encoder_outputs`
if self.config.is_encoder_decoder and "encoder_outputs" in model_inputs:
encoder_outputs = model_inputs["encoder_outputs"]
else:
encoder_outputs = None
next_model_inputs = self.prepare_inputs_for_generation(
top_k_ids.view(-1, 1),
past=past_key_values,
attention_mask=attention_mask,
use_cache=True,
encoder_outputs=encoder_outputs,
)
# compute the candidate tokens by the language model and collects their hidden_states
output = self(output_hidden_states=True, **next_model_inputs)
if "past_key_values" not in output:
raise ValueError(
"self.__class__ cannot return `past_key_values` and can therefore **not** be used for contrastive"
" search."
)
past_key_values = output.past_key_values
logits = output.logits[:, -1, :]
# name is different for encoder-decoder and decoder-only models
if self.config.is_encoder_decoder:
next_hidden = output.decoder_hidden_states[-1]
full_hidden_states = output.decoder_hidden_states
else:
next_hidden = output.hidden_states[-1]
full_hidden_states = output.hidden_states
context_hidden = (
last_hidden_states.unsqueeze(1).expand(-1, top_k, -1, -1).reshape(bsz * top_k, seqlen, embed_dim)
)
# compute the degeneratin penalty and re-rank the candidates based on the degeneration penalty and the model confidence
# the scores and index of the selected tokens are returned
selected_scores, selected_idx = ranking_fast(
context_hidden, next_hidden, top_k_probs, penalty_alpha, top_k
)
# prepare for the next step: (1) next token_id; (2) past_key_values; (3) last_hidden_states for computing the degeneration penalty; (4) logits for selecting next top-k candidates; (5) selected tokens scores (model confidence minus degeneration penalty); (6) decoder hidden_states
next_tokens = top_k_ids[range(len(top_k_ids)), selected_idx]
next_hidden = torch.stack(torch.split(next_hidden.squeeze(dim=1), top_k))
next_hidden = next_hidden[range(bsz), selected_idx, :]
last_hidden_states = torch.cat([last_hidden_states, next_hidden.unsqueeze(1)], dim=1)
decoder_hidden_states_one_step = []
for layer in full_hidden_states:
layer = torch.stack(torch.split(layer.squeeze(dim=1), top_k))
layer = layer[range(bsz), selected_idx, :]
decoder_hidden_states_one_step.append(layer)
# select the past_key_value
new_key_values = []
for layer in past_key_values:
items = []
# item is either the key or the value matrix
for item in layer:
bsz_and_beam, num_head, seq_len, esz = item.size()
bsz = int(bsz_and_beam // top_k)
item = torch.stack(torch.split(item, top_k, dim=0)) # [B, K, num_head, seq_len, esz]
item = item[range(bsz), selected_idx, :, :, :] # [B, num_head, seq_len, esz]
items.append(item)
new_key_values.append(items)
past_key_values = new_key_values
logit_for_next_step = torch.stack(torch.split(logits, top_k))[range(bsz), selected_idx, :]
# contrastive_search main logic end::
# after running the above codes, we update following parameters: next_tokens, past_key_values, logit_for_next_step, selected_score, decoder_hidden_states_one_step
if synced_gpus and this_peer_finished:
continue # don't waste resources running the code we don't need
# Store scores, attentions and hidden_states when required
if return_dict_in_generate:
if output_scores:
scores += (selected_scores,)
if output_hidden_states:
decoder_hidden_states += (decoder_hidden_states_one_step,)
# finished sentences should have their next token be a padding token
if eos_token_id is not None:
if pad_token_id is None:
raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
# update generated ids, model inputs, and length for next step
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
if self.config.is_encoder_decoder:
outputs = Seq2SeqLMOutput(
past_key_values=past_key_values,
)
else:
outputs = CausalLMOutputWithCrossAttentions(
past_key_values=past_key_values, attentions=model_kwargs["attention_mask"]
)
model_kwargs = self._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
)
# if eos_token was found in one sentence, set sentence to finished
if eos_token_id is not None:
unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long())
# stop when each sentence is finished, or if we exceed the maximum length
if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
if not synced_gpus:
break
else:
this_peer_finished = True
# prepare model inputs
model_kwargs["past_key_values"] = past_key_values
step_counter += 1
if return_dict_in_generate:
if self.config.is_encoder_decoder:
return ContrastiveSearchEncoderDecoderOutput(
sequences=input_ids,
scores=scores,
decoder_hidden_states=decoder_hidden_states,
)
else:
return ContrastiveSearchDecoderOnlyOutput(
sequences=input_ids,
scores=scores,
hidden_states=decoder_hidden_states,
)
else:
return input_ids
def greedy_search(
self,
input_ids: torch.LongTensor,
......@@ -3457,3 +3878,25 @@ def top_k_top_p_filtering(
)
return logits
def ranking_fast(
context_hidden: torch.FloatTensor,
next_hidden: torch.FloatTensor,
next_top_k_probs: torch.FloatTensor,
alpha: float,
beam_width: int,
) -> Tuple[torch.FloatTensor]:
"""
context_hidden: bsz*beam x seqlen x embed_dim next_hidden: bsz*beam x 1 x embed_dim next_top_k_probs: bsz x beam
"""
_, context_len, embed_dim = context_hidden.size()
norm_context_hidden = context_hidden / context_hidden.norm(dim=2, keepdim=True)
norm_next_hidden = next_hidden / next_hidden.norm(dim=2, keepdim=True)
cosine_matrix = torch.matmul(norm_context_hidden, norm_next_hidden.transpose(1, 2)).squeeze(-1) # [B*K, S]
degeneration_penalty, _ = torch.max(cosine_matrix, dim=-1) # [B*K]
next_top_k_probs = next_top_k_probs.view(-1) # [B*K]
contrastive_score = (1.0 - alpha) * next_top_k_probs - alpha * degeneration_penalty
contrastive_score = torch.stack(torch.split(contrastive_score, beam_width)) # [B, K]
selected_scores, selected_idx = contrastive_score.max(dim=-1) # [B]
return torch.log(selected_scores), selected_idx
......@@ -27,6 +27,7 @@ if is_torch_available():
import torch
from transformers import (
AutoModelForCausalLM,
AutoModelForSeq2SeqLM,
AutoTokenizer,
BartForConditionalGeneration,
......@@ -34,8 +35,10 @@ if is_torch_available():
GPT2LMHeadModel,
GPT2Tokenizer,
ImageGPTForCausalImageModeling,
OPTForCausalLM,
Speech2TextForConditionalGeneration,
SpeechEncoderDecoderModel,
T5ForConditionalGeneration,
VisionEncoderDecoderModel,
pipeline,
top_k_top_p_filtering,
......@@ -1693,6 +1696,140 @@ class GenerationIntegrationTests(unittest.TestCase):
],
)
@slow
def test_contrastive_search_bart(self):
article = """ New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County, New York.
A year later, she got married again in Westchester County, but to a different man and without divorcing her first husband.
Only 18 days after that marriage, she got hitched yet again. Then, Barrientos declared "I do" five more times, sometimes only within two weeks of each other.
In 2010, she married once more, this time in the Bronx. In an application for a marriage license, she stated it was her "first and only" marriage.
Barrientos, now 39, is facing two criminal counts of "offering a false instrument for filing in the first degree," referring to her false statements on the
2010 marriage license application, according to court documents.
Prosecutors said the marriages were part of an immigration scam.
On Friday, she pleaded not guilty at State Supreme Court in the Bronx, according to her attorney, Christopher Wright, who declined to comment further.
After leaving court, Barrientos was arrested and charged with theft of service and criminal trespass for allegedly sneaking into the New York subway through an emergency exit, said Detective
Annette Markowski, a police spokeswoman. In total, Barrientos has been married 10 times, with nine of her marriages occurring between 1999 and 2002.
All occurred either in Westchester County, Long Island, New Jersey or the Bronx. She is believed to still be married to four men, and at one time, she was married to eight men at once, prosecutors say.
Prosecutors said the immigration scam involved some of her husbands, who filed for permanent residence status shortly after the marriages.
Any divorces happened only after such filings were approved. It was unclear whether any of the men will be prosecuted.
The case was referred to the Bronx District Attorney\'s Office by Immigration and Customs Enforcement and the Department of Homeland Security\'s
Investigation Division. Seven of the men are from so-called "red-flagged" countries, including Egypt, Turkey, Georgia, Pakistan and Mali.
Her eighth husband, Rashid Rajput, was deported in 2006 to his native Pakistan after an investigation by the Joint Terrorism Task Force.
If convicted, Barrientos faces up to four years in prison. Her next court appearance is scheduled for May 18.
"""
bart_tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")
bart_model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn").to(torch_device)
input_ids = bart_tokenizer(
article, add_special_tokens=False, truncation=True, max_length=512, return_tensors="pt"
).input_ids.to(torch_device)
outputs = bart_model.generate(input_ids, penalty_alpha=0.5, top_k=5, max_length=64)
generated_text = bart_tokenizer.batch_decode(outputs, skip_special_tokens=True)
self.assertListEqual(
generated_text,
[
"""Liana Barrientos, 39, pleaded not guilty to two counts of "offering a false instrument" Prosecutors say the marriages were part of an immigration scam. In total, Barriento has been married 10 times, with nine of her marriages occurring between 1999 and 2002."""
],
)
@slow
def test_contrastive_search_t5(self):
article = """ New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County, New York.
A year later, she got married again in Westchester County, but to a different man and without divorcing her first husband.
Only 18 days after that marriage, she got hitched yet again. Then, Barrientos declared "I do" five more times, sometimes only within two weeks of each other.
In 2010, she married once more, this time in the Bronx. In an application for a marriage license, she stated it was her "first and only" marriage.
Barrientos, now 39, is facing two criminal counts of "offering a false instrument for filing in the first degree," referring to her false statements on the
2010 marriage license application, according to court documents.
Prosecutors said the marriages were part of an immigration scam.
On Friday, she pleaded not guilty at State Supreme Court in the Bronx, according to her attorney, Christopher Wright, who declined to comment further.
After leaving court, Barrientos was arrested and charged with theft of service and criminal trespass for allegedly sneaking into the New York subway through an emergency exit, said Detective
Annette Markowski, a police spokeswoman. In total, Barrientos has been married 10 times, with nine of her marriages occurring between 1999 and 2002.
All occurred either in Westchester County, Long Island, New Jersey or the Bronx. She is believed to still be married to four men, and at one time, she was married to eight men at once, prosecutors say.
Prosecutors said the immigration scam involved some of her husbands, who filed for permanent residence status shortly after the marriages.
Any divorces happened only after such filings were approved. It was unclear whether any of the men will be prosecuted.
The case was referred to the Bronx District Attorney\'s Office by Immigration and Customs Enforcement and the Department of Homeland Security\'s
Investigation Division. Seven of the men are from so-called "red-flagged" countries, including Egypt, Turkey, Georgia, Pakistan and Mali.
Her eighth husband, Rashid Rajput, was deported in 2006 to his native Pakistan after an investigation by the Joint Terrorism Task Force.
If convicted, Barrientos faces up to four years in prison. Her next court appearance is scheduled for May 18.
"""
article = "summarize: " + article.strip()
t5_tokenizer = AutoTokenizer.from_pretrained("flax-community/t5-base-cnn-dm")
t5_model = T5ForConditionalGeneration.from_pretrained("flax-community/t5-base-cnn-dm").to(torch_device)
input_ids = t5_tokenizer(
article, add_special_tokens=False, truncation=True, max_length=512, return_tensors="pt"
).input_ids.to(torch_device)
outputs = t5_model.generate(input_ids, penalty_alpha=0.5, top_k=5, max_length=64)
generated_text = t5_tokenizer.batch_decode(outputs, skip_special_tokens=True)
self.assertListEqual(
generated_text,
[
"""Liana Barrientos has been married 10 times, nine of them in the Bronx. Her husbands filed for permanent residence after the marriages, prosecutors say."""
],
)
@slow
def test_contrastive_search_opt(self):
article = r"""A chat between a curious human and the Statue of Liberty.
Human: What is your name?
Statue: I am the Statue of Liberty.
Human: Where do you live?
Statue: New York City.
Human: How long have you lived there?"""
opt_tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-6.7b")
opt_model = OPTForCausalLM.from_pretrained("facebook/opt-6.7b").to(torch_device)
input_ids = opt_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
outputs = opt_model.generate(input_ids, penalty_alpha=0.6, top_k=5, max_length=256)
generated_text = opt_tokenizer.batch_decode(outputs, skip_special_tokens=True)
self.assertListEqual(
generated_text,
[
"""A chat between a curious human and the Statue of Liberty.\n\nHuman: What is your name?\nStatue: I am the Statue of Liberty.\nHuman: Where do you live?\nStatue: New York City.\nHuman: How long have you lived there?\nStatue: Since 1884.\nHuman: Why did you come to America?\nStatue: I was given to the United States by France as a gift for helping the French during the Franco-Prussian War.\nHuman: What do you think of America?\nStatue: I love it. It is the greatest country in the world.\nHuman: What’s the weather like in New York?\nStatue: It is cold.\nHuman: Is it safe to walk around at night?\nStatue: Yes. There are policemen everywhere.\nHuman: Do you have any children?\nStatue: Not yet. My pedestal is empty.\nHuman: What would you like to say to people who want to immigrate to America?\nStatue: Come on over. You will be happy here. We have everything you need.\nSource: http://www.statueofliberty.org/index.cf"""
],
)
@slow
def test_contrastive_search_gptj(self):
article = """DeepMind Technologies is a British artificial intelligence subsidiary of Alphabet Inc. and research laboratory founded in 2010. DeepMind was acquired by Google in 2014. The company is based"""
opt_tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
opt_model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6B").to(torch_device)
input_ids = opt_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
outputs = opt_model.generate(input_ids, penalty_alpha=0.6, top_k=4, max_length=256)
generated_text = opt_tokenizer.batch_decode(outputs, skip_special_tokens=True)
self.assertListEqual(
generated_text,
[
"""DeepMind Technologies is a British artificial intelligence subsidiary of Alphabet Inc. and research laboratory founded in 2010. DeepMind was acquired by Google in 2014. The company is based in London, United Kingdom with offices in Mountain View, San Francisco, New York City, Paris, Tokyo, Seoul, Beijing, Singapore, Tel Aviv, Dublin, Sydney, and Melbourne.[1]\n\nContents\n\nIn 2010, Google\'s parent company, Alphabet, announced a $500 million investment in DeepMind, with the aim of creating a company that would apply deep learning to problems in healthcare, energy, transportation, and other areas.[2]\n\nOn April 23, 2014, Google announced that it had acquired DeepMind for $400 million in cash and stock.[3] The acquisition was seen as a move to strengthen Google\'s position in the fast-growing field of artificial intelligence (AI), which it had invested in since 2010.[4] Google CEO Larry Page said that the company was "excited to have DeepMind on board" and that "this is a step towards our goal of building AI that works for everyone, not just a few".[5]\n\nDeepMind\'s co-founders, Demis Hassabis and Mustafa Suleyman, were named CEO and C"""
],
)
@slow
def test_contrastive_search_gpt2(self):
article = """DeepMind Technologies is a British artificial intelligence subsidiary of Alphabet Inc. and research laboratory founded in 2010. DeepMind was acquired by Google in 2014. The company is based"""
gpt2_tokenizer = AutoTokenizer.from_pretrained("gpt2-large")
gpt2_model = GPT2LMHeadModel.from_pretrained("gpt2-large").to(torch_device)
input_ids = gpt2_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
outputs = gpt2_model.generate(input_ids, penalty_alpha=0.6, top_k=4, max_length=256)
generated_text = gpt2_tokenizer.batch_decode(outputs, skip_special_tokens=True)
self.assertListEqual(
generated_text,
[
"""DeepMind Technologies is a British artificial intelligence subsidiary of Alphabet Inc. and research laboratory founded in 2010. DeepMind was acquired by Google in 2014. The company is based in London, United Kingdom\n\nGoogle has a lot of data on its users and uses it to improve its products, such as Google Now, which helps users find the information they\'re looking for on the web. But the company is not the only one to collect data on its users. Facebook, for example, has its own facial recognition technology, as well as a database of millions of photos that it uses to personalize its News Feed.\n\nFacebook\'s use of data is a hot topic in the tech industry, with privacy advocates concerned about the company\'s ability to keep users\' information private. In a blog post last year, Facebook CEO Mark Zuckerberg said his company would "do our best to be transparent about our data use and how we use it."\n\n"We have made it clear that we do not sell or share your data with third parties," Zuckerberg wrote. "If you have questions or concerns, please reach out to us at privacy@facebook.com."\n\nGoogle declined to comment on the privacy implications of its use of data, but said in a statement to The Associated Press that"""
],
)
def test_max_length_backward_compat_greedy(self):
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
......@@ -2050,6 +2187,134 @@ class GenerationIntegrationTests(unittest.TestCase):
with self.assertRaises(ValueError):
bart_model.generate(decoder_input_ids=input_ids, max_new_tokens=10, max_length=20)
def test_max_new_tokens_decoder_only_contrastive_search_t5(self):
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
t5_tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
t5_model = T5ForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-t5").to(torch_device)
input_ids = t5_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
self.assertEqual(list(input_ids.shape), [1, 56])
max_new_tokens = 3
t5_model.config.max_length = 20
t5_model.config.eos_token_id = None
# Encoder decoder call
outputs = t5_model.generate(input_ids, max_new_tokens=max_new_tokens, penalty_alpha=0.6, top_k=4)
# 1 BOS + 3 new tokens
self.assertEqual(list(outputs.shape), [1, 4])
# Decoder only call
outputs = t5_model.generate(
decoder_input_ids=input_ids, max_new_tokens=max_new_tokens, penalty_alpha=0.6, top_k=4
)
# 56 + 3 new tokens
self.assertEqual(list(outputs.shape), [1, 59])
# Encoder decoder call > 20
outputs = t5_model.generate(max_new_tokens=max_new_tokens + 20, penalty_alpha=0.6, top_k=4)
# 1 BOS + 20 + 3 new tokens
self.assertEqual(list(outputs.shape), [1, 24])
# max_new_tokens and max_length serve the same purpose and must not be used together.
with self.assertRaises(ValueError):
t5_model.generate(
decoder_input_ids=input_ids, max_new_tokens=10, max_length=20, penalty_alpha=0.6, top_k=4
)
def test_max_new_tokens_decoder_only_contrastive_search_bart(self):
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to(
torch_device
)
input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
self.assertEqual(list(input_ids.shape), [1, 29])
max_new_tokens = 3
bart_model.config.max_length = 20
bart_model.config.eos_token_id = None
# Encoder decoder call
outputs = bart_model.generate(input_ids, max_new_tokens=max_new_tokens, penalty_alpha=0.6, top_k=4)
# 1 BOS + 3 new tokens
self.assertEqual(list(outputs.shape), [1, 4])
# Decoder only call
outputs = bart_model.generate(
decoder_input_ids=input_ids, max_new_tokens=max_new_tokens, penalty_alpha=0.6, top_k=4
)
# 29 + 3 new tokens
self.assertEqual(list(outputs.shape), [1, 32])
# Encoder decoder call > 20
outputs = bart_model.generate(max_new_tokens=max_new_tokens + 20, penalty_alpha=0.6, top_k=4)
# 1 BOS + 20 + 3 new tokens
self.assertEqual(list(outputs.shape), [1, 24])
# max_new_tokens and max_length serve the same purpose and must not be used together.
with self.assertRaises(ValueError):
bart_model.generate(
decoder_input_ids=input_ids, max_new_tokens=10, max_length=20, penalty_alpha=0.6, top_k=4
)
def test_max_new_tokens_decoder_only_contrastive_search_gptj(self):
article = """Justin Timberlake."""
gptj_tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gptj")
gptj_model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gptj").to(torch_device)
input_ids = gptj_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
self.assertEqual(list(input_ids.shape), [1, 9])
max_new_tokens = 3
gptj_model.config.max_length = 20
# call < 20
outputs = gptj_model.generate(input_ids, max_new_tokens=max_new_tokens, penalty_alpha=0.6, top_k=4)
# 9 input_ids + 3 new tokens
self.assertEqual(list(outputs.shape), [1, 12])
# call > 20
outputs = gptj_model.generate(max_new_tokens=max_new_tokens + 20, penalty_alpha=0.6, top_k=4)
# 1 BOS token + 23 new tokens
self.assertEqual(list(outputs.shape), [1, 24])
# max_new_tokens and max_length serve the same purpose and must not be used together.
with self.assertRaises(ValueError):
gptj_model.generate(input_ids=input_ids, max_new_tokens=10, max_length=20, penalty_alpha=0.6, top_k=4)
def test_max_new_tokens_decoder_only_contrastive_search_gpt2(self):
article = """Justin Timberlake."""
gpt2_tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
gpt2_model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
input_ids = gpt2_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
self.assertEqual(list(input_ids.shape), [1, 9])
max_new_tokens = 3
gpt2_model.config.max_length = 20
# call < 20
outputs = gpt2_model.generate(input_ids, max_new_tokens=max_new_tokens, penalty_alpha=0.6, top_k=4)
# 9 input_ids + 3 new tokens
self.assertEqual(list(outputs.shape), [1, 12])
# call > 20
outputs = gpt2_model.generate(max_new_tokens=max_new_tokens + 20, penalty_alpha=0.6, top_k=4)
# 1 BOS token + 23 new tokens
self.assertEqual(list(outputs.shape), [1, 24])
# max_new_tokens and max_length serve the same purpose and must not be used together.
with self.assertRaises(ValueError):
gpt2_model.generate(input_ids=input_ids, max_new_tokens=10, max_length=20, penalty_alpha=0.6, top_k=4)
def test_max_new_tokens_decoder_only(self):
article = """Justin Timberlake."""
gpt2_tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment