test_modeling_flax_common.py 9 KB
Newer Older
Sylvain Gugger's avatar
Sylvain Gugger committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15
import copy
Sylvain Gugger's avatar
Sylvain Gugger committed
16
import random
17
import tempfile
Sylvain Gugger's avatar
Sylvain Gugger committed
18
19
20
21
22

import numpy as np

import transformers
from transformers import is_flax_available, is_torch_available
23
from transformers.testing_utils import is_pt_flax_cross_test, require_flax
Sylvain Gugger's avatar
Sylvain Gugger committed
24
25
26
27
28
29
30


if is_flax_available():
    import os

    import jax
    import jax.numpy as jnp
31
32
33
34
    from transformers.modeling_flax_pytorch_utils import (
        convert_pytorch_state_dict_to_flax,
        load_flax_weights_in_pytorch_model,
    )
Sylvain Gugger's avatar
Sylvain Gugger committed
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66

    os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.12"  # assumed parallelism: 8

if is_torch_available():
    import torch


def ids_tensor(shape, vocab_size, rng=None):
    """Creates a random int32 tensor of the shape within the vocab size.

    Args:
        shape: Sequence of ints giving the output shape (it is iterated once to
            count elements and then passed to ``reshape``).
        vocab_size: Number of distinct ids; values are drawn uniformly from
            ``[0, vocab_size - 1]`` (inclusive).
        rng: Optional ``random.Random`` used for sampling. A fresh, unseeded
            instance is created when omitted, so pass one for reproducibility.

    Returns:
        A NumPy ``int32`` array of shape ``shape``.
    """
    if rng is None:
        rng = random.Random()

    total_dims = 1
    for dim in shape:
        total_dims *= dim

    values = [rng.randint(0, vocab_size - 1) for _ in range(total_dims)]

    # Use NumPy's int32 rather than jax.numpy's: ``jnp`` is only imported when Flax is
    # available, while this helper must stay usable without it. The dtype is identical.
    return np.array(values, dtype=np.int32).reshape(shape)


def random_attention_mask(shape, rng=None):
    """Return a random 0/1 attention mask of the requested ``shape``.

    Entries are sampled through ``ids_tensor`` with a vocabulary of size 2. The
    last index along axis 1 is then forced to 1 so that, for a 2-D
    ``(batch, seq_len)`` shape (presumably the intended use), every batch row
    attends to at least one token.
    """
    mask = ids_tensor(shape, vocab_size=2, rng=rng)
    last_position = -1
    mask[:, last_position] = 1
    return mask


67
@require_flax
class FlaxModelTesterMixin:
    """Tests shared by every Flax model test case.

    Mix into a ``unittest.TestCase`` subclass that sets:

    * ``model_tester``: object providing ``prepare_config_and_inputs_for_common()``
      (returning ``(config, inputs_dict)``) and, for multiple-choice heads,
      ``num_choices``.
    * ``all_model_classes``: the Flax model classes every test is run against.
    """

    model_tester = None
    all_model_classes = ()

    def _prepare_for_class(self, inputs_dict, model_class):
        """Return a deep copy of ``inputs_dict`` adapted to ``model_class``.

        For ``...ForMultipleChoice`` classes, every input ``v`` is broadcast to
        ``(v.shape[0], num_choices, v.shape[-1])`` by inserting a new axis 1. This
        assumes 2-D inputs, presumably ``(batch, seq_len)``.
        """
        inputs_dict = copy.deepcopy(inputs_dict)

        # hack for now until we have AutoModel classes
        if "ForMultipleChoice" in model_class.__name__:
            inputs_dict = {
                k: jnp.broadcast_to(v[:, None], (v.shape[0], self.model_tester.num_choices, v.shape[-1]))
                for k, v in inputs_dict.items()
            }

        return inputs_dict

    @staticmethod
    def _get_pt_model_class(model_class):
        """Return the PyTorch class from ``transformers`` matching the Flax ``model_class``."""
        pt_model_class_name = model_class.__name__[4:]  # Skip the "Flax" at the beginning
        return getattr(transformers, pt_model_class_name)

    @staticmethod
    def _to_pt_inputs(inputs_dict):
        """Convert each array in ``inputs_dict`` to a ``torch.Tensor`` (via ``tolist()``)."""
        return {k: torch.tensor(v.tolist()) for k, v in inputs_dict.items()}

    def assert_almost_equals(self, a: np.ndarray, b: np.ndarray, tol: float):
        """Assert that the largest absolute element-wise difference between ``a`` and ``b`` is at most ``tol``."""
        diff = np.abs(np.asarray(a) - np.asarray(b)).max()
        # The check fails only when diff > tol. The message is framework-neutral because
        # this helper also compares Flax outputs against reloaded Flax outputs.
        self.assertLessEqual(diff, tol, f"Max absolute difference between outputs is {diff} (> {tol}).")

    def _assert_fx_pt_outputs_close(self, fx_outputs, pt_outputs, tol):
        """Assert Flax and PyTorch output tuples have equal length and element-wise close values."""
        self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
        for fx_output, pt_output in zip(fx_outputs, pt_outputs):
            self.assert_almost_equals(fx_output, pt_output.numpy(), tol)

    @is_pt_flax_cross_test
    def test_equivalence_pt_to_flax(self):
        """A Flax model given PyTorch weights must reproduce the PyTorch model's outputs.

        Checked both for weights converted in memory and for a Flax model loaded
        with ``from_pretrained(..., from_pt=True)`` from a saved PyTorch checkpoint.
        """
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            with self.subTest(model_class.__name__):
                # prepare inputs
                prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
                pt_inputs = self._to_pt_inputs(prepared_inputs_dict)

                # load corresponding PyTorch class
                pt_model_class = self._get_pt_model_class(model_class)

                pt_model = pt_model_class(config).eval()
                fx_model = model_class(config, dtype=jnp.float32)

                # Overwrite the Flax parameters with the PyTorch model's weights.
                fx_model.params = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)

                with torch.no_grad():
                    pt_outputs = pt_model(**pt_inputs).to_tuple()

                fx_outputs = fx_model(**prepared_inputs_dict)
                self._assert_fx_pt_outputs_close(fx_outputs, pt_outputs, 1e-3)

                # Round-trip through a PyTorch checkpoint on disk.
                with tempfile.TemporaryDirectory() as tmpdirname:
                    pt_model.save_pretrained(tmpdirname)
                    fx_model_loaded = model_class.from_pretrained(tmpdirname, from_pt=True)

                fx_outputs_loaded = fx_model_loaded(**prepared_inputs_dict)
                self._assert_fx_pt_outputs_close(fx_outputs_loaded, pt_outputs, 1e-3)

    @is_pt_flax_cross_test
    def test_equivalence_flax_to_pt(self):
        """A PyTorch model given Flax weights must reproduce the Flax model's outputs.

        Checked both for weights loaded in memory with
        ``load_flax_weights_in_pytorch_model`` and for a PyTorch model loaded with
        ``from_pretrained(..., from_flax=True)`` from a saved Flax checkpoint.
        """
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            with self.subTest(model_class.__name__):
                # prepare inputs
                prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
                pt_inputs = self._to_pt_inputs(prepared_inputs_dict)

                # load corresponding PyTorch class
                pt_model_class = self._get_pt_model_class(model_class)

                pt_model = pt_model_class(config).eval()
                fx_model = model_class(config, dtype=jnp.float32)

                pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)

                # make sure weights are tied in PyTorch
                pt_model.tie_weights()

                with torch.no_grad():
                    pt_outputs = pt_model(**pt_inputs).to_tuple()

                fx_outputs = fx_model(**prepared_inputs_dict)
                self._assert_fx_pt_outputs_close(fx_outputs, pt_outputs, 1e-3)

                # Round-trip through a Flax checkpoint on disk.
                with tempfile.TemporaryDirectory() as tmpdirname:
                    fx_model.save_pretrained(tmpdirname)
                    pt_model_loaded = pt_model_class.from_pretrained(tmpdirname, from_flax=True)

                with torch.no_grad():
                    pt_outputs_loaded = pt_model_loaded(**pt_inputs).to_tuple()

                self._assert_fx_pt_outputs_close(fx_outputs, pt_outputs_loaded, 5e-3)

    def test_from_pretrained_save_pretrained(self):
        """Saving and reloading a Flax model must preserve the number and values of its outputs."""
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            with self.subTest(model_class.__name__):
                model = model_class(config)

                prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
                outputs = model(**prepared_inputs_dict)

                with tempfile.TemporaryDirectory() as tmpdirname:
                    model.save_pretrained(tmpdirname)
                    model_loaded = model_class.from_pretrained(tmpdirname)

                outputs_loaded = model_loaded(**prepared_inputs_dict)
                # zip() would silently ignore a missing or extra output, so check lengths first.
                self.assertEqual(len(outputs_loaded), len(outputs), "Output lengths differ after reloading")
                for output_loaded, output in zip(outputs_loaded, outputs):
                    self.assert_almost_equals(output_loaded, output, 1e-3)

    def test_jit_compilation(self):
        """Jitted and non-jitted forward passes must return the same number of outputs with matching shapes."""
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            with self.subTest(model_class.__name__):
                prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
                model = model_class(config)

                @jax.jit
                def model_jitted(**model_inputs):
                    # Forward inputs by keyword, exactly as the other tests call model(**inputs),
                    # so every key in the prepared inputs reaches the matching parameter.
                    return model(**model_inputs)

                outputs = jitted_outputs = None

                with self.subTest("JIT Disabled"):
                    with jax.disable_jit():
                        outputs = model_jitted(**prepared_inputs_dict)

                with self.subTest("JIT Enabled"):
                    jitted_outputs = model_jitted(**prepared_inputs_dict)

                # A missing result means one of the subTests above already failed; comparing
                # against it would only add a spurious follow-on error.
                if outputs is None or jitted_outputs is None:
                    continue

                self.assertEqual(len(outputs), len(jitted_outputs))
                for jitted_output, output in zip(jitted_outputs, outputs):
                    self.assertEqual(jitted_output.shape, output.shape)

    def test_naming_convention(self):
        """Each model class must have a companion ``...Module`` class in its own Python module.

        ``FlaxXxxModel`` maps to ``FlaxXxxModule``; any other class name ``Name``
        maps to ``NameModule``.
        """
        import importlib

        for model_class in self.all_model_classes:
            model_class_name = model_class.__name__
            if model_class_name.endswith("Model"):
                module_class_name = model_class_name[: -len("Model")] + "Module"
            else:
                module_class_name = model_class_name + "Module"

            modeling_module = importlib.import_module(model_class.__module__)
            # Default to None so that a missing class fails the assertion below with a
            # clear message instead of raising AttributeError.
            module_cls = getattr(modeling_module, module_class_name, None)

            self.assertIsNotNone(module_cls, f"{module_class_name} not found in {model_class.__module__}")