"docs/source/locales/zh/LC_MESSAGES/sphinx.po" did not exist on "f5b89bb6552e466cbcbd62147f180cd4124b58ab"
test_onnx.py 7.91 KB
Newer Older
Sylvain Gugger's avatar
Sylvain Gugger committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15
import unittest
16
from pathlib import Path
17
from tempfile import NamedTemporaryFile, TemporaryDirectory
18
19

from transformers import BertConfig, BertTokenizerFast, FeatureExtractionPipeline
20
21
22
23
24
25
26
from transformers.convert_graph_to_onnx import (
    convert,
    ensure_valid_input,
    generate_identified_filename,
    infer_shapes,
    quantize,
)
27
from transformers.testing_utils import require_tf, require_tokenizers, require_torch, slow
28
29
30
31
32
33
34
35
36
37
38
39
40


class FuncContiguousArgs:
    """Stand-in model whose ``forward`` receives the tokenizer outputs as consecutive parameters."""

    def forward(self, input_ids, token_type_ids, attention_mask):
        """No-op forward pass.

        The tests only care about the parameter order of this method, which
        ``ensure_valid_input`` uses to reorder the generated inputs. The call
        itself produces ``None``.
        """


class FuncNonContiguousArgs:
    """Stand-in model whose ``forward`` has an unrelated parameter right after ``input_ids``."""

    def forward(self, input_ids, some_other_args, token_type_ids, attention_mask):
        """No-op forward pass.

        ``some_other_args`` plays the role of an argument the tokenizer does not
        produce (e.g. GPT2's ``past``). The tests only inspect the parameter
        order, and the call itself produces ``None``.
        """


class OnnxExportTestCase(unittest.TestCase):
    """Tests for the ``transformers.convert_graph_to_onnx`` helpers.

    Covers ONNX export, quantization, dynamic-axis inference and input reordering.
    """

    # Models exercised by the export/quantization tests.
    MODEL_TO_TEST = [
        # (model_name, model_kwargs)
        ("bert-base-cased", {}),
        ("gpt2", {"use_cache": False}),  # We don't support exporting GPT2 past keys anymore
    ]

    @require_tf
    @slow
    def test_export_tensorflow(self):
        """Export every model in ``MODEL_TO_TEST`` from TensorFlow with opset 12."""
        for model, model_kwargs in OnnxExportTestCase.MODEL_TO_TEST:
            self._test_export(model, "tf", 12, **model_kwargs)

    @require_torch
    @slow
    def test_export_pytorch(self):
        """Export every model in ``MODEL_TO_TEST`` from PyTorch with opset 12."""
        for model, model_kwargs in OnnxExportTestCase.MODEL_TO_TEST:
            self._test_export(model, "pt", 12, **model_kwargs)

    @require_torch
    @slow
    def test_export_custom_bert_model(self):
        """Export a randomly initialised BERT saved locally, using a tiny custom vocabulary."""
        from transformers import BertModel

        vocab = ["[UNK]", "[SEP]", "[CLS]", "[PAD]", "[MASK]", "some", "other", "words"]
        with NamedTemporaryFile(mode="w+t") as vocab_file:
            vocab_file.write("\n".join(vocab))
            vocab_file.flush()
            # The tokenizer is built while the vocab file still exists on disk.
            tokenizer = BertTokenizerFast(vocab_file.name)

        with TemporaryDirectory() as bert_save_dir:
            model = BertModel(BertConfig(vocab_size=len(vocab)))
            model.save_pretrained(bert_save_dir)
            self._test_export(bert_save_dir, "pt", 12, tokenizer)

    @require_tf
    @slow
    def test_quantize_tf(self):
        """Quantize each TensorFlow export and check the result is smaller."""
        for model, model_kwargs in OnnxExportTestCase.MODEL_TO_TEST:
            path = self._test_export(model, "tf", 12, **model_kwargs)
            self._assert_quantized_smaller(path)

    @require_torch
    @slow
    def test_quantize_pytorch(self):
        """Quantize each PyTorch export and check the result is smaller."""
        for model, model_kwargs in OnnxExportTestCase.MODEL_TO_TEST:
            path = self._test_export(model, "pt", 12, **model_kwargs)
            self._assert_quantized_smaller(path)

    def _assert_quantized_smaller(self, path):
        """Quantize the ONNX model at ``path`` and fail unless the quantized file is strictly smaller.

        Args:
            path: ``Path`` to an exported ``.onnx`` file.
        """
        quantized_path = quantize(Path(path))

        # Ensure the actual quantized model is not bigger than the original one
        if quantized_path.stat().st_size >= Path(path).stat().st_size:
            self.fail("Quantized model is bigger than initial ONNX model")

    def _test_export(self, model, framework, opset, tokenizer=None, **model_kwargs):
        """Export ``model`` to ONNX and return the path of the ``.onnx`` file.

        The file is written into a temporary directory that stays alive until the
        calling test finishes, so follow-up steps such as quantization can use it.
        It is removed afterwards through ``addCleanup``.

        Args:
            model: model identifier or local directory, forwarded to ``convert``.
            framework: ``"pt"`` or ``"tf"``.
            opset: ONNX opset version.
            tokenizer: optional tokenizer, forwarded to ``convert``.
            **model_kwargs: extra keyword arguments forwarded to ``convert``.

        Returns:
            ``Path`` of the exported model.
        """
        # Not used as a ``with`` block: the directory must outlive this helper,
        # because callers keep using the exported file after it returns.
        tmp_dir = TemporaryDirectory()
        self.addCleanup(tmp_dir.cleanup)

        # Like the previous implementation, hand ``convert`` an output whose parent
        # folder does not exist yet. ``convert`` creates that folder during export.
        path = Path(tmp_dir.name).joinpath("onnx", "model.onnx")

        try:
            convert(framework, model, path, opset, tokenizer, **model_kwargs)
        except Exception as e:
            # Report export errors as test failures rather than test errors.
            self.fail(e)

        return path

    @require_torch
    @require_tokenizers
    @slow
    def test_infer_dynamic_axis_pytorch(self):
        """
        Validate that the dynamic axes generated for each parameter are correct (PyTorch).
        """
        from transformers import BertModel

        model = BertModel(BertConfig.from_pretrained("lysandre/tiny-bert-random"))
        tokenizer = BertTokenizerFast.from_pretrained("lysandre/tiny-bert-random")
        self._test_infer_dynamic_axis(model, tokenizer, "pt")

    @require_tf
    @require_tokenizers
    @slow
    def test_infer_dynamic_axis_tf(self):
        """
        Validate that the dynamic axes generated for each parameter are correct (TensorFlow).
        """
        from transformers import TFBertModel

        model = TFBertModel(BertConfig.from_pretrained("lysandre/tiny-bert-random"))
        tokenizer = BertTokenizerFast.from_pretrained("lysandre/tiny-bert-random")
        self._test_infer_dynamic_axis(model, tokenizer, "tf")

    def _test_infer_dynamic_axis(self, model, tokenizer, framework):
        """Run ``infer_shapes`` on a feature-extraction pipeline and check the inferred names and axes.

        Args:
            model: BERT-like model instance.
            tokenizer: tokenizer matching ``model``.
            framework: ``"pt"`` or ``"tf"``.
        """
        feature_extractor = FeatureExtractionPipeline(model, tokenizer)

        variable_names = ["input_ids", "token_type_ids", "attention_mask", "output_0", "output_1"]
        # The fourth returned value (sample tokens) is not needed here.
        input_vars, output_vars, shapes, _ = infer_shapes(feature_extractor, framework)

        # Assert all variables are present
        self.assertEqual(len(shapes), len(variable_names))
        for var_name in variable_names:
            self.assertIn(var_name, shapes)
        self.assertSequenceEqual(variable_names[:3], input_vars)
        self.assertSequenceEqual(variable_names[3:], output_vars)

        # Assert inputs are {0: batch, 1: sequence}
        for var_name in ["input_ids", "token_type_ids", "attention_mask"]:
            self.assertDictEqual(shapes[var_name], {0: "batch", 1: "sequence"})

        # Assert outputs are {0: batch, 1: sequence} and {0: batch}
        self.assertDictEqual(shapes["output_0"], {0: "batch", 1: "sequence"})
        self.assertDictEqual(shapes["output_1"], {0: "batch"})

    def test_ensure_valid_input(self):
        """
        Validate that parameters are correctly exported.

        GPT2 has a "past" parameter between input_ids and token_type_ids / attention_mask.
        ONNX export only accepts a tuple of inputs, not a dictionary, so token_type_ids and
        attention_mask must be dropped for now to avoid passing a None tensor in the middle.
        """
        # All generated args are valid
        input_names = ["input_ids", "attention_mask", "token_type_ids"]
        tokens = {"input_ids": [1, 2, 3, 4], "attention_mask": [0, 0, 0, 0], "token_type_ids": [1, 1, 1, 1]}
        ordered_input_names, inputs_args = ensure_valid_input(FuncContiguousArgs(), tokens, input_names)

        # Should have exactly the same number of args (all are valid)
        self.assertEqual(len(inputs_args), 3)

        # Should have exactly the same input names
        self.assertEqual(set(ordered_input_names), set(input_names))

        # Parameters should be reordered according to their position in the function signature:
        # (input_ids, token_type_ids, attention_mask)
        self.assertEqual(inputs_args, (tokens["input_ids"], tokens["token_type_ids"], tokens["attention_mask"]))

        # Generated args are interleaved with other args (for instance parameter "past" in GPT2)
        ordered_input_names, inputs_args = ensure_valid_input(FuncNonContiguousArgs(), tokens, input_names)

        # Should have exactly one arg (everything before the unprovided "some_other_args")
        self.assertEqual(len(inputs_args), 1)
        self.assertEqual(len(ordered_input_names), 1)

        # Should have only "input_ids"
        self.assertEqual(inputs_args[0], tokens["input_ids"])
        self.assertEqual(ordered_input_names[0], "input_ids")

    def test_generate_identified_name(self):
        """The identifier is inserted between the file stem and its suffix."""
        generated = generate_identified_filename(Path("/home/something/my_fake_model.onnx"), "-test")
        self.assertEqual("/home/something/my_fake_model-test.onnx", generated.as_posix())