Unverified Commit a1bbcf3f authored by Patrick von Platen, committed by GitHub

Refactoring the generate() function (#6949)

* first draft

* show design proposition for new generate method

* up

* make more readable

* make first version

* gpt2 tests pass

* make beam search for gpt2 work

* add first encoder-decoder code

* delete typo

* make t5 work

* save intermediate

* make bart work with beam search

* finish beam search bart / t5

* add default kwargs

* make more tests pass

* fix no bad words sampler

* some fixes and tests for all distribution processors

* fix test

* fix rag slow tests

* merge to master

* add nograd to generate

* make all slow tests pass

* speed up generate

* fix edge case bug

* small fix

* correct typo

* add type hints and docstrings

* fix typos in tests

* add beam search tests

* add tests for beam scorer

* fix test rag

* finish beam search tests

* move generation tests into separate file

* fix generation tests

* more tests

* add aggressive generation tests

* fix tests

* add gpt2 sample test

* add more docstring

* add more docs

* finish doc strings

* apply some more of Sylvain's and Sam's comments

* fix some typos

* make fix copies

* apply Lysandre's and Sylvain's comments

* final corrections on examples

* small fix for reformer
parent b63beb74
# coding=utf-8
# Copyright 2020 The HuggingFace Team Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from transformers import is_torch_available
from transformers.testing_utils import require_torch, torch_device
from .test_modeling_common import ids_tensor
if is_torch_available():
import torch
import torch.nn.functional as F
from transformers.generation_logits_process import (
LogitsProcessorList,
MinLengthLogitsProcessor,
NoBadWordsLogitsProcessor,
NoRepeatNGramLogitsProcessor,
RepetitionPenaltyLogitsProcessor,
TemperatureLogitsWarper,
TopKLogitsWarper,
TopPLogitsWarper,
)
@require_torch
class LogitsProcessorTest(unittest.TestCase):
def _get_uniform_logits(self, batch_size: int, length: int):
scores = torch.ones((batch_size, length), device=torch_device, dtype=torch.float) / length
return scores
def test_min_length_dist_processor(self):
vocab_size = 20
batch_size = 4
eos_token_id = 0
min_dist_processor = MinLengthLogitsProcessor(min_length=10, eos_token_id=eos_token_id)
# check that min length is applied at length 5
input_ids = ids_tensor((batch_size, 5), vocab_size=20)
scores = self._get_uniform_logits(batch_size, vocab_size)
scores_before_min_length = min_dist_processor(input_ids, scores)
self.assertListEqual(scores_before_min_length[:, eos_token_id].tolist(), 4 * [-float("inf")])
# check that min length is not applied anymore at length 15
input_ids = ids_tensor((batch_size, 15), vocab_size=20)
scores = self._get_uniform_logits(batch_size, vocab_size)
scores_before_min_length = min_dist_processor(input_ids, scores)
self.assertFalse(torch.isinf(scores_before_min_length).any())
def test_temperature_dist_warper(self):
input_ids = None
length = 20
scores = self._get_uniform_logits(batch_size=2, length=length)
# tweak scores to not be uniform anymore
scores[1, 5] = (1 / length) + 0.1 # peak, 2nd batch
scores[1, 10] = (1 / length) - 0.4 # valley, 2nd batch
# compute softmax
probs = F.softmax(scores, dim=-1)
temp_dist_warper_sharper = TemperatureLogitsWarper(temperature=0.5)
temp_dist_warper_smoother = TemperatureLogitsWarper(temperature=1.3)
warped_prob_sharp = F.softmax(temp_dist_warper_sharper(input_ids, scores.clone()), dim=-1)
warped_prob_smooth = F.softmax(temp_dist_warper_smoother(input_ids, scores.clone()), dim=-1)
# uniform distribution stays uniform
self.assertTrue(torch.allclose(probs[0, :], warped_prob_sharp[0, :], atol=1e-3))
self.assertTrue(torch.allclose(probs[0, :], warped_prob_smooth[0, :], atol=1e-3))
# sharp peaks get higher, valleys get lower
self.assertLess(probs[1, :].max(), warped_prob_sharp[1, :].max())
self.assertGreater(probs[1, :].min(), warped_prob_sharp[1, :].min())
# smooth peaks get lower, valleys get higher
self.assertGreater(probs[1, :].max(), warped_prob_smooth[1, :].max())
self.assertLess(probs[1, :].min(), warped_prob_smooth[1, :].min())
def test_repetition_penalty_dist_process(self):
input_ids = torch.tensor([[0, 1], [5, 0]], device=torch_device, dtype=torch.long)
vocab_size = 10
scores = self._get_uniform_logits(batch_size=2, length=vocab_size)
# give tokens that appear in the input special values
scores[0, 0] = -(1 / vocab_size)
scores[1, 5] = 4 / vocab_size
rep_penalty_proc = RepetitionPenaltyLogitsProcessor(penalty=2.0)
scores = rep_penalty_proc(input_ids, scores.clone())
# check that values were correctly changed
self.assertAlmostEqual(scores[0, 0].item(), -(1 / vocab_size) * 2)
self.assertAlmostEqual(scores[0, 1].item(), (1 / vocab_size) / 2)
self.assertAlmostEqual(scores[1, 0].item(), (1 / vocab_size) / 2)
self.assertAlmostEqual(scores[1, 5].item(), (4 / vocab_size) / 2)
def test_top_k_dist_warper(self):
input_ids = None
vocab_size = 10
batch_size = 2
# create ramp distribution
ramp_logits = (
torch.arange(vocab_size, device=torch_device, dtype=torch.float).unsqueeze(0).repeat(batch_size, 1)
)
ramp_logits[1:, : vocab_size // 2] = ramp_logits[1:, : vocab_size // 2] + vocab_size
top_k_warp = TopKLogitsWarper(3)
scores = top_k_warp(input_ids, ramp_logits)
# check that correct tokens are filtered
self.assertListEqual(torch.isinf(scores[0]).tolist(), 7 * [True] + 3 * [False])
self.assertListEqual(torch.isinf(scores[1]).tolist(), 2 * [True] + 3 * [False] + 5 * [True])
# check special cases
length = 5
logits = self._get_uniform_logits(batch_size=batch_size, length=length)
top_k_warp_safety_check = TopKLogitsWarper(top_k=1, filter_value=0.0, min_tokens_to_keep=3)
scores = top_k_warp_safety_check(input_ids, logits)
# uniform dist is not changed
self.assertListEqual((scores == 0.0).to(torch.long).sum(dim=-1).tolist(), [0, 0])
ramp_logits = torch.arange(length, device=torch_device, dtype=torch.float).unsqueeze(0).repeat(batch_size, 1)
scores = top_k_warp_safety_check(input_ids, ramp_logits)
# min_tokens overwrites k: 3 tokens are kept => 2 tokens are nullified
self.assertListEqual((scores == 0.0).to(torch.long).sum(dim=-1).tolist(), [2, 2])
def test_top_p_dist_warper(self):
input_ids = None
vocab_size = 10
batch_size = 2
# create distribution and take log (inverse of the softmax applied inside TopPLogitsWarper)
dist = torch.log(
torch.tensor([[0.3, 0.1, 0.1, 0.5], [0.15, 0.3, 0.3, 0.25]], device=torch_device, dtype=torch.float)
)
top_p_warp = TopPLogitsWarper(0.7)
filtered_dist = torch.exp(top_p_warp(input_ids, dist))
# dist should be filtered to keep min num values so that sum is >= 0.7
# exp (-inf) => 0
EXPECTED_FILTERED_DIST = torch.tensor(
[[0.3, 0.0, 0.0, 0.5], [0.0, 0.3, 0.3, 0.25]], device=torch_device, dtype=torch.float
)
self.assertTrue(torch.allclose(filtered_dist, EXPECTED_FILTERED_DIST, atol=1e-3))
# check edge cases with negative and extreme logits
ramp_logits = torch.arange(vocab_size, device=torch_device, dtype=torch.float).unsqueeze(0).repeat(
batch_size, 1
) - (vocab_size // 2)
# make ramp_logits more extreme
ramp_logits[1] = ramp_logits[1] * 100.0
# make sure at least 2 tokens are kept
top_p_warp = TopPLogitsWarper(0.9, min_tokens_to_keep=2, filter_value=0.0)
filtered_dist = top_p_warp(input_ids, ramp_logits)
# first batch should keep three tokens, second batch would keep only 1, but due to `min_tokens_to_keep=2` keeps 2.
self.assertListEqual((filtered_dist != 0.0).to(torch.long).sum(dim=-1).tolist(), [3, 2])
def test_no_repeat_ngram_dist_processor(self):
vocab_size = 3
batch_size = 2
input_ids = torch.tensor([[1, 1, 2, 1], [0, 1, 0, 1]], device=torch_device, dtype=torch.long)
scores = self._get_uniform_logits(batch_size, vocab_size)
no_repeat_proc_2_gram = NoRepeatNGramLogitsProcessor(2)
no_repeat_proc_3_gram = NoRepeatNGramLogitsProcessor(3)
filtered_scores_2_gram = no_repeat_proc_2_gram(input_ids, scores.clone())
filtered_scores_3_gram = no_repeat_proc_3_gram(input_ids, scores.clone())
# 2-gram would forbid 2nd and 3rd token (1,2) at 1st batch and 1st token (0) at 2nd batch
self.assertListEqual(torch.isinf(filtered_scores_2_gram).tolist(), [[False, True, True], [True, False, False]])
# 3-gram would forbid no token at 1st batch and 1st token (0) at 2nd batch
self.assertListEqual(
torch.isinf(filtered_scores_3_gram).tolist(), [[False, False, False], [True, False, False]]
)
def test_no_bad_words_dist_processor(self):
vocab_size = 5
batch_size = 2
eos_token_id = 4
input_ids = torch.tensor([[0, 1, 3, 1], [0, 1, 0, 1]], device=torch_device, dtype=torch.long)
bad_word_tokens = [[1], [4], [1, 0], [0, 1, 2], [1, 3, 1, 3]]
scores = self._get_uniform_logits(batch_size, vocab_size)
no_bad_words_dist_proc = NoBadWordsLogitsProcessor(bad_words_ids=bad_word_tokens, eos_token_id=eos_token_id)
filtered_scores = no_bad_words_dist_proc(input_ids, scores.clone())
# batch 1: 1st, 2nd, and 4th (0, 1, 3) token are forbidden
# batch 2: 1st, 2nd, and 3rd (0, 1, 2) token are forbidden
# Note that 5th element cannot be forbidden as it is EOS token
self.assertListEqual(
torch.isinf(filtered_scores).tolist(), [[True, True, False, True, False], [True, True, True, False, False]]
)
# check edge case
no_bad_words_dist_proc = NoBadWordsLogitsProcessor(bad_words_ids=[[4]], eos_token_id=eos_token_id)
filtered_scores = no_bad_words_dist_proc(input_ids, scores.clone())
self.assertTrue(torch.allclose(scores, filtered_scores, atol=1e-3))
def test_processor_list(self):
batch_size = 4
sequence_length = 10
vocab_size = 15
eos_token_id = 0
# dummy input_ids and scores
input_ids = ids_tensor((batch_size, sequence_length), vocab_size)
input_ids_comp = input_ids.clone()
scores = self._get_uniform_logits(batch_size, vocab_size)
scores_comp = scores.clone()
# instantiate all dist processors
min_dist_proc = MinLengthLogitsProcessor(min_length=10, eos_token_id=eos_token_id)
temp_dist_warp = TemperatureLogitsWarper(temperature=0.5)
rep_penalty_proc = RepetitionPenaltyLogitsProcessor(penalty=2.0)
top_k_warp = TopKLogitsWarper(3)
top_p_warp = TopPLogitsWarper(0.8)
no_repeat_proc = NoRepeatNGramLogitsProcessor(2)
no_bad_words_dist_proc = NoBadWordsLogitsProcessor(bad_words_ids=[[1]], eos_token_id=eos_token_id)
# no processor list
scores = min_dist_proc(input_ids, scores)
scores = temp_dist_warp(input_ids, scores)
scores = rep_penalty_proc(input_ids, scores)
scores = top_k_warp(input_ids, scores)
scores = top_p_warp(input_ids, scores)
scores = no_repeat_proc(input_ids, scores)
scores = no_bad_words_dist_proc(input_ids, scores)
# with processor list
processor = LogitsProcessorList(
[
min_dist_proc,
temp_dist_warp,
rep_penalty_proc,
top_k_warp,
top_p_warp,
no_repeat_proc,
no_bad_words_dist_proc,
]
)
scores_comp = processor(input_ids, scores_comp)
# scores should be equal
self.assertTrue(torch.allclose(scores, scores_comp, atol=1e-3))
# input_ids should never be changed
self.assertListEqual(input_ids.tolist(), input_ids_comp.tolist())
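The file above only unit-tests each processor in isolation. As a rough illustration of how they are meant to compose inside the refactored generate() loop, the following sketch chains a few of them manually on each decoding step; it is not part of the diff, and the model, prompt, and parameter values are placeholder assumptions.

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from transformers.generation_logits_process import (
    LogitsProcessorList,
    MinLengthLogitsProcessor,
    NoRepeatNGramLogitsProcessor,
    TemperatureLogitsWarper,
    TopKLogitsWarper,
)

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2").eval()
input_ids = tokenizer("Today is a nice day and", return_tensors="pt").input_ids

# every processor/warper maps (input_ids, scores) -> scores, so they can be chained
processors = LogitsProcessorList(
    [
        MinLengthLogitsProcessor(min_length=10, eos_token_id=model.config.eos_token_id),
        NoRepeatNGramLogitsProcessor(2),
        TemperatureLogitsWarper(temperature=0.7),
        TopKLogitsWarper(top_k=50),
    ]
)

with torch.no_grad():
    for _ in range(10):
        next_token_logits = model(input_ids)[0][:, -1, :]  # logits for the next token
        next_token_logits = processors(input_ids, next_token_logits)
        probs = torch.nn.functional.softmax(next_token_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        input_ids = torch.cat([input_ids, next_token], dim=-1)

print(tokenizer.decode(input_ids[0], skip_special_tokens=True))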
......@@ -23,6 +23,7 @@ from transformers.file_utils import cached_property
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
from .test_configuration_common import ConfigTester
from .test_generation_utils import GenerationTesterMixin
from .test_modeling_common import ModelTesterMixin, ids_tensor
......@@ -128,7 +129,7 @@ def prepare_bart_inputs_dict(
@require_torch
class BARTModelTest(ModelTesterMixin, unittest.TestCase):
class BARTModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
all_model_classes = (
(BartModel, BartForConditionalGeneration, BartForSequenceClassification, BartForQuestionAnswering)
if is_torch_available()
......
......@@ -20,6 +20,7 @@ from transformers import is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
from .test_configuration_common import ConfigTester
from .test_generation_utils import GenerationTesterMixin
from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
......@@ -357,11 +358,12 @@ class BertModelTester:
@require_torch
class BertModelTest(ModelTesterMixin, unittest.TestCase):
class BertModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
all_model_classes = (
(
BertModel,
BertLMHeadModel,
BertForMaskedLM,
BertForMultipleChoice,
BertForNextSentencePrediction,
......@@ -373,6 +375,7 @@ class BertModelTest(ModelTesterMixin, unittest.TestCase):
if is_torch_available()
else ()
)
all_generative_model_classes = (BertLMHeadModel,) if is_torch_available() else ()
def setUp(self):
self.model_tester = BertModelTester(self)
......
......@@ -20,6 +20,7 @@ from transformers import is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
from .test_configuration_common import ConfigTester
from .test_generation_utils import GenerationTesterMixin
from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
......@@ -183,9 +184,10 @@ class BertGenerationEncoderTester:
@require_torch
class BertGenerationEncoderTest(ModelTesterMixin, unittest.TestCase):
class BertGenerationEncoderTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
all_model_classes = (BertGenerationEncoder, BertGenerationDecoder) if is_torch_available() else ()
all_generative_model_classes = (BertGenerationDecoder,) if is_torch_available() else ()
def setUp(self):
self.model_tester = BertGenerationEncoderTester(self)
......
......@@ -147,6 +147,7 @@ class Blenderbot3BIntegrationTests(unittest.TestCase):
src_text = ["Sam"]
model_inputs = self.tokenizer(src_text, return_tensors="pt").to(torch_device)
generated_utterances = model.generate(**model_inputs, **FASTER_GEN_KWARGS)
tgt_text = 'Sam is a great name. It means "sun" in Gaelic.'
......@@ -156,6 +157,7 @@ class Blenderbot3BIntegrationTests(unittest.TestCase):
src_text = "Social anxiety\nWow, I am never shy. Do you have anxiety?\nYes. I end up sweating and blushing and feel like i'm going to throw up.\nand why is that?"
model_inputs = self.tokenizer([src_text], return_tensors="pt").to(torch_device)
generated_ids = model.generate(**model_inputs, **FASTER_GEN_KWARGS)[0]
reply = self.tokenizer.decode(generated_ids, **TOK_DECODE_KW)
......@@ -187,6 +189,9 @@ class Blenderbot90MIntegrationTests(unittest.TestCase):
]
model_inputs = self.tokenizer(src_text, return_tensors="pt").to(torch_device)
# model does not have "token_type_ids"
model_inputs.pop("token_type_ids")
assert isinstance(self.tokenizer, BlenderbotSmallTokenizer)
generated_ids = self.model.generate(**model_inputs)[0]
reply = self.tokenizer.decode(generated_ids, **TOK_DECODE_KW)
......@@ -198,10 +203,11 @@ class Blenderbot90MIntegrationTests(unittest.TestCase):
def test_90_generation_from_short_input(self):
model_inputs = self.tokenizer(["sam"], return_tensors="pt").to(torch_device)
# model does not have "token_type_ids"
model_inputs.pop("token_type_ids")
generated_utterances = self.model.generate(**model_inputs)
# generated_txt = self.tokenizer.decode(generated_utterances[0])
# assert generated_txt == "__start__ have you ever heard of sam harris? he's an american singer, songwriter, and actor. __end__"
clean_txt = self.tokenizer.decode(generated_utterances[0], **TOK_DECODE_KW)
assert clean_txt in (
"have you ever been to a sam club? it's a great club in the south.",
......
......@@ -44,7 +44,6 @@ if is_torch_available():
BertModel,
PretrainedConfig,
PreTrainedModel,
top_k_top_p_filtering,
)
......@@ -882,126 +881,6 @@ class ModelTesterMixin:
with torch.no_grad():
model(**inputs)[0]
def test_lm_head_model_random_no_beam_search_generate(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
input_ids = inputs_dict["input_ids"] if "input_ids" in inputs_dict else inputs_dict["inputs"]
# make sure that input_ids is at most of size 15
input_ids = input_ids[..., :15]
# iterate over all generative models
for model_class in self.all_generative_model_classes:
model = model_class(config).to(torch_device)
model.eval()
if config.bos_token_id is None:
# if bos token id is not defined, model needs input_ids
with self.assertRaises(AssertionError):
model.generate(do_sample=True, max_length=5)
# num_return_sequences = 1
self._check_generated_ids(model.generate(input_ids, do_sample=True))
else:
# num_return_sequences = 1
self._check_generated_ids(model.generate(do_sample=True, max_length=5))
with self.assertRaises(AssertionError):
# generating multiple return sequences without sampling or beam search
# is not allowed, as it would always produce the same sequences
model.generate(input_ids, do_sample=False, num_beams=1, num_return_sequences=2)
# num_return_sequences > 1, sample
self._check_generated_ids(model.generate(input_ids, do_sample=True, num_return_sequences=2))
# check bad words tokens language generation
# create list of 1-seq bad token and list of 2-seq of bad tokens
bad_words_ids = [
self._generate_random_bad_tokens(1, model.config),
self._generate_random_bad_tokens(2, model.config),
]
output_tokens = model.generate(
input_ids, do_sample=True, bad_words_ids=bad_words_ids, num_return_sequences=2
)
# only count generated tokens
generated_ids = output_tokens[:, input_ids.shape[-1] :]
self.assertFalse(self._check_match_tokens(generated_ids.tolist(), bad_words_ids))
def test_lm_head_model_random_beam_search_generate(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
input_ids = (inputs_dict["input_ids"] if "input_ids" in inputs_dict else inputs_dict["inputs"]).to(
torch_device
)
# make sure that input_ids is at most of size 15
input_ids = input_ids[..., :15]
for model_class in self.all_generative_model_classes:
model = model_class(config).to(torch_device)
model.eval()
if config.bos_token_id is None:
# if bos token id is not defined, model needs input_ids; num_return_sequences = 1
self._check_generated_ids(model.generate(input_ids, do_sample=True, num_beams=2))
else:
# num_return_sequences = 1
self._check_generated_ids(model.generate(do_sample=True, max_length=5, num_beams=2))
with self.assertRaises(AssertionError):
# generating more return sequences than there are beams is not possible
model.generate(input_ids, do_sample=False, num_return_sequences=3, num_beams=2)
# num_return_sequences > 1, sample
self._check_generated_ids(
model.generate(
input_ids,
do_sample=True,
num_beams=2,
num_return_sequences=2,
)
)
# num_return_sequences > 1, greedy
self._check_generated_ids(model.generate(input_ids, do_sample=False, num_beams=2, num_return_sequences=2))
# check bad words tokens language generation
# create list of 1-seq bad token and list of 2-seq of bad tokens
bad_words_ids = [
self._generate_random_bad_tokens(1, model.config),
self._generate_random_bad_tokens(2, model.config),
]
output_tokens = model.generate(
input_ids, do_sample=False, bad_words_ids=bad_words_ids, num_beams=2, num_return_sequences=2
)
# only count generated tokens
generated_ids = output_tokens[:, input_ids.shape[-1] :]
self.assertFalse(self._check_match_tokens(generated_ids.tolist(), bad_words_ids))
def _generate_random_bad_tokens(self, num_bad_tokens: int, config) -> List[int]:
# special tokens cannot be bad tokens
special_tokens = [x for x in [config.bos_token_id, config.eos_token_id, config.pad_token_id] if x is not None]
# create random bad tokens that are not special tokens
bad_tokens = []
while len(bad_tokens) < num_bad_tokens:
token = ids_tensor((1, 1), self.model_tester.vocab_size).squeeze(0).cpu().numpy()[0]
if token not in special_tokens:
bad_tokens.append(token)
return bad_tokens
def _check_generated_ids(self, output_ids):
for token_id in output_ids[0].tolist():
self.assertGreaterEqual(token_id, 0)
self.assertLess(token_id, self.model_tester.vocab_size)
def _check_match_tokens(self, generated_ids, bad_words_ids):
# for all bad word tokens
for bad_word_ids in bad_words_ids:
# for all slices in batch
for generated_ids_slice in generated_ids:
# for all word idx
for i in range(len(bad_word_ids), len(generated_ids_slice)):
# if tokens match
if generated_ids_slice[i - len(bad_word_ids) : i] == bad_word_ids:
return True
return False
@require_torch_multigpu
def test_multigpu_data_parallel_forward(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
......@@ -1094,110 +973,3 @@ class ModelUtilsTest(unittest.TestCase):
model = BertModel.from_pretrained(model_name, output_attentions=True, output_hidden_states=True)
self.assertEqual(model.config.output_hidden_states, True)
self.assertEqual(model.config, config)
@require_torch
class UtilsFunctionsTest(unittest.TestCase):
# tests whether the top_k_top_p function behaves as expected
def test_top_k_top_p_filtering(self):
logits = torch.tensor(
[
[
8.2220991, # 3rd highest value; idx. 0
-0.5620044,
5.23229752,
4.0386393,
-6.8798378,
-0.54785802,
-3.2012153,
2.92777176,
1.88171953,
7.35341276, # 5th highest value; idx. 9
8.43207833, # 2nd highest value; idx. 10
-9.85711836,
-5.96209236,
-1.13039161,
-7.1115294,
-0.8369633,
-5.3186408,
7.06427407,
0.81369344,
-0.82023817,
-5.9179796,
0.58813443,
-6.99778438,
4.71551189,
-0.18771637,
7.44020759, # 4th highest value; idx. 25
9.38450987, # 1st highest value; idx. 26
2.12662941,
-9.32562038,
2.35652522,
], # cumulative prob of 5 highest values <= 0.6
[
0.58425518,
4.53139238,
-5.57510464,
-6.28030699,
-7.19529503,
-4.02122551,
1.39337037,
-6.06707057,
1.59480517,
-9.643119,
0.03907799,
0.67231762,
-8.88206726,
6.27115922, # 4th highest value; idx. 13
2.28520723,
4.82767506,
4.30421368,
8.8275313, # 2nd highest value; idx. 17
5.44029958, # 5th highest value; idx. 18
-4.4735794,
7.38579536, # 3rd highest value; idx. 20
-2.91051663,
2.61946077,
-2.5674762,
-9.48959302,
-4.02922645,
-1.35416918,
9.67702323, # 1st highest value; idx. 27
-5.89478553,
1.85370467,
], # cumulative prob of 5 highest values <= 0.6
],
dtype=torch.float,
device=torch_device,
)
non_inf_expected_idx = torch.tensor(
[[0, 0], [0, 9], [0, 10], [0, 25], [0, 26], [1, 13], [1, 17], [1, 18], [1, 20], [1, 27]],
dtype=torch.long,
device=torch_device,
) # expected non filtered idx as noted above
non_inf_expected_output = torch.tensor(
[
8.2221,
7.3534,
8.4321,
7.4402,
9.3845,
6.2712,
8.8275,
5.4403,
7.3858,
9.6770,
], # expected non filtered values as noted above
dtype=torch.float,
device=torch_device,
)
output = top_k_top_p_filtering(logits, top_k=10, top_p=0.6, min_tokens_to_keep=4)
non_inf_output = output[output != -float("inf")].to(device=torch_device)
non_inf_idx = (output != -float("inf")).nonzero().to(device=torch_device)
self.assertTrue(torch.allclose(non_inf_expected_output, non_inf_output, atol=1e-12))
self.assertTrue(torch.all(torch.eq(non_inf_expected_idx, non_inf_idx)))
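With top_k_top_p_filtering no longer covered here, the same filtering behaviour can be reproduced by composing the new warper classes. A minimal sketch under that assumption (the tensor and thresholds below are illustrative, not taken from the removed test):

import torch
from transformers.generation_logits_process import LogitsProcessorList, TopKLogitsWarper, TopPLogitsWarper

logits = torch.randn(2, 30)  # (batch_size, vocab_size)
warpers = LogitsProcessorList(
    [
        TopKLogitsWarper(top_k=10, min_tokens_to_keep=4),
        TopPLogitsWarper(top_p=0.6, min_tokens_to_keep=4),
    ]
)
# roughly equivalent to top_k_top_p_filtering(logits, top_k=10, top_p=0.6, min_tokens_to_keep=4)
filtered_logits = warpers(None, logits)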
......@@ -19,6 +19,7 @@ from transformers import is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
from .test_configuration_common import ConfigTester
from .test_generation_utils import GenerationTesterMixin
from .test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
......@@ -151,7 +152,7 @@ class CTRLModelTester:
@require_torch
class CTRLModelTest(ModelTesterMixin, unittest.TestCase):
class CTRLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
all_model_classes = (CTRLModel, CTRLLMHeadModel) if is_torch_available() else ()
all_generative_model_classes = (CTRLLMHeadModel,) if is_torch_available() else ()
......
......@@ -24,6 +24,7 @@ from transformers.file_utils import cached_property
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
from .test_configuration_common import ConfigTester
from .test_generation_utils import GenerationTesterMixin
from .test_modeling_common import ModelTesterMixin, ids_tensor
......@@ -120,7 +121,7 @@ def prepare_fsmt_inputs_dict(
@require_torch
class FSMTModelTest(ModelTesterMixin, unittest.TestCase):
class FSMTModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
all_model_classes = (FSMTModel, FSMTForConditionalGeneration) if is_torch_available() else ()
all_generative_model_classes = (FSMTForConditionalGeneration,) if is_torch_available() else ()
is_encoder_decoder = True
......
......@@ -20,6 +20,7 @@ from transformers import is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
from .test_configuration_common import ConfigTester
from .test_generation_utils import GenerationTesterMixin
from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
......@@ -377,7 +378,7 @@ class GPT2ModelTester:
@require_torch
class GPT2ModelTest(ModelTesterMixin, unittest.TestCase):
class GPT2ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
all_model_classes = (
(GPT2Model, GPT2LMHeadModel, GPT2DoubleHeadsModel, GPT2ForSequenceClassification)
......@@ -510,32 +511,17 @@ class GPT2ModelLanguageGenerationTest(unittest.TestCase):
self.assertListEqual(output_ids[0].tolist(), expected_output_ids)
@slow
def test_lm_generate_distilgpt2(self):
model = GPT2LMHeadModel.from_pretrained("distilgpt2")
def test_gpt2_sample(self):
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
model.to(torch_device)
input_ids = torch.tensor([[464, 1893]], dtype=torch.long, device=torch_device) # The president
expected_output_ids = [
464,
1893,
286,
262,
1578,
1829,
11,
290,
262,
1893,
286,
262,
1578,
7526,
11,
423,
587,
287,
262,
2635,
] # The president of the United States, and the president of the United Kingdom, have been in the White
output_ids = model.generate(input_ids, do_sample=False)
self.assertListEqual(output_ids[0].tolist(), expected_output_ids)
torch.manual_seed(0)
input_ids = tokenizer("Today is a nice day and", return_tensors="pt").input_ids.to(torch_device)
output_ids = model.generate(input_ids, do_sample=True)
output_str = tokenizer.decode(output_ids[0], skip_special_tokens=True)
EXPECTED_OUTPUT_STR = (
"Today is a nice day and if you don't know anything about the state of play during your holiday"
)
self.assertEqual(output_str, EXPECTED_OUTPUT_STR)
......@@ -20,6 +20,7 @@ from transformers import is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
from .test_configuration_common import ConfigTester
from .test_generation_utils import GenerationTesterMixin
from .test_modeling_common import ModelTesterMixin, ids_tensor
......@@ -170,7 +171,7 @@ class OpenAIGPTModelTester:
@require_torch
class OpenAIGPTModelTest(ModelTesterMixin, unittest.TestCase):
class OpenAIGPTModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
all_model_classes = (
(OpenAIGPTModel, OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel, OpenAIGPTForSequenceClassification)
......
......@@ -22,6 +22,7 @@ from transformers import is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
from .test_configuration_common import ConfigTester
from .test_generation_utils import GenerationTesterMixin
from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
......@@ -853,7 +854,7 @@ class ProphetNetStandaloneEncoderModelTester:
@require_torch
class ProphetNetModelTest(ModelTesterMixin, unittest.TestCase):
class ProphetNetModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
all_model_classes = (ProphetNetModel, ProphetNetForConditionalGeneration) if is_torch_available() else ()
all_generative_model_classes = (ProphetNetForConditionalGeneration,) if is_torch_available() else ()
test_pruning = False
......@@ -917,7 +918,7 @@ class ProphetNetModelTest(ModelTesterMixin, unittest.TestCase):
@require_torch
class ProphetNetStandaloneDecoderModelTest(ModelTesterMixin, unittest.TestCase):
class ProphetNetStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
all_model_classes = (ProphetNetDecoder, ProphetNetForCausalLM) if is_torch_available() else ()
all_generative_model_classes = (ProphetNetForCausalLM,) if is_torch_available() else ()
test_pruning = False
......
......@@ -26,6 +26,7 @@ from transformers.testing_utils import (
)
from .test_configuration_common import ConfigTester
from .test_generation_utils import GenerationTesterMixin
from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
......@@ -196,11 +197,14 @@ class ReformerModelTester:
)
def create_and_check_reformer_model_with_lm_backward(self, config, input_ids, input_mask, choice_labels):
if not self.is_training:
return
config.is_decoder = False
config.lsh_num_chunks_after = 1
model = ReformerForMaskedLM(config=config)
model.to(torch_device)
model.eval()
model.train()
loss = model(input_ids, attention_mask=input_mask, labels=input_ids)["loss"]
loss.backward()
......@@ -569,7 +573,7 @@ class ReformerTesterMixin:
@require_torch
class ReformerLocalAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest.TestCase):
class ReformerLocalAttnModelTest(ReformerTesterMixin, GenerationTesterMixin, ModelTesterMixin, unittest.TestCase):
all_model_classes = (
(ReformerModel, ReformerModelWithLMHead, ReformerForSequenceClassification, ReformerForQuestionAnswering)
if is_torch_available()
......@@ -629,7 +633,7 @@ class ReformerLocalAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest
@require_torch
class ReformerLSHAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest.TestCase):
class ReformerLSHAttnModelTest(ReformerTesterMixin, ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
all_model_classes = (
(ReformerModel, ReformerModelWithLMHead, ReformerForSequenceClassification, ReformerForQuestionAnswering)
if is_torch_available()
......
......@@ -20,6 +20,7 @@ from transformers import is_torch_available
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
from .test_configuration_common import ConfigTester
from .test_generation_utils import GenerationTesterMixin
from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
......@@ -267,7 +268,7 @@ class RobertaModelTester:
@require_torch
class RobertaModelTest(ModelTesterMixin, unittest.TestCase):
class RobertaModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
all_model_classes = (
(
......@@ -282,6 +283,7 @@ class RobertaModelTest(ModelTesterMixin, unittest.TestCase):
if is_torch_available()
else ()
)
all_generative_model_classes = (RobertaForCausalLM,) if is_torch_available() else ()
def setUp(self):
self.model_tester = RobertaModelTester(self)
......
......@@ -23,6 +23,7 @@ from transformers.file_utils import cached_property
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
from .test_configuration_common import ConfigTester
from .test_generation_utils import GenerationTesterMixin
from .test_modeling_common import ModelTesterMixin, ids_tensor
......@@ -466,7 +467,7 @@ class T5ModelTester:
@require_torch
class T5ModelTest(ModelTesterMixin, unittest.TestCase):
class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
all_model_classes = (T5Model, T5ForConditionalGeneration) if is_torch_available() else ()
all_generative_model_classes = (T5ForConditionalGeneration,) if is_torch_available() else ()
......@@ -592,6 +593,7 @@ class T5ModelIntegrationTests(unittest.TestCase):
do_sample=False,
early_stopping=True,
)
decoded = tok.batch_decode(hypotheses_batch, skip_special_tokens=True, clean_up_tokenization_spaces=False)
self.assertListEqual(
expected_summaries,
......
......@@ -20,6 +20,7 @@ from transformers import is_torch_available
from transformers.testing_utils import require_torch, require_torch_multigpu, slow, torch_device
from .test_configuration_common import ConfigTester
from .test_generation_utils import GenerationTesterMixin
from .test_modeling_common import ModelTesterMixin, ids_tensor
......@@ -156,7 +157,7 @@ class TransfoXLModelTester:
@require_torch
class TransfoXLModelTest(ModelTesterMixin, unittest.TestCase):
class TransfoXLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
all_model_classes = (TransfoXLModel, TransfoXLLMHeadModel) if is_torch_available() else ()
all_generative_model_classes = (TransfoXLLMHeadModel,) if is_torch_available() else ()
test_pruning = False
......
......@@ -20,6 +20,7 @@ from transformers import is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
from .test_configuration_common import ConfigTester
from .test_generation_utils import GenerationTesterMixin
from .test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
......@@ -331,7 +332,7 @@ class XLMModelTester:
@require_torch
class XLMModelTest(ModelTesterMixin, unittest.TestCase):
class XLMModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
all_model_classes = (
(
......
......@@ -21,6 +21,7 @@ from transformers import is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
from .test_configuration_common import ConfigTester
from .test_generation_utils import GenerationTesterMixin
from .test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
......@@ -479,7 +480,7 @@ class XLNetModelTester:
@require_torch
class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
class XLNetModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
all_model_classes = (
(
XLNetModel,
......
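Every per-model diff above follows the same pattern: the test class additionally inherits GenerationTesterMixin and declares which of its model classes are generative. A minimal sketch of how a new model test would opt in (the model class names are placeholders, not from this PR):

import unittest

from transformers import is_torch_available
from transformers.testing_utils import require_torch

from .test_generation_utils import GenerationTesterMixin
from .test_modeling_common import ModelTesterMixin

if is_torch_available():
    from transformers import MyModel, MyModelForCausalLM  # hypothetical model classes


@require_torch
class MyModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
    all_model_classes = (MyModel, MyModelForCausalLM) if is_torch_available() else ()
    # the mixin's shared generation tests only run on the classes listed here
    all_generative_model_classes = (MyModelForCausalLM,) if is_torch_available() else ()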