Unverified commit a26d71d6 authored by Dean Wyatte, committed by GitHub

Export TensorFlow models to ONNX with dynamic input shapes (#19255)

* validate onnx models with a different input geometry than saved with

* only test working features for now

* simpler test skipping

* rm TODO

* expose batch_size/seq_length on vit

* skip certain name, feature, framework parameterizations known to fail validation

* Trigger CI

* Trigger CI
parent 5fef17f4
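
Taken together, the diffs below make the exported graph accept arbitrary input geometry instead of being pinned to the dummy-input shapes. A minimal sketch of the resulting workflow, assuming a distilbert-base-uncased checkpoint and an illustrative model.onnx output path (neither appears in this commit):

```python
# Hypothetical end-to-end usage of the programmatic export API after this
# change. Model name and output path are illustrative, not part of the diff.
from pathlib import Path

from transformers import AutoTokenizer, TFAutoModel
from transformers.models.distilbert import DistilBertOnnxConfig
from transformers.onnx import export

model_name = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = TFAutoModel.from_pretrained(model_name)

onnx_config = DistilBertOnnxConfig(model.config)

# export() dispatches to export_tensorflow() for TF models; with this commit
# the resulting ONNX graph has fully dynamic input axes.
export(tokenizer, model, onnx_config, onnx_config.default_onnx_opset, Path("model.onnx"))
```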
@@ -355,11 +355,17 @@ class CLIPOnnxConfig(OnnxConfig):
     def generate_dummy_inputs(
         self,
         processor: "ProcessorMixin",
+        batch_size: int = -1,
+        seq_length: int = -1,
         framework: Optional["TensorType"] = None,
     ) -> Mapping[str, Any]:
-        text_input_dict = super().generate_dummy_inputs(processor.tokenizer, framework=framework)
-        image_input_dict = super().generate_dummy_inputs(processor.feature_extractor, framework=framework)
+        text_input_dict = super().generate_dummy_inputs(
+            processor.tokenizer, batch_size=batch_size, seq_length=seq_length, framework=framework
+        )
+        image_input_dict = super().generate_dummy_inputs(
+            processor.feature_extractor, batch_size=batch_size, framework=framework
+        )
         return {**text_input_dict, **image_input_dict}
 
     @property
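
The identical change is mirrored in GroupViTOnnxConfig and OwlViTOnnxConfig below. A minimal sketch of the new arguments, assuming the openai/clip-vit-base-patch32 checkpoint (illustrative, not part of the diff):

```python
# Sketch: requesting dummy inputs at a non-default geometry. Passing -1 (the
# default) keeps the fixed batch/sequence sizes defined on OnnxConfig.
from transformers import CLIPConfig, CLIPProcessor, TensorType
from transformers.models.clip.configuration_clip import CLIPOnnxConfig

processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
onnx_config = CLIPOnnxConfig(CLIPConfig())

dummy_inputs = onnx_config.generate_dummy_inputs(
    processor, batch_size=4, seq_length=16, framework=TensorType.TENSORFLOW
)
for name, tensor in dummy_inputs.items():
    print(name, tensor.shape)  # e.g. input_ids (4, 16), pixel_values (4, 3, 224, 224)
```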
@@ -381,11 +381,17 @@ class GroupViTOnnxConfig(OnnxConfig):
     def generate_dummy_inputs(
         self,
         processor: "ProcessorMixin",
+        batch_size: int = -1,
+        seq_length: int = -1,
         framework: Optional["TensorType"] = None,
     ) -> Mapping[str, Any]:
-        text_input_dict = super().generate_dummy_inputs(processor.tokenizer, framework=framework)
-        image_input_dict = super().generate_dummy_inputs(processor.feature_extractor, framework=framework)
+        text_input_dict = super().generate_dummy_inputs(
+            processor.tokenizer, batch_size=batch_size, seq_length=seq_length, framework=framework
+        )
+        image_input_dict = super().generate_dummy_inputs(
+            processor.feature_extractor, batch_size=batch_size, framework=framework
+        )
         return {**text_input_dict, **image_input_dict}
 
     @property
@@ -372,11 +372,17 @@ class OwlViTOnnxConfig(OnnxConfig):
     def generate_dummy_inputs(
         self,
         processor: "ProcessorMixin",
+        batch_size: int = -1,
+        seq_length: int = -1,
         framework: Optional["TensorType"] = None,
     ) -> Mapping[str, Any]:
-        text_input_dict = super().generate_dummy_inputs(processor.tokenizer, framework=framework)
-        image_input_dict = super().generate_dummy_inputs(processor.feature_extractor, framework=framework)
+        text_input_dict = super().generate_dummy_inputs(
+            processor.tokenizer, batch_size=batch_size, seq_length=seq_length, framework=framework
+        )
+        image_input_dict = super().generate_dummy_inputs(
+            processor.feature_extractor, batch_size=batch_size, framework=framework
+        )
         return {**text_input_dict, **image_input_dict}
 
     @property
@@ -262,7 +262,9 @@ def export_tensorflow(
     inputs_match, matched_inputs = ensure_model_and_config_inputs_match(model, model_inputs.keys())
     onnx_outputs = list(config.outputs.keys())
 
-    input_signature = [tf.TensorSpec.from_tensor(tensor, name=key) for key, tensor in model_inputs.items()]
+    input_signature = [
+        tf.TensorSpec([None] * tensor.ndim, dtype=tensor.dtype, name=key) for key, tensor in model_inputs.items()
+    ]
     onnx_model, _ = tf2onnx.convert.from_keras(model, input_signature, opset=opset)
     onnx.save(onnx_model, output.as_posix())
     config.restore_ops()
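
This is the core of the change: tf.TensorSpec.from_tensor freezes the dummy tensors' concrete shapes into the signature, while [None] * tensor.ndim marks every axis as dynamic. A standalone sketch of the difference:

```python
import tensorflow as tf

dummy = tf.zeros((2, 8), dtype=tf.int32)

# Old behaviour: the spec inherits the dummy tensor's concrete shape, so the
# exported ONNX graph only accepts inputs of exactly that geometry.
frozen_spec = tf.TensorSpec.from_tensor(dummy, name="input_ids")
print(frozen_spec.shape)   # (2, 8)

# New behaviour: every axis is None, i.e. dynamic, so the exported graph
# accepts any batch size and sequence length.
dynamic_spec = tf.TensorSpec([None] * dummy.ndim, dtype=dummy.dtype, name="input_ids")
print(dynamic_spec.shape)  # (None, None)
```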
@@ -363,12 +365,22 @@ def validate_model_outputs(
         logger.info("Overwriting the `preprocessor` argument with `tokenizer` to generate dummy inputs.")
         preprocessor = tokenizer
 
-    # TODO: generate inputs with a different batch_size and seq_len that was used for conversion to properly test
+    # generate inputs with a different batch_size and seq_len than was used for conversion to properly test
     # dynamic input shapes.
     if is_torch_available() and issubclass(type(reference_model), PreTrainedModel):
-        reference_model_inputs = config.generate_dummy_inputs(preprocessor, framework=TensorType.PYTORCH)
+        reference_model_inputs = config.generate_dummy_inputs(
+            preprocessor,
+            batch_size=config.default_fixed_batch + 1,
+            seq_length=config.default_fixed_sequence + 1,
+            framework=TensorType.PYTORCH,
+        )
     else:
-        reference_model_inputs = config.generate_dummy_inputs(preprocessor, framework=TensorType.TENSORFLOW)
+        reference_model_inputs = config.generate_dummy_inputs(
+            preprocessor,
+            batch_size=config.default_fixed_batch + 1,
+            seq_length=config.default_fixed_sequence + 1,
+            framework=TensorType.TENSORFLOW,
+        )
 
     # Create ONNX Runtime session
     options = SessionOptions()
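
OnnxConfig defines default_fixed_batch = 2 and default_fixed_sequence = 8, so validation now feeds the ONNX model a 3 x 9 geometry that differs from the export-time dummy inputs. A sketch of the property being checked, with an illustrative model path and input names:

```python
# Sketch: an export with dynamic axes should accept a geometry different from
# the dummy inputs used at export time. Path and input names are illustrative.
import numpy as np
from onnxruntime import InferenceSession, SessionOptions

options = SessionOptions()
session = InferenceSession("model.onnx", options, providers=["CPUExecutionProvider"])

# Geometry deliberately offset from the export-time defaults of (2, 8).
inputs = {
    "input_ids": np.ones((3, 9), dtype=np.int64),
    "attention_mask": np.ones((3, 9), dtype=np.int64),
}
outputs = session.run(None, inputs)
print([output.shape for output in outputs])  # leading dims follow (3, 9)
```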
@@ -284,10 +284,12 @@ class OnnxExportTestCaseV2(TestCase):
     Integration tests ensuring supported models are correctly exported
     """
 
-    def _onnx_export(self, test_name, name, model_name, feature, onnx_config_class_constructor, device="cpu"):
+    def _onnx_export(
+        self, test_name, name, model_name, feature, onnx_config_class_constructor, device="cpu", framework="pt"
+    ):
         from transformers.onnx import export
 
-        model_class = FeaturesManager.get_model_class_for_feature(feature)
+        model_class = FeaturesManager.get_model_class_for_feature(feature, framework=framework)
         config = AutoConfig.from_pretrained(model_name)
         model = model_class.from_config(config)
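
A sketch of what the new framework argument selects; the printed class names are what I would expect from FeaturesManager's feature-to-AutoModel mapping, not output captured from this test run:

```python
# Sketch: with framework="tf" the test instantiates the TensorFlow model class
# for a feature instead of the PyTorch one.
from transformers.onnx.features import FeaturesManager

pt_class = FeaturesManager.get_model_class_for_feature("sequence-classification", framework="pt")
tf_class = FeaturesManager.get_model_class_for_feature("sequence-classification", framework="tf")
print(pt_class.__name__)  # e.g. AutoModelForSequenceClassification
print(tf_class.__name__)  # e.g. TFAutoModelForSequenceClassification
```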
@@ -296,6 +298,22 @@ class OnnxExportTestCaseV2(TestCase):
         if model.__class__.__name__.startswith("Yolos") and device != "cpu":
             return
 
+        # ONNX inference fails with the following name, feature, framework parameterizations
+        # See: https://github.com/huggingface/transformers/issues/19357
+        if (name, feature, framework) in {
+            ("deberta-v2", "question-answering", "pt"),
+            ("deberta-v2", "multiple-choice", "pt"),
+            ("roformer", "multiple-choice", "pt"),
+            ("groupvit", "default", "pt"),
+            ("perceiver", "masked-lm", "pt"),
+            ("perceiver", "sequence-classification", "pt"),
+            ("perceiver", "image-classification", "pt"),
+            ("bert", "multiple-choice", "tf"),
+            ("camembert", "multiple-choice", "tf"),
+            ("roberta", "multiple-choice", "tf"),
+        }:
+            return
+
         onnx_config = onnx_config_class_constructor(model.config)
 
         if is_torch_available():
@@ -364,13 +382,13 @@ class OnnxExportTestCaseV2(TestCase):
     @require_tf
     @require_vision
     def test_tensorflow_export(self, test_name, name, model_name, feature, onnx_config_class_constructor):
-        self._onnx_export(test_name, name, model_name, feature, onnx_config_class_constructor)
+        self._onnx_export(test_name, name, model_name, feature, onnx_config_class_constructor, framework="tf")
 
     @parameterized.expand(_get_models_to_test(TENSORFLOW_EXPORT_WITH_PAST_MODELS), skip_on_empty=True)
     @slow
     @require_tf
     def test_tensorflow_export_with_past(self, test_name, name, model_name, feature, onnx_config_class_constructor):
-        self._onnx_export(test_name, name, model_name, feature, onnx_config_class_constructor)
+        self._onnx_export(test_name, name, model_name, feature, onnx_config_class_constructor, framework="tf")
 
     @parameterized.expand(_get_models_to_test(TENSORFLOW_EXPORT_SEQ2SEQ_WITH_PAST_MODELS), skip_on_empty=True)
     @slow
@@ -378,7 +396,7 @@ class OnnxExportTestCaseV2(TestCase):
     def test_tensorflow_export_seq2seq_with_past(
         self, test_name, name, model_name, feature, onnx_config_class_constructor
     ):
-        self._onnx_export(test_name, name, model_name, feature, onnx_config_class_constructor)
+        self._onnx_export(test_name, name, model_name, feature, onnx_config_class_constructor, framework="tf")
 
 
 class StableDropoutTestCase(TestCase):