Unverified Commit ec62b7d9 authored by Rens's avatar Rens Committed by GitHub
Browse files

Fix onnx export input names order (#4641)

* pass on tokenizer to pipeline

* order input names when converting to ONNX

* update style

* remove unused imports

* ordered inputs list needs to be mutable

* add test custom bert model

* remove unused imports
parent bf760c80
from argparse import ArgumentParser
from itertools import takewhile
from os import listdir, makedirs
from os.path import abspath, dirname, exists
from typing import Dict, List, Optional, Tuple
......@@ -38,14 +37,17 @@ def ensure_valid_input(model, tokens, input_names):
"""
model_args_name = model.forward.__code__.co_varnames
model_args_pos = [(model_args_name.index(name) - 1, name) for name in input_names]
model_args = [None] * (max(map(lambda x: x[0], model_args_pos)) + 1)
for arg_pos, arg_name in model_args_pos:
model_args[arg_pos] = tokens[arg_name]
ordered_input_names = []
model_args = []
for arg_name in model_args_name[1:]: # start at index 1 to skip "self" argument
if arg_name in input_names:
ordered_input_names.append(arg_name)
model_args.append(tokens[arg_name])
else:
break
model_args = tuple(model_args) # Need to be ordered
return tuple(takewhile(lambda arg: arg is not None, model_args))
return ordered_input_names, tuple(model_args)
def infer_shapes(nlp: Pipeline, framework: str) -> Tuple[List[str], List[str], Dict, BatchEncoding]:
......@@ -117,13 +119,13 @@ def convert_pytorch(nlp: Pipeline, opset: int, output: str, use_external_format:
with torch.no_grad():
input_names, output_names, dynamic_axes, tokens = infer_shapes(nlp, "pt")
model_args = ensure_valid_input(nlp.model, tokens, input_names)
ordered_input_names, model_args = ensure_valid_input(nlp.model, tokens, input_names)
export(
nlp.model,
model_args,
f=output,
input_names=input_names,
input_names=ordered_input_names,
output_names=output_names,
dynamic_axes=dynamic_axes,
do_constant_folding=True,
......
import unittest
from os import sep
from os.path import dirname, exists
from shutil import rmtree
from tempfile import NamedTemporaryFile, TemporaryDirectory
from tests.utils import require_tf, require_torch, slow
from transformers import BertConfig, BertTokenizerFast, FeatureExtractionPipeline
......@@ -33,17 +33,34 @@ class OnnxExportTestCase(unittest.TestCase):
for model in OnnxExportTestCase.MODEL_TO_TEST:
self._test_export(model, "pt", 11)
def _test_export(self, model, framework, opset):
# Exports a BertModel built directly from a BertConfig (rather than a named
# checkpoint) together with an explicitly supplied tokenizer, via _test_export.
@require_torch
@slow
def test_export_custom_bert_model(self):
from transformers import BertModel
# Minimal vocabulary: the special tokens plus a few ordinary words.
vocab = ["[UNK]", "[SEP]", "[CLS]", "[PAD]", "[MASK]", "some", "other", "words"]
with NamedTemporaryFile(mode="w+t") as vocab_file:
vocab_file.write("\n".join(vocab))
# Flush so the vocabulary is on disk before the tokenizer reads it by file name.
vocab_file.flush()
tokenizer = BertTokenizerFast(vocab_file.name)
with TemporaryDirectory() as bert_save_dir:
# Size the model's vocabulary to match the tokenizer vocabulary above.
model = BertModel(BertConfig(vocab_size=len(vocab)))
model.save_pretrained(bert_save_dir)
# Export from the saved directory with the "pt" framework and opset 11,
# passing the custom tokenizer through.
self._test_export(bert_save_dir, "pt", 11, tokenizer)
def _test_export(self, model, framework, opset, tokenizer=None):
try:
# Compute path
path = "onnx" + sep + model + ".onnx"
with TemporaryDirectory() as tempdir:
path = tempdir + "/model.onnx"
# Remove folder if exists
if exists(dirname(path)):
rmtree(dirname(path))
# Export
convert(framework, model, path, opset)
# Export
convert(framework, model, path, opset, tokenizer)
except Exception as e:
self.fail(e)
......@@ -99,20 +116,25 @@ class OnnxExportTestCase(unittest.TestCase):
# All generated args are valid
input_names = ["input_ids", "attention_mask", "token_type_ids"]
tokens = {"input_ids": [1, 2, 3, 4], "attention_mask": [0, 0, 0, 0], "token_type_ids": [1, 1, 1, 1]}
inputs_args = ensure_valid_input(FuncContiguousArgs(), tokens, input_names)
ordered_input_names, inputs_args = ensure_valid_input(FuncContiguousArgs(), tokens, input_names)
# Should have exactly the same number of args (all are valid)
self.assertEqual(len(inputs_args), 3)
# Should have exactly the same input names
self.assertEqual(set(ordered_input_names), set(input_names))
# Parameter should be reordered according to their respective place in the function:
# (input_ids, token_type_ids, attention_mask)
self.assertEqual(inputs_args, (tokens["input_ids"], tokens["token_type_ids"], tokens["attention_mask"]))
# Generated args are interleaved with other args (for instance parameter "past" in GPT2)
inputs_args = ensure_valid_input(FuncNonContiguousArgs(), tokens, input_names)
ordered_input_names, inputs_args = ensure_valid_input(FuncNonContiguousArgs(), tokens, input_names)
# Should have exactly one arg (all those before the one not provided, "some_other_args")
self.assertEqual(len(inputs_args), 1)
self.assertEqual(len(ordered_input_names), 1)
# Should have only "input_ids"
self.assertEqual(inputs_args[0], tokens["input_ids"])
self.assertEqual(ordered_input_names[0], "input_ids")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment