Unverified Commit 8a133490 authored by Joao Gante's avatar Joao Gante Committed by GitHub
Browse files

Add TF generate sample tests with all logit processors (#15852)

* Add GPT2 TF generate sample test with all logits processor

* Add T5 generate sample test
parent 40040727
......@@ -442,7 +442,7 @@ class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids)
@slow
def test_lm_generate_distilgpt2_batch_special(self):
def test_lm_generate_greedy_distilgpt2_batch_special(self):
model = TFGPT2LMHeadModel.from_pretrained("distilgpt2")
tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2")
......@@ -468,6 +468,37 @@ class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
]
self.assertListEqual(output_strings, expected_output_string)
@slow
def test_lm_generate_sample_distilgpt2_batch_special(self):
model = TFGPT2LMHeadModel.from_pretrained("distilgpt2")
tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2")
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"
sentences = ["Today is a beautiful day and", "Yesterday was"]
input_ids = tokenizer(sentences, return_tensors="tf", padding=True).input_ids
generation_kwargs = {
"do_sample": True,
"bad_words_ids": [tokenizer("is").input_ids, tokenizer("angry about").input_ids],
"no_repeat_ngram_size": 2,
"repetition_penalty": 1.3,
"temperature": 1.5,
"top_k": 500,
"top_p": 0.9,
}
tf.random.set_seed(42) # deterministic sampling sequence -> deterministic generation
output_ids = model.generate(input_ids, **generation_kwargs)
output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
expected_output_string = [
"Today is a beautiful day and this makes finding holiday travel easier for you to do other project\nOh",
"Yesterday was an enjoyable but especially great note though it certainly upset many Democrats who say",
]
self.assertListEqual(output_strings, expected_output_string)
@slow
def test_lm_generate_gpt2(self):
model = TFGPT2LMHeadModel.from_pretrained("gpt2")
......
......@@ -480,6 +480,33 @@ class TFT5GenerationIntegrationTests(unittest.TestCase):
self.assertListEqual(expected_output_string, output_strings)
@slow
def test_sample_generate(self):
model = TFT5ForConditionalGeneration.from_pretrained("t5-small")
tokenizer = T5Tokenizer.from_pretrained("t5-small")
sentences = ["I really love my", "Translate English to German: the transformers are truly amazing"]
input_ids = tokenizer(sentences, return_tensors="tf", padding=True).input_ids
generation_kwargs = {
"do_sample": True,
"bad_words_ids": [tokenizer("my").input_ids, tokenizer("ein schöner").input_ids],
"no_repeat_ngram_size": 3,
"repetition_penalty": 2.2,
"temperature": 0.8,
"top_k": 500,
"top_p": 0.9,
}
tf.random.set_seed(42) # deterministic sampling sequence -> deterministic generation
output_ids = model.generate(input_ids, **generation_kwargs)
output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
expected_output_string = ["i love her I really love my heart", "die Transformatoren sind wirklich erstaunlich"]
self.assertListEqual(expected_output_string, output_strings)
@require_tf
@require_sentencepiece
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment