# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import inspect
import random
import tempfile
import unittest
from typing import List, Tuple

import numpy as np

import transformers
from huggingface_hub import HfApi
from requests.exceptions import HTTPError
from transformers import BertConfig, is_flax_available, is_torch_available
from transformers.models.auto import get_values
from transformers.testing_utils import (
    ENDPOINT_STAGING,
    PASS,
    USER,
    CaptureLogger,
    is_pt_flax_cross_test,
    is_staging_test,
    require_flax,
    slow,
)
from transformers.utils import logging


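# JAX/Flax (and PyTorch further below) are imported only when available, so this module
# can still be imported and collected in environments without those optional dependencies.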
if is_flax_available():
    import os

    import jax
    import jax.numpy as jnp
    from flax.core.frozen_dict import unfreeze
    from flax.traverse_util import flatten_dict
    from transformers import (
        FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
        FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
        FLAX_MODEL_MAPPING,
        FlaxAutoModelForSequenceClassification,
        FlaxBertModel,
    )
    from transformers.modeling_flax_pytorch_utils import (
        convert_pytorch_state_dict_to_flax,
        load_flax_weights_in_pytorch_model,
    )

    os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.12"  # assumed parallelism: 8

if is_torch_available():
    import torch


def _config_zero_init(config):
    configs_no_init = copy.deepcopy(config)
    for key in configs_no_init.__dict__.keys():
        if "_range" in key or "_std" in key or "initializer_factor" in key:
            setattr(configs_no_init, key, 1e-10)
    return configs_no_init


def ids_tensor(shape, vocab_size, rng=None):
    """Creates a random int32 tensor of the shape within the vocab size."""
    if rng is None:
        rng = random.Random()

    total_dims = 1
    for dim in shape:
        total_dims *= dim

    values = []
    for _ in range(total_dims):
        values.append(rng.randint(0, vocab_size - 1))

    output = np.array(values, dtype=jnp.int32).reshape(shape)

    return output


def floats_tensor(shape, scale=1.0, rng=None, name=None):
    """Creates a random float32 tensor"""
    if rng is None:
        rng = random.Random()

    total_dims = 1
    for dim in shape:
        total_dims *= dim

    values = []
    for _ in range(total_dims):
        values.append(rng.random() * scale)

    return np.array(values, dtype=jnp.float32).reshape(shape)


def random_attention_mask(shape, rng=None):
    attn_mask = ids_tensor(shape, vocab_size=2, rng=rng)
    # make sure that at least one token is attended to for each batch
    attn_mask[:, -1] = 1
    return attn_mask

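# Example use of the helpers above (hypothetical shapes, not taken from a real model
# tester): a (batch=2, seq_len=7) batch of token ids over a 99-token vocab and a
# matching attention mask could be built with
#   input_ids = ids_tensor([2, 7], vocab_size=99)
#   attention_mask = random_attention_mask([2, 7])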

@require_flax
class FlaxModelTesterMixin:
    model_tester = None
    all_model_classes = ()
    is_encoder_decoder = False
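    # Concrete test suites subclass this mixin together with unittest.TestCase and fill
    # in `model_tester` (an object exposing prepare_config_and_inputs_for_common) and
    # `all_model_classes`. A hypothetical example:
    #   class FlaxBertModelTest(FlaxModelTesterMixin, unittest.TestCase):
    #       all_model_classes = (FlaxBertModel,) if is_flax_available() else ()
    #       def setUp(self):
    #           self.model_tester = FlaxBertModelTester(self)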

    def _prepare_for_class(self, inputs_dict, model_class):
        inputs_dict = copy.deepcopy(inputs_dict)

        # hack for now until we have AutoModel classes
        if "ForMultipleChoice" in model_class.__name__:
            inputs_dict = {
                k: jnp.broadcast_to(v[:, None], (v.shape[0], self.model_tester.num_choices, v.shape[-1]))
                if isinstance(v, (jnp.ndarray, np.ndarray))
                else v
                for k, v in inputs_dict.items()
            }

        return inputs_dict

    def assert_almost_equals(self, a: np.ndarray, b: np.ndarray, tol: float):
        diff = np.abs((a - b)).max()
        self.assertLessEqual(diff, tol, f"Difference between torch and flax is {diff} (>= {tol}).")

    def test_model_outputs_equivalence(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        def set_nan_tensor_to_zero(t):
            # jax arrays are immutable, so replace NaNs functionally rather than in place
            return jnp.where(jnp.isnan(t), 0.0, t)

        def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}):
            tuple_output = model(**tuple_inputs, return_dict=False, **additional_kwargs)
            dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs).to_tuple()

            def recursive_check(tuple_object, dict_object):
                if isinstance(tuple_object, (List, Tuple)):
                    for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object):
                        recursive_check(tuple_iterable_value, dict_iterable_value)
                elif tuple_object is None:
                    return
                else:
                    self.assert_almost_equals(
                        set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), 1e-5
                    )

            recursive_check(tuple_output, dict_output)

        for model_class in self.all_model_classes:
            model = model_class(config)

            tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
            dict_inputs = self._prepare_for_class(inputs_dict, model_class)
            check_equivalence(model, tuple_inputs, dict_inputs)

            tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
            dict_inputs = self._prepare_for_class(inputs_dict, model_class)
            check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})

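    # The two cross-framework tests below round-trip weights between PyTorch and Flax
    # and compare the resulting outputs; they are gated by the @is_pt_flax_cross_test
    # marker.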
    @is_pt_flax_cross_test
    def test_equivalence_pt_to_flax(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            with self.subTest(model_class.__name__):
                # prepare inputs
                prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
                pt_inputs = {k: torch.tensor(v.tolist()) for k, v in prepared_inputs_dict.items()}

                # load corresponding PyTorch class
                pt_model_class_name = model_class.__name__[4:]  # Skip the "Flax" at the beginning
                pt_model_class = getattr(transformers, pt_model_class_name)

                pt_model = pt_model_class(config).eval()
                # Flax models don't use the `use_cache` option and cache is not returned as a default.
                # So we disable `use_cache` here for PyTorch model.
                pt_model.config.use_cache = False
                fx_model = model_class(config, dtype=jnp.float32)

                fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
                fx_model.params = fx_state

                with torch.no_grad():
                    pt_outputs = pt_model(**pt_inputs).to_tuple()

                fx_outputs = fx_model(**prepared_inputs_dict).to_tuple()
                self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
                for fx_output, pt_output in zip(fx_outputs, pt_outputs):
                    self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)

                with tempfile.TemporaryDirectory() as tmpdirname:
                    pt_model.save_pretrained(tmpdirname)
                    fx_model_loaded = model_class.from_pretrained(tmpdirname, from_pt=True)

                fx_outputs_loaded = fx_model_loaded(**prepared_inputs_dict).to_tuple()
                self.assertEqual(
                    len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch"
                )
                for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs):
                    self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 4e-2)

    @is_pt_flax_cross_test
    def test_equivalence_flax_to_pt(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            with self.subTest(model_class.__name__):
                # prepare inputs
                prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
                pt_inputs = {k: torch.tensor(v.tolist()) for k, v in prepared_inputs_dict.items()}

                # load corresponding PyTorch class
                pt_model_class_name = model_class.__name__[4:]  # Skip the "Flax" at the beginning
                pt_model_class = getattr(transformers, pt_model_class_name)

                pt_model = pt_model_class(config).eval()
                # Flax models don't use the `use_cache` option and cache is not returned as a default.
                # So we disable `use_cache` here for PyTorch model.
                pt_model.config.use_cache = False
                fx_model = model_class(config, dtype=jnp.float32)

                pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)

                # make sure weights are tied in PyTorch
                pt_model.tie_weights()

                with torch.no_grad():
                    pt_outputs = pt_model(**pt_inputs).to_tuple()

                fx_outputs = fx_model(**prepared_inputs_dict).to_tuple()
                self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")

                for fx_output, pt_output in zip(fx_outputs, pt_outputs):
                    self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)

                with tempfile.TemporaryDirectory() as tmpdirname:
                    fx_model.save_pretrained(tmpdirname)
                    pt_model_loaded = pt_model_class.from_pretrained(tmpdirname, from_flax=True)

                with torch.no_grad():
                    pt_outputs_loaded = pt_model_loaded(**pt_inputs).to_tuple()

                self.assertEqual(
                    len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch"
                )
                for fx_output, pt_output in zip(fx_outputs, pt_outputs_loaded):
                    self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)

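    # Round-trips every model class through save_pretrained/from_pretrained (both with
    # and without explicitly passing `params=`) and checks that outputs are preserved.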
    def test_from_pretrained_save_pretrained(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            with self.subTest(model_class.__name__):
                model = model_class(config)

                prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
                outputs = model(**prepared_inputs_dict).to_tuple()

                # verify that normal save_pretrained works as expected
                with tempfile.TemporaryDirectory() as tmpdirname:
                    model.save_pretrained(tmpdirname)
                    model_loaded = model_class.from_pretrained(tmpdirname)

                outputs_loaded = model_loaded(**prepared_inputs_dict).to_tuple()
                for output_loaded, output in zip(outputs_loaded, outputs):
                    self.assert_almost_equals(output_loaded, output, 1e-3)

                # verify that save_pretrained for distributed training
                # with `params=params` works as expected
                with tempfile.TemporaryDirectory() as tmpdirname:
                    model.save_pretrained(tmpdirname, params=model.params)
                    model_loaded = model_class.from_pretrained(tmpdirname)

                outputs_loaded = model_loaded(**prepared_inputs_dict).to_tuple()
                for output_loaded, output in zip(outputs_loaded, outputs):
                    self.assert_almost_equals(output_loaded, output, 1e-3)

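    # The next two tests check that base-model weights survive a save/load between the
    # bare base model and a task-specific head model, in both directions.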
    def test_save_load_from_base(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
        base_class = FLAX_MODEL_MAPPING[config.__class__]

        for model_class in self.all_model_classes:
            if model_class == base_class:
                continue

            model = base_class(config)
            base_params = flatten_dict(unfreeze(model.params))

            # check that all base model weights are loaded correctly
            with tempfile.TemporaryDirectory() as tmpdirname:
                model.save_pretrained(tmpdirname)
                head_model = model_class.from_pretrained(tmpdirname)

                base_param_from_head = flatten_dict(unfreeze(head_model.params[head_model.base_model_prefix]))

                for key in base_param_from_head.keys():
                    max_diff = (base_params[key] - base_param_from_head[key]).sum().item()
                    self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")

    def test_save_load_to_base(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
        base_class = FLAX_MODEL_MAPPING[config.__class__]

        for model_class in self.all_model_classes:
            if model_class == base_class:
                continue

            model = model_class(config)
            base_params_from_head = flatten_dict(unfreeze(model.params[model.base_model_prefix]))

            # check that all base model weights are loaded correctly
            with tempfile.TemporaryDirectory() as tmpdirname:
                model.save_pretrained(tmpdirname)
                base_model = base_class.from_pretrained(tmpdirname)

                base_params = flatten_dict(unfreeze(base_model.params))

                for key in base_params_from_head.keys():
                    max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
                    self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")

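    # Compares a jit-compiled forward pass against the same call under
    # jax.disable_jit(); the output shapes must match in both modes.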
    @slow
    def test_jit_compilation(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            with self.subTest(model_class.__name__):
                prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
                model = model_class(config)

                @jax.jit
                def model_jitted(input_ids, attention_mask=None, **kwargs):
                    return model(input_ids=input_ids, attention_mask=attention_mask, **kwargs)

                with self.subTest("JIT Enabled"):
                    jitted_outputs = model_jitted(**prepared_inputs_dict).to_tuple()

                with self.subTest("JIT Disabled"):
                    with jax.disable_jit():
                        outputs = model_jitted(**prepared_inputs_dict).to_tuple()

                self.assertEqual(len(outputs), len(jitted_outputs))
                for jitted_output, output in zip(jitted_outputs, outputs):
                    self.assertEqual(jitted_output.shape, output.shape)

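    # Checks the positional order of __call__ arguments: input_ids/attention_mask for
    # encoder-only models, plus decoder_input_ids/decoder_attention_mask for
    # encoder-decoder models.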
    def test_forward_signature(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            signature = inspect.signature(model.__call__)
            # signature.parameters is an OrderedDict => so arg_names order is deterministic
            arg_names = [*signature.parameters.keys()]

            if model.config.is_encoder_decoder:
                expected_arg_names = [
                    "input_ids",
                    "attention_mask",
                    "decoder_input_ids",
                    "decoder_attention_mask",
                ]
                self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)
            else:
                expected_arg_names = ["input_ids", "attention_mask"]
                self.assertListEqual(arg_names[:2], expected_arg_names)

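    # Each Flax model class should be backed by a corresponding ...Module class in the
    # same modeling file (e.g. FlaxBertModel -> FlaxBertModule,
    # FlaxBertForMaskedLM -> FlaxBertForMaskedLMModule); this test only asserts that
    # such a module exists.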
    def test_naming_convention(self):
        for model_class in self.all_model_classes:
            model_class_name = model_class.__name__
            module_class_name = (
                model_class_name[:-5] + "Module" if model_class_name[-5:] == "Model" else model_class_name + "Module"
            )
            bert_modeling_flax_module = __import__(model_class.__module__, fromlist=[module_class_name])
            module_cls = getattr(bert_modeling_flax_module, module_class_name)

            self.assertIsNotNone(module_cls)

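    # With output_hidden_states enabled (via kwargs or via the config), models should
    # return num_hidden_layers + 1 hidden states (unless the tester overrides
    # expected_num_hidden_layers), each of shape (..., seq_length, hidden_size).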
    def test_hidden_states_output(self):
        def check_hidden_states_output(inputs_dict, config, model_class):
            model = model_class(config)

            outputs = model(**self._prepare_for_class(inputs_dict, model_class))
            hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states

            expected_num_layers = getattr(
                self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
            )
            self.assertEqual(len(hidden_states), expected_num_layers)

            if hasattr(self.model_tester, "encoder_seq_length"):
                seq_length = self.model_tester.encoder_seq_length
            else:
                seq_length = self.model_tester.seq_length

            self.assertListEqual(
                list(hidden_states[0].shape[-2:]),
                [seq_length, self.model_tester.hidden_size],
            )

            if config.is_encoder_decoder:
                hidden_states = outputs.decoder_hidden_states

                self.assertIsInstance(hidden_states, (list, tuple))
                self.assertEqual(len(hidden_states), expected_num_layers)
                seq_len = getattr(self.model_tester, "seq_length", None)
                decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", seq_len)

                self.assertListEqual(
                    list(hidden_states[0].shape[-2:]),
                    [decoder_seq_length, self.model_tester.hidden_size],
                )

        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            inputs_dict["output_hidden_states"] = True
            check_hidden_states_output(inputs_dict, config, model_class)

            # check that output_hidden_states also work using config
            del inputs_dict["output_hidden_states"]
            config.output_hidden_states = True

            check_hidden_states_output(inputs_dict, config, model_class)

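    # With output_attentions enabled, each returned attention tensor should have shape
    # (..., num_attention_heads, seq_length, key_length); encoder-decoder models are
    # additionally checked for decoder and cross attentions.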
    def test_attention_outputs(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.return_dict = True

        seq_length = getattr(self.model_tester, "seq_length", None)
        decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", seq_length)
        encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_length)
        decoder_key_length = getattr(self.model_tester, "decoder_key_length", decoder_seq_length)
        encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)

        for model_class in self.all_model_classes:
            inputs_dict["output_attentions"] = True
            inputs_dict["output_hidden_states"] = False
            model = model_class(config)
            outputs = model(**self._prepare_for_class(inputs_dict, model_class))
            attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
            self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)

            # check that output_attentions also work using config
            del inputs_dict["output_attentions"]
            config.output_attentions = True
            model = model_class(config)
            outputs = model(**self._prepare_for_class(inputs_dict, model_class))
            attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
            self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)

            self.assertListEqual(
                list(attentions[0].shape[-3:]),
                [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
            )
            out_len = len(outputs)

            if self.is_encoder_decoder:
                correct_outlen = 5

                # Question Answering model returns start_logits and end_logits
                if model_class in get_values(FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING):
                    correct_outlen += 1  # start_logits and end_logits instead of only 1 output

                self.assertEqual(out_len, correct_outlen)

                # decoder attentions
                decoder_attentions = outputs.decoder_attentions
                self.assertIsInstance(decoder_attentions, (list, tuple))
                self.assertEqual(len(decoder_attentions), self.model_tester.num_hidden_layers)
                self.assertListEqual(
                    list(decoder_attentions[0].shape[-3:]),
                    [self.model_tester.num_attention_heads, decoder_seq_length, decoder_key_length],
                )

                # cross attentions
                cross_attentions = outputs.cross_attentions
                self.assertIsInstance(cross_attentions, (list, tuple))
                self.assertEqual(len(cross_attentions), self.model_tester.num_hidden_layers)
                self.assertListEqual(
                    list(cross_attentions[0].shape[-3:]),
                    [
                        self.model_tester.num_attention_heads,
                        decoder_seq_length,
                        encoder_key_length,
                    ],
                )

            # Check attention is always last and order is fine
            inputs_dict["output_attentions"] = True
            inputs_dict["output_hidden_states"] = True
            model = model_class(config)
            outputs = model(**self._prepare_for_class(inputs_dict, model_class))

            if hasattr(self.model_tester, "num_hidden_states_types"):
                added_hidden_states = self.model_tester.num_hidden_states_types
            elif self.is_encoder_decoder:
                added_hidden_states = 2
            else:
                added_hidden_states = 1
            self.assertEqual(out_len + added_hidden_states, len(outputs))

            self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
            self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)

            self.assertListEqual(
                list(self_attentions[0].shape[-3:]),
                [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
            )

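    # Loading a checkpoint into a head with a different size (num_labels=42 here) must
    # fail unless ignore_mismatched_sizes=True is passed, in which case a warning
    # mentioning the mismatched shapes is logged and the new head is freshly initialized.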
    def test_load_with_mismatched_shapes(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            if model_class not in get_values(FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING):
                continue

            with self.subTest(msg=f"Testing {model_class}"):
                with tempfile.TemporaryDirectory() as tmp_dir:
                    model = model_class(config)
                    model.save_pretrained(tmp_dir)

                    # Fails when we don't set ignore_mismatched_sizes=True
                    with self.assertRaises(ValueError):
                        new_model = FlaxAutoModelForSequenceClassification.from_pretrained(tmp_dir, num_labels=42)

                    logger = logging.get_logger("transformers.modeling_flax_utils")
                    with CaptureLogger(logger) as cl:
                        new_model = FlaxAutoModelForSequenceClassification.from_pretrained(
                            tmp_dir, num_labels=42, ignore_mismatched_sizes=True
                        )
                    self.assertIn("the shapes did not match", cl.out)

                    logits = new_model(**inputs_dict)["logits"]
                    self.assertEqual(logits.shape[1], 42)


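# The push-to-hub tests below talk to the Hugging Face Hub staging endpoint
# (ENDPOINT_STAGING) and are gated by the @is_staging_test marker in addition to
# @require_flax.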
@require_flax
@is_staging_test
class FlaxModelPushToHubTester(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls._api = HfApi(endpoint=ENDPOINT_STAGING)
        cls._token = cls._api.login(username=USER, password=PASS)

    @classmethod
    def tearDownClass(cls):
        try:
            cls._api.delete_repo(token=cls._token, name="test-model-flax")
        except HTTPError:
            pass

        try:
            cls._api.delete_repo(token=cls._token, name="test-model-flax-org", organization="valid_org")
        except HTTPError:
            pass

    def test_push_to_hub(self):
        config = BertConfig(
            vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
        )
        model = FlaxBertModel(config)
        with tempfile.TemporaryDirectory() as tmp_dir:
            model.save_pretrained(
                os.path.join(tmp_dir, "test-model-flax"), push_to_hub=True, use_auth_token=self._token
            )

            new_model = FlaxBertModel.from_pretrained(f"{USER}/test-model-flax")

            base_params = flatten_dict(unfreeze(model.params))
            new_params = flatten_dict(unfreeze(new_model.params))

            for key in base_params.keys():
                max_diff = (base_params[key] - new_params[key]).sum().item()
                self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")

    def test_push_to_hub_in_organization(self):
        config = BertConfig(
            vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
        )
        model = FlaxBertModel(config)
        with tempfile.TemporaryDirectory() as tmp_dir:
            model.save_pretrained(
                os.path.join(tmp_dir, "test-model-flax-org"),
                push_to_hub=True,
                use_auth_token=self._token,
                organization="valid_org",
            )

            new_model = FlaxBertModel.from_pretrained("valid_org/test-model-flax-org")

            base_params = flatten_dict(unfreeze(model.params))
            new_params = flatten_dict(unfreeze(new_model.params))

            for key in base_params.keys():
                max_diff = (base_params[key] - new_params[key]).sum().item()
                self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")