test_onnx.py 7.61 KB
Newer Older
Sylvain Gugger's avatar
Sylvain Gugger committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15
import unittest
16
from pathlib import Path
17
from tempfile import NamedTemporaryFile, TemporaryDirectory
18
19

from transformers import BertConfig, BertTokenizerFast, FeatureExtractionPipeline
20
21
22
23
24
25
26
from transformers.convert_graph_to_onnx import (
    convert,
    ensure_valid_input,
    generate_identified_filename,
    infer_shapes,
    quantize,
)
27
from transformers.testing_utils import require_tf, require_tokenizers, require_torch, slow
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43


class FuncContiguousArgs:
    """Stand-in object whose ``forward`` lists input_ids, token_type_ids and attention_mask back to back."""

    def forward(self, input_ids, token_type_ids, attention_mask):
        """Accept the three arguments and do nothing (the result is always ``None``)."""
        return


class FuncNonContiguousArgs:
    """Stand-in object whose ``forward`` has an extra parameter between input_ids and token_type_ids."""

    def forward(self, input_ids, some_other_args, token_type_ids, attention_mask):
        """Accept the four arguments and do nothing (the result is always ``None``)."""
        return


class OnnxExportTestCase(unittest.TestCase):
    """Tests for the helpers imported from ``transformers.convert_graph_to_onnx``."""

    # Model identifiers exported by the slow export / quantization tests
    MODEL_TO_TEST = ["bert-base-cased", "gpt2", "roberta-base"]

    @require_tf
    @slow
    def test_export_tensorflow(self):
        """
        Export every model of ``MODEL_TO_TEST`` with the "tf" framework (opset 12)
        """
        for model in OnnxExportTestCase.MODEL_TO_TEST:
            self._test_export(model, "tf", 12)

    @require_torch
    @slow
    def test_export_pytorch(self):
        """
        Export every model of ``MODEL_TO_TEST`` with the "pt" framework (opset 12)
        """
        for model in OnnxExportTestCase.MODEL_TO_TEST:
            self._test_export(model, "pt", 12)

    @require_torch
    @require_tokenizers
    @slow
    def test_export_custom_bert_model(self):
        """
        Export a locally saved BertModel together with a tokenizer built from a small hand-written vocabulary
        """
        from transformers import BertModel

        vocab = ["[UNK]", "[SEP]", "[CLS]", "[PAD]", "[MASK]", "some", "other", "words"]
        with NamedTemporaryFile(mode="w+t") as vocab_file:
            vocab_file.write("\n".join(vocab))
            # Make sure the vocabulary is on disk before the tokenizer opens the file by name
            vocab_file.flush()
            tokenizer = BertTokenizerFast(vocab_file.name)

        with TemporaryDirectory() as bert_save_dir:
            model = BertModel(BertConfig(vocab_size=len(vocab)))
            model.save_pretrained(bert_save_dir)
            # The export has to happen while the saved model directory still exists
            self._test_export(bert_save_dir, "pt", 12, tokenizer)

    @require_tf
    @slow
    def test_quantize_tf(self):
        """
        Quantize the "tf" export of every model of ``MODEL_TO_TEST``
        """
        self._test_quantize("tf")

    @require_torch
    @slow
    def test_quantize_pytorch(self):
        """
        Quantize the "pt" export of every model of ``MODEL_TO_TEST``
        """
        self._test_quantize("pt")

    def _test_quantize(self, framework):
        """
        Export each model of ``MODEL_TO_TEST`` with ``framework``, quantize the result and check its size

        Args:
            framework: Framework identifier forwarded to ``_test_export`` ("pt" or "tf")
        """
        for model in OnnxExportTestCase.MODEL_TO_TEST:
            path = self._test_export(model, framework, 12)
            quantized_path = quantize(path)

            # Ensure the actual quantized model is not bigger than the original one
            if quantized_path.stat().st_size >= path.stat().st_size:
                self.fail("Quantized model is bigger than initial ONNX model")

    def _test_export(self, model, framework, opset, tokenizer=None):
        """
        Run ``convert`` towards a fresh temporary location and return the path handed to it

        Args:
            model: Model name or directory, forwarded to ``convert``
            framework: Framework identifier, forwarded to ``convert``
            opset: ONNX opset version, forwarded to ``convert``
            tokenizer: Optional tokenizer, forwarded to ``convert``

        Returns:
            The ``pathlib.Path`` of the form ``<temporary folder>/model.onnx`` given to ``convert``

        Any exception raised along the way is reported through ``self.fail``.
        """
        import shutil

        try:
            # The context manager is only used to obtain a unique path: the folder is deleted again when
            # the block exits, so the export target does not exist when ``convert`` is called
            with TemporaryDirectory() as tempdir:
                path = Path(tempdir).joinpath("model.onnx")

            # Nothing owns that folder anymore once the block above has exited, remove whatever the export
            # leaves there when the test is over (callers still need the files until then)
            self.addCleanup(shutil.rmtree, path.parent, ignore_errors=True)

            # Export
            convert(framework, model, path, opset, tokenizer)

            return path
        except Exception as e:
            self.fail(e)

    @require_torch
    @require_tokenizers
    @slow
    def test_infer_dynamic_axis_pytorch(self):
        """
        Validate that the dynamic axes generated for each parameter are correct
        """
        from transformers import BertModel

        model = BertModel(BertConfig.from_pretrained("lysandre/tiny-bert-random"))
        tokenizer = BertTokenizerFast.from_pretrained("lysandre/tiny-bert-random")
        self._test_infer_dynamic_axis(model, tokenizer, "pt")

    @require_tf
    @require_tokenizers
    @slow
    def test_infer_dynamic_axis_tf(self):
        """
        Validate that the dynamic axes generated for each parameter are correct
        """
        from transformers import TFBertModel

        model = TFBertModel(BertConfig.from_pretrained("lysandre/tiny-bert-random"))
        tokenizer = BertTokenizerFast.from_pretrained("lysandre/tiny-bert-random")
        self._test_infer_dynamic_axis(model, tokenizer, "tf")

    def _test_infer_dynamic_axis(self, model, tokenizer, framework):
        """
        Check the variable names and dynamic axes returned by ``infer_shapes`` for a feature extraction pipeline

        Args:
            model: Model wrapped in the ``FeatureExtractionPipeline``
            tokenizer: Tokenizer wrapped in the ``FeatureExtractionPipeline``
            framework: Framework identifier forwarded to ``infer_shapes``
        """
        nlp = FeatureExtractionPipeline(model, tokenizer)

        variable_names = ["input_ids", "token_type_ids", "attention_mask", "output_0", "output_1"]
        # The fourth returned value is not needed by this check
        input_vars, output_vars, shapes, _ = infer_shapes(nlp, framework)

        # Assert all variables are present
        self.assertEqual(len(shapes), len(variable_names))
        for var_name in variable_names:
            self.assertIn(var_name, shapes)
        self.assertSequenceEqual(variable_names[:3], input_vars)
        self.assertSequenceEqual(variable_names[3:], output_vars)

        # Assert inputs are {0: batch, 1: sequence}
        for var_name in ["input_ids", "token_type_ids", "attention_mask"]:
            self.assertDictEqual(shapes[var_name], {0: "batch", 1: "sequence"})

        # Assert outputs are {0: batch, 1: sequence} and {0: batch}
        self.assertDictEqual(shapes["output_0"], {0: "batch", 1: "sequence"})
        self.assertDictEqual(shapes["output_1"], {0: "batch"})

    def test_ensure_valid_input(self):
        """
        Validate parameters are correctly exported
        GPT2 has a "past" parameter in the middle of input_ids, token_type_ids and attention_mask.
        ONNX doesn't support export with a dictionary, only a tuple. Thus we need to ensure we remove
        token_type_ids and attention_mask for now, to avoid having a None tensor in the middle
        """
        # All generated args are valid
        input_names = ["input_ids", "attention_mask", "token_type_ids"]
        tokens = {"input_ids": [1, 2, 3, 4], "attention_mask": [0, 0, 0, 0], "token_type_ids": [1, 1, 1, 1]}
        ordered_input_names, inputs_args = ensure_valid_input(FuncContiguousArgs(), tokens, input_names)

        # Should have exactly the same number of args (all are valid)
        self.assertEqual(len(inputs_args), 3)

        # Should have exactly the same input names
        self.assertEqual(set(ordered_input_names), set(input_names))

        # Parameters should be reordered according to their respective place in the function:
        # (input_ids, token_type_ids, attention_mask)
        self.assertEqual(inputs_args, (tokens["input_ids"], tokens["token_type_ids"], tokens["attention_mask"]))

        # Generated args are interleaved with other args (for instance parameter "past" in GPT2)
        ordered_input_names, inputs_args = ensure_valid_input(FuncNonContiguousArgs(), tokens, input_names)

        # Should have exactly one arg (only those before the one not provided, "some_other_args")
        self.assertEqual(len(inputs_args), 1)
        self.assertEqual(len(ordered_input_names), 1)

        # Should have only "input_ids"
        self.assertEqual(inputs_args[0], tokens["input_ids"])
        self.assertEqual(ordered_input_names[0], "input_ids")

    def test_generate_identified_name(self):
        """
        Validate the identifier is inserted between the file stem and its extension
        """
        generated = generate_identified_filename(Path("/home/something/my_fake_model.onnx"), "-test")
        self.assertEqual("/home/something/my_fake_model-test.onnx", generated.as_posix())