# coding=utf-8
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Aymeric Augustin's avatar
Aymeric Augustin committed
15

Matt's avatar
Matt committed
16
17
from __future__ import annotations

18
19
import unittest

Aymeric Augustin's avatar
Aymeric Augustin committed
20
from transformers import GPT2Config, is_tf_available
21
from transformers.testing_utils import require_tf, require_tf2onnx, slow
thomwolf's avatar
thomwolf committed
22

Yih-Dar's avatar
Yih-Dar committed
23
24
from ...test_configuration_common import ConfigTester
from ...test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
25
from ...test_pipeline_mixin import PipelineTesterMixin
Yih-Dar's avatar
Yih-Dar committed
26
from ...utils.test_modeling_tf_core import TFCoreModelTesterMixin
thomwolf's avatar
thomwolf committed
27
28


29
if is_tf_available():
thomwolf's avatar
thomwolf committed
30
    import tensorflow as tf
31

32
    from transformers import GPT2Tokenizer
Sylvain Gugger's avatar
Sylvain Gugger committed
33
    from transformers.models.gpt2.modeling_tf_gpt2 import (
34
        TF_GPT2_PRETRAINED_MODEL_ARCHIVE_LIST,
35
        TFGPT2DoubleHeadsModel,
36
        TFGPT2ForSequenceClassification,
37
38
        TFGPT2LMHeadModel,
        TFGPT2Model,
39
    )
40
    from transformers.tf_utils import shape_list
thomwolf's avatar
thomwolf committed
41
42


43
44
class TFGPT2ModelTester:
    """Builds a tiny GPT-2 configuration plus random inputs and runs model-level checks on it.

    Every ``create_and_check_*`` method reports through ``parent`` (the owning ``unittest.TestCase``), so a
    failed check surfaces as a normal test assertion.
    """

    def __init__(
        self,
        parent,
    ):
        self.parent = parent
        self.batch_size = 13
        self.seq_length = 7
        self.is_training = True
        self.use_token_type_ids = True
        self.use_input_mask = True
        self.use_labels = True
        self.use_mc_token_ids = True
        self.vocab_size = 99
        self.hidden_size = 32
        self.num_hidden_layers = 2
        self.num_attention_heads = 4
        self.intermediate_size = 37
        self.hidden_act = "gelu"
        self.hidden_dropout_prob = 0.1
        self.attention_probs_dropout_prob = 0.1
        self.max_position_embeddings = 512
        self.type_vocab_size = 16
        self.type_sequence_label_size = 2
        self.initializer_range = 0.02
        self.num_labels = 3
        self.num_choices = 4
        self.scope = None
        # BOS, EOS and padding all share the last vocabulary id.
        self.bos_token_id = self.vocab_size - 1
        self.eos_token_id = self.vocab_size - 1
        self.pad_token_id = self.vocab_size - 1

    def _random_index(self, upper):
        """Return a random Python ``int`` usable as an index into a dimension of size ``upper``.

        ``ids_tensor((1,), upper)`` yields a one-element tensor. The element is indexed out before ``int()`` so
        that no array with ``ndim > 0`` is converted to a Python scalar, a conversion NumPy has deprecated.
        """
        return int(ids_tensor((1,), upper)[0])

    def prepare_config_and_inputs(self):
        """Build a ``GPT2Config`` and random inputs.

        Returns:
            ``(config, input_ids, input_mask, head_mask, token_type_ids, mc_token_ids, sequence_labels,
            token_labels, choice_labels)``. Entries controlled by a ``use_*`` flag are ``None`` when that flag is
            off.
        """
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

        input_mask = None
        if self.use_input_mask:
            input_mask = random_attention_mask([self.batch_size, self.seq_length])

        token_type_ids = None
        if self.use_token_type_ids:
            token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)

        mc_token_ids = None
        if self.use_mc_token_ids:
            # one position per choice, used by the multiple-choice head
            mc_token_ids = ids_tensor([self.batch_size, self.num_choices], self.seq_length)

        sequence_labels = None
        token_labels = None
        choice_labels = None
        if self.use_labels:
            sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
            token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
            choice_labels = ids_tensor([self.batch_size], self.num_choices)

        # intermediate_size, hidden_act, the dropout probabilities, type_vocab_size and initializer_range are
        # intentionally not forwarded to GPT2Config.
        config = GPT2Config(
            vocab_size=self.vocab_size,
            n_embd=self.hidden_size,
            n_layer=self.num_hidden_layers,
            n_head=self.num_attention_heads,
            n_positions=self.max_position_embeddings,
            bos_token_id=self.bos_token_id,
            eos_token_id=self.eos_token_id,
            pad_token_id=self.pad_token_id,
            return_dict=True,
        )

        # random 0/1 mask over (layer, head)
        head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2)

        return (
            config,
            input_ids,
            input_mask,
            head_mask,
            token_type_ids,
            mc_token_ids,
            sequence_labels,
            token_labels,
            choice_labels,
        )

    def prepare_config_and_inputs_for_decoder(self):
        """Like ``prepare_config_and_inputs`` but without ``mc_token_ids`` and with random encoder states/mask."""
        (
            config,
            input_ids,
            input_mask,
            head_mask,
            token_type_ids,
            mc_token_ids,
            sequence_labels,
            token_labels,
            choice_labels,
        ) = self.prepare_config_and_inputs()

        encoder_hidden_states = floats_tensor([self.batch_size, self.seq_length, self.hidden_size])
        encoder_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)

        return (
            config,
            input_ids,
            input_mask,
            head_mask,
            token_type_ids,
            sequence_labels,
            token_labels,
            choice_labels,
            encoder_hidden_states,
            encoder_attention_mask,
        )

    def create_and_check_gpt2_model(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
        """Run the base model with dict, positional-list and bare-tensor inputs and check every output shape."""
        model = TFGPT2Model(config=config)
        expected_shape = (self.batch_size, self.seq_length, self.hidden_size)

        inputs = {
            "input_ids": input_ids,
            "attention_mask": input_mask,
            "token_type_ids": token_type_ids,
        }
        result = model(inputs)
        self.parent.assertEqual(result.last_hidden_state.shape, expected_shape)

        inputs = [input_ids, None, input_mask]  # None is the input for 'past'
        result = model(inputs)
        self.parent.assertEqual(result.last_hidden_state.shape, expected_shape)

        result = model(input_ids)
        self.parent.assertEqual(result.last_hidden_state.shape, expected_shape)

    def create_and_check_gpt2_model_past(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
        """Check that decoding one extra token from cached ``past_key_values`` matches a full forward pass."""
        model = TFGPT2Model(config=config)

        # first forward pass
        outputs = model(input_ids, token_type_ids=token_type_ids, use_cache=True)
        outputs_use_cache_conf = model(input_ids, token_type_ids=token_type_ids)
        outputs_no_past = model(input_ids, token_type_ids=token_type_ids, use_cache=False)

        # omitting `use_cache` must behave like use_cache=True; disabling the cache drops exactly one output
        self.parent.assertEqual(len(outputs), len(outputs_use_cache_conf))
        self.parent.assertEqual(len(outputs), len(outputs_no_past) + 1)

        _, past_key_values = outputs.to_tuple()

        # create hypothetical next token and extend to next_input_ids
        next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
        next_token_types = ids_tensor([self.batch_size, 1], self.type_vocab_size)

        # append to next input_ids and token_type_ids
        next_input_ids = tf.concat([input_ids, next_tokens], axis=-1)
        next_token_type_ids = tf.concat([token_type_ids, next_token_types], axis=-1)

        output_from_no_past = model(next_input_ids, token_type_ids=next_token_type_ids)["last_hidden_state"]
        output_from_past = model(next_tokens, token_type_ids=next_token_types, past_key_values=past_key_values)[
            "last_hidden_state"
        ]

        # select random slice
        random_slice_idx = self._random_index(shape_list(output_from_past)[-1])
        output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx]
        output_from_past_slice = output_from_past[:, 0, random_slice_idx]

        # test that outputs are equal for slice
        tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-6)

    def create_and_check_gpt2_model_attention_mask_past(
        self, config, input_ids, input_mask, head_mask, token_type_ids, *args
    ):
        """Check cached decoding when the second half of the prompt is masked out.

        A token inside the masked half is replaced before the second pass. Because it is masked, the output for
        the new token must not change.
        """
        model = TFGPT2Model(config=config)

        # create attention mask: first half visible, second half masked
        half_seq_length = self.seq_length // 2
        attn_mask_begin = tf.ones((self.batch_size, half_seq_length), dtype=tf.int32)
        attn_mask_end = tf.zeros((self.batch_size, self.seq_length - half_seq_length), dtype=tf.int32)
        attn_mask = tf.concat([attn_mask_begin, attn_mask_end], axis=1)

        # first forward pass
        _, past_key_values = model(input_ids, attention_mask=attn_mask).to_tuple()

        # create hypothetical next token and extend to next_input_ids
        next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)

        # change a random masked slice from input_ids
        random_seq_idx_to_change = ids_tensor((1,), half_seq_length).numpy() + 1
        random_other_next_tokens = ids_tensor((self.batch_size, self.seq_length), config.vocab_size)
        vector_condition = tf.range(self.seq_length) == (self.seq_length - random_seq_idx_to_change)
        condition = tf.transpose(
            tf.broadcast_to(tf.expand_dims(vector_condition, -1), (self.seq_length, self.batch_size))
        )
        input_ids = tf.where(condition, random_other_next_tokens, input_ids)

        # append to next input_ids and attn_mask
        next_input_ids = tf.concat([input_ids, next_tokens], axis=-1)
        attn_mask = tf.concat([attn_mask, tf.ones((shape_list(attn_mask)[0], 1), dtype=tf.int32)], axis=1)

        # get two different outputs
        output_from_no_past = model(next_input_ids, attention_mask=attn_mask)["last_hidden_state"]
        output_from_past = model(next_tokens, past_key_values=past_key_values, attention_mask=attn_mask)[
            "last_hidden_state"
        ]

        # select random slice
        random_slice_idx = self._random_index(shape_list(output_from_past)[-1])
        output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx]
        output_from_past_slice = output_from_past[:, 0, random_slice_idx]

        # test that outputs are equal for slice
        tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-12)

    def create_and_check_gpt2_model_past_large_inputs(
        self, config, input_ids, input_mask, head_mask, token_type_ids, *args
    ):
        """Check cached decoding of three extra tokens (with a random attention mask) on a one-example batch."""
        model = TFGPT2Model(config=config)

        # Keep only the first example. A local batch size is used instead of overwriting `self.batch_size`, so
        # the tester's shared state is left untouched.
        batch_size = 1
        input_ids = input_ids[:batch_size, :]
        input_mask = input_mask[:batch_size, :]
        token_type_ids = token_type_ids[:batch_size, :]

        # first forward pass
        outputs = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, use_cache=True)
        _, past_key_values = outputs.to_tuple()

        # create three hypothetical next tokens (with mask and token types) and extend the inputs with them
        next_tokens = ids_tensor((batch_size, 3), config.vocab_size)
        next_attn_mask = ids_tensor((batch_size, 3), 2)
        next_token_types = ids_tensor((batch_size, 3), self.type_vocab_size)

        # append to next input_ids and token_type_ids
        next_input_ids = tf.concat([input_ids, next_tokens], axis=-1)
        next_attention_mask = tf.concat([input_mask, next_attn_mask], axis=-1)
        next_token_type_ids = tf.concat([token_type_ids, next_token_types], axis=-1)

        output_from_no_past = model(
            next_input_ids, token_type_ids=next_token_type_ids, attention_mask=next_attention_mask
        )["last_hidden_state"]
        output_from_past = model(
            next_tokens,
            token_type_ids=next_token_types,
            attention_mask=next_attention_mask,
            past_key_values=past_key_values,
        )["last_hidden_state"]
        self.parent.assertEqual(output_from_past.shape[1], next_tokens.shape[1])

        # select random slice
        random_slice_idx = self._random_index(shape_list(output_from_past)[-1])
        output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx]
        output_from_past_slice = output_from_past[:, :, random_slice_idx]

        # test that outputs are equal for slice
        tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-3)

    def create_and_check_gpt2_lm_head(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
        """Check the LM head's logits shape: ``(batch, seq, vocab)``."""
        model = TFGPT2LMHeadModel(config=config)
        inputs = {
            "input_ids": input_ids,
            "attention_mask": input_mask,
            "token_type_ids": token_type_ids,
        }
        result = model(inputs)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))

    def create_and_check_gpt2_double_head(
        self, config, input_ids, input_mask, head_mask, token_type_ids, mc_token_ids, *args
    ):
        """Tile the inputs over ``num_choices`` and check the LM and multiple-choice logits shapes."""
        model = TFGPT2DoubleHeadsModel(config=config)

        multiple_choice_inputs_ids = tf.tile(tf.expand_dims(input_ids, 1), (1, self.num_choices, 1))
        multiple_choice_input_mask = tf.tile(tf.expand_dims(input_mask, 1), (1, self.num_choices, 1))
        multiple_choice_token_type_ids = tf.tile(tf.expand_dims(token_type_ids, 1), (1, self.num_choices, 1))

        inputs = {
            "input_ids": multiple_choice_inputs_ids,
            "mc_token_ids": mc_token_ids,
            "attention_mask": multiple_choice_input_mask,
            "token_type_ids": multiple_choice_token_type_ids,
        }
        result = model(inputs)
        self.parent.assertEqual(
            result.logits.shape, (self.batch_size, self.num_choices, self.seq_length, self.vocab_size)
        )
        self.parent.assertEqual(result.mc_logits.shape, (self.batch_size, self.num_choices))

    def create_and_check_gpt2_for_sequence_classification(
        self, config, input_ids, input_mask, head_mask, token_type_ids, mc_token_ids, sequence_labels, *args
    ):
        """Check the sequence-classification logits shape: ``(batch, num_labels)``."""
        config.num_labels = self.num_labels
        inputs = {
            "input_ids": input_ids,
            "attention_mask": input_mask,
            "token_type_ids": token_type_ids,
            "labels": sequence_labels,
        }
        model = TFGPT2ForSequenceClassification(config)

        result = model(inputs)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))

    def prepare_config_and_inputs_for_common(self):
        """Return ``(config, inputs_dict)`` with ``input_ids``, ``token_type_ids`` and ``attention_mask``.

        Used by ``test_onnx_runtime_optimize`` and presumably by the shared TF mixin tests.
        """
        config, input_ids, input_mask, _, token_type_ids, *_ = self.prepare_config_and_inputs()

        inputs_dict = {
            "input_ids": input_ids,
            "token_type_ids": token_type_ids,
            "attention_mask": input_mask,
        }
        return config, inputs_dict


366
@require_tf
class TFGPT2ModelTest(TFModelTesterMixin, TFCoreModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
    """Shared TF model/pipeline tests plus GPT-2 specific checks driven by ``TFGPT2ModelTester``."""

    all_model_classes = (
        (TFGPT2Model, TFGPT2LMHeadModel, TFGPT2ForSequenceClassification, TFGPT2DoubleHeadsModel)
        if is_tf_available()
        else ()
    )
    all_generative_model_classes = (TFGPT2LMHeadModel,) if is_tf_available() else ()
    pipeline_model_mapping = (
        {
            "feature-extraction": TFGPT2Model,
            "text-classification": TFGPT2ForSequenceClassification,
            "text-generation": TFGPT2LMHeadModel,
            "zero-shot": TFGPT2ForSequenceClassification,
        }
        if is_tf_available()
        else {}
    )
    test_head_masking = False
    test_onnx = True
    onnx_min_opset = 10

    def setUp(self):
        self.model_tester = TFGPT2ModelTester(self)
        self.config_tester = ConfigTester(self, config_class=GPT2Config, n_embd=37)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_gpt2_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_gpt2_model(*config_and_inputs)

    def test_gpt2_model_past(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_gpt2_model_past(*config_and_inputs)

    def test_gpt2_model_att_mask_past(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_gpt2_model_attention_mask_past(*config_and_inputs)

    def test_gpt2_model_past_large_inputs(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_gpt2_model_past_large_inputs(*config_and_inputs)

    def test_gpt2_lm_head(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_gpt2_lm_head(*config_and_inputs)

    def test_gpt2_double_head(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_gpt2_double_head(*config_and_inputs)

    def test_gpt2_sequence_classification_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_gpt2_for_sequence_classification(*config_and_inputs)

    @slow
    def test_model_from_pretrained(self):
        for model_name in TF_GPT2_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
            model = TFGPT2Model.from_pretrained(model_name)
            self.assertIsNotNone(model)

    # Overrides the common test: ONNX Runtime optimization does not work with `tf.gather()` when the argument
    # `batch_dims` > 0.
    @require_tf2onnx
    @slow
    def test_onnx_runtime_optimize(self):
        if not self.test_onnx:
            # report a skip instead of silently passing
            self.skipTest("ONNX export is disabled for this model (`test_onnx` is False)")

        import onnxruntime
        import tf2onnx

        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        # These two classes use `tf.gather` with `batch_dims=1`, so they are excluded.
        unsupported_classes = (TFGPT2ForSequenceClassification, TFGPT2DoubleHeadsModel)

        for model_class in self.all_model_classes:
            if model_class in unsupported_classes:
                continue

            model = model_class(config)
            model.build_in_name_scope()

            onnx_model_proto, _ = tf2onnx.convert.from_keras(model, opset=self.onnx_min_opset)

            # building the session is the check: it fails if the exported graph cannot be loaded/optimized
            onnxruntime.InferenceSession(onnx_model_proto.SerializeToString())

    # TODO (Joao): fix me
    @unittest.skip("Onnx compliancy broke with TF 2.10")
    def test_onnx_compliancy(self):
        pass

459

460
@require_tf
461
462
class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
    @slow
    def test_lm_generate_greedy_distilgpt2_batch_special(self):
        """Greedy generation on a left-padded distilgpt2 batch with bad-word, n-gram and repetition constraints."""
        model = TFGPT2LMHeadModel.from_pretrained("distilgpt2")
        tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2")

        # GPT-2 ships without a pad token: reuse EOS and pad on the left so generation continues each prompt
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.padding_side = "left"

        prompts = ["Today is a beautiful day and", "Yesterday was"]
        encoded = tokenizer(prompts, return_tensors="tf", padding=True)

        banned = [tokenizer("is").input_ids, tokenizer("angry about").input_ids]
        greedy_kwargs = dict(
            bad_words_ids=banned,
            no_repeat_ngram_size=2,
            do_sample=False,
            repetition_penalty=1.3,
        )

        generated = model.generate(**encoded, **greedy_kwargs)
        decoded = tokenizer.batch_decode(generated, skip_special_tokens=True)

        self.assertListEqual(
            decoded,
            [
                "Today is a beautiful day and I am so happy to be able take part in this amazing event.",
                "Yesterday was a very interesting time for the world to see how much of this is",
            ],
        )

489
490
491
492
493
494
495
496
497
    @slow
    def test_lm_generate_sample_distilgpt2_batch_special(self):
        """Seeded sampling on a left-padded distilgpt2 batch must reproduce a fixed pair of strings."""
        model = TFGPT2LMHeadModel.from_pretrained("distilgpt2")
        tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2")

        # GPT-2 has no pad token: reuse EOS and pad on the left
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.padding_side = "left"

        sentences = ["Today is a beautiful day and", "Yesterday was"]
        input_ids = tokenizer(sentences, return_tensors="tf", padding=True)

        generation_kwargs = {
            "do_sample": True,
            "bad_words_ids": [tokenizer("is").input_ids, tokenizer("angry about").input_ids],
            "no_repeat_ngram_size": 2,
            "repetition_penalty": 1.3,
            "temperature": 1.5,
            "top_k": 500,
            "top_p": 0.9,
            "seed": [42, 0],  # seed set -> deterministic sampling sequence -> deterministic generation
        }

        # forces the generation to happen on CPU, to avoid GPU-related quirks
        # ("/CPU:0" is the canonical TF device name; the previous ":/CPU:0" only parsed by accident)
        with tf.device("/CPU:0"):
            output_ids = model.generate(**input_ids, **generation_kwargs)

        output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)

        expected_output_strings = [
            "Today is a beautiful day and we will make you feel very hot/terrific in all your",
            "Yesterday was known by national television networks as Le Big Show or Wild Dog Jeopard",
        ]
        self.assertListEqual(output_strings, expected_output_strings)

523
524
525
526
527
528
529
530
531
    @slow
    def test_lm_generate_greedy_distilgpt2_beam_search_special(self):
        """Two-beam search on a left-padded distilgpt2 batch with bad-word and n-gram constraints."""
        model = TFGPT2LMHeadModel.from_pretrained("distilgpt2")
        tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2")

        # GPT-2 has no pad token: reuse EOS and pad on the left
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.padding_side = "left"

        sentences = ["Today is a beautiful day and", "Yesterday was"]
        input_ids = tokenizer(sentences, return_tensors="tf", padding=True)

        generation_kwargs = {
            "bad_words_ids": [tokenizer("is").input_ids, tokenizer("angry about").input_ids],
            "no_repeat_ngram_size": 2,
            "do_sample": False,
            "num_beams": 2,
        }

        output_ids = model.generate(**input_ids, **generation_kwargs)

        output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
        expected_output_string = [
            # "\u2019" is the typographic apostrophe GPT-2 emits ("I’m"); written as an escape so the literal
            # cannot be re-garbled by a wrong source encoding (it had become the mojibake "I鈥檓").
            "Today is a beautiful day and a great day for all of us.\n\nI\u2019m",
            "Yesterday was the first time that a person has been arrested in the United States for",
        ]
        self.assertListEqual(output_strings, expected_output_string)

550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
    @slow
    def test_lm_generate_distilgpt2_left_padding(self):
        """Tests that the generated text is the same, regardless of left padding"""
        model = TFGPT2LMHeadModel.from_pretrained("distilgpt2")
        tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2")

        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.padding_side = "left"

        generation_kwargs = {
            "bad_words_ids": [tokenizer("is").input_ids, tokenizer("angry about").input_ids],
            "no_repeat_ngram_size": 2,
            "do_sample": False,
            "repetition_penalty": 1.3,
        }
        expected_output_string = (
            "Today is a beautiful day and I am so happy to be able take part in this amazing event."
        )

        # 1) the prompt alone (no padding), generated with the default max length
        sentences = ["Today is a beautiful day and"]
        input_ids = tokenizer(sentences, return_tensors="tf", padding=True)
        output_ids = model.generate(**input_ids, **generation_kwargs)
        output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
        self.assertEqual(output_strings[0], expected_output_string)

        # 2) the same prompt batched with a much longer one, so it gets left-padded. max_length is raised so
        # the padded prompt still has room to produce the full expected continuation.
        sentences = ["Today is a beautiful day and", "This is a very long input that we absolutely don't care about"]
        input_ids = tokenizer(sentences, return_tensors="tf", padding=True)
        output_ids = model.generate(**input_ids, **generation_kwargs, max_length=27)
        output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
        self.assertEqual(output_strings[0], expected_output_string)

583
    @slow
    def test_lm_generate_gpt2_greedy_xla(self):
        """Greedy gpt2 generation must give the same strings in eager mode and under XLA compilation."""
        model = TFGPT2LMHeadModel.from_pretrained("gpt2")
        tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.padding_side = "left"

        prompts = ["The dog", "The flying machine"]
        expected_output_strings = [
            "The dog was found in a field near the intersection of West and West Streets.\n\nThe",
            "The flying machine is a small, lightweight, and lightweight aircraft that can be used for any type of",
        ]
        encoded = tokenizer(prompts, return_tensors="tf", padding=True)

        eager_ids = model.generate(**encoded, do_sample=False)
        self.assertListEqual(tokenizer.batch_decode(eager_ids, skip_special_tokens=True), expected_output_strings)

        compiled_generate = tf.function(model.generate, jit_compile=True)
        xla_ids = compiled_generate(**encoded, do_sample=False)
        self.assertListEqual(tokenizer.batch_decode(xla_ids, skip_special_tokens=True), expected_output_strings)
606
607

    @slow
    def test_lm_generate_gpt2_sample_xla(self):
        # NOTE: due to the small numerical differences that are natural when we compile to XLA, sampling the same
        # output out of the same seed is far from guaranteed. We can, however, confirm that the results are sensible
        # and that we can seed both versions.

        # forces the generation to happen on CPU, to avoid GPU-related quirks
        # ("/CPU:0" is the canonical TF device name; the previous ":/CPU:0" only parsed by accident)
        with tf.device("/CPU:0"):
            model = TFGPT2LMHeadModel.from_pretrained("gpt2")
            tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

            tokenizer.pad_token = tokenizer.eos_token
            tokenizer.padding_side = "left"

            sentences = ["The dog", "The flying machine"]
            expected_output_strings = [
                "The dog owner asked why did our vet decide there needed to be extra ventilation inside because most"
                " puppies",
                "The flying machine was made by an artist who found it difficult to control it as it did not use",
            ]
            expected_output_strings_xla = [
                "The dog has been named in connection with the murder of a 20-year-old man in",
                "The flying machine is a new and improved system to operate and operate a new system and system "
                "system system",
            ]
            input_ids = tokenizer(sentences, return_tensors="tf", padding=True)

            # eager sampling with a fixed seed
            output_ids = model.generate(**input_ids, do_sample=True, seed=[7, 0])
            output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
            self.assertListEqual(output_strings, expected_output_strings)

            # XLA-compiled sampling with the same seed (expected to differ from eager, see NOTE above)
            xla_generate = tf.function(model.generate, jit_compile=True)
            output_ids = xla_generate(**input_ids, do_sample=True, seed=[7, 0])
            output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
            self.assertListEqual(output_strings, expected_output_strings_xla)
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665

    @slow
    def test_lm_generate_gpt2_beam_search_xla(self):
        """Two-beam gpt2 search must give the same strings in eager mode and under XLA compilation."""
        model = TFGPT2LMHeadModel.from_pretrained("gpt2")
        tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.padding_side = "left"

        prompts = ["The dog", "The flying machine"]
        expected_output_strings = [
            "The dog was found in the backyard of a home in the 6500 block of South Main Street",
            "The flying machine is a very powerful machine, but it's not a very powerful machine. It's",
        ]
        encoded = tokenizer(prompts, return_tensors="tf", padding=True)
        beam_kwargs = {"do_sample": False, "num_beams": 2}

        eager_ids = model.generate(**encoded, **beam_kwargs)
        self.assertListEqual(tokenizer.batch_decode(eager_ids, skip_special_tokens=True), expected_output_strings)

        compiled_generate = tf.function(model.generate, jit_compile=True)
        xla_ids = compiled_generate(**encoded, **beam_kwargs)
        self.assertListEqual(tokenizer.batch_decode(xla_ids, skip_special_tokens=True), expected_output_strings)
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734

    @slow
    def test_contrastive_search_gpt2(self):
        """Eager contrastive search on `gpt2-large` must reproduce a pinned continuation of a DeepMind prompt.

        Decoding uses `penalty_alpha=0.6` and `top_k=4`, with the total length capped by `max_length=256`.
        """
        prompt = (
            "DeepMind Technologies is a British artificial intelligence subsidiary of Alphabet Inc. and research "
            "laboratory founded in 2010. DeepMind was acquired by Google in 2014. The company is based"
        )

        tokenizer = GPT2Tokenizer.from_pretrained("gpt2-large")
        model = TFGPT2LMHeadModel.from_pretrained("gpt2-large")
        # Full tokenizer output (ids plus mask), splatted into `generate` below.
        encoded = tokenizer(prompt, return_tensors="tf")

        generated_ids = model.generate(**encoded, penalty_alpha=0.6, top_k=4, max_length=256)
        decoded = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

        expected_text = (
            "DeepMind Technologies is a British artificial intelligence subsidiary of Alphabet Inc. and research "
            "laboratory founded in 2010. DeepMind was acquired by Google in 2014. The company is based in London, "
            "United Kingdom\n\nGoogle has a lot of data on its users and uses it to improve its products, such as "
            "Google Now, which helps users find the information they're looking for on the web. But the company "
            "is not the only one to collect data on its users. Facebook, for example, has its own facial "
            "recognition technology, as well as a database of millions of photos that it uses to personalize its "
            "News Feed.\n\nFacebook's use of data is a hot topic in the tech industry, with privacy advocates "
            "concerned about the company's ability to keep users' information private. In a blog post last "
            'year, Facebook CEO Mark Zuckerberg said his company would "do our best to be transparent about our '
            'data use and how we use it."\n\n"We have made it clear that we do not sell or share your data with '
            'third parties," Zuckerberg wrote. "If you have questions or concerns, please reach out to us at '
            'privacy@facebook.com."\n\nGoogle declined to comment on the privacy implications of its use of data, '
            "but said in a statement to The Associated Press that"
        )
        # A single prompt was encoded, so exactly one decoded string is expected.
        self.assertListEqual(decoded, [expected_text])

    @slow
    def test_contrastive_search_gpt2_xla(self):
        """XLA-compiled contrastive search on `gpt2-large` must reproduce a pinned continuation of a DeepMind prompt.

        Same settings as the eager check (`penalty_alpha=0.6`, `top_k=4`, `max_length=256`), but `generate`
        is wrapped in `tf.function(..., jit_compile=True)` before being called.
        """
        prompt = (
            "DeepMind Technologies is a British artificial intelligence subsidiary of Alphabet Inc. and research "
            "laboratory founded in 2010. DeepMind was acquired by Google in 2014. The company is based"
        )

        tokenizer = GPT2Tokenizer.from_pretrained("gpt2-large")
        model = TFGPT2LMHeadModel.from_pretrained("gpt2-large")
        # Full tokenizer output (ids plus mask), splatted into the compiled `generate` below.
        encoded = tokenizer(prompt, return_tensors="tf")

        compiled_generate = tf.function(model.generate, jit_compile=True)
        generated_ids = compiled_generate(**encoded, penalty_alpha=0.6, top_k=4, max_length=256)
        decoded = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

        expected_text = (
            "DeepMind Technologies is a British artificial intelligence subsidiary of Alphabet Inc. and research "
            "laboratory founded in 2010. DeepMind was acquired by Google in 2014. The company is based in London, "
            "United Kingdom\n\nGoogle has a lot of data on its users and uses it to improve its products, such as "
            "Google Now, which helps users find the information they're looking for on the web. But the company "
            "is not the only one to collect data on its users. Facebook, for example, has its own facial "
            "recognition technology, as well as a database of millions of photos that it uses to personalize its "
            "News Feed.\n\nFacebook's use of data is a hot topic in the tech industry, with privacy advocates "
            "concerned about the company's ability to keep users' information private. In a blog post last "
            'year, Facebook CEO Mark Zuckerberg said his company would "do our best to be transparent about our '
            'data use and how we use it."\n\n"We have made it clear that we do not sell or share your data with '
            'third parties," Zuckerberg wrote. "If you have questions or concerns, please reach out to us at '
            'privacy@facebook.com."\n\nGoogle declined to comment on the privacy implications of its use of data, '
            "but said in a statement to The Associated Press that"
        )
        # A single prompt was encoded, so exactly one decoded string is expected.
        self.assertListEqual(decoded, [expected_text])