Unverified Commit cb7ed6e0 authored by Alberto Bégué, committed by GitHub

Add Tensorflow handling of ONNX conversion (#13831)



* Add TensorFlow support for ONNX export

* Change documentation to mention conversion with Tensorflow

* Refactor export into export_pytorch and export_tensorflow

* Check model's type instead of framework installation to choose between TF and Pytorch
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
Co-authored-by: Alberto Bégué <alberto.begue@della.ai>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
parent e923917c
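With this change, a TensorFlow checkpoint can be exported through the same `transformers.onnx` entry point as a PyTorch one, since `export` now dispatches on the model's class. A minimal sketch of the new behaviour (the checkpoint name and opset are illustrative, the `DistilBertOnnxConfig` import mirrors the per-model configs used in `features.py`, and `tensorflow`, `tf2onnx` and `onnx` are assumed to be installed):

```python
from pathlib import Path

from transformers import AutoConfig, AutoTokenizer, TFAutoModel
from transformers.models.distilbert import DistilBertOnnxConfig
from transformers.onnx import export

ckpt = "distilbert-base-uncased"  # illustrative checkpoint
tokenizer = AutoTokenizer.from_pretrained(ckpt)
model = TFAutoModel.from_pretrained(ckpt)  # a TFPreTrainedModel, so export_tensorflow is used

onnx_config = DistilBertOnnxConfig(AutoConfig.from_pretrained(ckpt))

# export() chooses between export_pytorch and export_tensorflow based on the model's type,
# not on which framework happens to be installed
onnx_inputs, onnx_outputs = export(tokenizer, model, onnx_config, opset=12, output=Path("model.onnx"))
```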
......@@ -62,10 +62,6 @@ Ready-made configurations include the following architectures:
- XLM-RoBERTa
- XLM-RoBERTa-XL
The ONNX conversion is supported for the PyTorch versions of the models. If you
would like to be able to convert a TensorFlow model, please let us know by
opening an issue.
In the next two sections, we'll show you how to:
* Export a supported model using the `transformers.onnx` package.
......@@ -150,6 +146,8 @@ DistilBERT we have:
["last_hidden_state"]
```
The approach is similar for TensorFlow models.
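Because the input and output names come from the ONNX configuration rather than from the framework, the same names apply whether the checkpoint is loaded in PyTorch or TensorFlow. A small sketch of inspecting them (checkpoint name is illustrative):

```python
from transformers import AutoConfig
from transformers.models.distilbert import DistilBertOnnxConfig

onnx_config = DistilBertOnnxConfig(AutoConfig.from_pretrained("distilbert-base-uncased"))

print(list(onnx_config.inputs.keys()))   # e.g. ["input_ids", "attention_mask"]
print(list(onnx_config.outputs.keys()))  # ["last_hidden_state"]
```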
### Selecting features for different model topologies
Each ready-made configuration comes with a set of _features_ that enable you to
......
......@@ -21,7 +21,7 @@ import numpy as np
from packaging.version import Version, parse
from transformers import PreTrainedModel, PreTrainedTokenizer, TensorType, TFPreTrainedModel, is_torch_available
from transformers.file_utils import is_torch_onnx_dict_inputs_support_available
from transformers.file_utils import is_tf_available, is_torch_onnx_dict_inputs_support_available
from transformers.onnx.config import OnnxConfig
from transformers.utils import logging
......@@ -62,90 +62,190 @@ def check_onnxruntime_requirements(minimum_version: Version):
)
def export(
tokenizer: PreTrainedTokenizer, model: PreTrainedModel, config: OnnxConfig, opset: int, output: Path
def export_pytorch(
tokenizer: PreTrainedTokenizer,
model: PreTrainedModel,
config: OnnxConfig,
opset: int,
output: Path,
) -> Tuple[List[str], List[str]]:
"""
Export a PyTorch backed pipeline to ONNX Intermediate Representation (IR
Export a PyTorch model to an ONNX Intermediate Representation (IR)
Args:
tokenizer:
model:
config:
opset:
output:
tokenizer ([`PreTrainedTokenizer`]):
The tokenizer used for encoding the data.
model ([`PreTrainedModel`]):
The model to export.
config ([`~onnx.config.OnnxConfig`]):
The ONNX configuration associated with the exported model.
opset (`int`):
The version of the ONNX operator set to use.
output (`Path`):
Directory to store the exported ONNX model.
Returns:
`Tuple[List[str], List[str]]`: A tuple with an ordered list of the model's inputs, and the named inputs from
the ONNX configuration.
"""
if issubclass(type(model), PreTrainedModel):
import torch
from torch.onnx import export as onnx_export
logger.info(f"Using framework PyTorch: {torch.__version__}")
with torch.no_grad():
model.config.return_dict = True
model.eval()
# Check if we need to override certain configuration item
if config.values_override is not None:
logger.info(f"Overriding {len(config.values_override)} configuration item(s)")
for override_config_key, override_config_value in config.values_override.items():
logger.info(f"\t- {override_config_key} -> {override_config_value}")
setattr(model.config, override_config_key, override_config_value)
# Ensure inputs match
# TODO: Check when exporting QA we provide "is_pair=True"
model_inputs = config.generate_dummy_inputs(tokenizer, framework=TensorType.PYTORCH)
inputs_match, matched_inputs = ensure_model_and_config_inputs_match(model, model_inputs.keys())
onnx_outputs = list(config.outputs.keys())
if not inputs_match:
raise ValueError("Model and config inputs doesn't match")
config.patch_ops()
# PyTorch deprecated the `enable_onnx_checker` and `use_external_data_format` arguments in v1.11,
# so we check the torch version for backwards compatibility
if parse(torch.__version__) <= parse("1.10.99"):
# export can work with named args but the dict containing named args
# has to be the last element of the args tuple.
onnx_export(
model,
(model_inputs,),
f=output.as_posix(),
input_names=list(config.inputs.keys()),
output_names=onnx_outputs,
dynamic_axes={name: axes for name, axes in chain(config.inputs.items(), config.outputs.items())},
do_constant_folding=True,
use_external_data_format=config.use_external_data_format(model.num_parameters()),
enable_onnx_checker=True,
opset_version=opset,
)
else:
onnx_export(
model,
(model_inputs,),
f=output.as_posix(),
input_names=list(config.inputs.keys()),
output_names=onnx_outputs,
dynamic_axes={name: axes for name, axes in chain(config.inputs.items(), config.outputs.items())},
do_constant_folding=True,
opset_version=opset,
)
config.restore_ops()
return matched_inputs, onnx_outputs
if not is_torch_available():
raise ImportError("Cannot convert because PyTorch is not installed. Please install torch first.")
import torch
from torch.onnx import export
from ..file_utils import torch_version
if not is_torch_onnx_dict_inputs_support_available():
raise AssertionError(f"Unsupported PyTorch version, minimum required is 1.8.0, got: {torch_version}")
logger.info(f"Using framework PyTorch: {torch.__version__}")
with torch.no_grad():
model.config.return_dict = True
model.eval()
# Check if we need to override certain configuration item
if config.values_override is not None:
logger.info(f"Overriding {len(config.values_override)} configuration item(s)")
for override_config_key, override_config_value in config.values_override.items():
logger.info(f"\t- {override_config_key} -> {override_config_value}")
setattr(model.config, override_config_key, override_config_value)
# Ensure inputs match
# TODO: Check when exporting QA we provide "is_pair=True"
model_inputs = config.generate_dummy_inputs(tokenizer, framework=TensorType.PYTORCH)
inputs_match, matched_inputs = ensure_model_and_config_inputs_match(model, model_inputs.keys())
onnx_outputs = list(config.outputs.keys())
if not inputs_match:
raise ValueError("Model and config inputs doesn't match")
config.patch_ops()
# PyTorch deprecated the `enable_onnx_checker` and `use_external_data_format` arguments in v1.11,
# so we check the torch version for backwards compatibility
if parse(torch.__version__) <= parse("1.10.99"):
# export can work with named args but the dict containing named args
# has to be the last element of the args tuple.
export(
model,
(model_inputs,),
f=output.as_posix(),
input_names=list(config.inputs.keys()),
output_names=onnx_outputs,
dynamic_axes={name: axes for name, axes in chain(config.inputs.items(), config.outputs.items())},
do_constant_folding=True,
use_external_data_format=config.use_external_data_format(model.num_parameters()),
enable_onnx_checker=True,
opset_version=opset,
)
else:
export(
model,
(model_inputs,),
f=output.as_posix(),
input_names=list(config.inputs.keys()),
output_names=onnx_outputs,
dynamic_axes={name: axes for name, axes in chain(config.inputs.items(), config.outputs.items())},
do_constant_folding=True,
opset_version=opset,
)
def export_tensorflow(
tokenizer: PreTrainedTokenizer,
model: TFPreTrainedModel,
config: OnnxConfig,
opset: int,
output: Path,
) -> Tuple[List[str], List[str]]:
"""
Export a TensorFlow model to an ONNX Intermediate Representation (IR)
Args:
tokenizer ([`PreTrainedTokenizer`]):
The tokenizer used for encoding the data.
model ([`TFPreTrainedModel`]):
The model to export.
config ([`~onnx.config.OnnxConfig`]):
The ONNX configuration associated with the exported model.
opset (`int`):
The version of the ONNX operator set to use.
output (`Path`):
Directory to store the exported ONNX model.
Returns:
`Tuple[List[str], List[str]]`: A tuple with an ordered list of the model's inputs, and the named inputs from
the ONNX configuration.
"""
import tensorflow as tf
import onnx
import tf2onnx
model.config.return_dict = True
# Check if we need to override certain configuration item
if config.values_override is not None:
logger.info(f"Overriding {len(config.values_override)} configuration item(s)")
for override_config_key, override_config_value in config.values_override.items():
logger.info(f"\t- {override_config_key} -> {override_config_value}")
setattr(model.config, override_config_key, override_config_value)
config.restore_ops()
# Ensure inputs match
model_inputs = config.generate_dummy_inputs(tokenizer, framework=TensorType.TENSORFLOW)
inputs_match, matched_inputs = ensure_model_and_config_inputs_match(model, model_inputs.keys())
onnx_outputs = list(config.outputs.keys())
input_signature = [tf.TensorSpec.from_tensor(tensor, name=key) for key, tensor in model_inputs.items()]
onnx_model, _ = tf2onnx.convert.from_keras(model, input_signature, opset=opset)
onnx.save(onnx_model, output.as_posix())
config.restore_ops()
return matched_inputs, onnx_outputs
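# A standalone sketch of the tf2onnx conversion performed above, with an explicit input
# signature instead of one derived from the tokenizer's dummy inputs (illustrative checkpoint
# and opset; assumes tensorflow, tf2onnx and onnx are installed):
#
#   import tensorflow as tf
#   import tf2onnx
#   import onnx
#   from transformers import TFAutoModel
#
#   tf_model = TFAutoModel.from_pretrained("distilbert-base-uncased")
#   input_signature = (
#       tf.TensorSpec((None, None), tf.int32, name="input_ids"),
#       tf.TensorSpec((None, None), tf.int32, name="attention_mask"),
#   )
#   onnx_model, _ = tf2onnx.convert.from_keras(tf_model, input_signature, opset=12)
#   onnx.save(onnx_model, "model.onnx")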
def export(
tokenizer: PreTrainedTokenizer,
model: Union[PreTrainedModel, TFPreTrainedModel],
config: OnnxConfig,
opset: int,
output: Path,
) -> Tuple[List[str], List[str]]:
"""
Export a PyTorch or TensorFlow model to an ONNX Intermediate Representation (IR)
Args:
tokenizer ([`PreTrainedTokenizer`]):
The tokenizer used for encoding the data.
model ([`PreTrainedModel`] or [`TFPreTrainedModel`]):
The model to export.
config ([`~onnx.config.OnnxConfig`]):
The ONNX configuration associated with the exported model.
opset (`int`):
The version of the ONNX operator set to use.
output (`Path`):
Directory to store the exported ONNX model.
Returns:
`Tuple[List[str], List[str]]`: A tuple with an ordered list of the model's inputs, and the named inputs from
the ONNX configuration.
"""
if not (is_torch_available() or is_tf_available()):
raise ImportError(
"Cannot convert because neither PyTorch nor TensorFlow are not installed. "
"Please install torch or tensorflow first."
)
if is_torch_available():
from transformers.file_utils import torch_version
if not is_torch_onnx_dict_inputs_support_available():
raise AssertionError(f"Unsupported PyTorch version, minimum required is 1.8.0, got: {torch_version}")
if is_torch_available() and issubclass(type(model), PreTrainedModel):
return export_pytorch(tokenizer, model, config, opset, output)
elif is_tf_available() and issubclass(type(model), TFPreTrainedModel):
return export_tensorflow(tokenizer, model, config, opset, output)
def validate_model_outputs(
config: OnnxConfig,
tokenizer: PreTrainedTokenizer,
......@@ -160,7 +260,10 @@ def validate_model_outputs(
# TODO: generate inputs with a different batch_size and seq_len that was used for conversion to properly test
# dynamic input shapes.
reference_model_inputs = config.generate_dummy_inputs(tokenizer, framework=TensorType.PYTORCH)
if issubclass(type(reference_model), PreTrainedModel):
reference_model_inputs = config.generate_dummy_inputs(tokenizer, framework=TensorType.PYTORCH)
else:
reference_model_inputs = config.generate_dummy_inputs(tokenizer, framework=TensorType.TENSORFLOW)
# Create ONNX Runtime session
options = SessionOptions()
......@@ -210,7 +313,10 @@ def validate_model_outputs(
# Check the shape and values match
for name, ort_value in zip(onnx_named_outputs, onnx_outputs):
ref_value = ref_outputs_dict[name].detach().numpy()
if issubclass(type(reference_model), PreTrainedModel):
ref_value = ref_outputs_dict[name].detach().numpy()
else:
ref_value = ref_outputs_dict[name].numpy()
logger.info(f'\t- Validating ONNX Model output "{name}":')
# Shape
......@@ -241,7 +347,10 @@ def ensure_model_and_config_inputs_match(
:param model_inputs:
:param config_inputs:
:return:
"""
forward_parameters = signature(model.forward).parameters
if issubclass(type(model), PreTrainedModel):
forward_parameters = signature(model.forward).parameters
else:
forward_parameters = signature(model.call).parameters
model_inputs_set = set(model_inputs)
# We are fine if config_inputs has more keys than model_inputs
......
from functools import partial, reduce
from typing import Callable, Dict, Optional, Tuple, Type
from typing import Callable, Dict, Optional, Tuple, Type, Union
from .. import PretrainedConfig, is_torch_available
from .. import PretrainedConfig, PreTrainedModel, TFPreTrainedModel, is_tf_available, is_torch_available
from ..models.albert import AlbertOnnxConfig
from ..models.bart import BartOnnxConfig
from ..models.bert import BertOnnxConfig
......@@ -24,7 +24,6 @@ from .config import OnnxConfig
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
if is_torch_available():
from transformers import PreTrainedModel
from transformers.models.auto import (
AutoModel,
AutoModelForCausalLM,
......@@ -35,9 +34,20 @@ if is_torch_available():
AutoModelForSequenceClassification,
AutoModelForTokenClassification,
)
elif is_tf_available():
from transformers.models.auto import (
TFAutoModel,
TFAutoModelForCausalLM,
TFAutoModelForMaskedLM,
TFAutoModelForMultipleChoice,
TFAutoModelForQuestionAnswering,
TFAutoModelForSeq2SeqLM,
TFAutoModelForSequenceClassification,
TFAutoModelForTokenClassification,
)
else:
logger.warning(
"The ONNX export features are only supported for PyTorch, you will not be able to export models without it."
"The ONNX export features are only supported for PyTorch or TensorFlow. You will not be able to export models without one of these libraries installed."
)
......@@ -80,6 +90,17 @@ class FeaturesManager:
"multiple-choice": AutoModelForMultipleChoice,
"question-answering": AutoModelForQuestionAnswering,
}
elif is_tf_available():
_TASKS_TO_AUTOMODELS = {
"default": TFAutoModel,
"masked-lm": TFAutoModelForMaskedLM,
"causal-lm": TFAutoModelForCausalLM,
"seq2seq-lm": TFAutoModelForSeq2SeqLM,
"sequence-classification": TFAutoModelForSequenceClassification,
"token-classification": TFAutoModelForTokenClassification,
"multiple-choice": TFAutoModelForMultipleChoice,
"question-answering": TFAutoModelForQuestionAnswering,
}
else:
_TASKS_TO_AUTOMODELS = {}
......@@ -270,7 +291,7 @@ class FeaturesManager:
)
return FeaturesManager._TASKS_TO_AUTOMODELS[task]
def get_model_from_feature(feature: str, model: str) -> PreTrainedModel:
def get_model_from_feature(feature: str, model: str) -> Union[PreTrainedModel, TFPreTrainedModel]:
"""
Attempt to retrieve a model from a model's name and the feature to be enabled.
......@@ -286,7 +307,9 @@ class FeaturesManager:
return model_class.from_pretrained(model)
@staticmethod
def check_supported_model_or_raise(model: PreTrainedModel, feature: str = "default") -> Tuple[str, Callable]:
def check_supported_model_or_raise(
model: Union[PreTrainedModel, TFPreTrainedModel], feature: str = "default"
) -> Tuple[str, Callable]:
"""
Check whether or not the model has the requested features.
......
......@@ -4,7 +4,7 @@ from unittest import TestCase
from unittest.mock import patch
from parameterized import parameterized
from transformers import AutoConfig, AutoTokenizer, is_torch_available
from transformers import AutoConfig, AutoTokenizer, is_tf_available, is_torch_available
from transformers.onnx import (
EXTERNAL_DATA_FORMAT_SIZE_LIMIT,
OnnxConfig,
......@@ -15,11 +15,11 @@ from transformers.onnx import (
from transformers.onnx.config import OnnxConfigWithPast
if is_torch_available():
if is_torch_available() or is_tf_available():
from transformers.onnx.features import FeaturesManager
from transformers.onnx.utils import compute_effective_axis_dimension, compute_serialized_parameters_size
from transformers.testing_utils import require_onnx, require_torch, slow
from transformers.testing_utils import require_onnx, require_tf, require_torch, slow
@require_onnx
......@@ -192,19 +192,44 @@ PYTORCH_EXPORT_SEQ2SEQ_WITH_PAST_MODELS = {
("marian", "Helsinki-NLP/opus-mt-en-de"),
}
TENSORFLOW_EXPORT_DEFAULT_MODELS = {
("albert", "hf-internal-testing/tiny-albert"),
("bert", "bert-base-cased"),
("ibert", "kssteven/ibert-roberta-base"),
("camembert", "camembert-base"),
("distilbert", "distilbert-base-cased"),
("roberta", "roberta-base"),
("xlm-roberta", "xlm-roberta-base"),
("layoutlm", "microsoft/layoutlm-base-uncased"),
}
TENSORFLOW_EXPORT_WITH_PAST_MODELS = {
("gpt2", "gpt2"),
("gpt-neo", "EleutherAI/gpt-neo-125M"),
}
TENSORFLOW_EXPORT_SEQ2SEQ_WITH_PAST_MODELS = {
("bart", "facebook/bart-base"),
("mbart", "sshleifer/tiny-mbart"),
("t5", "t5-small"),
("marian", "Helsinki-NLP/opus-mt-en-de"),
}
def _get_models_to_test(export_models_list):
models_to_test = []
if not is_torch_available():
# Returning some dummy test that should not be ever called because of the @require_torch decorator.
if is_torch_available() or is_tf_available():
for (name, model) in export_models_list:
for feature, onnx_config_class_constructor in FeaturesManager.get_supported_features_for_model_type(
name
).items():
models_to_test.append((f"{name}_{feature}", name, model, feature, onnx_config_class_constructor))
return sorted(models_to_test)
else:
# Returning some dummy test that should not be ever called because of the @require_torch / @require_tf
# decorators.
# The reason for not returning an empty list is because parameterized.expand complains when it's empty.
return [("dummy", "dummy", "dummy", "dummy", OnnxConfig.from_model_config)]
for (name, model) in export_models_list:
for feature, onnx_config_class_constructor in FeaturesManager.get_supported_features_for_model_type(
name
).items():
models_to_test.append((f"{name}_{feature}", name, model, feature, onnx_config_class_constructor))
return sorted(models_to_test)
class OnnxExportTestCaseV2(TestCase):
......@@ -212,7 +237,7 @@ class OnnxExportTestCaseV2(TestCase):
Integration tests ensuring supported models are correctly exported
"""
def _pytorch_export(self, test_name, name, model_name, feature, onnx_config_class_constructor):
def _onnx_export(self, test_name, name, model_name, feature, onnx_config_class_constructor):
from transformers.onnx import export
tokenizer = AutoTokenizer.from_pretrained(model_name)
......@@ -246,13 +271,13 @@ class OnnxExportTestCaseV2(TestCase):
@slow
@require_torch
def test_pytorch_export(self, test_name, name, model_name, feature, onnx_config_class_constructor):
self._pytorch_export(test_name, name, model_name, feature, onnx_config_class_constructor)
self._onnx_export(test_name, name, model_name, feature, onnx_config_class_constructor)
@parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_WITH_PAST_MODELS))
@slow
@require_torch
def test_pytorch_export_with_past(self, test_name, name, model_name, feature, onnx_config_class_constructor):
self._pytorch_export(test_name, name, model_name, feature, onnx_config_class_constructor)
self._onnx_export(test_name, name, model_name, feature, onnx_config_class_constructor)
@parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_SEQ2SEQ_WITH_PAST_MODELS))
@slow
......@@ -260,4 +285,24 @@ class OnnxExportTestCaseV2(TestCase):
def test_pytorch_export_seq2seq_with_past(
self, test_name, name, model_name, feature, onnx_config_class_constructor
):
self._pytorch_export(test_name, name, model_name, feature, onnx_config_class_constructor)
self._onnx_export(test_name, name, model_name, feature, onnx_config_class_constructor)
@parameterized.expand(_get_models_to_test(TENSORFLOW_EXPORT_DEFAULT_MODELS))
@slow
@require_tf
def test_tensorflow_export(self, test_name, name, model_name, feature, onnx_config_class_constructor):
self._onnx_export(test_name, name, model_name, feature, onnx_config_class_constructor)
@parameterized.expand(_get_models_to_test(TENSORFLOW_EXPORT_WITH_PAST_MODELS))
@slow
@require_tf
def test_tensorflow_export_with_past(self, test_name, name, model_name, feature, onnx_config_class_constructor):
self._onnx_export(test_name, name, model_name, feature, onnx_config_class_constructor)
@parameterized.expand(_get_models_to_test(TENSORFLOW_EXPORT_SEQ2SEQ_WITH_PAST_MODELS))
@slow
@require_tf
def test_tensorflow_export_seq2seq_with_past(
self, test_name, name, model_name, feature, onnx_config_class_constructor
):
self._onnx_export(test_name, name, model_name, feature, onnx_config_class_constructor)
......@@ -211,7 +211,7 @@ def check_onnx_model_list(overwrite=False):
current_list, start_index, end_index, lines = _find_text_in_file(
filename=os.path.join(PATH_TO_DOCS, "serialization.mdx"),
start_prompt="<!--This table is automatically generated by `make fix-copies`, do not fill manually!-->",
end_prompt="The ONNX conversion is supported for the PyTorch versions of the models.",
end_prompt="In the next two sections, we'll show you how to:",
)
new_list = get_onnx_model_list()
......