# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import inspect
import random
import tempfile
from typing import List, Tuple

import numpy as np

import transformers
from transformers import is_flax_available, is_torch_available
from transformers.testing_utils import is_pt_flax_cross_test, require_flax


if is_flax_available():
    import os

    import jax
    import jax.numpy as jnp
    import jaxlib.xla_extension as jax_xla

    from transformers.modeling_flax_pytorch_utils import (
        convert_pytorch_state_dict_to_flax,
        load_flax_weights_in_pytorch_model,
    )

    # NOTE(review): presumably caps the share of accelerator memory each JAX process
    # preallocates, so several test workers (8, per the note below) can share one device.
    os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.12"  # assumed parallelism: 8

if is_torch_available():
    import torch


def ids_tensor(shape, vocab_size, rng=None):
    """Create a random int32 NumPy array of ``shape`` with ids in ``[0, vocab_size - 1]``.

    Args:
        shape: Sequence of ints giving the output shape.
        vocab_size: Number of distinct ids; sampled values lie in ``[0, vocab_size - 1]``.
        rng: Optional ``random.Random`` used for sampling. A fresh, unseeded one is
            created when omitted.

    Returns:
        ``np.ndarray`` of dtype ``int32`` reshaped to ``shape``.
    """
    if rng is None:
        rng = random.Random()

    total_dims = 1
    for dim in shape:
        total_dims *= dim

    values = [rng.randint(0, vocab_size - 1) for _ in range(total_dims)]

    # Use NumPy's int32 rather than ``jnp.int32``: ``jnp`` is only imported when Flax is
    # available, and the result is a plain NumPy array either way (same dtype).
    return np.array(values, dtype=np.int32).reshape(shape)


def random_attention_mask(shape, rng=None):
    """Return a random 0/1 mask of ``shape`` whose last column is always 1.

    Forcing index -1 along axis 1 to 1 guarantees that every batch row attends
    to at least one token.
    """
    mask = ids_tensor(shape, vocab_size=2, rng=rng)
    last_position = -1
    mask[:, last_position] = 1
    return mask


@require_flax
class FlaxModelTesterMixin:
    """Tests shared by the Flax model test suites.

    The methods call ``self.assertEqual`` / ``self.subTest``, so this presumably gets mixed
    into a ``unittest.TestCase``. Subclasses are expected to set:

    * ``model_tester``: exposes ``prepare_config_and_inputs_for_common()`` and the attributes
      read below (``num_choices``, ``num_hidden_layers``, ``seq_length``, ``hidden_size``);
    * ``all_model_classes``: the Flax model classes to exercise.
    """

    model_tester = None
    all_model_classes = ()

    def _prepare_for_class(self, inputs_dict, model_class):
        """Return a deep copy of ``inputs_dict`` adapted to ``model_class``.

        For ``*ForMultipleChoice`` classes each array input of shape ``(d0, d1)`` is broadcast
        to ``(d0, num_choices, d1)``. Non-array entries (e.g. ``output_hidden_states=True``)
        are passed through unchanged rather than dropped.
        """
        inputs_dict = copy.deepcopy(inputs_dict)

        # hack for now until we have AutoModel classes
        if "ForMultipleChoice" in model_class.__name__:
            inputs_dict = {
                k: jnp.broadcast_to(v[:, None], (v.shape[0], self.model_tester.num_choices, v.shape[-1]))
                if isinstance(v, (jax_xla.DeviceArray, np.ndarray))
                else v
                for k, v in inputs_dict.items()
            }

        return inputs_dict

    def assert_almost_equals(self, a: np.ndarray, b: np.ndarray, tol: float):
        """Assert that the maximum absolute element-wise difference between ``a`` and ``b`` is <= ``tol``."""
        diff = np.abs(a - b).max()
        self.assertLessEqual(diff, tol, f"Difference between the two arrays is {diff} (> {tol}).")

    def test_model_outputs_equivalence(self):
        """Check that ``return_dict=False`` and ``return_dict=True`` give the same output values."""
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        def set_nan_tensor_to_zero(t):
            # Work on a NumPy copy: JAX arrays are immutable, so in-place assignment would fail.
            t = np.array(t)
            t[t != t] = 0  # NaN is the only value not equal to itself
            return t

        def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs=None):
            additional_kwargs = additional_kwargs or {}
            tuple_output = model(**tuple_inputs, return_dict=False, **additional_kwargs)
            dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs).to_tuple()

            def recursive_check(tuple_object, dict_object):
                if isinstance(tuple_object, (List, Tuple)):
                    for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object):
                        recursive_check(tuple_iterable_value, dict_iterable_value)
                elif tuple_object is None:
                    return
                else:
                    self.assert_almost_equals(
                        set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), 1e-5
                    )

            # This call must sit here, outside ``recursive_check``; otherwise nothing is compared.
            recursive_check(tuple_output, dict_output)

        for model_class in self.all_model_classes:
            model = model_class(config)

            tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
            dict_inputs = self._prepare_for_class(inputs_dict, model_class)
            check_equivalence(model, tuple_inputs, dict_inputs)

            tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
            dict_inputs = self._prepare_for_class(inputs_dict, model_class)
            check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})

    @is_pt_flax_cross_test
    def test_equivalence_pt_to_flax(self):
        """Check that PyTorch weights converted to Flax reproduce the PyTorch outputs.

        Covers in-memory conversion (``convert_pytorch_state_dict_to_flax``) and loading a saved
        PyTorch checkpoint via ``from_pretrained(..., from_pt=True)``.
        """
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            with self.subTest(model_class.__name__):
                # prepare inputs
                prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
                pt_inputs = {k: torch.tensor(v.tolist()) for k, v in prepared_inputs_dict.items()}

                # load corresponding PyTorch class
                pt_model_class_name = model_class.__name__[4:]  # Skip the "Flax" at the beginning
                pt_model_class = getattr(transformers, pt_model_class_name)

                pt_model = pt_model_class(config).eval()
                fx_model = model_class(config, dtype=jnp.float32)

                fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
                fx_model.params = fx_state

                with torch.no_grad():
                    pt_outputs = pt_model(**pt_inputs).to_tuple()

                fx_outputs = fx_model(**prepared_inputs_dict).to_tuple()
                self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
                for fx_output, pt_output in zip(fx_outputs, pt_outputs):
                    self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)

                with tempfile.TemporaryDirectory() as tmpdirname:
                    pt_model.save_pretrained(tmpdirname)
                    fx_model_loaded = model_class.from_pretrained(tmpdirname, from_pt=True)

                fx_outputs_loaded = fx_model_loaded(**prepared_inputs_dict).to_tuple()
                self.assertEqual(
                    len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch"
                )
                for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs):
                    self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 4e-2)

    @is_pt_flax_cross_test
    def test_equivalence_flax_to_pt(self):
        """Check that Flax weights loaded into PyTorch reproduce the Flax outputs.

        Covers in-memory loading (``load_flax_weights_in_pytorch_model``) and loading a saved
        Flax checkpoint via ``from_pretrained(..., from_flax=True)``.
        """
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            with self.subTest(model_class.__name__):
                # prepare inputs
                prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
                pt_inputs = {k: torch.tensor(v.tolist()) for k, v in prepared_inputs_dict.items()}

                # load corresponding PyTorch class
                pt_model_class_name = model_class.__name__[4:]  # Skip the "Flax" at the beginning
                pt_model_class = getattr(transformers, pt_model_class_name)

                pt_model = pt_model_class(config).eval()
                fx_model = model_class(config, dtype=jnp.float32)

                pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)

                # make sure weights are tied in PyTorch
                pt_model.tie_weights()

                with torch.no_grad():
                    pt_outputs = pt_model(**pt_inputs).to_tuple()

                fx_outputs = fx_model(**prepared_inputs_dict).to_tuple()
                self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
                for fx_output, pt_output in zip(fx_outputs, pt_outputs):
                    self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)

                with tempfile.TemporaryDirectory() as tmpdirname:
                    fx_model.save_pretrained(tmpdirname)
                    pt_model_loaded = pt_model_class.from_pretrained(tmpdirname, from_flax=True)

                with torch.no_grad():
                    pt_outputs_loaded = pt_model_loaded(**pt_inputs).to_tuple()

                self.assertEqual(
                    len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch"
                )
                for fx_output, pt_output in zip(fx_outputs, pt_outputs_loaded):
                    self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)

    def test_from_pretrained_save_pretrained(self):
        """Check that a save/load round-trip reproduces outputs, with and without an explicit ``params``."""
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            # NOTE(review): only FlaxBertModel is exercised; every other class is skipped.
            # Presumably a temporary restriction -- confirm before widening.
            if model_class.__name__ != "FlaxBertModel":
                continue

            with self.subTest(model_class.__name__):
                model = model_class(config)

                prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
                outputs = model(**prepared_inputs_dict).to_tuple()

                # verify that normal save_pretrained works as expected
                with tempfile.TemporaryDirectory() as tmpdirname:
                    model.save_pretrained(tmpdirname)
                    model_loaded = model_class.from_pretrained(tmpdirname)

                outputs_loaded = model_loaded(**prepared_inputs_dict).to_tuple()
                for output_loaded, output in zip(outputs_loaded, outputs):
                    self.assert_almost_equals(output_loaded, output, 1e-3)

                # verify that save_pretrained for distributed training
                # with `params=params` works as expected
                with tempfile.TemporaryDirectory() as tmpdirname:
                    model.save_pretrained(tmpdirname, params=model.params)
                    model_loaded = model_class.from_pretrained(tmpdirname)

                outputs_loaded = model_loaded(**prepared_inputs_dict).to_tuple()
                for output_loaded, output in zip(outputs_loaded, outputs):
                    self.assert_almost_equals(output_loaded, output, 1e-3)

    def test_jit_compilation(self):
        """Check that jitted and non-jitted calls agree on output count and shapes.

        Also checks that jitting a call that returns the output object itself raises ``TypeError``.
        """
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            with self.subTest(model_class.__name__):
                prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
                model = model_class(config)

                @jax.jit
                def model_jitted(input_ids, attention_mask=None, **kwargs):
                    return model(input_ids=input_ids, attention_mask=attention_mask, **kwargs).to_tuple()

                with self.subTest("JIT Enabled"):
                    jitted_outputs = model_jitted(**prepared_inputs_dict)

                with self.subTest("JIT Disabled"):
                    with jax.disable_jit():
                        outputs = model_jitted(**prepared_inputs_dict)

                self.assertEqual(len(outputs), len(jitted_outputs))
                for jitted_output, output in zip(jitted_outputs, outputs):
                    self.assertEqual(jitted_output.shape, output.shape)

                @jax.jit
                def model_jitted_return_dict(input_ids, attention_mask=None, **kwargs):
                    return model(
                        input_ids=input_ids,
                        attention_mask=attention_mask,
                        **kwargs,
                    )

                # jitted function cannot return OrderedDict
                with self.assertRaises(TypeError):
                    model_jitted_return_dict(**prepared_inputs_dict)

    def test_forward_signature(self):
        """Check that ``__call__`` takes ``input_ids`` and ``attention_mask`` as its first two parameters."""
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            signature = inspect.signature(model.__call__)
            # signature.parameters is an OrderedDict => so arg_names order is deterministic
            arg_names = [*signature.parameters.keys()]

            expected_arg_names = ["input_ids", "attention_mask"]
            self.assertListEqual(arg_names[:2], expected_arg_names)

    def test_naming_convention(self):
        """Check that each ``Flax<X>Model`` / ``Flax<X>`` class has a ``...Module`` sibling in its module."""
        for model_class in self.all_model_classes:
            model_class_name = model_class.__name__
            if model_class_name.endswith("Model"):
                module_class_name = model_class_name[: -len("Model")] + "Module"
            else:
                module_class_name = model_class_name + "Module"
            modeling_module = __import__(model_class.__module__, fromlist=[module_class_name])
            # getattr with a default so a missing class fails the assertion below rather than raising.
            module_cls = getattr(modeling_module, module_class_name, None)

            self.assertIsNotNone(module_cls)

    def test_hidden_states_output(self):
        """Check hidden-state count and shape, requested both via a call kwarg and via the config."""

        def check_hidden_states_output(inputs_dict, config, model_class):
            model = model_class(config)

            outputs = model(**self._prepare_for_class(inputs_dict, model_class))
            hidden_states = outputs.hidden_states

            # one entry per layer plus one for the embedding output
            self.assertEqual(len(hidden_states), self.model_tester.num_hidden_layers + 1)
            seq_length = self.model_tester.seq_length

            self.assertListEqual(
                list(hidden_states[0].shape[-2:]),
                [seq_length, self.model_tester.hidden_size],
            )

        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            inputs_dict["output_hidden_states"] = True
            check_hidden_states_output(inputs_dict, config, model_class)

            # check that output_hidden_states also work using config
            del inputs_dict["output_hidden_states"]
            config.output_hidden_states = True

            check_hidden_states_output(inputs_dict, config, model_class)