Unverified Commit 71786b10 authored by GMFTBY's avatar GMFTBY Committed by GitHub
Browse files

Adding the state-of-the-art contrastive search decoding methods for the...

Adding the state-of-the-art contrastive search decoding methods for the codebase of generation_utils.py (#19477)

* add: the contrastive search for generaton_utils

* add: testing scripts for contrastive search under examples/text-generation

* update the quality of codes

* revise the docstring; make the generation_contrastive_search.py scripts;

* revise the examples/pytorch/text-generation/run_generation_contrastive_search.py to the auto-APIs format

* revise the necessary documents

* fix: revise the docstring of generation_contrastive_search.py

* Fix the code indentation

* fix: revise the nits and examples in contrastive_search docstring.

* fix the copyright

* delete generation_contrastive_search.py

* revise the logic in contrastive_search

* update the intergration test and the docstring

* run the tests over

* add the slow decorate to the contrastive_search intergrate test

* add more test

* do the style, quality, consistency checks
parent fc5fdc10
...@@ -14,6 +14,7 @@ specific language governing permissions and limitations under the License. ...@@ -14,6 +14,7 @@ specific language governing permissions and limitations under the License.
This page lists all the utility functions used by [`~generation_utils.GenerationMixin.generate`], This page lists all the utility functions used by [`~generation_utils.GenerationMixin.generate`],
[`~generation_utils.GenerationMixin.greedy_search`], [`~generation_utils.GenerationMixin.greedy_search`],
[`~generation_utils.GenerationMixin.contrastive_search`],
[`~generation_utils.GenerationMixin.sample`], [`~generation_utils.GenerationMixin.sample`],
[`~generation_utils.GenerationMixin.beam_search`], [`~generation_utils.GenerationMixin.beam_search`],
[`~generation_utils.GenerationMixin.beam_sample`], [`~generation_utils.GenerationMixin.beam_sample`],
......
...@@ -26,6 +26,7 @@ Each framework has a generate method for auto-regressive text generation impleme ...@@ -26,6 +26,7 @@ Each framework has a generate method for auto-regressive text generation impleme
- sample - sample
- beam_search - beam_search
- beam_sample - beam_sample
- contrastive_search
- group_beam_search - group_beam_search
- constrained_beam_search - constrained_beam_search
......
#!/usr/bin/env python
# coding=utf-8
# Copyright 2022 University of Cambridge, Tencent AI Lab, DeepMind and The University of Hong Kong Authors and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" The examples of running contrastive search on the auto-APIs;
Running this example:
python run_generation_contrastive_search.py --model_name_or_path=gpt2-large --penalty_alpha=0.6 --k=4 --length=256
"""
import argparse
import logging
import numpy as np
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
logger = logging.getLogger(__name__)
def set_seed(args):
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if args.n_gpu > 0:
torch.cuda.manual_seed_all(args.seed)
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_name_or_path",
default=None,
type=str,
required=True,
)
parser.add_argument("--prompt", type=str, default="")
parser.add_argument("--length", type=int, default=20)
parser.add_argument("--stop_token", type=str, default=None, help="Token at which text generation is stopped")
parser.add_argument(
"--temperature",
type=float,
default=1.0,
help="temperature of 1.0 has no effect, lower tend toward greedy sampling",
)
parser.add_argument(
"--repetition_penalty", type=float, default=1.0, help="primarily useful for CTRL model; in that case, use 1.2"
)
parser.add_argument("--k", type=int, default=0)
parser.add_argument("--penalty_alpha", type=float, default=0.0)
parser.add_argument("--p", type=float, default=0.9)
parser.add_argument("--prefix", type=str, default="", help="Text added prior to input.")
parser.add_argument("--padding_text", type=str, default="", help="Deprecated, the use of `--prefix` is preferred.")
parser.add_argument("--xlm_language", type=str, default="", help="Optional language when used with the XLM model.")
parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
parser.add_argument(
"--fp16",
action="store_true",
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
)
args = parser.parse_args()
args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
args.n_gpu = 0 if args.no_cuda else torch.cuda.device_count()
logger.warning(f"device: {args.device}, n_gpu: {args.n_gpu}, 16-bits training: {args.fp16}")
set_seed(args)
# Initialize the model and tokenizer
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path)
# tokenizer = GPT2Tokenizer.from_pretrained(args.model_name_or_path)
# model = OPTForCausalLM.from_pretrained(args.model_name_or_path)
model.to(args.device)
if args.fp16:
model.half()
logger.info(args)
prompt_text = args.prompt if args.prompt else input("Model prompt >>> ")
inputs = tokenizer(prompt_text, return_tensors="pt", add_special_tokens=False)
inputs = {key: value.to(args.device) for key, value in inputs.items()}
output_sequences = model.generate(
**inputs,
max_length=args.length + len(inputs["input_ids"][0]),
penalty_alpha=args.penalty_alpha,
top_k=args.k,
)
generated_sequences = []
for generated_sequence_idx, generated_sequence in enumerate(output_sequences):
print(f"=== GENERATED SEQUENCE {generated_sequence_idx + 1} ===")
generated_sequence = generated_sequence.tolist()
# Decode text
text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True, add_special_tokens=False)
# Remove all text after the stop token
text = text[: text.find(args.stop_token) if args.stop_token else None]
# Add the prompt at the beginning of the sequence. Remove the excess text that was used for pre-processing
total_sequence = (
prompt_text + text[len(tokenizer.decode(inputs["input_ids"][0], clean_up_tokenization_spaces=True)) :]
)
generated_sequences.append(total_sequence)
print(total_sequence)
return generated_sequences
if __name__ == "__main__":
main()
This diff is collapsed.
...@@ -27,6 +27,7 @@ if is_torch_available(): ...@@ -27,6 +27,7 @@ if is_torch_available():
import torch import torch
from transformers import ( from transformers import (
AutoModelForCausalLM,
AutoModelForSeq2SeqLM, AutoModelForSeq2SeqLM,
AutoTokenizer, AutoTokenizer,
BartForConditionalGeneration, BartForConditionalGeneration,
...@@ -34,8 +35,10 @@ if is_torch_available(): ...@@ -34,8 +35,10 @@ if is_torch_available():
GPT2LMHeadModel, GPT2LMHeadModel,
GPT2Tokenizer, GPT2Tokenizer,
ImageGPTForCausalImageModeling, ImageGPTForCausalImageModeling,
OPTForCausalLM,
Speech2TextForConditionalGeneration, Speech2TextForConditionalGeneration,
SpeechEncoderDecoderModel, SpeechEncoderDecoderModel,
T5ForConditionalGeneration,
VisionEncoderDecoderModel, VisionEncoderDecoderModel,
pipeline, pipeline,
top_k_top_p_filtering, top_k_top_p_filtering,
...@@ -1693,6 +1696,140 @@ class GenerationIntegrationTests(unittest.TestCase): ...@@ -1693,6 +1696,140 @@ class GenerationIntegrationTests(unittest.TestCase):
], ],
) )
@slow
def test_contrastive_search_bart(self):
article = """ New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County, New York.
A year later, she got married again in Westchester County, but to a different man and without divorcing her first husband.
Only 18 days after that marriage, she got hitched yet again. Then, Barrientos declared "I do" five more times, sometimes only within two weeks of each other.
In 2010, she married once more, this time in the Bronx. In an application for a marriage license, she stated it was her "first and only" marriage.
Barrientos, now 39, is facing two criminal counts of "offering a false instrument for filing in the first degree," referring to her false statements on the
2010 marriage license application, according to court documents.
Prosecutors said the marriages were part of an immigration scam.
On Friday, she pleaded not guilty at State Supreme Court in the Bronx, according to her attorney, Christopher Wright, who declined to comment further.
After leaving court, Barrientos was arrested and charged with theft of service and criminal trespass for allegedly sneaking into the New York subway through an emergency exit, said Detective
Annette Markowski, a police spokeswoman. In total, Barrientos has been married 10 times, with nine of her marriages occurring between 1999 and 2002.
All occurred either in Westchester County, Long Island, New Jersey or the Bronx. She is believed to still be married to four men, and at one time, she was married to eight men at once, prosecutors say.
Prosecutors said the immigration scam involved some of her husbands, who filed for permanent residence status shortly after the marriages.
Any divorces happened only after such filings were approved. It was unclear whether any of the men will be prosecuted.
The case was referred to the Bronx District Attorney\'s Office by Immigration and Customs Enforcement and the Department of Homeland Security\'s
Investigation Division. Seven of the men are from so-called "red-flagged" countries, including Egypt, Turkey, Georgia, Pakistan and Mali.
Her eighth husband, Rashid Rajput, was deported in 2006 to his native Pakistan after an investigation by the Joint Terrorism Task Force.
If convicted, Barrientos faces up to four years in prison. Her next court appearance is scheduled for May 18.
"""
bart_tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")
bart_model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn").to(torch_device)
input_ids = bart_tokenizer(
article, add_special_tokens=False, truncation=True, max_length=512, return_tensors="pt"
).input_ids.to(torch_device)
outputs = bart_model.generate(input_ids, penalty_alpha=0.5, top_k=5, max_length=64)
generated_text = bart_tokenizer.batch_decode(outputs, skip_special_tokens=True)
self.assertListEqual(
generated_text,
[
"""Liana Barrientos, 39, pleaded not guilty to two counts of "offering a false instrument" Prosecutors say the marriages were part of an immigration scam. In total, Barriento has been married 10 times, with nine of her marriages occurring between 1999 and 2002."""
],
)
@slow
def test_contrastive_search_t5(self):
article = """ New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County, New York.
A year later, she got married again in Westchester County, but to a different man and without divorcing her first husband.
Only 18 days after that marriage, she got hitched yet again. Then, Barrientos declared "I do" five more times, sometimes only within two weeks of each other.
In 2010, she married once more, this time in the Bronx. In an application for a marriage license, she stated it was her "first and only" marriage.
Barrientos, now 39, is facing two criminal counts of "offering a false instrument for filing in the first degree," referring to her false statements on the
2010 marriage license application, according to court documents.
Prosecutors said the marriages were part of an immigration scam.
On Friday, she pleaded not guilty at State Supreme Court in the Bronx, according to her attorney, Christopher Wright, who declined to comment further.
After leaving court, Barrientos was arrested and charged with theft of service and criminal trespass for allegedly sneaking into the New York subway through an emergency exit, said Detective
Annette Markowski, a police spokeswoman. In total, Barrientos has been married 10 times, with nine of her marriages occurring between 1999 and 2002.
All occurred either in Westchester County, Long Island, New Jersey or the Bronx. She is believed to still be married to four men, and at one time, she was married to eight men at once, prosecutors say.
Prosecutors said the immigration scam involved some of her husbands, who filed for permanent residence status shortly after the marriages.
Any divorces happened only after such filings were approved. It was unclear whether any of the men will be prosecuted.
The case was referred to the Bronx District Attorney\'s Office by Immigration and Customs Enforcement and the Department of Homeland Security\'s
Investigation Division. Seven of the men are from so-called "red-flagged" countries, including Egypt, Turkey, Georgia, Pakistan and Mali.
Her eighth husband, Rashid Rajput, was deported in 2006 to his native Pakistan after an investigation by the Joint Terrorism Task Force.
If convicted, Barrientos faces up to four years in prison. Her next court appearance is scheduled for May 18.
"""
article = "summarize: " + article.strip()
t5_tokenizer = AutoTokenizer.from_pretrained("flax-community/t5-base-cnn-dm")
t5_model = T5ForConditionalGeneration.from_pretrained("flax-community/t5-base-cnn-dm").to(torch_device)
input_ids = t5_tokenizer(
article, add_special_tokens=False, truncation=True, max_length=512, return_tensors="pt"
).input_ids.to(torch_device)
outputs = t5_model.generate(input_ids, penalty_alpha=0.5, top_k=5, max_length=64)
generated_text = t5_tokenizer.batch_decode(outputs, skip_special_tokens=True)
self.assertListEqual(
generated_text,
[
"""Liana Barrientos has been married 10 times, nine of them in the Bronx. Her husbands filed for permanent residence after the marriages, prosecutors say."""
],
)
@slow
def test_contrastive_search_opt(self):
article = r"""A chat between a curious human and the Statue of Liberty.
Human: What is your name?
Statue: I am the Statue of Liberty.
Human: Where do you live?
Statue: New York City.
Human: How long have you lived there?"""
opt_tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-6.7b")
opt_model = OPTForCausalLM.from_pretrained("facebook/opt-6.7b").to(torch_device)
input_ids = opt_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
outputs = opt_model.generate(input_ids, penalty_alpha=0.6, top_k=5, max_length=256)
generated_text = opt_tokenizer.batch_decode(outputs, skip_special_tokens=True)
self.assertListEqual(
generated_text,
[
"""A chat between a curious human and the Statue of Liberty.\n\nHuman: What is your name?\nStatue: I am the Statue of Liberty.\nHuman: Where do you live?\nStatue: New York City.\nHuman: How long have you lived there?\nStatue: Since 1884.\nHuman: Why did you come to America?\nStatue: I was given to the United States by France as a gift for helping the French during the Franco-Prussian War.\nHuman: What do you think of America?\nStatue: I love it. It is the greatest country in the world.\nHuman: What’s the weather like in New York?\nStatue: It is cold.\nHuman: Is it safe to walk around at night?\nStatue: Yes. There are policemen everywhere.\nHuman: Do you have any children?\nStatue: Not yet. My pedestal is empty.\nHuman: What would you like to say to people who want to immigrate to America?\nStatue: Come on over. You will be happy here. We have everything you need.\nSource: http://www.statueofliberty.org/index.cf"""
],
)
@slow
def test_contrastive_search_gptj(self):
article = """DeepMind Technologies is a British artificial intelligence subsidiary of Alphabet Inc. and research laboratory founded in 2010. DeepMind was acquired by Google in 2014. The company is based"""
opt_tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
opt_model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6B").to(torch_device)
input_ids = opt_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
outputs = opt_model.generate(input_ids, penalty_alpha=0.6, top_k=4, max_length=256)
generated_text = opt_tokenizer.batch_decode(outputs, skip_special_tokens=True)
self.assertListEqual(
generated_text,
[
"""DeepMind Technologies is a British artificial intelligence subsidiary of Alphabet Inc. and research laboratory founded in 2010. DeepMind was acquired by Google in 2014. The company is based in London, United Kingdom with offices in Mountain View, San Francisco, New York City, Paris, Tokyo, Seoul, Beijing, Singapore, Tel Aviv, Dublin, Sydney, and Melbourne.[1]\n\nContents\n\nIn 2010, Google\'s parent company, Alphabet, announced a $500 million investment in DeepMind, with the aim of creating a company that would apply deep learning to problems in healthcare, energy, transportation, and other areas.[2]\n\nOn April 23, 2014, Google announced that it had acquired DeepMind for $400 million in cash and stock.[3] The acquisition was seen as a move to strengthen Google\'s position in the fast-growing field of artificial intelligence (AI), which it had invested in since 2010.[4] Google CEO Larry Page said that the company was "excited to have DeepMind on board" and that "this is a step towards our goal of building AI that works for everyone, not just a few".[5]\n\nDeepMind\'s co-founders, Demis Hassabis and Mustafa Suleyman, were named CEO and C"""
],
)
@slow
def test_contrastive_search_gpt2(self):
article = """DeepMind Technologies is a British artificial intelligence subsidiary of Alphabet Inc. and research laboratory founded in 2010. DeepMind was acquired by Google in 2014. The company is based"""
gpt2_tokenizer = AutoTokenizer.from_pretrained("gpt2-large")
gpt2_model = GPT2LMHeadModel.from_pretrained("gpt2-large").to(torch_device)
input_ids = gpt2_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
outputs = gpt2_model.generate(input_ids, penalty_alpha=0.6, top_k=4, max_length=256)
generated_text = gpt2_tokenizer.batch_decode(outputs, skip_special_tokens=True)
self.assertListEqual(
generated_text,
[
"""DeepMind Technologies is a British artificial intelligence subsidiary of Alphabet Inc. and research laboratory founded in 2010. DeepMind was acquired by Google in 2014. The company is based in London, United Kingdom\n\nGoogle has a lot of data on its users and uses it to improve its products, such as Google Now, which helps users find the information they\'re looking for on the web. But the company is not the only one to collect data on its users. Facebook, for example, has its own facial recognition technology, as well as a database of millions of photos that it uses to personalize its News Feed.\n\nFacebook\'s use of data is a hot topic in the tech industry, with privacy advocates concerned about the company\'s ability to keep users\' information private. In a blog post last year, Facebook CEO Mark Zuckerberg said his company would "do our best to be transparent about our data use and how we use it."\n\n"We have made it clear that we do not sell or share your data with third parties," Zuckerberg wrote. "If you have questions or concerns, please reach out to us at privacy@facebook.com."\n\nGoogle declined to comment on the privacy implications of its use of data, but said in a statement to The Associated Press that"""
],
)
def test_max_length_backward_compat_greedy(self): def test_max_length_backward_compat_greedy(self):
article = """Justin Timberlake and Jessica Biel, welcome to parenthood.""" article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart") bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
...@@ -2050,6 +2187,134 @@ class GenerationIntegrationTests(unittest.TestCase): ...@@ -2050,6 +2187,134 @@ class GenerationIntegrationTests(unittest.TestCase):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
bart_model.generate(decoder_input_ids=input_ids, max_new_tokens=10, max_length=20) bart_model.generate(decoder_input_ids=input_ids, max_new_tokens=10, max_length=20)
def test_max_new_tokens_decoder_only_contrastive_search_t5(self):
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
t5_tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
t5_model = T5ForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-t5").to(torch_device)
input_ids = t5_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
self.assertEqual(list(input_ids.shape), [1, 56])
max_new_tokens = 3
t5_model.config.max_length = 20
t5_model.config.eos_token_id = None
# Encoder decoder call
outputs = t5_model.generate(input_ids, max_new_tokens=max_new_tokens, penalty_alpha=0.6, top_k=4)
# 1 BOS + 3 new tokens
self.assertEqual(list(outputs.shape), [1, 4])
# Decoder only call
outputs = t5_model.generate(
decoder_input_ids=input_ids, max_new_tokens=max_new_tokens, penalty_alpha=0.6, top_k=4
)
# 56 + 3 new tokens
self.assertEqual(list(outputs.shape), [1, 59])
# Encoder decoder call > 20
outputs = t5_model.generate(max_new_tokens=max_new_tokens + 20, penalty_alpha=0.6, top_k=4)
# 1 BOS + 20 + 3 new tokens
self.assertEqual(list(outputs.shape), [1, 24])
# max_new_tokens and max_length serve the same purpose and must not be used together.
with self.assertRaises(ValueError):
t5_model.generate(
decoder_input_ids=input_ids, max_new_tokens=10, max_length=20, penalty_alpha=0.6, top_k=4
)
def test_max_new_tokens_decoder_only_contrastive_search_bart(self):
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to(
torch_device
)
input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
self.assertEqual(list(input_ids.shape), [1, 29])
max_new_tokens = 3
bart_model.config.max_length = 20
bart_model.config.eos_token_id = None
# Encoder decoder call
outputs = bart_model.generate(input_ids, max_new_tokens=max_new_tokens, penalty_alpha=0.6, top_k=4)
# 1 BOS + 3 new tokens
self.assertEqual(list(outputs.shape), [1, 4])
# Decoder only call
outputs = bart_model.generate(
decoder_input_ids=input_ids, max_new_tokens=max_new_tokens, penalty_alpha=0.6, top_k=4
)
# 29 + 3 new tokens
self.assertEqual(list(outputs.shape), [1, 32])
# Encoder decoder call > 20
outputs = bart_model.generate(max_new_tokens=max_new_tokens + 20, penalty_alpha=0.6, top_k=4)
# 1 BOS + 20 + 3 new tokens
self.assertEqual(list(outputs.shape), [1, 24])
# max_new_tokens and max_length serve the same purpose and must not be used together.
with self.assertRaises(ValueError):
bart_model.generate(
decoder_input_ids=input_ids, max_new_tokens=10, max_length=20, penalty_alpha=0.6, top_k=4
)
def test_max_new_tokens_decoder_only_contrastive_search_gptj(self):
article = """Justin Timberlake."""
gptj_tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gptj")
gptj_model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gptj").to(torch_device)
input_ids = gptj_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
self.assertEqual(list(input_ids.shape), [1, 9])
max_new_tokens = 3
gptj_model.config.max_length = 20
# call < 20
outputs = gptj_model.generate(input_ids, max_new_tokens=max_new_tokens, penalty_alpha=0.6, top_k=4)
# 9 input_ids + 3 new tokens
self.assertEqual(list(outputs.shape), [1, 12])
# call > 20
outputs = gptj_model.generate(max_new_tokens=max_new_tokens + 20, penalty_alpha=0.6, top_k=4)
# 1 BOS token + 23 new tokens
self.assertEqual(list(outputs.shape), [1, 24])
# max_new_tokens and max_length serve the same purpose and must not be used together.
with self.assertRaises(ValueError):
gptj_model.generate(input_ids=input_ids, max_new_tokens=10, max_length=20, penalty_alpha=0.6, top_k=4)
def test_max_new_tokens_decoder_only_contrastive_search_gpt2(self):
article = """Justin Timberlake."""
gpt2_tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
gpt2_model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
input_ids = gpt2_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
self.assertEqual(list(input_ids.shape), [1, 9])
max_new_tokens = 3
gpt2_model.config.max_length = 20
# call < 20
outputs = gpt2_model.generate(input_ids, max_new_tokens=max_new_tokens, penalty_alpha=0.6, top_k=4)
# 9 input_ids + 3 new tokens
self.assertEqual(list(outputs.shape), [1, 12])
# call > 20
outputs = gpt2_model.generate(max_new_tokens=max_new_tokens + 20, penalty_alpha=0.6, top_k=4)
# 1 BOS token + 23 new tokens
self.assertEqual(list(outputs.shape), [1, 24])
# max_new_tokens and max_length serve the same purpose and must not be used together.
with self.assertRaises(ValueError):
gpt2_model.generate(input_ids=input_ids, max_new_tokens=10, max_length=20, penalty_alpha=0.6, top_k=4)
def test_max_new_tokens_decoder_only(self): def test_max_new_tokens_decoder_only(self):
article = """Justin Timberlake.""" article = """Justin Timberlake."""
gpt2_tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2") gpt2_tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment