"...lm-evaluation-harness.git" did not exist on "616a54032614a1cddd18b1e6bd82855fca39cad6"
test_modeling_flax_common.py 19.9 KB
Newer Older
Sylvain Gugger's avatar
Sylvain Gugger committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15
import copy
16
import inspect
Sylvain Gugger's avatar
Sylvain Gugger committed
17
import random
18
import tempfile
19
from typing import List, Tuple
Sylvain Gugger's avatar
Sylvain Gugger committed
20
21
22
23
24

import numpy as np

import transformers
from transformers import is_flax_available, is_torch_available
Daniel Stancl's avatar
Daniel Stancl committed
25
from transformers.models.auto import get_values
26
from transformers.testing_utils import is_pt_flax_cross_test, require_flax
Sylvain Gugger's avatar
Sylvain Gugger committed
27
28
29
30
31
32
33


if is_flax_available():
    import os

    import jax
    import jax.numpy as jnp
34
    import jaxlib.xla_extension as jax_xla
Daniel Stancl's avatar
Daniel Stancl committed
35
    from transformers import FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING
36
37
38
39
    from transformers.modeling_flax_pytorch_utils import (
        convert_pytorch_state_dict_to_flax,
        load_flax_weights_in_pytorch_model,
    )
Sylvain Gugger's avatar
Sylvain Gugger committed
40
41
42
43
44
45
46

    os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.12"  # assumed parallelism: 8

if is_torch_available():
    import torch


Daniel Stancl's avatar
Daniel Stancl committed
47
48
49
50
51
52
53
54
def _config_zero_init(config):
    """Return a deep copy of ``config`` with its initializer-scale attributes set to ``1e-10``.

    Every instance attribute whose name contains ``"_range"``, ``"_std"`` or
    ``"initializer_factor"`` is overwritten on the copy. The ``config`` passed in
    is not modified.
    """
    init_markers = ("_range", "_std", "initializer_factor")
    zeroed_config = copy.deepcopy(config)
    # Only existing attributes are reassigned, so the attribute dict keeps its size while we walk it.
    for attr_name in vars(zeroed_config):
        if any(marker in attr_name for marker in init_markers):
            setattr(zeroed_config, attr_name, 1e-10)
    return zeroed_config


Sylvain Gugger's avatar
Sylvain Gugger committed
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
def ids_tensor(shape, vocab_size, rng=None):
    """Build a random int32 array of the given ``shape`` with ids drawn from ``0`` to ``vocab_size - 1``.

    A fresh ``random.Random()`` is created when ``rng`` is ``None``; pass a seeded
    generator to get reproducible values.
    """
    generator = random.Random() if rng is None else rng

    # Number of scalars to draw: the product of all dimensions (1 for an empty shape).
    num_elements = 1
    for extent in shape:
        num_elements *= extent

    flat_ids = [generator.randint(0, vocab_size - 1) for _ in range(num_elements)]
    return np.array(flat_ids, dtype=jnp.int32).reshape(shape)


Suraj Patil's avatar
Suraj Patil committed
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
def floats_tensor(shape, scale=1.0, rng=None, name=None):
    """Build a random float32 array of the given ``shape`` whose entries are ``rng.random() * scale``.

    A fresh ``random.Random()`` is created when ``rng`` is ``None``. ``name`` is
    accepted but not used by this function.
    """
    generator = random.Random() if rng is None else rng

    # Number of scalars to draw: the product of all dimensions (1 for an empty shape).
    num_elements = 1
    for extent in shape:
        num_elements *= extent

    flat_values = [generator.random() * scale for _ in range(num_elements)]
    return np.array(flat_values, dtype=jnp.float32).reshape(shape)


Sylvain Gugger's avatar
Sylvain Gugger committed
89
90
91
92
93
94
95
def random_attention_mask(shape, rng=None):
    """Build a random 0/1 mask of the given ``shape`` using ``ids_tensor`` with a vocabulary of two."""
    mask = ids_tensor(shape, vocab_size=2, rng=rng)
    # Switch the last position of every row on, so each batch entry attends to at least one token.
    mask[:, -1] = 1
    return mask


96
@require_flax
Sylvain Gugger's avatar
Sylvain Gugger committed
97
98
99
class FlaxModelTesterMixin:
    model_tester = None
    all_model_classes = ()
Daniel Stancl's avatar
Daniel Stancl committed
100
    is_encoder_decoder = False
Sylvain Gugger's avatar
Sylvain Gugger committed
101

102
103
104
105
106
107
108
    def _prepare_for_class(self, inputs_dict, model_class):
        """Return a deep copy of ``inputs_dict`` adapted to ``model_class``.

        For classes whose name contains ``"ForMultipleChoice"``, every array value
        gets a new axis inserted after the first one and is broadcast to
        ``(v.shape[0], self.model_tester.num_choices, v.shape[-1])``. Non-array
        values are passed through unchanged. Any other class receives the plain copy.
        """
        prepared = copy.deepcopy(inputs_dict)

        # hack for now until we have AutoModel classes
        if "ForMultipleChoice" not in model_class.__name__:
            return prepared

        # `jnp.ndarray` is the public array type and matches JAX arrays in `isinstance`
        # checks; the private `jaxlib.xla_extension.DeviceArray` is not available in
        # recent jaxlib releases.
        array_types = (jnp.ndarray, np.ndarray)
        return {
            key: jnp.broadcast_to(
                value[:, None], (value.shape[0], self.model_tester.num_choices, value.shape[-1])
            )
            if isinstance(value, array_types)
            else value
            for key, value in prepared.items()
        }

Sylvain Gugger's avatar
Sylvain Gugger committed
116
    def assert_almost_equals(self, a: np.ndarray, b: np.ndarray, tol: float):
        """Assert that the largest element-wise absolute difference between ``a`` and ``b`` is at most ``tol``."""
        delta = a - b
        diff = np.abs(delta).max()
        failure_message = f"Difference between torch and flax is {diff} (>= {tol})."
        self.assertLessEqual(diff, tol, failure_message)

120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
    def test_model_outputs_equivalence(self):
        """Check that ``return_dict=False`` and ``return_dict=True`` yield the same values.

        Every model class is called twice per configuration (tuple output and
        dict output converted with ``to_tuple()``), once with default arguments
        and once with ``output_hidden_states=True``. The two results are walked
        in parallel and every leaf pair must agree within ``1e-5`` after NaNs
        have been replaced by zero.
        """
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        def set_nan_tensor_to_zero(t):
            # Build a new array rather than assigning in place, so this also works for
            # outputs that do not support item assignment (JAX arrays are immutable).
            t = np.asarray(t)
            # `t != t` is True exactly where `t` is NaN.
            return np.where(t != t, 0, t)

        def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs=None):
            additional_kwargs = {} if additional_kwargs is None else additional_kwargs
            tuple_output = model(**tuple_inputs, return_dict=False, **additional_kwargs)
            dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs).to_tuple()

            def recursive_check(tuple_object, dict_object):
                if isinstance(tuple_object, (list, tuple)):
                    for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object):
                        recursive_check(tuple_iterable_value, dict_iterable_value)
                elif tuple_object is None:
                    return
                else:
                    self.assert_almost_equals(
                        set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), 1e-5
                    )

            # Must run at this level: nested inside `recursive_check` it would never be
            # executed and the test would pass without comparing anything.
            recursive_check(tuple_output, dict_output)

        for model_class in self.all_model_classes:
            model = model_class(config)

            tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
            dict_inputs = self._prepare_for_class(inputs_dict, model_class)
            check_equivalence(model, tuple_inputs, dict_inputs)

            tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
            dict_inputs = self._prepare_for_class(inputs_dict, model_class)
            check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})

155
    @is_pt_flax_cross_test
    def test_equivalence_pt_to_flax(self):
        """Check that Flax models loaded with PyTorch weights reproduce the PyTorch outputs.

        The weights are transferred twice: in memory through
        ``convert_pytorch_state_dict_to_flax``, and through a ``save_pretrained`` /
        ``from_pretrained(..., from_pt=True)`` round trip. In both cases each
        output must match the PyTorch one within ``4e-2``.
        """
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        def assert_matches_pt(flax_outputs, torch_outputs):
            self.assertEqual(
                len(flax_outputs), len(torch_outputs), "Output lengths differ between Flax and PyTorch"
            )
            for flax_out, torch_out in zip(flax_outputs, torch_outputs):
                self.assert_almost_equals(flax_out, torch_out.numpy(), 4e-2)

        for model_class in self.all_model_classes:
            with self.subTest(model_class.__name__):
                # the same inputs, once as prepared for Flax and once as torch tensors
                flax_inputs = self._prepare_for_class(inputs_dict, model_class)
                torch_inputs = {name: torch.tensor(value.tolist()) for name, value in flax_inputs.items()}

                # the PyTorch class is looked up by the Flax class name without its leading "Flax"
                torch_class = getattr(transformers, model_class.__name__[4:])
                pt_model = torch_class(config).eval()
                # Flax models don't use the `use_cache` option and cache is not returned as a default.
                # So we disable `use_cache` here for PyTorch model.
                pt_model.config.use_cache = False

                fx_model = model_class(config, dtype=jnp.float32)
                fx_model.params = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)

                with torch.no_grad():
                    pt_outputs = pt_model(**torch_inputs).to_tuple()

                assert_matches_pt(fx_model(**flax_inputs).to_tuple(), pt_outputs)

                with tempfile.TemporaryDirectory() as checkpoint_dir:
                    pt_model.save_pretrained(checkpoint_dir)
                    fx_model_loaded = model_class.from_pretrained(checkpoint_dir, from_pt=True)

                assert_matches_pt(fx_model_loaded(**flax_inputs).to_tuple(), pt_outputs)
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211

    @is_pt_flax_cross_test
    def test_equivalence_flax_to_pt(self):
        """Check that PyTorch models loaded with Flax weights reproduce the Flax outputs.

        The weights are transferred twice: in memory through
        ``load_flax_weights_in_pytorch_model``, and through a ``save_pretrained`` /
        ``from_pretrained(..., from_flax=True)`` round trip. In both cases each
        PyTorch output must match the Flax one within ``4e-2``.
        """
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        def assert_matches_flax(flax_outputs, torch_outputs):
            self.assertEqual(
                len(flax_outputs), len(torch_outputs), "Output lengths differ between Flax and PyTorch"
            )
            for flax_out, torch_out in zip(flax_outputs, torch_outputs):
                self.assert_almost_equals(flax_out, torch_out.numpy(), 4e-2)

        for model_class in self.all_model_classes:
            with self.subTest(model_class.__name__):
                # the same inputs, once as prepared for Flax and once as torch tensors
                flax_inputs = self._prepare_for_class(inputs_dict, model_class)
                torch_inputs = {name: torch.tensor(value.tolist()) for name, value in flax_inputs.items()}

                # the PyTorch class is looked up by the Flax class name without its leading "Flax"
                torch_class = getattr(transformers, model_class.__name__[4:])
                pt_model = torch_class(config).eval()
                # Flax models don't use the `use_cache` option and cache is not returned as a default.
                # So we disable `use_cache` here for PyTorch model.
                pt_model.config.use_cache = False

                fx_model = model_class(config, dtype=jnp.float32)
                pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)
                # make sure weights are tied in PyTorch
                pt_model.tie_weights()

                with torch.no_grad():
                    pt_outputs = pt_model(**torch_inputs).to_tuple()

                fx_outputs = fx_model(**flax_inputs).to_tuple()
                assert_matches_flax(fx_outputs, pt_outputs)

                with tempfile.TemporaryDirectory() as checkpoint_dir:
                    fx_model.save_pretrained(checkpoint_dir)
                    pt_model_loaded = torch_class.from_pretrained(checkpoint_dir, from_flax=True)

                with torch.no_grad():
                    pt_outputs_loaded = pt_model_loaded(**torch_inputs).to_tuple()

                assert_matches_flax(fx_outputs, pt_outputs_loaded)
243
244

    def test_from_pretrained_save_pretrained(self):
        """Check that a ``save_pretrained`` / ``from_pretrained`` round trip preserves the outputs.

        Only the class named ``"FlaxBertModel"`` is exercised; every other class
        is skipped. The round trip is done once with a plain save and once with
        ``params=model.params``, each compared within ``1e-3``.
        """
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        def assert_round_trip(model, model_class, model_inputs, reference_outputs, **save_kwargs):
            with tempfile.TemporaryDirectory() as checkpoint_dir:
                model.save_pretrained(checkpoint_dir, **save_kwargs)
                reloaded = model_class.from_pretrained(checkpoint_dir)

            reloaded_outputs = reloaded(**model_inputs).to_tuple()
            for reloaded_output, reference_output in zip(reloaded_outputs, reference_outputs):
                self.assert_almost_equals(reloaded_output, reference_output, 1e-3)

        for model_class in self.all_model_classes:
            if model_class.__name__ != "FlaxBertModel":
                continue

            with self.subTest(model_class.__name__):
                model = model_class(config)

                model_inputs = self._prepare_for_class(inputs_dict, model_class)
                reference_outputs = model(**model_inputs).to_tuple()

                # verify that normal save_pretrained works as expected
                assert_round_trip(model, model_class, model_inputs, reference_outputs)

                # verify that save_pretrained for distributed training
                # with `params=params` works as expected
                assert_round_trip(model, model_class, model_inputs, reference_outputs, params=model.params)
275
276
277
278
279
280

    def test_jit_compilation(self):
        """Check that the jitted call and the call under ``jax.disable_jit()`` agree in output count and shapes."""
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            with self.subTest(model_class.__name__):
                model_inputs = self._prepare_for_class(inputs_dict, model_class)
                model = model_class(config)

                def forward(input_ids, attention_mask=None, **kwargs):
                    return model(input_ids=input_ids, attention_mask=attention_mask, **kwargs)

                model_jitted = jax.jit(forward)

                with self.subTest("JIT Enabled"):
                    jitted_outputs = model_jitted(**model_inputs).to_tuple()

                # same callable, this time invoked while jit is disabled
                with self.subTest("JIT Disabled"), jax.disable_jit():
                    outputs = model_jitted(**model_inputs).to_tuple()

                self.assertEqual(len(outputs), len(jitted_outputs))
                for jitted_output, output in zip(jitted_outputs, outputs):
                    self.assertEqual(jitted_output.shape, output.shape)
299

300
301
302
303
304
305
306
307
308
    def test_forward_signature(self):
        """Check that the leading ``__call__`` parameters of every model have the expected names and order.

        Encoder-decoder models (per ``model.config.is_encoder_decoder``) must start
        with the four encoder/decoder argument names; all other models must start
        with ``input_ids`` and ``attention_mask``.
        """
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        encoder_decoder_arg_names = [
            "input_ids",
            "attention_mask",
            "decoder_input_ids",
            "decoder_attention_mask",
        ]
        default_arg_names = ["input_ids", "attention_mask"]

        for model_class in self.all_model_classes:
            model = model_class(config)
            # signature.parameters is an ordered mapping => so arg_names order is deterministic
            arg_names = list(inspect.signature(model.__call__).parameters)

            if model.config.is_encoder_decoder:
                expected_arg_names = encoder_decoder_arg_names
            else:
                expected_arg_names = default_arg_names
            self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)
320

321
322
323
324
325
326
327
328
329
330
    def test_naming_convention(self):
        """Check that every model class has a matching ``...Module`` class in the module that defines it.

        The expected name is the model class name with a trailing ``"Model"``
        replaced by ``"Module"``, or with ``"Module"`` appended when the name does
        not end in ``"Model"``.
        """
        suffix = "Model"
        for model_class in self.all_model_classes:
            class_name = model_class.__name__
            if class_name.endswith(suffix):
                module_class_name = class_name[: -len(suffix)] + "Module"
            else:
                module_class_name = class_name + "Module"

            # a non-empty `fromlist` makes `__import__` return the named module itself
            defining_module = __import__(model_class.__module__, fromlist=[module_class_name])
            module_cls = getattr(defining_module, module_class_name)

            self.assertIsNotNone(module_cls)
331
332
333
334
335
336

    def test_hidden_states_output(self):
        """Check the count and trailing shape of the hidden states returned by every model class.

        Each class is checked twice: with ``output_hidden_states=True`` passed as
        an input, and with the flag set on the config instead. For encoder-decoder
        configs the encoder hidden states are checked first, then the decoder ones.
        """

        def check_hidden_states_output(inputs_dict, config, model_class):
            tester = self.model_tester
            outputs = model_class(config)(**self._prepare_for_class(inputs_dict, model_class))

            if config.is_encoder_decoder:
                hidden_states = outputs.encoder_hidden_states
            else:
                hidden_states = outputs.hidden_states

            # the fallback (layers + 1) is always computed, whether or not the tester overrides it
            fallback_num_layers = tester.num_hidden_layers + 1
            expected_num_layers = getattr(tester, "expected_num_hidden_layers", fallback_num_layers)
            self.assertEqual(len(hidden_states), expected_num_layers)

            seq_length = tester.encoder_seq_length if hasattr(tester, "encoder_seq_length") else tester.seq_length
            self.assertListEqual(list(hidden_states[0].shape[-2:]), [seq_length, tester.hidden_size])

            if not config.is_encoder_decoder:
                return

            decoder_hidden_states = outputs.decoder_hidden_states
            self.assertIsInstance(decoder_hidden_states, (list, tuple))
            self.assertEqual(len(decoder_hidden_states), expected_num_layers)

            # decoder length falls back to `seq_length`, and to None when neither attribute exists
            decoder_seq_length = getattr(tester, "decoder_seq_length", getattr(tester, "seq_length", None))
            self.assertListEqual(
                list(decoder_hidden_states[0].shape[-2:]),
                [decoder_seq_length, tester.hidden_size],
            )

        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            # first request the hidden states through the call arguments
            inputs_dict["output_hidden_states"] = True
            check_hidden_states_output(inputs_dict, config, model_class)

            # check that output_hidden_states also work using config
            del inputs_dict["output_hidden_states"]
            config.output_hidden_states = True

            check_hidden_states_output(inputs_dict, config, model_class)
378
379
380
381
382
383

    def test_attention_outputs(self):
        """Check the count, shape and position of the attention outputs of every model class.

        For each class the attentions are requested through the inputs, then
        through the config, and finally together with the hidden states. When
        ``self.is_encoder_decoder`` is set, the total number of outputs and the
        decoder and cross attentions are checked as well.
        """
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.return_dict = True

        tester = self.model_tester
        seq_length = getattr(tester, "seq_length", None)
        decoder_seq_length = getattr(tester, "decoder_seq_length", seq_length)
        encoder_seq_length = getattr(tester, "encoder_seq_length", seq_length)
        decoder_key_length = getattr(tester, "decoder_key_length", decoder_seq_length)
        encoder_key_length = getattr(tester, "key_length", encoder_seq_length)

        def run_model(model_class):
            # a fresh model is built for every call so that the current config state is picked up
            return model_class(config)(**self._prepare_for_class(inputs_dict, model_class))

        def encoder_side_attentions(outputs):
            return outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions

        def assert_attention_shape(attentions, query_length, key_length):
            self.assertListEqual(
                list(attentions[0].shape[-3:]),
                [tester.num_attention_heads, query_length, key_length],
            )

        def assert_decoder_side(attentions, key_length):
            self.assertIsInstance(attentions, (list, tuple))
            self.assertEqual(len(attentions), tester.num_hidden_layers)
            assert_attention_shape(attentions, decoder_seq_length, key_length)

        for model_class in self.all_model_classes:
            # attentions requested through the call arguments
            inputs_dict["output_attentions"] = True
            inputs_dict["output_hidden_states"] = False
            outputs = run_model(model_class)
            attentions = encoder_side_attentions(outputs)
            self.assertEqual(len(attentions), tester.num_hidden_layers)

            # check that output_attentions also work using config
            del inputs_dict["output_attentions"]
            config.output_attentions = True
            outputs = run_model(model_class)
            attentions = encoder_side_attentions(outputs)
            self.assertEqual(len(attentions), tester.num_hidden_layers)
            assert_attention_shape(attentions, encoder_seq_length, encoder_key_length)
            out_len = len(outputs)

            if self.is_encoder_decoder:
                # Question Answering model returns start_logits and end_logits
                # (two outputs instead of only one), hence one extra entry.
                is_question_answering = model_class in get_values(FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING)
                correct_outlen = 6 if is_question_answering else 5
                self.assertEqual(out_len, correct_outlen)

                # decoder attentions
                assert_decoder_side(outputs.decoder_attentions, decoder_key_length)

                # cross attentions
                assert_decoder_side(outputs.cross_attentions, encoder_key_length)

            # Check attention is always last and order is fine
            inputs_dict["output_attentions"] = True
            inputs_dict["output_hidden_states"] = True
            outputs = run_model(model_class)

            if hasattr(tester, "num_hidden_states_types"):
                added_hidden_states = tester.num_hidden_states_types
            elif self.is_encoder_decoder:
                added_hidden_states = 2
            else:
                added_hidden_states = 1
            self.assertEqual(out_len + added_hidden_states, len(outputs))

            self_attentions = encoder_side_attentions(outputs)
            self.assertEqual(len(self_attentions), tester.num_hidden_layers)
            assert_attention_shape(self_attentions, encoder_seq_length, encoder_key_length)