# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import tempfile
import unittest

import numpy as np

import transformers
from transformers import GPT2Config, GPT2Tokenizer, is_flax_available, is_torch_available
from transformers.testing_utils import is_pt_flax_cross_test, require_flax, slow

from ...generation.test_flax_utils import FlaxGenerationTesterMixin
from ...test_modeling_flax_common import FlaxModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask


# Framework-specific imports are guarded so the module can still be imported
# when only one of the backends is installed.
if is_flax_available():
    import jax
    import jax.numpy as jnp

    from transformers.modeling_flax_pytorch_utils import (
        convert_pytorch_state_dict_to_flax,
        load_flax_weights_in_pytorch_model,
    )
    from transformers.models.gpt2.modeling_flax_gpt2 import FlaxGPT2LMHeadModel, FlaxGPT2Model

if is_torch_available():
    import torch


class FlaxGPT2ModelTester:
    """Build a tiny GPT-2 config plus random inputs, and run cache / generation checks.

    ``parent`` is the test case that owns this tester. Every ``check_*`` helper reports
    its result through ``self.parent.assertTrue``.
    """

    def __init__(
        self,
        parent,
        batch_size=14,
        seq_length=7,
        is_training=True,
        use_input_mask=True,
        use_token_type_ids=False,
        use_labels=True,
        vocab_size=99,
        hidden_size=32,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=37,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        initializer_range=0.02,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.use_input_mask = use_input_mask
        self.use_token_type_ids = use_token_type_ids
        self.use_labels = use_labels
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.scope = None
        # BOS, EOS and PAD all share the last id of the vocabulary.
        self.bos_token_id = vocab_size - 1
        self.eos_token_id = vocab_size - 1
        self.pad_token_id = vocab_size - 1

    def prepare_config_and_inputs(self):
        """Return ``(config, input_ids, input_mask)``.

        ``input_ids`` is a random id tensor of shape ``(batch_size, seq_length)``.
        ``input_mask`` is a random attention mask of that shape, or ``None`` when
        ``use_input_mask`` is False. The config is built with ``use_cache=False``.

        Note: ``intermediate_size``, ``hidden_act``, the dropout probabilities and
        ``initializer_range`` are stored on the tester but are not forwarded to
        ``GPT2Config`` here.
        """
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

        input_mask = None
        if self.use_input_mask:
            input_mask = random_attention_mask([self.batch_size, self.seq_length])

        config = GPT2Config(
            vocab_size=self.vocab_size,
            n_embd=self.hidden_size,
            n_layer=self.num_hidden_layers,
            n_head=self.num_attention_heads,
            n_positions=self.max_position_embeddings,
            use_cache=False,
            bos_token_id=self.bos_token_id,
            eos_token_id=self.eos_token_id,
            pad_token_id=self.pad_token_id,
        )

        return (config, input_ids, input_mask)

    def prepare_config_and_inputs_for_common(self):
        """Return ``(config, inputs_dict)`` with ``input_ids`` and ``attention_mask`` keys."""
        config_and_inputs = self.prepare_config_and_inputs()
        config, input_ids, attention_mask = config_and_inputs
        inputs_dict = {"input_ids": input_ids, "attention_mask": attention_mask}
        return config, inputs_dict

    def prepare_config_and_inputs_for_decoder(self):
        """Return the common inputs plus random encoder states and a 0/1 encoder mask.

        ``encoder_hidden_states`` has shape ``(batch_size, seq_length, hidden_size)``.
        ``encoder_attention_mask`` holds random ids drawn from ``{0, 1}``.
        """
        config, input_ids, attention_mask = self.prepare_config_and_inputs()

        encoder_hidden_states = floats_tensor([self.batch_size, self.seq_length, self.hidden_size])
        encoder_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)

        return (
            config,
            input_ids,
            attention_mask,
            encoder_hidden_states,
            encoder_attention_mask,
        )

    def check_use_cache_forward(self, model_class_name, config, input_ids, attention_mask):
        """Check that cached incremental decoding matches a full uncached forward pass.

        All tokens except the last are run through a freshly initialised cache of length
        ``max_decoder_length``. The last token is then fed alone with the returned
        ``past_key_values``. The first five entries at the final position of the first
        output must match those of an uncached pass over the full ``input_ids`` within 1e-3.

        The incoming ``attention_mask`` is not used. It is replaced by an all-ones mask that
        spans the whole cache.
        """
        max_decoder_length = 20
        model = model_class_name(config)

        past_key_values = model.init_cache(input_ids.shape[0], max_decoder_length)
        attention_mask = jnp.ones((input_ids.shape[0], max_decoder_length), dtype="i4")

        # Positions 0 .. seq_len-2 for the prefix, broadcast over the batch.
        position_ids = jnp.broadcast_to(
            jnp.arange(input_ids.shape[-1] - 1)[None, :], (input_ids.shape[0], input_ids.shape[-1] - 1)
        )
        outputs_cache = model(
            input_ids[:, :-1],
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            position_ids=position_ids,
        )

        # The single remaining token sits at position seq_len-1.
        position_ids = jnp.array(input_ids.shape[0] * [[input_ids.shape[-1] - 1]], dtype="i4")
        outputs_cache_next = model(
            input_ids[:, -1:],
            attention_mask=attention_mask,
            past_key_values=outputs_cache.past_key_values,
            position_ids=position_ids,
        )

        outputs = model(input_ids)

        diff = np.max(np.abs((outputs_cache_next[0][:, -1, :5] - outputs[0][:, -1, :5])))
        self.parent.assertTrue(diff < 1e-3, msg=f"Max diff is {diff}")

    def check_use_cache_forward_with_attn_mask(self, model_class_name, config, input_ids, attention_mask):
        """Like ``check_use_cache_forward``, but honouring the provided ``attention_mask``.

        The mask is right-padded with zeros up to ``max_decoder_length`` and used for both
        cached calls. The reference pass uses the unpadded mask.
        """
        max_decoder_length = 20
        model = model_class_name(config)

        attention_mask_cache = jnp.concatenate(
            [attention_mask, jnp.zeros((attention_mask.shape[0], max_decoder_length - attention_mask.shape[1]))],
            axis=-1,
        )

        past_key_values = model.init_cache(input_ids.shape[0], max_decoder_length)
        position_ids = jnp.broadcast_to(
            jnp.arange(input_ids.shape[-1] - 1)[None, :], (input_ids.shape[0], input_ids.shape[-1] - 1)
        )

        outputs_cache = model(
            input_ids[:, :-1],
            attention_mask=attention_mask_cache,
            past_key_values=past_key_values,
            position_ids=position_ids,
        )
        position_ids = jnp.array(input_ids.shape[0] * [[input_ids.shape[-1] - 1]], dtype="i4")
        outputs_cache_next = model(
            input_ids[:, -1:],
            past_key_values=outputs_cache.past_key_values,
            attention_mask=attention_mask_cache,
            position_ids=position_ids,
        )

        outputs = model(input_ids, attention_mask=attention_mask)

        diff = np.max(np.abs((outputs_cache_next[0][:, -1, :5] - outputs[0][:, -1, :5])))
        self.parent.assertTrue(diff < 1e-3, msg=f"Max diff is {diff}")

    def check_bool_attention_mask_in_generation(self, model_class_name, config, input_ids, attention_mask):
        """Check that ``generate`` returns identical sequences for an integer and a boolean mask."""
        model = model_class_name(config)

        output_int_att_mask = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=3,
        )

        output_bool_att_mask = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask.astype(bool),
            max_new_tokens=3,
        )

        self.parent.assertTrue(
            (output_bool_att_mask.sequences == output_int_att_mask.sequences).all(),
            "Generated response differ between boolean and integer attention mask",
        )


@require_flax
class FlaxGPT2ModelTest(FlaxModelTesterMixin, FlaxGenerationTesterMixin, unittest.TestCase):
    """Flax GPT-2 tests: cache consistency, generation, and PyTorch <-> Flax equivalence."""

    all_model_classes = (FlaxGPT2Model, FlaxGPT2LMHeadModel) if is_flax_available() else ()
    all_generative_model_classes = (FlaxGPT2LMHeadModel,) if is_flax_available() else ()

    def setUp(self):
        self.model_tester = FlaxGPT2ModelTester(self)

    def test_use_cache_forward(self):
        for model_class_name in self.all_model_classes:
            config, input_ids, attention_mask = self.model_tester.prepare_config_and_inputs()
            self.model_tester.check_use_cache_forward(model_class_name, config, input_ids, attention_mask)

    def test_use_cache_forward_with_attn_mask(self):
        for model_class_name in self.all_model_classes:
            config, input_ids, attention_mask = self.model_tester.prepare_config_and_inputs()
            self.model_tester.check_use_cache_forward_with_attn_mask(
                model_class_name, config, input_ids, attention_mask
            )

    def test_bool_attention_mask_in_generation(self):
        for model_class_name in self.all_generative_model_classes:
            config, input_ids, attention_mask = self.model_tester.prepare_config_and_inputs()
            self.model_tester.check_bool_attention_mask_in_generation(
                model_class_name, config, input_ids, attention_mask
            )

    @slow
    def test_batch_generation(self):
        # Pad on the left so both prompts end at the same (last) position.
        tokenizer = GPT2Tokenizer.from_pretrained("gpt2", pad_token="</s>", padding_side="left")
        inputs = tokenizer(["Hello this is a long string", "Hey"], return_tensors="np", padding=True, truncation=True)

        model = FlaxGPT2LMHeadModel.from_pretrained("gpt2")
        # NOTE(review): this sets a plain attribute on the model object, not on
        # `model.config` / its generation config; confirm `generate` actually reads it.
        model.do_sample = False
        # Reuse the EOS id as the padding id for generation.
        model.config.pad_token_id = model.config.eos_token_id

        jit_generate = jax.jit(model.generate)

        output_sequences = jit_generate(inputs["input_ids"], attention_mask=inputs["attention_mask"]).sequences

        output_string = tokenizer.batch_decode(output_sequences, skip_special_tokens=True)

        expected_string = [
            "Hello this is a long string of words. I'm going to start with the first one.\n",
            "Hey, I'm not sure if I'm going to be able to do",
        ]

        self.assertListEqual(output_string, expected_string)

    @staticmethod
    def _randomly_left_pad_attention_masks(pt_inputs, fx_inputs):
        """Overwrite both attention masks in place with the same random left padding.

        For every row a start index is drawn from ``[0, seq_length - 1)``. Positions before
        it are set to 0 (masked) and positions from it onwards to 1. The PyTorch and Flax
        masks receive identical values, so the two frameworks see the same padding.
        """
        batch_size, seq_length = pt_inputs["input_ids"].shape
        rnd_start_indices = np.random.randint(0, seq_length - 1, size=(batch_size,))
        for batch_idx, start_index in enumerate(rnd_start_indices):
            pt_inputs["attention_mask"][batch_idx, :start_index] = 0
            pt_inputs["attention_mask"][batch_idx, start_index:] = 1
            fx_inputs["attention_mask"][batch_idx, :start_index] = 0
            fx_inputs["attention_mask"][batch_idx, start_index:] = 1

    # overwrite from common since `attention_mask` in combination
    # with `causal_mask` behaves slightly differently
    @is_pt_flax_cross_test
    def test_equivalence_pt_to_flax(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            with self.subTest(model_class.__name__):
                # prepare inputs
                prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
                pt_inputs = {k: torch.tensor(v.tolist()) for k, v in prepared_inputs_dict.items()}

                # load corresponding PyTorch class
                pt_model_class_name = model_class.__name__[4:]  # Skip the "Flax" at the beginning
                pt_model_class = getattr(transformers, pt_model_class_name)

                self._randomly_left_pad_attention_masks(pt_inputs, prepared_inputs_dict)

                pt_model = pt_model_class(config).eval()
                fx_model = model_class(config, dtype=jnp.float32)

                fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
                fx_model.params = fx_state

                with torch.no_grad():
                    pt_outputs = pt_model(**pt_inputs).to_tuple()

                # Only the last position is compared (see the comment above this test).
                fx_outputs = fx_model(**prepared_inputs_dict).to_tuple()
                self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
                for fx_output, pt_output in zip(fx_outputs, pt_outputs):
                    self.assert_almost_equals(fx_output[:, -1], pt_output[:, -1].numpy(), 4e-2)

                # Round-trip: save the PyTorch weights and reload them into Flax.
                with tempfile.TemporaryDirectory() as tmpdirname:
                    pt_model.save_pretrained(tmpdirname)
                    fx_model_loaded = model_class.from_pretrained(tmpdirname, from_pt=True)

                fx_outputs_loaded = fx_model_loaded(**prepared_inputs_dict).to_tuple()
                self.assertEqual(
                    len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch"
                )
                for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs):
                    self.assert_almost_equals(fx_output_loaded[:, -1], pt_output[:, -1].numpy(), 4e-2)

    # overwrite from common since `attention_mask` in combination
    # with `causal_mask` behaves slightly differently
    @is_pt_flax_cross_test
    def test_equivalence_flax_to_pt(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        for model_class in self.all_model_classes:
            with self.subTest(model_class.__name__):
                # prepare inputs
                prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
                pt_inputs = {k: torch.tensor(v.tolist()) for k, v in prepared_inputs_dict.items()}

                # load corresponding PyTorch class
                pt_model_class_name = model_class.__name__[4:]  # Skip the "Flax" at the beginning
                pt_model_class = getattr(transformers, pt_model_class_name)

                pt_model = pt_model_class(config).eval()
                fx_model = model_class(config, dtype=jnp.float32)

                pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)
                self._randomly_left_pad_attention_masks(pt_inputs, prepared_inputs_dict)

                # make sure weights are tied in PyTorch
                pt_model.tie_weights()

                with torch.no_grad():
                    pt_outputs = pt_model(**pt_inputs).to_tuple()

                # Only the last position is compared (see the comment above this test).
                fx_outputs = fx_model(**prepared_inputs_dict).to_tuple()
                self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
                for fx_output, pt_output in zip(fx_outputs, pt_outputs):
                    self.assert_almost_equals(fx_output[:, -1], pt_output[:, -1].numpy(), 4e-2)

                # Round-trip: save the Flax weights and reload them into PyTorch.
                with tempfile.TemporaryDirectory() as tmpdirname:
                    fx_model.save_pretrained(tmpdirname)
                    pt_model_loaded = pt_model_class.from_pretrained(tmpdirname, from_flax=True)

                with torch.no_grad():
                    pt_outputs_loaded = pt_model_loaded(**pt_inputs).to_tuple()

                self.assertEqual(
                    len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch"
                )
                for fx_output, pt_output in zip(fx_outputs, pt_outputs_loaded):
                    self.assert_almost_equals(fx_output[:, -1], pt_output[:, -1].numpy(), 4e-2)

    @slow
    def test_model_from_pretrained(self):
        for model_class_name in self.all_model_classes:
            model = model_class_name.from_pretrained("gpt2", from_pt=True)
            outputs = model(np.ones((1, 1)))
            self.assertIsNotNone(outputs)