# test_onnx.py (6.9 KB)
1
import unittest
2
from pathlib import Path
3
from tempfile import NamedTemporaryFile, TemporaryDirectory
4
5

from transformers import BertConfig, BertTokenizerFast, FeatureExtractionPipeline
6
7
8
9
10
11
12
from transformers.convert_graph_to_onnx import (
    convert,
    ensure_valid_input,
    generate_identified_filename,
    infer_shapes,
    quantize,
)
13
from transformers.testing_utils import require_tf, require_torch, slow
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29


class FuncContiguousArgs:
    """Stand-in model whose ``forward`` lists the three tokenizer inputs back to back.

    ``test_ensure_valid_input`` hands an instance of this class to ``ensure_valid_input``;
    the method body itself does nothing.
    """

    def forward(self, input_ids, token_type_ids, attention_mask):
        """Accept the three inputs and return ``None``."""
        return


class FuncNonContiguousArgs:
    """Stand-in model whose ``forward`` has an extra argument right after ``input_ids``.

    ``test_ensure_valid_input`` hands an instance of this class to ``ensure_valid_input``
    to cover the case where an argument that is not provided (``some_other_args``) sits
    between the tokenizer inputs. The method body itself does nothing.
    """

    def forward(self, input_ids, some_other_args, token_type_ids, attention_mask):
        """Accept the four inputs and return ``None``."""
        return


class OnnxExportTestCase(unittest.TestCase):
    """Tests for the helpers imported from ``transformers.convert_graph_to_onnx``.

    The slow tests export (and optionally quantize) every checkpoint listed in
    ``MODEL_TO_TEST``; the remaining tests exercise ``infer_shapes``,
    ``ensure_valid_input`` and ``generate_identified_filename`` directly.
    """

    # Checkpoints run through the export / quantization tests.
    MODEL_TO_TEST = ["bert-base-cased", "gpt2", "roberta-base"]

    @require_tf
    @slow
    def test_export_tensorflow(self):
        """Export every checkpoint of ``MODEL_TO_TEST`` with the "tf" framework."""
        for model in OnnxExportTestCase.MODEL_TO_TEST:
            # One sub-test per checkpoint so a single failure does not hide the others.
            with self.subTest(model=model):
                self._test_export(model, "tf", 12)

    @require_torch
    @slow
    def test_export_pytorch(self):
        """Export every checkpoint of ``MODEL_TO_TEST`` with the "pt" framework."""
        for model in OnnxExportTestCase.MODEL_TO_TEST:
            with self.subTest(model=model):
                self._test_export(model, "pt", 12)

    @require_torch
    @slow
    def test_export_custom_bert_model(self):
        """Export a freshly initialised BERT saved locally, with an explicit tokenizer."""
        from transformers import BertModel

        vocab = ["[UNK]", "[SEP]", "[CLS]", "[PAD]", "[MASK]", "some", "other", "words"]
        with NamedTemporaryFile(mode="w+t") as vocab_file:
            vocab_file.write("\n".join(vocab))
            # Flush so the tokenizer sees the full vocabulary when it opens the file by name.
            vocab_file.flush()
            tokenizer = BertTokenizerFast(vocab_file.name)

        with TemporaryDirectory() as bert_save_dir:
            model = BertModel(BertConfig(vocab_size=len(vocab)))
            model.save_pretrained(bert_save_dir)
            self._test_export(bert_save_dir, "pt", 12, tokenizer)

    @require_tf
    @slow
    def test_quantize_tf(self):
        """Quantize the "tf" export of every checkpoint of ``MODEL_TO_TEST``."""
        self._test_quantize("tf")

    @require_torch
    @slow
    def test_quantize_pytorch(self):
        """Quantize the "pt" export of every checkpoint of ``MODEL_TO_TEST``."""
        self._test_quantize("pt")

    def _test_quantize(self, framework):
        """Export each checkpoint with ``framework``, quantize it and compare file sizes.

        Args:
            framework: Framework identifier forwarded to ``_test_export`` ("pt" or "tf").
        """
        for model in OnnxExportTestCase.MODEL_TO_TEST:
            with self.subTest(model=model):
                path = self._test_export(model, framework, 12)
                quantized_path = quantize(path)

                # Ensure the actual quantized model is not bigger than the original one
                self.assertLess(
                    quantized_path.stat().st_size,
                    path.stat().st_size,
                    "Quantized model is bigger than initial ONNX model",
                )

    def _test_export(self, model, framework, opset, tokenizer=None):
        """Run ``convert`` into a temporary location and return the output path.

        Args:
            model: Model name or local directory forwarded to ``convert``.
            framework: Framework identifier forwarded to ``convert`` ("pt" or "tf").
            opset: ONNX opset version forwarded to ``convert``.
            tokenizer: Optional tokenizer forwarded to ``convert``.

        Returns:
            ``pathlib.Path`` of the requested ``model.onnx`` file. The enclosing temporary
            directory is removed by a cleanup registered on the test, so the path stays
            usable until the calling test method finishes.
        """
        # The export must outlive this call (the quantization tests read it back), so the
        # directory's lifetime is tied to the test instead of to a ``with`` block.
        workdir = TemporaryDirectory()
        self.addCleanup(workdir.cleanup)

        # Target a sub-folder that does not exist yet: this mirrors the original test, which
        # handed convert() a path whose parent folder had already been removed.
        path = Path(workdir.name).joinpath("onnx", "model.onnx")

        try:
            # Export
            convert(framework, model, path, opset, tokenizer)
        except Exception as e:
            # Report conversion problems as a test failure that names the offending model.
            self.fail("ONNX export of {!r} ({}) failed: {!r}".format(model, framework, e))

        return path

    @require_torch
    def test_infer_dynamic_axis_pytorch(self):
        """
        Validate that the dynamic axes generated for each parameter are correct
        """
        from transformers import BertModel

        model = BertModel(BertConfig.from_pretrained("bert-base-cased"))
        tokenizer = BertTokenizerFast.from_pretrained("bert-base-cased")
        self._test_infer_dynamic_axis(model, tokenizer, "pt")

    @require_tf
    def test_infer_dynamic_axis_tf(self):
        """
        Validate that the dynamic axes generated for each parameter are correct
        """
        from transformers import TFBertModel

        model = TFBertModel(BertConfig.from_pretrained("bert-base-cased"))
        tokenizer = BertTokenizerFast.from_pretrained("bert-base-cased")
        self._test_infer_dynamic_axis(model, tokenizer, "tf")

    def _test_infer_dynamic_axis(self, model, tokenizer, framework):
        """Check the names and dynamic axes returned by ``infer_shapes``.

        Args:
            model: Model instance wrapped in a ``FeatureExtractionPipeline``.
            tokenizer: Tokenizer instance wrapped in the same pipeline.
            framework: Framework identifier forwarded to ``infer_shapes`` ("pt" or "tf").
        """
        nlp = FeatureExtractionPipeline(model, tokenizer)

        variable_names = ["input_ids", "token_type_ids", "attention_mask", "output_0", "output_1"]
        # The fourth returned value (the tokens) is not checked here.
        input_vars, output_vars, shapes, _ = infer_shapes(nlp, framework)

        # Assert all variables are present
        self.assertEqual(len(shapes), len(variable_names))
        for var_name in variable_names:
            self.assertIn(var_name, shapes)
        self.assertSequenceEqual(variable_names[:3], input_vars)
        self.assertSequenceEqual(variable_names[3:], output_vars)

        # Assert inputs are {0: batch, 1: sequence}
        for var_name in ["input_ids", "token_type_ids", "attention_mask"]:
            self.assertDictEqual(shapes[var_name], {0: "batch", 1: "sequence"})

        # Assert outputs are {0: batch, 1: sequence} and {0: batch}
        self.assertDictEqual(shapes["output_0"], {0: "batch", 1: "sequence"})
        self.assertDictEqual(shapes["output_1"], {0: "batch"})

    def test_ensure_valid_input(self):
        """
        Validate parameters are correctly exported
        GPT2 has a "past" parameter in the middle of input_ids, token_type_ids and attention_mask.
        ONNX doesn't support export with a dictionary, only a tuple. Thus we need to ensure we remove
        token_type_ids and attention_mask for now so as not to have a None tensor in the middle
        """
        # All generated args are valid
        input_names = ["input_ids", "attention_mask", "token_type_ids"]
        tokens = {"input_ids": [1, 2, 3, 4], "attention_mask": [0, 0, 0, 0], "token_type_ids": [1, 1, 1, 1]}
        ordered_input_names, inputs_args = ensure_valid_input(FuncContiguousArgs(), tokens, input_names)

        # Should have exactly the same number of args (all are valid)
        self.assertEqual(len(inputs_args), 3)

        # Should have exactly the same input names
        self.assertEqual(set(ordered_input_names), set(input_names))

        # Parameters should be reordered according to their respective place in the function:
        # (input_ids, token_type_ids, attention_mask)
        self.assertEqual(inputs_args, (tokens["input_ids"], tokens["token_type_ids"], tokens["attention_mask"]))

        # Generated args are interleaved with other args (for instance parameter "past" in GPT2)
        ordered_input_names, inputs_args = ensure_valid_input(FuncNonContiguousArgs(), tokens, input_names)

        # Should have exactly one arg (everything before the one not provided, "some_other_args")
        self.assertEqual(len(inputs_args), 1)
        self.assertEqual(len(ordered_input_names), 1)

        # Should have only "input_ids"
        self.assertEqual(inputs_args[0], tokens["input_ids"])
        self.assertEqual(ordered_input_names[0], "input_ids")

    def test_generate_identified_name(self):
        """Check that the identifier is inserted between the file stem and its extension."""
        generated = generate_identified_filename(Path("/home/something/my_fake_model.onnx"), "-test")
        self.assertEqual("/home/something/my_fake_model-test.onnx", generated.as_posix())