Unverified Commit 7231f7b5 authored by Morgan Funtowicz, committed by GitHub

Enable ONNX/ONNXRuntime optimizations through converter script (#6131)



* Add onnxruntime transformers optimization support
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Added Optimization section in ONNX/ONNXRuntime documentation.
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Improve note reference
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Fixing imports order.
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Add warning about different level of optimization between torch and tf export.
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Address @LysandreJik wording suggestion
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* Address @LysandreJik wording suggestion
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* Always optimize model before quantization for maximum performances.
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Address comments on the documentation.
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Improve TensorFlow optimization message as suggested by @yufenglee
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Removed --optimize parameter
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Warn the user about current quantization limitation when model is larger than 2GB.
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Trigger CI for last check

* Small change in print for the optimization section.
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
parent c0b93a1c
@@ -5,7 +5,7 @@ Exporting transformers models
 ONNX / ONNXRuntime
 ==============================================
-Projects ONNX (Open Neural Network eXchange) and ONNXRuntime (ORT) are part of an effort from leading industries in the AI field
+Projects `ONNX (Open Neural Network eXchange) <http://onnx.ai>`_ and `ONNXRuntime (ORT) <https://microsoft.github.io/onnxruntime/>`_ are part of an effort from leading industries in the AI field
 to provide a unified and community-driven format to store and, by extension, efficiently execute neural network leveraging a variety
 of hardware and dedicated optimizations.
@@ -34,9 +34,36 @@ The conversion tool works for both PyTorch and Tensorflow models and ensures:
 Also, the conversion tool supports different options which let you tune the behavior of the generated model:
-* Change the target opset version of the generated model: More recent opset generally supports more operator and enables faster inference.
-* Export pipeline specific prediction heads: Allow to export model along with its task-specific prediction head(s).
-* Use the external data format (PyTorch only): Lets you export model which size is above 2Gb (`More info <https://github.com/pytorch/pytorch/pull/33062>`_).
+* **Change the target opset version of the generated model.** (More recent opsets generally support more operators and enable faster inference)
+* **Export pipeline-specific prediction heads.** (Allows exporting the model along with its task-specific prediction head(s))
+* **Use the external data format (PyTorch only).** (Lets you export models larger than 2GB, `more info <https://github.com/pytorch/pytorch/pull/33062>`_)
+
+Optimizations
+------------------------------------------------
+
+ONNXRuntime includes some transformers-specific transformations to leverage optimized operations in the graph.
+Below are some of the operators which can be enabled to speed up inference through ONNXRuntime (*see note below*):
+
+* Constant folding
+* Attention layer fusion
+* Skip connection LayerNormalization fusion
+* FastGeLU approximation
+
+Fortunately, you can let ONNXRuntime find all the possible optimized operators for you. Simply add ``--optimize``
+when exporting your model through ``convert_graph_to_onnx.py``.
+
+Example:
+
+.. code-block:: bash
+
+    python convert_graph_to_onnx.py --framework <pt, tf> --model bert-base-cased --optimize bert-base-cased.onnx
+
+.. note::
+    For more information about the optimizations enabled by ONNXRuntime, please have a look at the
+    `ONNXRuntime GitHub <https://github.com/microsoft/onnxruntime/tree/master/onnxruntime/python/tools/transformers>`_.
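As an aside on what this optimization pass relies on: onnxruntime applies its graph transformations when an ``InferenceSession`` is created, and can serialize the optimized graph back to disk. Below is a minimal sketch of that mechanism (the file names are placeholders, not outputs of this PR):

.. code-block:: python

    from onnxruntime import GraphOptimizationLevel, InferenceSession, SessionOptions

    options = SessionOptions()
    # Enable every graph-level optimization onnxruntime knows about
    options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
    # Ask onnxruntime to write the optimized graph next to the original model
    options.optimized_model_filepath = "bert-base-cased-optimized.onnx"

    # Creating the session runs the optimizations and saves the optimized model
    session = InferenceSession("bert-base-cased.onnx", options)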
 Quantization
 ------------------------------------------------
@@ -85,6 +112,8 @@ Example of quantized BERT model export:
 above command will contain the original ONNX model storing `float32` weights.
 The second one, with ``-quantized`` suffix, will hold the quantized parameters.
+
+.. note::
+    The quantization export gives the best performance when used in combination with ``--optimize``.

 TorchScript
 =======================================
...
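For readers who prefer Python over the CLI shown above, the converter can also be driven programmatically. A minimal sketch, assuming the script is importable as ``transformers.convert_graph_to_onnx`` and that the keyword names mirror the CLI options described above (treat the exact signature as an assumption, not a reference):

.. code-block:: python

    from pathlib import Path

    from transformers.convert_graph_to_onnx import convert

    # Keyword names below mirror the documented options and are assumptions
    convert(
        framework="pt",                      # "pt" (PyTorch) or "tf" (TensorFlow)
        model="bert-base-cased",             # model id or local checkpoint
        output=Path("onnx/bert-base-cased.onnx"),
        opset=11,                            # target ONNX opset version
        pipeline_name="feature-extraction",  # task-specific prediction head to export
        use_external_format=False,           # PyTorch only: True for models above 2GB
    )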
@@ -3,7 +3,7 @@ from os import listdir, makedirs
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple

-from packaging.version import parse
+from packaging.version import Version, parse

 from transformers import is_tf_available, is_torch_available
 from transformers.file_utils import ModelOutput
@@ -72,7 +72,7 @@ def generate_identified_filename(filename: Path, identifier: str) -> Path:
     return filename.parent.joinpath(filename.stem + identifier).with_suffix(filename.suffix)


-def ensure_onnxruntime_installed():
+def check_onnxruntime_requirements(minimum_version: Version):
     """
     Check onnxruntime is installed and if the installed version match is recent enough.

     Raises:
@@ -88,7 +88,7 @@ def ensure_onnxruntime_installed():
         if ort_version < ORT_QUANTIZE_MINIMUM_VERSION:
             raise ImportError(
                 f"We found an older version of onnxruntime ({onnxruntime.__version__}) "
-                f"but we require onnxruntime to be >= 1.4.0 to enable all the conversions options.\n"
+                f"but we require onnxruntime to be >= {minimum_version} to enable all the conversions options.\n"
                 f"Please update onnxruntime by running `pip install --upgrade onnxruntime`"
             )
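The requirement check above boils down to comparing parsed version strings. A short sketch of the same idea with ``packaging.version``, using the 1.4.0 minimum referenced in the messages of this script:

.. code-block:: python

    import onnxruntime
    from packaging.version import parse

    # Minimum onnxruntime release required for the quantization path
    ORT_QUANTIZE_MINIMUM_VERSION = parse("1.4.0")

    if parse(onnxruntime.__version__) < ORT_QUANTIZE_MINIMUM_VERSION:
        raise ImportError(
            f"onnxruntime {onnxruntime.__version__} is too old, please run "
            f"`pip install --upgrade onnxruntime` (>= {ORT_QUANTIZE_MINIMUM_VERSION} required)"
        )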
@@ -330,6 +330,30 @@ def convert(
         convert_tensorflow(nlp, opset, output)


+def optimize(onnx_model_path: Path) -> Path:
+    """
+    Load the model at the specified path and let onnxruntime look at transformations on the graph
+    to enable all the optimizations possible
+    Args:
+        onnx_model_path: filepath where the model binary description is stored
+
+    Returns: Path where the optimized model binary description has been saved
+    """
+    from onnxruntime import SessionOptions, InferenceSession
+
+    # Generate model name with suffix "optimized"
+    opt_model_path = generate_identified_filename(onnx_model_path, "-optimized")
+    sess_option = SessionOptions()
+    sess_option.optimized_model_filepath = opt_model_path.as_posix()
+    _ = InferenceSession(onnx_model_path.as_posix(), sess_option)
+
+    print(f"Optimized model has been written at {opt_model_path}: \N{heavy check mark}")
+    print("/!\\ Optimized model contains hardware specific operators which might not be portable. /!\\")
+
+    return opt_model_path
+
+
 def quantize(onnx_model_path: Path) -> Path:
     """
     Quantize the weights of the model from float32 to int8 to allow very efficient inference on modern CPU.
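A usage note on the ``optimize`` helper above: the ``-optimized`` file it writes may contain hardware-specific fused operators (hence the portability warning), so it is typically consumed on the same kind of machine that produced it. A small sketch of loading it while skipping a second optimization pass (the file name is a placeholder):

.. code-block:: python

    from onnxruntime import GraphOptimizationLevel, InferenceSession, SessionOptions

    options = SessionOptions()
    # The graph was already optimized offline, so skip re-optimizing it at load time
    options.graph_optimization_level = GraphOptimizationLevel.ORT_DISABLE_ALL

    session = InferenceSession(
        "bert-base-cased-optimized.onnx", options, providers=["CPUExecutionProvider"]
    )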
@@ -338,17 +362,18 @@ def quantize(onnx_model_path: Path) -> Path:
     Returns: The Path generated for the quantized
     """
     try:
-        ensure_onnxruntime_installed()
         import onnx
-        from onnxruntime import __version__ as ort_version
         from onnxruntime.quantization import quantize, QuantizationMode

-        print(f"Found ONNX: {onnx.__version__}")
-        print(f"Found ONNXRuntime: {ort_version}")
-
         onnx_model = onnx.load(onnx_model_path.as_posix())

+        # Discussed with @yufenglee from ONNX runtime, this will be addressed in the next release of onnxruntime
+        print(
+            "As of onnxruntime 1.4.0, models larger than 2GB will fail to quantize due to protobuf constraint.\n"
+            "This limitation will be removed in the next release of onnxruntime."
+        )
+
         quantized_model = quantize(
             model=onnx_model, quantization_mode=QuantizationMode.IntegerOps, force_fusions=True, symmetric_weight=True,
         )
@@ -357,11 +382,11 @@ def quantize(onnx_model_path: Path) -> Path:
         quantized_model_path = generate_identified_filename(onnx_model_path, "-quantized")

         # Save model
-        print(f"Storing quantized model at {quantized_model_path}")
-        onnx.save(quantized_model, quantized_model_path.as_posix())
+        print(f"Quantized model has been written at {quantized_model_path}: \N{heavy check mark}")
+        onnx.save_model(quantized_model, quantized_model_path.as_posix())

         return quantized_model_path
-    except ImportError as ie:
+    except Exception as ie:
         print(f"Error while quantizing the model:\n{str(ie)}")
@@ -369,7 +394,7 @@ def verify(path: Path):
     from onnxruntime import InferenceSession, SessionOptions
     from onnxruntime.capi.onnxruntime_pybind11_state import RuntimeException

-    print(f"Checking ONNX model loading from: {path}")
+    print(f"Checking ONNX model loading from: {path} ...")
     try:
         onnx_options = SessionOptions()
         _ = InferenceSession(path.as_posix(), onnx_options, providers=["CPUExecutionProvider"])
@@ -386,6 +411,7 @@ if __name__ == "__main__":
     args.output = Path(args.output).absolute()

     try:
+        print("\n====== Converting model to ONNX ======")
         # Convert
         convert(
             args.framework,
@@ -398,12 +424,34 @@
         )

         if args.quantize:
-            args.quantized_output = quantize(args.output)
+            # Ensure the requirements for quantization on onnxruntime are met
+            check_onnxruntime_requirements(ORT_QUANTIZE_MINIMUM_VERSION)
+
+            # onnxruntime optimizations don't provide the same level of performance on TensorFlow as on PyTorch
+            if args.framework == "tf":
+                print(
+                    "\t Using TensorFlow might not provide the same optimization level compared to PyTorch.\n"
+                    "\t For TensorFlow users you can try optimizing the model directly through onnxruntime_tools.\n"
+                    "\t For more information, please refer to the onnxruntime documentation:\n"
+                    "\t\thttps://github.com/microsoft/onnxruntime/tree/master/onnxruntime/python/tools/transformers\n"
+                )
+
+            print("\n====== Optimizing ONNX model ======")
+
+            # Quantization works best when using the optimized version of the model
+            args.optimized_output = optimize(args.output)
+
+            # Do the quantization on the right graph
+            args.quantized_output = quantize(args.optimized_output)

         # And verify
         if args.check_loading:
+            print("\n====== Check exported ONNX model(s) ======")
             verify(args.output)

+            if hasattr(args, "optimized_output"):
+                verify(args.optimized_output)
+
             if hasattr(args, "quantized_output"):
                 verify(args.quantized_output)
...
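Finally, the verification step shown above only checks that onnxruntime can build a session from each exported file. A compact sketch of running the same check over the three artifacts the script can produce (paths are placeholders):

.. code-block:: python

    from pathlib import Path

    from onnxruntime import InferenceSession, SessionOptions

    candidates = [
        Path("onnx/bert-base-cased.onnx"),
        Path("onnx/bert-base-cased-optimized.onnx"),
        Path("onnx/bert-base-cased-optimized-quantized.onnx"),
    ]

    for path in candidates:
        if path.exists():
            # Loading the graph on CPU is enough to validate the exported model
            _ = InferenceSession(path.as_posix(), SessionOptions(), providers=["CPUExecutionProvider"])
            print(f"{path} loads correctly")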