from pathlib import Path
from tempfile import NamedTemporaryFile
from unittest import TestCase
from unittest.mock import patch

import pytest

from parameterized import parameterized
from transformers import AutoConfig, AutoFeatureExtractor, AutoTokenizer, is_tf_available, is_torch_available
from transformers.onnx import (
    EXTERNAL_DATA_FORMAT_SIZE_LIMIT,
    OnnxConfig,
    OnnxConfigWithPast,
    ParameterFormat,
    export,
    validate_model_outputs,
)
from transformers.onnx.utils import compute_effective_axis_dimension, compute_serialized_parameters_size
from transformers.testing_utils import require_onnx, require_rjieba, require_tf, require_torch, require_vision, slow


if is_torch_available() or is_tf_available():
    from transformers.onnx.features import FeaturesManager


@require_onnx
class OnnxUtilsTestCaseV2(TestCase):
    """
    Cover all the utilities involved in exporting ONNX models.
    """

    @require_torch
    @patch("transformers.onnx.convert.is_torch_onnx_dict_inputs_support_available", return_value=False)
    def test_ensure_pytorch_version_ge_1_8_0(self, mock_is_torch_onnx_dict_inputs_support_available):
        """
        Ensure we raise an Exception if the PyTorch version is unsupported (< 1.8.0).
        """
        self.assertRaises(AssertionError, export, None, None, None, None, None)
        mock_is_torch_onnx_dict_inputs_support_available.assert_called()

    def test_compute_effective_axis_dimension(self):
        """
        When exporting an ONNX model with a dynamic axis (batch or sequence), we set batch_size and/or
        sequence_length = -1. We cannot generate an effective tensor with axis dim == -1, so we trick the exporter
        by using some "fixed" values (> 1 to avoid ONNX squeezing the axis).

        This test ensures we correctly replace the generated batch / sequence tensor with axis > 1.
        """

        # Dynamic axis (batch, no token added by the tokenizer)
        self.assertEqual(compute_effective_axis_dimension(-1, fixed_dimension=2, num_token_to_add=0), 2)

        # Static axis (batch, no token added by the tokenizer)
        self.assertEqual(compute_effective_axis_dimension(0, fixed_dimension=2, num_token_to_add=0), 2)

        # Dynamic axis (sequence, 2 tokens added by the tokenizer (no pair))
        self.assertEqual(compute_effective_axis_dimension(0, fixed_dimension=8, num_token_to_add=2), 6)
        self.assertEqual(compute_effective_axis_dimension(0, fixed_dimension=8, num_token_to_add=2), 6)

        # Dynamic axis (sequence, 3 tokens added by the tokenizer (pair))
        self.assertEqual(compute_effective_axis_dimension(0, fixed_dimension=8, num_token_to_add=3), 5)
        self.assertEqual(compute_effective_axis_dimension(0, fixed_dimension=8, num_token_to_add=3), 5)
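        # For reference, a minimal sketch of the behaviour the assertions above rely on
        # (the real implementation lives in `transformers.onnx.utils`):
        #
        #     def compute_effective_axis_dimension(dimension, fixed_dimension, num_token_to_add=0):
        #         if dimension <= 0:  # a non-positive dim is replaced by the fixed value
        #             dimension = fixed_dimension
        #         return dimension - num_token_to_add  # leave room for the special tokens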

    def test_compute_parameters_serialized_size(self):
        """
        This test ensures we compute a "correct" approximation of the underlying storage requirement (size) for all the
        parameters for the specified parameter's dtype.
        """
        self.assertEqual(compute_serialized_parameters_size(2, ParameterFormat.Float), 2 * ParameterFormat.Float.size)

    def test_flatten_output_collection_property(self):
        """
        This test ensures we correctly flatten nested collections such as the one we use when returning past_keys.
        past_keys = Tuple[Tuple]

        The ONNX exporter will export nested collections as ${collection_name}.${level_idx_0}.${level_idx_1}...${idx_n}
        """
        self.assertEqual(
            OnnxConfig.flatten_output_collection_property("past_key", [[0], [1], [2]]),
            {
                "past_key.0": 0,
                "past_key.1": 1,
                "past_key.2": 2,
            },
        )


class OnnxConfigTestCaseV2(TestCase):
    """
    Cover the tests for the models' default configuration.

    Default means no specific feature is enabled on the model.
    """

    @patch.multiple(OnnxConfig, __abstractmethods__=set())
    def test_use_external_data_format(self):
        """
        External data format is required only if the serialized size of the parameters is 2Gb or bigger.
        """
        TWO_GB_LIMIT = EXTERNAL_DATA_FORMAT_SIZE_LIMIT
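        # EXTERNAL_DATA_FORMAT_SIZE_LIMIT should be 2 GiB (2 * 1024**3 bytes); with float32
        # parameters at 4 bytes each, the threshold sits around 2**29 (~537M) parameters.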

        # No parameters
        self.assertFalse(OnnxConfig.use_external_data_format(0))

        # Some parameters
        self.assertFalse(OnnxConfig.use_external_data_format(1))

        # Almost 2Gb parameters
        self.assertFalse(OnnxConfig.use_external_data_format((TWO_GB_LIMIT - 1) // ParameterFormat.Float.size))

        # Exactly 2Gb parameters
        self.assertTrue(OnnxConfig.use_external_data_format(TWO_GB_LIMIT // ParameterFormat.Float.size))

        # More than 2Gb parameters
        self.assertTrue(OnnxConfig.use_external_data_format((TWO_GB_LIMIT + 1) // ParameterFormat.Float.size))


class OnnxConfigWithPastTestCaseV2(TestCase):
    """
    Cover the tests for models which have the use_cache feature (i.e. "with_past" for ONNX)
    """

    SUPPORTED_WITH_PAST_CONFIGS = {}
    # SUPPORTED_WITH_PAST_CONFIGS = {
    #     ("BART", BartConfig),
    #     ("GPT2", GPT2Config),
    #     # ("T5", T5Config)
    # }

    @patch.multiple(OnnxConfigWithPast, __abstractmethods__=set())
    def test_use_past(self):
        """
        Ensure the use_past variable is correctly set.
        """
        for name, config in OnnxConfigWithPastTestCaseV2.SUPPORTED_WITH_PAST_CONFIGS:
            with self.subTest(name):
                self.assertFalse(
                    OnnxConfigWithPast.from_model_config(config()).use_past,
                    "OnnxConfigWithPast.from_model_config() should not use_past",
                )

                self.assertTrue(
                    OnnxConfigWithPast.with_past(config()).use_past,
                    "OnnxConfigWithPast.from_model_config() should use_past",
                )

    @patch.multiple(OnnxConfigWithPast, __abstractmethods__=set())
    def test_values_override(self):
        """
        Ensure the use_past variable correctly sets the `use_cache` value in the model's configuration.
        """
        for name, config in OnnxConfigWithPastTestCaseV2.SUPPORTED_WITH_PAST_CONFIGS:
            with self.subTest(name):

                # without past
                onnx_config_default = OnnxConfigWithPast.from_model_config(config())
                self.assertIsNotNone(onnx_config_default.values_override, "values_override should not be None")
                self.assertIn("use_cache", onnx_config_default.values_override, "use_cache should be present")
                self.assertFalse(
                    onnx_config_default.values_override["use_cache"], "use_cache should be False if not using past"
                )

                # with past
                onnx_config_default = OnnxConfigWithPast.with_past(config())
                self.assertIsNotNone(onnx_config_default.values_override, "values_override should not be None")
                self.assertIn("use_cache", onnx_config_default.values_override, "use_cache should be present")
                self.assertTrue(
                    onnx_config_default.values_override["use_cache"], "use_cache should be False if not using past"
                )


PYTORCH_EXPORT_MODELS = {
    ("albert", "hf-internal-testing/tiny-albert"),
    ("bert", "bert-base-cased"),
    ("big-bird", "google/bigbird-roberta-base"),
    ("ibert", "kssteven/ibert-roberta-base"),
    ("camembert", "camembert-base"),
    ("convbert", "YituTech/conv-bert-base"),
    ("distilbert", "distilbert-base-cased"),
    ("electra", "google/electra-base-generator"),
    ("roberta", "roberta-base"),
    ("roformer", "junnyu/roformer_chinese_base"),
    ("mobilebert", "google/mobilebert-uncased"),
    ("xlm-roberta", "xlm-roberta-base"),
    ("layoutlm", "microsoft/layoutlm-base-uncased"),
    ("vit", "google/vit-base-patch16-224"),
    ("deit", "facebook/deit-small-patch16-224"),
    ("beit", "microsoft/beit-base-patch16-224"),
    ("data2vec-text", "facebook/data2vec-text-base"),
}

PYTORCH_EXPORT_WITH_PAST_MODELS = {
    ("gpt2", "gpt2"),
    ("gpt-neo", "EleutherAI/gpt-neo-125M"),
}

PYTORCH_EXPORT_SEQ2SEQ_WITH_PAST_MODELS = {
    ("bart", "facebook/bart-base"),
    ("mbart", "sshleifer/tiny-mbart"),
    ("t5", "t5-small"),
    ("marian", "Helsinki-NLP/opus-mt-en-de"),
    ("m2m-100", "facebook/m2m100_418M"),
    ("blenderbot-small", "facebook/blenderbot_small-90M"),
    ("blenderbot", "facebook/blenderbot-400M-distill"),
    ("bigbird-pegasus", "google/bigbird-pegasus-large-arxiv"),
}

# TODO(lewtun): Include the same model types in `PYTORCH_EXPORT_MODELS` once TensorFlow has parity with the PyTorch model implementations.
TENSORFLOW_EXPORT_DEFAULT_MODELS = {
    ("albert", "hf-internal-testing/tiny-albert"),
    ("bert", "bert-base-cased"),
    ("camembert", "camembert-base"),
    ("distilbert", "distilbert-base-cased"),
    ("roberta", "roberta-base"),
}

# TODO(lewtun): Include the same model types in `PYTORCH_EXPORT_WITH_PAST_MODELS` once TensorFlow has parity with the PyTorch model implementations.
TENSORFLOW_EXPORT_WITH_PAST_MODELS = {}

# TODO(lewtun): Include the same model types in `PYTORCH_EXPORT_SEQ2SEQ_WITH_PAST_MODELS` once TensorFlow has parity with the PyTorch model implementations.
TENSORFLOW_EXPORT_SEQ2SEQ_WITH_PAST_MODELS = {}


def _get_models_to_test(export_models_list):
    models_to_test = []
    if is_torch_available() or is_tf_available():
        for name, model in export_models_list:
            for feature, onnx_config_class_constructor in FeaturesManager.get_supported_features_for_model_type(
                name
            ).items():
                models_to_test.append((f"{name}_{feature}", name, model, feature, onnx_config_class_constructor))
        return sorted(models_to_test)
    else:
        # Return a dummy test that should never be called because of the @require_torch / @require_tf
        # decorators.
        # The reason for not returning an empty list is that parameterized.expand complains when it's empty.
        return [("dummy", "dummy", "dummy", "dummy", OnnxConfig.from_model_config)]


class OnnxExportTestCaseV2(TestCase):
    """
    Integration tests ensuring supported models are correctly exported
    """

    def _onnx_export(self, test_name, name, model_name, feature, onnx_config_class_constructor):
        from transformers.onnx import export

        model_class = FeaturesManager.get_model_class_for_feature(feature)
        config = AutoConfig.from_pretrained(model_name)
        model = model_class.from_config(config)
        onnx_config = onnx_config_class_constructor(model.config)

        if is_torch_available():
            from transformers.utils import torch_version

            if torch_version < onnx_config.torch_onnx_minimum_version:
                pytest.skip(
                    "Skipping due to incompatible PyTorch version. Minimum required is"
                    f" {onnx_config.torch_onnx_minimum_version}, got: {torch_version}"
                )

        # Check the modality of the inputs and instantiate the appropriate preprocessor
        if model.main_input_name == "input_ids":
            preprocessor = AutoTokenizer.from_pretrained(model_name)
            # Useful for causal lm models that do not use pad tokens.
            if not getattr(config, "pad_token_id", None):
                config.pad_token_id = preprocessor.eos_token_id
        elif model.main_input_name == "pixel_values":
            preprocessor = AutoFeatureExtractor.from_pretrained(model_name)
        else:
            raise ValueError(f"Unsupported model input name: {model.main_input_name}")

        with NamedTemporaryFile("w") as output:
            try:
                onnx_inputs, onnx_outputs = export(
                    preprocessor, model, onnx_config, onnx_config.default_onnx_opset, Path(output.name)
                )
                validate_model_outputs(
                    onnx_config,
                    preprocessor,
                    model,
                    Path(output.name),
                    onnx_outputs,
                    onnx_config.atol_for_validation,
                )
            except (RuntimeError, ValueError) as e:
                self.fail(f"{name}, {feature} -> {e}")

    @parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_MODELS))
    @slow
    @require_torch
    @require_vision
    @require_rjieba
    def test_pytorch_export(self, test_name, name, model_name, feature, onnx_config_class_constructor):
        self._onnx_export(test_name, name, model_name, feature, onnx_config_class_constructor)

    @parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_WITH_PAST_MODELS))
    @slow
    @require_torch
    def test_pytorch_export_with_past(self, test_name, name, model_name, feature, onnx_config_class_constructor):
        self._onnx_export(test_name, name, model_name, feature, onnx_config_class_constructor)

    @parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_SEQ2SEQ_WITH_PAST_MODELS))
    @slow
    @require_torch
    def test_pytorch_export_seq2seq_with_past(
        self, test_name, name, model_name, feature, onnx_config_class_constructor
    ):
        self._onnx_export(test_name, name, model_name, feature, onnx_config_class_constructor)

    @parameterized.expand(_get_models_to_test(TENSORFLOW_EXPORT_DEFAULT_MODELS))
    @slow
    @require_tf
    @require_vision
    def test_tensorflow_export(self, test_name, name, model_name, feature, onnx_config_class_constructor):
        self._onnx_export(test_name, name, model_name, feature, onnx_config_class_constructor)

    @parameterized.expand(_get_models_to_test(TENSORFLOW_EXPORT_WITH_PAST_MODELS), skip_on_empty=True)
    @slow
    @require_tf
    def test_tensorflow_export_with_past(self, test_name, name, model_name, feature, onnx_config_class_constructor):
        self._onnx_export(test_name, name, model_name, feature, onnx_config_class_constructor)

    @parameterized.expand(_get_models_to_test(TENSORFLOW_EXPORT_SEQ2SEQ_WITH_PAST_MODELS), skip_on_empty=True)
    @slow
    @require_tf
    def test_tensorflow_export_seq2seq_with_past(
        self, test_name, name, model_name, feature, onnx_config_class_constructor
    ):
        self._onnx_export(test_name, name, model_name, feature, onnx_config_class_constructor)