# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import inspect
import random
import tempfile
from typing import List, Tuple

import numpy as np

import transformers
from transformers import is_flax_available, is_torch_available
from transformers.models.auto import get_values
from transformers.testing_utils import is_pt_flax_cross_test, require_flax, slow


if is_flax_available():
    import os

    import jax
    import jax.numpy as jnp
    import jaxlib.xla_extension as jax_xla
    from transformers import FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING
    from transformers.modeling_flax_pytorch_utils import (
        convert_pytorch_state_dict_to_flax,
        load_flax_weights_in_pytorch_model,
    )

    os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.12"  # assumed parallelism: 8

if is_torch_available():
    import torch


def _config_zero_init(config):
    configs_no_init = copy.deepcopy(config)
    for key in configs_no_init.__dict__.keys():
        if "_range" in key or "_std" in key or "initializer_factor" in key:
            setattr(configs_no_init, key, 1e-10)
    return configs_no_init


def ids_tensor(shape, vocab_size, rng=None):
    """Creates a random int32 tensor of the shape within the vocab size."""
    if rng is None:
        rng = random.Random()

    # total number of scalar entries to draw
    num_elements = 1
    for dim in shape:
        num_elements *= dim

    flat_values = [rng.randint(0, vocab_size - 1) for _ in range(num_elements)]

    return np.array(flat_values, dtype=jnp.int32).reshape(shape)


def floats_tensor(shape, scale=1.0, rng=None, name=None):
    """Creates a random float32 tensor"""
    if rng is None:
        rng = random.Random()

    # total number of scalar entries to draw
    num_elements = 1
    for dim in shape:
        num_elements *= dim

    flat_values = [rng.random() * scale for _ in range(num_elements)]

    return np.array(flat_values, dtype=jnp.float32).reshape(shape)


def random_attention_mask(shape, rng=None):
    """Return a random 0/1 attention mask of `shape`; the last position of every row is forced to 1."""
    mask = ids_tensor(shape, vocab_size=2, rng=rng)
    # make sure that at least one token is attended to for each batch
    mask[:, -1] = 1
    return mask


@require_flax
class FlaxModelTesterMixin:
    # Set by concrete test classes: helper object that builds configs/inputs for the common tests.
    model_tester = None
    # Set by concrete test classes: tuple of Flax model classes exercised by the tests below.
    all_model_classes = ()
    is_encoder_decoder = False
    def _prepare_for_class(self, inputs_dict, model_class):
        """Return a deep copy of `inputs_dict`, expanding arrays along a choices axis for multiple-choice models."""
        prepared = copy.deepcopy(inputs_dict)

        # hack for now until we have AutoModel classes
        if "ForMultipleChoice" in model_class.__name__:
            num_choices = self.model_tester.num_choices
            expanded = {}
            for key, value in prepared.items():
                if isinstance(value, (jax_xla.DeviceArray, np.ndarray)):
                    # (batch, seq) -> (batch, num_choices, seq)
                    value = jnp.broadcast_to(value[:, None], (value.shape[0], num_choices, value.shape[-1]))
                expanded[key] = value
            prepared = expanded

        return prepared

Sylvain Gugger's avatar
Sylvain Gugger committed
116
    def assert_almost_equals(self, a: np.ndarray, b: np.ndarray, tol: float):
        """Assert that the maximum elementwise absolute difference between `a` and `b` is at most `tol`."""
        max_diff = np.max(np.abs(a - b))
        self.assertLessEqual(max_diff, tol, f"Difference between torch and flax is {max_diff} (>= {tol}).")

    def test_model_outputs_equivalence(self):
        """Check that tuple (`return_dict=False`) and dict (`return_dict=True`) outputs match elementwise."""
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        def set_nan_tensor_to_zero(t):
            # NaN != NaN, so this zeroes out NaN entries (in place) before comparison
            t[t != t] = 0
            return t

        def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs=None):
            # avoid a shared mutable default argument
            if additional_kwargs is None:
                additional_kwargs = {}
            tuple_output = model(**tuple_inputs, return_dict=False, **additional_kwargs)
            dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs).to_tuple()

            def recursive_check(tuple_object, dict_object):
                if isinstance(tuple_object, (List, Tuple)):
                    for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object):
                        recursive_check(tuple_iterable_value, dict_iterable_value)
                elif tuple_object is None:
                    return
                else:
                    self.assert_almost_equals(
                        set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), 1e-5
                    )

            # BUG FIX: this call was indented inside `recursive_check` itself, making it run
            # unconditionally on every invocation -> infinite recursion. It belongs here, in
            # `check_equivalence`, as the single entry call on the two output tuples.
            recursive_check(tuple_output, dict_output)

        for model_class in self.all_model_classes:
            model = model_class(config)

            tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
            dict_inputs = self._prepare_for_class(inputs_dict, model_class)
            check_equivalence(model, tuple_inputs, dict_inputs)

            tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
            dict_inputs = self._prepare_for_class(inputs_dict, model_class)
            check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})

    @is_pt_flax_cross_test
    def test_equivalence_pt_to_flax(self):
        """PyTorch -> Flax weight conversion must preserve model outputs, both in memory and via a disk round-trip."""
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            with self.subTest(model_class.__name__):
                # prepare inputs
                flax_inputs = self._prepare_for_class(inputs_dict, model_class)
                torch_inputs = {name: torch.tensor(array.tolist()) for name, array in flax_inputs.items()}

                # load corresponding PyTorch class
                torch_class_name = model_class.__name__[4:]  # Skip the "Flax" at the beginning
                torch_class = getattr(transformers, torch_class_name)

                torch_model = torch_class(config).eval()
                # Flax models don't use the `use_cache` option and cache is not returned as a default.
                # So we disable `use_cache` here for PyTorch model.
                torch_model.config.use_cache = False

                flax_model = model_class(config, dtype=jnp.float32)
                flax_model.params = convert_pytorch_state_dict_to_flax(torch_model.state_dict(), flax_model)

                with torch.no_grad():
                    torch_outputs = torch_model(**torch_inputs).to_tuple()

                flax_outputs = flax_model(**flax_inputs).to_tuple()
                self.assertEqual(
                    len(flax_outputs), len(torch_outputs), "Output lengths differ between Flax and PyTorch"
                )
                for flax_out, torch_out in zip(flax_outputs, torch_outputs):
                    self.assert_almost_equals(flax_out, torch_out.numpy(), 4e-2)

                # round-trip through disk: save PyTorch weights and reload them into Flax
                with tempfile.TemporaryDirectory() as tmpdirname:
                    torch_model.save_pretrained(tmpdirname)
                    flax_model_loaded = model_class.from_pretrained(tmpdirname, from_pt=True)

                flax_outputs_loaded = flax_model_loaded(**flax_inputs).to_tuple()
                self.assertEqual(
                    len(flax_outputs_loaded), len(torch_outputs), "Output lengths differ between Flax and PyTorch"
                )
                for flax_out_loaded, torch_out in zip(flax_outputs_loaded, torch_outputs):
                    self.assert_almost_equals(flax_out_loaded, torch_out.numpy(), 4e-2)

    @is_pt_flax_cross_test
    def test_equivalence_flax_to_pt(self):
        """Flax -> PyTorch weight conversion must preserve model outputs, both in memory and via a disk round-trip."""
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            with self.subTest(model_class.__name__):
                # prepare inputs
                flax_inputs = self._prepare_for_class(inputs_dict, model_class)
                torch_inputs = {name: torch.tensor(array.tolist()) for name, array in flax_inputs.items()}

                # load corresponding PyTorch class
                torch_class_name = model_class.__name__[4:]  # Skip the "Flax" at the beginning
                torch_class = getattr(transformers, torch_class_name)

                torch_model = torch_class(config).eval()
                # Flax models don't use the `use_cache` option and cache is not returned as a default.
                # So we disable `use_cache` here for PyTorch model.
                torch_model.config.use_cache = False

                flax_model = model_class(config, dtype=jnp.float32)
                torch_model = load_flax_weights_in_pytorch_model(torch_model, flax_model.params)

                # make sure weights are tied in PyTorch
                torch_model.tie_weights()

                with torch.no_grad():
                    torch_outputs = torch_model(**torch_inputs).to_tuple()

                flax_outputs = flax_model(**flax_inputs).to_tuple()
                self.assertEqual(
                    len(flax_outputs), len(torch_outputs), "Output lengths differ between Flax and PyTorch"
                )
                for flax_out, torch_out in zip(flax_outputs, torch_outputs):
                    self.assert_almost_equals(flax_out, torch_out.numpy(), 4e-2)

                # round-trip through disk: save Flax weights and reload them into PyTorch
                with tempfile.TemporaryDirectory() as tmpdirname:
                    flax_model.save_pretrained(tmpdirname)
                    torch_model_loaded = torch_class.from_pretrained(tmpdirname, from_flax=True)

                with torch.no_grad():
                    torch_outputs_loaded = torch_model_loaded(**torch_inputs).to_tuple()

                self.assertEqual(
                    len(flax_outputs), len(torch_outputs_loaded), "Output lengths differ between Flax and PyTorch"
                )
                for flax_out, torch_out_loaded in zip(flax_outputs, torch_outputs_loaded):
                    self.assert_almost_equals(flax_out, torch_out_loaded.numpy(), 4e-2)

    def test_from_pretrained_save_pretrained(self):
        """`save_pretrained` / `from_pretrained` must round-trip model outputs (checked on FlaxBertModel only)."""
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            if model_class.__name__ != "FlaxBertModel":
                continue

            with self.subTest(model_class.__name__):
                model = model_class(config)

                model_inputs = self._prepare_for_class(inputs_dict, model_class)
                reference_outputs = model(**model_inputs).to_tuple()

                def save_load_and_compare(**save_kwargs):
                    # save to a temp dir, reload, and compare against the in-memory outputs
                    with tempfile.TemporaryDirectory() as tmpdirname:
                        model.save_pretrained(tmpdirname, **save_kwargs)
                        reloaded = model_class.from_pretrained(tmpdirname)
                    reloaded_outputs = reloaded(**model_inputs).to_tuple()
                    for reloaded_output, reference_output in zip(reloaded_outputs, reference_outputs):
                        self.assert_almost_equals(reloaded_output, reference_output, 1e-3)

                # verify that normal save_pretrained works as expected
                save_load_and_compare()

                # verify that save_pretrained for distributed training
                # with `params=params` works as expected
                save_load_and_compare(params=model.params)
    @slow
    def test_jit_compilation(self):
        """`jax.jit`-compiled and eager execution must yield the same number of outputs with identical shapes."""
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            with self.subTest(model_class.__name__):
                model_inputs = self._prepare_for_class(inputs_dict, model_class)
                model = model_class(config)

                @jax.jit
                def run_model(input_ids, attention_mask=None, **kwargs):
                    return model(input_ids=input_ids, attention_mask=attention_mask, **kwargs)

                with self.subTest("JIT Enabled"):
                    jitted_outputs = run_model(**model_inputs).to_tuple()

                with self.subTest("JIT Disabled"):
                    with jax.disable_jit():
                        eager_outputs = run_model(**model_inputs).to_tuple()

                self.assertEqual(len(eager_outputs), len(jitted_outputs))
                for jitted_output, eager_output in zip(jitted_outputs, eager_outputs):
                    self.assertEqual(jitted_output.shape, eager_output.shape)
    def test_forward_signature(self):
        """The model's `__call__` signature must start with the canonical argument names."""
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            call_signature = inspect.signature(model.__call__)
            # signature.parameters is an OrderedDict => so arg_names order is deterministic
            arg_names = list(call_signature.parameters.keys())

            if model.config.is_encoder_decoder:
                expected = [
                    "input_ids",
                    "attention_mask",
                    "decoder_input_ids",
                    "decoder_attention_mask",
                ]
                self.assertListEqual(arg_names[: len(expected)], expected)
            else:
                expected = ["input_ids", "attention_mask"]
                self.assertListEqual(arg_names[:2], expected)
    def test_naming_convention(self):
        """Every model class must have a sibling `...Module` class in its defining module."""
        for model_class in self.all_model_classes:
            class_name = model_class.__name__
            # FlaxFooModel -> FlaxFooModule; anything else just gets a "Module" suffix
            if class_name.endswith("Model"):
                module_class_name = class_name[:-5] + "Module"
            else:
                module_class_name = class_name + "Module"
            defining_module = __import__(model_class.__module__, fromlist=[module_class_name])
            self.assertIsNotNone(getattr(defining_module, module_class_name))

    def test_hidden_states_output(self):
        """Hidden-state outputs must have the expected layer count and per-layer shapes."""

        def check_hidden_states_output(inputs_dict, config, model_class):
            model = model_class(config)
            outputs = model(**self._prepare_for_class(inputs_dict, model_class))

            encoder_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states

            # embeddings output plus one state per hidden layer, unless the tester overrides it
            expected_num_layers = getattr(
                self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
            )
            self.assertEqual(len(encoder_states), expected_num_layers)

            if hasattr(self.model_tester, "encoder_seq_length"):
                expected_seq_len = self.model_tester.encoder_seq_length
            else:
                expected_seq_len = self.model_tester.seq_length

            self.assertListEqual(
                list(encoder_states[0].shape[-2:]),
                [expected_seq_len, self.model_tester.hidden_size],
            )

            if config.is_encoder_decoder:
                decoder_states = outputs.decoder_hidden_states

                self.assertIsInstance(decoder_states, (list, tuple))
                self.assertEqual(len(decoder_states), expected_num_layers)
                fallback_seq_len = getattr(self.model_tester, "seq_length", None)
                decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", fallback_seq_len)

                self.assertListEqual(
                    list(decoder_states[0].shape[-2:]),
                    [decoder_seq_length, self.model_tester.hidden_size],
                )

        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            # request hidden states via the forward kwargs
            inputs_dict["output_hidden_states"] = True
            check_hidden_states_output(inputs_dict, config, model_class)

            # check that output_hidden_states also work using config
            del inputs_dict["output_hidden_states"]
            config.output_hidden_states = True

            check_hidden_states_output(inputs_dict, config, model_class)

    def test_attention_outputs(self):
        """Attention outputs must have the expected count and shape, whether requested via kwargs or config."""
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.return_dict = True

        tester = self.model_tester
        seq_length = getattr(tester, "seq_length", None)
        decoder_seq_length = getattr(tester, "decoder_seq_length", seq_length)
        encoder_seq_length = getattr(tester, "encoder_seq_length", seq_length)
        decoder_key_length = getattr(tester, "decoder_key_length", decoder_seq_length)
        encoder_key_length = getattr(tester, "key_length", encoder_seq_length)

        for model_class in self.all_model_classes:
            # request attentions via the forward kwargs
            inputs_dict["output_attentions"] = True
            inputs_dict["output_hidden_states"] = False
            model = model_class(config)
            model_outputs = model(**self._prepare_for_class(inputs_dict, model_class))
            attentions = model_outputs.encoder_attentions if config.is_encoder_decoder else model_outputs.attentions
            self.assertEqual(len(attentions), tester.num_hidden_layers)

            # check that output_attentions also work using config
            del inputs_dict["output_attentions"]
            config.output_attentions = True
            model = model_class(config)
            model_outputs = model(**self._prepare_for_class(inputs_dict, model_class))
            attentions = model_outputs.encoder_attentions if config.is_encoder_decoder else model_outputs.attentions
            self.assertEqual(len(attentions), tester.num_hidden_layers)

            self.assertListEqual(
                list(attentions[0].shape[-3:]),
                [tester.num_attention_heads, encoder_seq_length, encoder_key_length],
            )
            base_output_count = len(model_outputs)

            if self.is_encoder_decoder:
                correct_outlen = 5

                # Question Answering model returns start_logits and end_logits
                if model_class in get_values(FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING):
                    correct_outlen += 1  # start_logits and end_logits instead of only 1 output

                self.assertEqual(base_output_count, correct_outlen)

                # decoder attentions
                decoder_attentions = model_outputs.decoder_attentions
                self.assertIsInstance(decoder_attentions, (list, tuple))
                self.assertEqual(len(decoder_attentions), tester.num_hidden_layers)
                self.assertListEqual(
                    list(decoder_attentions[0].shape[-3:]),
                    [tester.num_attention_heads, decoder_seq_length, decoder_key_length],
                )

                # cross attentions
                cross_attentions = model_outputs.cross_attentions
                self.assertIsInstance(cross_attentions, (list, tuple))
                self.assertEqual(len(cross_attentions), tester.num_hidden_layers)
                self.assertListEqual(
                    list(cross_attentions[0].shape[-3:]),
                    [
                        tester.num_attention_heads,
                        decoder_seq_length,
                        encoder_key_length,
                    ],
                )

            # Check attention is always last and order is fine
            inputs_dict["output_attentions"] = True
            inputs_dict["output_hidden_states"] = True
            model = model_class(config)
            model_outputs = model(**self._prepare_for_class(inputs_dict, model_class))

            if hasattr(tester, "num_hidden_states_types"):
                added_hidden_states = tester.num_hidden_states_types
            elif self.is_encoder_decoder:
                added_hidden_states = 2
            else:
                added_hidden_states = 1
            self.assertEqual(base_output_count + added_hidden_states, len(model_outputs))

            self_attentions = (
                model_outputs.encoder_attentions if config.is_encoder_decoder else model_outputs.attentions
            )
            self.assertEqual(len(self_attentions), tester.num_hidden_layers)

            self.assertListEqual(
                list(self_attentions[0].shape[-3:]),
                [tester.num_attention_heads, encoder_seq_length, encoder_key_length],
            )