# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import inspect
import random
import tempfile
from typing import List, Tuple

import numpy as np

import transformers
from transformers import is_flax_available, is_torch_available
from transformers.models.auto import get_values
from transformers.testing_utils import is_pt_flax_cross_test, require_flax


if is_flax_available():
    import os

    import jax
    import jax.numpy as jnp
    import jaxlib.xla_extension as jax_xla
    from transformers import FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING
    from transformers.modeling_flax_pytorch_utils import (
        convert_pytorch_state_dict_to_flax,
        load_flax_weights_in_pytorch_model,
    )

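    # Cap JAX's GPU memory pre-allocation at 12% per process so the assumed 8 parallel
    # test workers (8 * 0.12 = 0.96) can share one accelerator without exhausting its memory.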
    os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.12"  # assumed parallelism: 8

if is_torch_available():
    import torch


def _config_zero_init(config):
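    """Returns a deep copy of `config` with all initializer ranges/stds/factors set to a tiny value (1e-10)."""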
    configs_no_init = copy.deepcopy(config)
    for key in configs_no_init.__dict__.keys():
        if "_range" in key or "_std" in key or "initializer_factor" in key:
            setattr(configs_no_init, key, 1e-10)
    return configs_no_init


def ids_tensor(shape, vocab_size, rng=None):
    """Creates a random int32 tensor of the shape within the vocab size."""
    if rng is None:
        rng = random.Random()

    total_dims = 1
    for dim in shape:
        total_dims *= dim

    values = []
    for _ in range(total_dims):
        values.append(rng.randint(0, vocab_size - 1))

    output = np.array(values, dtype=jnp.int32).reshape(shape)

    return output


def floats_tensor(shape, scale=1.0, rng=None, name=None):
    """Creates a random float32 tensor"""
    if rng is None:
        rng = random.Random()

    total_dims = 1
    for dim in shape:
        total_dims *= dim

    values = []
    for _ in range(total_dims):
        values.append(rng.random() * scale)

    return np.array(values, dtype=jnp.float32).reshape(shape)


def random_attention_mask(shape, rng=None):
    attn_mask = ids_tensor(shape, vocab_size=2, rng=rng)
    # make sure that at least one token is attended to for each batch
    attn_mask[:, -1] = 1
    return attn_mask


@require_flax
class FlaxModelTesterMixin:
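    """Common test suite for Flax models.

    Model-specific test classes are expected to mix this in alongside `unittest.TestCase`
    and to set `model_tester` and `all_model_classes`.
    """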
    model_tester = None
    all_model_classes = ()
    is_encoder_decoder = False

    def _prepare_for_class(self, inputs_dict, model_class):
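        """Returns a copy of `inputs_dict`, expanded over a `num_choices` axis for multiple-choice models."""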
        inputs_dict = copy.deepcopy(inputs_dict)

        # hack for now until we have AutoModel classes
        if "ForMultipleChoice" in model_class.__name__:
            inputs_dict = {
                k: jnp.broadcast_to(v[:, None], (v.shape[0], self.model_tester.num_choices, v.shape[-1]))
                if isinstance(v, (jax_xla.DeviceArray, np.ndarray))
                else v
                for k, v in inputs_dict.items()
            }

        return inputs_dict

    def assert_almost_equals(self, a: np.ndarray, b: np.ndarray, tol: float):
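        """Asserts that the maximum absolute element-wise difference between `a` and `b` is at most `tol`."""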
        diff = np.abs((a - b)).max()
        self.assertLessEqual(diff, tol, f"Difference between torch and flax is {diff} (>= {tol}).")

    def test_model_outputs_equivalence(self):
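        """Checks that tuple outputs (`return_dict=False`) match dict outputs (`return_dict=True`) element-wise."""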
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        def set_nan_tensor_to_zero(t):
            t[t != t] = 0
            return t

        def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}):
            tuple_output = model(**tuple_inputs, return_dict=False, **additional_kwargs)
            dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs).to_tuple()

            def recursive_check(tuple_object, dict_object):
                if isinstance(tuple_object, (List, Tuple)):
                    for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object):
                        recursive_check(tuple_iterable_value, dict_iterable_value)
                elif tuple_object is None:
                    return
                else:
                    self.assert_almost_equals(
                        set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), 1e-5
                    )

            recursive_check(tuple_output, dict_output)

        for model_class in self.all_model_classes:
            model = model_class(config)

            tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
            dict_inputs = self._prepare_for_class(inputs_dict, model_class)
            check_equivalence(model, tuple_inputs, dict_inputs)

            tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
            dict_inputs = self._prepare_for_class(inputs_dict, model_class)
            check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})

    @is_pt_flax_cross_test
    def test_equivalence_pt_to_flax(self):
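        """Converts randomly initialized PyTorch weights to Flax, both in memory and through a
        `save_pretrained` / `from_pretrained(..., from_pt=True)` round trip, and checks that both
        frameworks produce matching outputs."""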
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            with self.subTest(model_class.__name__):
                # prepare inputs
                prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
                pt_inputs = {k: torch.tensor(v.tolist()) for k, v in prepared_inputs_dict.items()}

                # load corresponding PyTorch class
                pt_model_class_name = model_class.__name__[4:]  # Skip the "Flax" at the beginning
                pt_model_class = getattr(transformers, pt_model_class_name)

                pt_model = pt_model_class(config).eval()
                # Flax models don't use the `use_cache` option and cache is not returned as a default.
                # So we disable `use_cache` here for PyTorch model.
                pt_model.config.use_cache = False
                fx_model = model_class(config, dtype=jnp.float32)

                fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
                fx_model.params = fx_state

                with torch.no_grad():
                    pt_outputs = pt_model(**pt_inputs).to_tuple()

                fx_outputs = fx_model(**prepared_inputs_dict).to_tuple()
                self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
                for fx_output, pt_output in zip(fx_outputs, pt_outputs):
                    self.assert_almost_equals(fx_output, pt_output.numpy(), 1e-3)

                with tempfile.TemporaryDirectory() as tmpdirname:
                    pt_model.save_pretrained(tmpdirname)
                    fx_model_loaded = model_class.from_pretrained(tmpdirname, from_pt=True)

                fx_outputs_loaded = fx_model_loaded(**prepared_inputs_dict).to_tuple()
                self.assertEqual(
                    len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch"
                )
                for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs):
                    if not isinstance(
                        fx_output_loaded, tuple
                    ):  # TODO(Patrick, Daniel) - let's discard use_cache for now
                        self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 1e-3)

    @is_pt_flax_cross_test
    def test_equivalence_flax_to_pt(self):
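        """Loads Flax parameters into the PyTorch model, both in memory and through a
        `save_pretrained` / `from_pretrained(..., from_flax=True)` round trip, and checks that both
        frameworks produce matching outputs."""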
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            with self.subTest(model_class.__name__):
                # prepare inputs
                prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
                pt_inputs = {k: torch.tensor(v.tolist()) for k, v in prepared_inputs_dict.items()}

                # load corresponding PyTorch class
                pt_model_class_name = model_class.__name__[4:]  # Skip the "Flax" at the beginning
                pt_model_class = getattr(transformers, pt_model_class_name)

                pt_model = pt_model_class(config).eval()
                # Flax models don't use the `use_cache` option and cache is not returned as a default.
                # So we disable `use_cache` here for PyTorch model.
                pt_model.config.use_cache = False
                fx_model = model_class(config, dtype=jnp.float32)

                pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)

                # make sure weights are tied in PyTorch
                pt_model.tie_weights()

                with torch.no_grad():
                    pt_outputs = pt_model(**pt_inputs).to_tuple()

                fx_outputs = fx_model(**prepared_inputs_dict).to_tuple()
                self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")

                for fx_output, pt_output in zip(fx_outputs, pt_outputs):
                    self.assert_almost_equals(fx_output, pt_output.numpy(), 1e-3)

                with tempfile.TemporaryDirectory() as tmpdirname:
                    fx_model.save_pretrained(tmpdirname)
                    pt_model_loaded = pt_model_class.from_pretrained(tmpdirname, from_flax=True)

                with torch.no_grad():
                    pt_outputs_loaded = pt_model_loaded(**pt_inputs).to_tuple()

                self.assertEqual(
                    len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch"
                )
                for fx_output, pt_output in zip(fx_outputs, pt_outputs_loaded):
                    if not isinstance(fx_output, tuple):  # TODO(Patrick, Daniel) - let's discard use_cache for now
                        self.assert_almost_equals(fx_output, pt_output.numpy(), 5e-3)

    def test_from_pretrained_save_pretrained(self):
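        """Checks that `save_pretrained` / `from_pretrained` round trips (with and without an explicit
        `params` argument) preserve the model outputs. Only exercised for `FlaxBertModel`."""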
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            if model_class.__name__ != "FlaxBertModel":
                continue

            with self.subTest(model_class.__name__):
                model = model_class(config)

                prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
                outputs = model(**prepared_inputs_dict).to_tuple()

                # verify that normal save_pretrained works as expected
                with tempfile.TemporaryDirectory() as tmpdirname:
                    model.save_pretrained(tmpdirname)
                    model_loaded = model_class.from_pretrained(tmpdirname)

                outputs_loaded = model_loaded(**prepared_inputs_dict).to_tuple()
                for output_loaded, output in zip(outputs_loaded, outputs):
                    self.assert_almost_equals(output_loaded, output, 1e-3)

                # verify that save_pretrained for distributed training
                # with `params=params` works as expected
                with tempfile.TemporaryDirectory() as tmpdirname:
                    model.save_pretrained(tmpdirname, params=model.params)
                    model_loaded = model_class.from_pretrained(tmpdirname)

                outputs_loaded = model_loaded(**prepared_inputs_dict).to_tuple()
                for output_loaded, output in zip(outputs_loaded, outputs):
                    self.assert_almost_equals(output_loaded, output, 1e-3)

    def test_jit_compilation(self):
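        """Checks that outputs computed under `jax.jit` have the same structure and shapes as with JIT disabled."""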
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            with self.subTest(model_class.__name__):
                prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
                model = model_class(config)

                @jax.jit
                def model_jitted(input_ids, attention_mask=None, **kwargs):
                    return model(input_ids=input_ids, attention_mask=attention_mask, **kwargs)

                with self.subTest("JIT Enabled"):
                    jitted_outputs = model_jitted(**prepared_inputs_dict).to_tuple()

                with self.subTest("JIT Disabled"):
                    with jax.disable_jit():
                        outputs = model_jitted(**prepared_inputs_dict).to_tuple()

                self.assertEqual(len(outputs), len(jitted_outputs))
                for jitted_output, output in zip(jitted_outputs, outputs):
                    self.assertEqual(jitted_output.shape, output.shape)

    def test_forward_signature(self):
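        """Checks that `__call__` exposes the expected leading argument names for encoder-only and encoder-decoder models."""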
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            signature = inspect.signature(model.__call__)
            # signature.parameters is an OrderedDict => so arg_names order is deterministic
            arg_names = [*signature.parameters.keys()]

            if model.config.is_encoder_decoder:
                expected_arg_names = [
                    "input_ids",
                    "attention_mask",
                    "decoder_input_ids",
                    "decoder_attention_mask",
                ]
                self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)
            else:
                expected_arg_names = ["input_ids", "attention_mask"]
                self.assertListEqual(arg_names[:2], expected_arg_names)

    def test_naming_convention(self):
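        """Checks that every Flax model class has a corresponding `...Module` class defined in the same Python module."""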
        for model_class in self.all_model_classes:
            model_class_name = model_class.__name__
            module_class_name = (
                model_class_name[:-5] + "Module" if model_class_name[-5:] == "Model" else model_class_name + "Module"
            )
            bert_modeling_flax_module = __import__(model_class.__module__, fromlist=[module_class_name])
            module_cls = getattr(bert_modeling_flax_module, module_class_name)

            self.assertIsNotNone(module_cls)

    def test_hidden_states_output(self):
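        """Checks the number and shapes of the hidden states returned when `output_hidden_states` is enabled (via kwargs or config)."""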
        def check_hidden_states_output(inputs_dict, config, model_class):
            model = model_class(config)

            outputs = model(**self._prepare_for_class(inputs_dict, model_class))
            hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states

            expected_num_layers = getattr(
                self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
            )
            self.assertEqual(len(hidden_states), expected_num_layers)

            if hasattr(self.model_tester, "encoder_seq_length"):
                seq_length = self.model_tester.encoder_seq_length
            else:
                seq_length = self.model_tester.seq_length

            self.assertListEqual(
                list(hidden_states[0].shape[-2:]),
                [seq_length, self.model_tester.hidden_size],
            )

            if config.is_encoder_decoder:
                hidden_states = outputs.decoder_hidden_states

                self.assertIsInstance(hidden_states, (list, tuple))
                self.assertEqual(len(hidden_states), expected_num_layers)
                seq_len = getattr(self.model_tester, "seq_length", None)
                decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", seq_len)

                self.assertListEqual(
                    list(hidden_states[0].shape[-2:]),
                    [decoder_seq_length, self.model_tester.hidden_size],
                )

        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            inputs_dict["output_hidden_states"] = True
            check_hidden_states_output(inputs_dict, config, model_class)

            # check that output_hidden_states also work using config
            del inputs_dict["output_hidden_states"]
            config.output_hidden_states = True

            check_hidden_states_output(inputs_dict, config, model_class)

    def test_attention_outputs(self):
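        """Checks the number and shapes of the attention weights (encoder, decoder and cross attentions) returned when `output_attentions` is enabled."""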
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.return_dict = True

        seq_length = getattr(self.model_tester, "seq_length", None)
        decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", seq_length)
        encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_length)
        decoder_key_length = getattr(self.model_tester, "decoder_key_length", decoder_seq_length)
        encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)

        for model_class in self.all_model_classes:
            inputs_dict["output_attentions"] = True
            inputs_dict["output_hidden_states"] = False
            model = model_class(config)
            outputs = model(**self._prepare_for_class(inputs_dict, model_class))
            attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
            self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)

            # check that output_attentions also work using config
            del inputs_dict["output_attentions"]
            config.output_attentions = True
            model = model_class(config)
            outputs = model(**self._prepare_for_class(inputs_dict, model_class))
            attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
            self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)

            self.assertListEqual(
                list(attentions[0].shape[-3:]),
                [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
            )
            out_len = len(outputs)

            if self.is_encoder_decoder:
                correct_outlen = 5

                # Question Answering model returns start_logits and end_logits
                if model_class in get_values(FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING):
                    correct_outlen += 1  # start_logits and end_logits instead of only 1 output

                self.assertEqual(out_len, correct_outlen)

                # decoder attentions
                decoder_attentions = outputs.decoder_attentions
                self.assertIsInstance(decoder_attentions, (list, tuple))
                self.assertEqual(len(decoder_attentions), self.model_tester.num_hidden_layers)
                self.assertListEqual(
                    list(decoder_attentions[0].shape[-3:]),
                    [self.model_tester.num_attention_heads, decoder_seq_length, decoder_key_length],
                )

                # cross attentions
                cross_attentions = outputs.cross_attentions
                self.assertIsInstance(cross_attentions, (list, tuple))
                self.assertEqual(len(cross_attentions), self.model_tester.num_hidden_layers)
                self.assertListEqual(
                    list(cross_attentions[0].shape[-3:]),
                    [
                        self.model_tester.num_attention_heads,
                        decoder_seq_length,
                        encoder_key_length,
                    ],
                )

            # Check attention is always last and order is fine
            inputs_dict["output_attentions"] = True
            inputs_dict["output_hidden_states"] = True
            model = model_class(config)
            outputs = model(**self._prepare_for_class(inputs_dict, model_class))

            if hasattr(self.model_tester, "num_hidden_states_types"):
                added_hidden_states = self.model_tester.num_hidden_states_types
            elif self.is_encoder_decoder:
                added_hidden_states = 2
            else:
                added_hidden_states = 1
            self.assertEqual(out_len + added_hidden_states, len(outputs))

            self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
            self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)

            self.assertListEqual(
                list(self_attentions[0].shape[-3:]),
                [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
            )