"docs/vscode:/vscode.git/clone" did not exist on "56bd7e67c2e01122cc93d98f5bd114f9312a5cce"
Unverified Commit 6c002853 authored by Funtowicz Morgan, committed by GitHub

Added capability to quantize a model while exporting through ONNX. (#6089)



* Added capability to quantize a model while exporting through ONNX.
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

We do not support multiple extensions
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Reformat files
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* More quality
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Ensure test_generate_identified_name compares the same object types
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Added documentation everywhere on ONNX exporter
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Use pathlib.Path instead of plain-old string
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Use f-string everywhere
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Use the correct parameters for black formatting
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Use Python 3 super() style.
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Use packaging.version to ensure installed onnxruntime version match requirements
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Fixing imports sorting order.
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Missing raise(s)
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Added quantization documentation
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Fix some spelling.
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Fix bad list header format
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
parent 25de74cc
......@@ -21,9 +21,10 @@ The following command shows how easy it is to export a BERT model from the libra
python convert_graph_to_onnx.py --framework <pt, tf> --model bert-base-cased bert-base-cased.onnx
The conversion tool works for both PyTorch and Tensorflow models and ensures:
* The model and its weights are correctly initialized from the Hugging Face model hub or a local checkpoint.
* The inputs and outputs are correctly generated to their ONNX counterpart.
* The generated model can be correctly loaded through onnxruntime.
* The model and its weights are correctly initialized from the Hugging Face model hub or a local checkpoint.
* The inputs and outputs are correctly generated to their ONNX counterpart.
* The generated model can be correctly loaded through onnxruntime.
.. note::

    Currently, inputs and outputs are always exported with dynamic sequence axes preventing some optimizations
......@@ -32,9 +33,57 @@ The conversion tool works for both PyTorch and Tensorflow models and ensures:
Also, the conversion tool supports different options which let you tune the behavior of the generated model:
* Change the target opset version of the generated model: More recent opset generally supports more operator and enables faster inference.
* Export pipeline specific prediction heads: Allow to export model along with its task-specific prediction head(s).
* Use the external data format (PyTorch only): Lets you export model which size is above 2Gb (`More info <https://github.com/pytorch/pytorch/pull/33062>`_).
* Change the target opset version of the generated model: More recent opsets generally support more operators and enable faster inference.
* Export pipeline-specific prediction heads: Allows exporting the model along with its task-specific prediction head(s).
* Use the external data format (PyTorch only): Lets you export models whose size is above 2GB (`More info <https://github.com/pytorch/pytorch/pull/33062>`_). A programmatic sketch using these options follows below.
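Below is a rough sketch of driving the same options through the underlying ``convert()`` utility rather than the CLI. It is only an illustration: the output path and opset value are arbitrary choices, and ``convert()`` expects the output folder to be empty or not to exist yet.

.. code-block:: python

    from pathlib import Path

    from transformers.convert_graph_to_onnx import convert

    # Export a PyTorch-backed BERT pipeline, bumping the opset and enabling the
    # external data format (useful for models whose size is above 2GB).
    convert(
        framework="pt",
        model="bert-base-cased",
        output=Path("onnx/bert-base-cased.onnx"),  # hypothetical output location
        opset=12,
        use_external_format=True,
        pipeline_name="feature-extraction",
    )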
Quantization
------------------------------------------------
The ONNX exporter supports generating a quantized version of the model to allow efficient inference.
Quantization works by converting the memory representation of the parameters in the neural network
to a compact integer format. By default, the weights of a neural network are stored as single-precision floats (`float32`),
which can express a wide range of floating-point numbers with decent precision.
These properties are especially interesting during training, where you want a fine-grained representation.
On the other hand, after the training phase, it has been shown that one can greatly reduce the range and the precision of `float32` numbers
without changing the performance of the neural network.
More technically, `float32` parameters are converted to a type requiring fewer bits to represent each number, thus reducing
the overall size of the model. Here, we enable `float32` mapping to `int8` values (a non-floating-point, single-byte number representation)
according to the following formula:
.. math::

    y_{float32} = scale * x_{int8} - zero\_point
.. note::

    The quantization process will infer the parameters `scale` and `zero_point` from the neural network's parameters.
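For illustration only, assuming a hypothetical `scale` of 0.05 and a `zero_point` of 0, the stored `int8` value 20 would map back to:

.. math::

    y_{float32} = 0.05 * 20 - 0 = 1.0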
Leveraging tiny integers has numerous advantages when it comes to inference:

* Storing 8 bits instead of the 32 bits needed for `float32` reduces the size of the model and makes it load faster.
* Integer operations execute an order of magnitude faster on modern hardware.
* Integer operations require less power to do the computations.
In order to convert a transformers model to ONNX IR with quantized weights, you just need to specify ``--quantize``
when using ``convert_graph_to_onnx.py``. You can also have a look at the ``quantize()`` utility method in the
same script file.
Example of quantized BERT model export:
.. code-block:: bash

    python convert_graph_to_onnx.py --framework <pt, tf> --model bert-base-cased --quantize bert-base-cased.onnx
.. note::

    Quantization support requires ONNX Runtime >= 1.4.0
.. note::

    When exporting a quantized model, you will end up with two different ONNX files. The one specified at the end of the
    above command will contain the original ONNX model storing `float32` weights.
    The second one, with the ``-quantized`` suffix, will hold the quantized parameters.
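For completeness, here is a minimal sketch (the file name below is an assumption matching the example above) of driving the same quantization step programmatically with the ``quantize()`` utility, and of loading the resulting file with onnxruntime:

.. code-block:: python

    from pathlib import Path

    from onnxruntime import InferenceSession, SessionOptions
    from transformers.convert_graph_to_onnx import quantize

    # Quantize an already exported ONNX graph; the quantized graph is written
    # next to the original one with a "-quantized" suffix and its path is returned.
    quantized_model_path = quantize(Path("bert-base-cased.onnx"))

    # Load the quantized graph with ONNX Runtime for CPU inference.
    options = SessionOptions()
    session = InferenceSession(quantized_model_path.as_posix(), options, providers=["CPUExecutionProvider"])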
TorchScript
......
from argparse import ArgumentParser
from os import listdir, makedirs
from os.path import abspath, dirname, exists
from pathlib import Path
from typing import Dict, List, Optional, Tuple
from packaging.version import parse
from transformers import is_tf_available, is_torch_available
from transformers.file_utils import ModelOutput
from transformers.pipelines import Pipeline, pipeline
from transformers.tokenization_utils import BatchEncoding
# This is the minimal required version to
# support some ONNX Runtime features
ORT_QUANTIZE_MINIMUM_VERSION = parse("1.4.0")
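# For illustration: parse("1.3.0") < ORT_QUANTIZE_MINIMUM_VERSION evaluates to True,
# while parse("1.4.0") < ORT_QUANTIZE_MINIMUM_VERSION evaluates to False.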
SUPPORTED_PIPELINES = [
"feature-extraction",
"ner",
......@@ -28,18 +35,71 @@ class OnnxConverterArgumentParser(ArgumentParser):
"""
def __init__(self):
super(OnnxConverterArgumentParser, self).__init__("ONNX Converter")
super().__init__("ONNX Converter")
self.add_argument("--pipeline", type=str, choices=SUPPORTED_PIPELINES, default="feature-extraction")
self.add_argument("--model", type=str, required=True, help="Model's id or path (ex: bert-base-cased)")
self.add_argument(
"--pipeline", type=str, choices=SUPPORTED_PIPELINES, default="feature-extraction",
)
self.add_argument(
"--model", type=str, required=True, help="Model's id or path (ex: bert-base-cased)",
)
self.add_argument("--tokenizer", type=str, help="Tokenizer's id or path (ex: bert-base-cased)")
self.add_argument("--framework", type=str, choices=["pt", "tf"], help="Framework for loading the model")
self.add_argument(
"--framework", type=str, choices=["pt", "tf"], help="Framework for loading the model",
)
self.add_argument("--opset", type=int, default=11, help="ONNX opset to use")
self.add_argument("--check-loading", action="store_true", help="Check ONNX is able to load the model")
self.add_argument("--use-external-format", action="store_true", help="Allow exporting model >= than 2Gb")
self.add_argument(
"--check-loading", action="store_true", help="Check ONNX is able to load the model",
)
self.add_argument(
"--use-external-format", action="store_true", help="Allow exporting model >= than 2Gb",
)
self.add_argument(
"--quantize", action="store_true", help="Quantize the neural network to be run with int8",
)
self.add_argument("output")
def generate_identified_filename(filename: Path, identifier: str) -> Path:
"""
Append a string-identifier at the end (before the extension, if any) to the provided filepath.
Args:
filename: pathlib.Path The actual path object to which we would like to add an identifier suffix
identifier: The suffix to add
Returns: Path with the identifier concatenated at the end of the filename (before the extension)
"""
return filename.parent.joinpath(filename.stem + identifier).with_suffix(filename.suffix)
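# For illustration: the identifier is inserted just before the extension, e.g.
#   generate_identified_filename(Path("bert.onnx"), "-quantized") -> Path("bert-quantized.onnx")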
def ensure_onnxruntime_installed():
"""
Check that onnxruntime is installed and that the installed version is recent enough.
Raises:
ImportError: If onnxruntime is not installed or the installed version is too old
"""
try:
import onnxruntime
# Parse the version of the installed onnxruntime
ort_version = parse(onnxruntime.__version__)
# We require 1.4.0 minimum
if ort_version < ORT_QUANTIZE_MINIMUM_VERSION:
raise ImportError(
f"We found an older version of onnxruntime ({onnxruntime.__version__}) "
f"but we require onnxruntime to be >= 1.4.0 to enable all the conversions options.\n"
f"Please update onnxruntime by running `pip install --upgrade onnxruntime`"
)
except ImportError:
raise ImportError(
"onnxruntime doesn't seem to be currently installed. "
"Please install the onnxruntime by running `pip install onnxruntime`"
" and relaunch the conversion."
)
def ensure_valid_input(model, tokens, input_names):
"""
Ensure inputs are presented in the correct order, without any None
......@@ -60,7 +120,7 @@ def ensure_valid_input(model, tokens, input_names):
ordered_input_names.append(arg_name)
model_args.append(tokens[arg_name])
else:
print("{} is not present in the generated input list.".format(arg_name))
print(f"{arg_name} is not present in the generated input list.")
break
print("Generated inputs order: {}".format(ordered_input_names))
......@@ -68,6 +128,19 @@ def ensure_valid_input(model, tokens, input_names):
def infer_shapes(nlp: Pipeline, framework: str) -> Tuple[List[str], List[str], Dict, BatchEncoding]:
"""
Attempt to infer the static vs dynamic axes for each input and output tensor of a specific model.
Args:
nlp: The pipeline object holding the model to be exported
framework: The framework identifier to dispatch to the correct inference scheme (pt/tf)
Returns:
- List of the inferred input variable names
- List of the inferred output variable names
- Dictionary with input/output variable names as keys and shape tensors as values
- A BatchEncoding reference which was used to infer all the above information
"""
def build_shape_dict(name: str, tensor, is_input: bool, seq_len: int):
if isinstance(tensor, (tuple, list)):
return [build_shape_dict(name, t, is_input, seq_len) for t in tensor]
......@@ -79,12 +152,12 @@ def infer_shapes(nlp: Pipeline, framework: str) -> Tuple[List[str], List[str], D
if len(tensor.shape) == 2:
axes[1] = "sequence"
else:
raise ValueError("Unable to infer tensor axes ({})".format(len(tensor.shape)))
raise ValueError(f"Unable to infer tensor axes ({len(tensor.shape)})")
else:
seq_axes = [dim for dim, shape in enumerate(tensor.shape) if shape == seq_len]
axes.update({dim: "sequence" for dim in seq_axes})
print("Found {} {} with shape: {}".format("input" if is_input else "output", name, axes))
print(f"Found {'input' if is_input else 'output'} {name} with shape: {axes}")
return axes
tokens = nlp.tokenizer("This is a sample output", return_tensors=framework)
......@@ -108,7 +181,7 @@ def infer_shapes(nlp: Pipeline, framework: str) -> Tuple[List[str], List[str], D
outputs_flat.append(output)
# Generate output names & axes
output_names = ["output_{}".format(i) for i in range(len(outputs_flat))]
output_names = [f"output_{i}" for i in range(len(outputs_flat))]
output_dynamic_axes = {k: build_shape_dict(k, v, False, seq_len) for k, v in zip(output_names, outputs_flat)}
# Create the aggregated axes representation
......@@ -117,6 +190,17 @@ def infer_shapes(nlp: Pipeline, framework: str) -> Tuple[List[str], List[str], D
def load_graph_from_args(pipeline_name: str, framework: str, model: str, tokenizer: Optional[str] = None) -> Pipeline:
"""
Convert the set of arguments provided through the CLI to an actual pipeline reference (tokenizer + model)
Args:
pipeline_name: The kind of pipeline to use (ner, question-answering, etc.)
framework: The framework the pipeline will be backed by ("pt" or "tf")
model: The model name which will be loaded by the pipeline
tokenizer: The tokenizer name which will be loaded by the pipeline, defaults to the model's value
Returns: Pipeline object
"""
# If no tokenizer provided
if tokenizer is None:
tokenizer = model
......@@ -127,20 +211,31 @@ def load_graph_from_args(pipeline_name: str, framework: str, model: str, tokeniz
if framework == "tf" and not is_tf_available():
raise Exception("Cannot convert because TF is not installed. Please install tensorflow first.")
print("Loading pipeline (model: {}, tokenizer: {})".format(model, tokenizer))
print(f"Loading pipeline (model: {model}, tokenizer: {tokenizer})")
# Allocate tokenizer and model
return pipeline(pipeline_name, model=model, tokenizer=tokenizer, framework=framework)
def convert_pytorch(nlp: Pipeline, opset: int, output: str, use_external_format: bool):
def convert_pytorch(nlp: Pipeline, opset: int, output: Path, use_external_format: bool):
"""
Export a PyTorch backed pipeline to ONNX Intermediate Representation (IR)
Args:
nlp: The pipeline to be exported
opset: The actual version of the ONNX operator set to use
output: Path where the generated ONNX model will be stored
use_external_format: Split the model definition from its parameters to allow models bigger than 2GB
Returns:
"""
if not is_torch_available():
raise Exception("Cannot convert because PyTorch is not installed. Please install torch first.")
import torch
from torch.onnx import export
print("Using framework PyTorch: {}".format(torch.__version__))
print(f"Using framework PyTorch: {torch.__version__}")
with torch.no_grad():
input_names, output_names, dynamic_axes, tokens = infer_shapes(nlp, "pt")
......@@ -149,7 +244,7 @@ def convert_pytorch(nlp: Pipeline, opset: int, output: str, use_external_format:
export(
nlp.model,
model_args,
f=output,
f=output.as_posix(),
input_names=ordered_input_names,
output_names=output_names,
dynamic_axes=dynamic_axes,
......@@ -160,7 +255,17 @@ def convert_pytorch(nlp: Pipeline, opset: int, output: str, use_external_format:
)
def convert_tensorflow(nlp: Pipeline, opset: int, output: str):
def convert_tensorflow(nlp: Pipeline, opset: int, output: Path):
"""
Export a TensorFlow backed pipeline to ONNX Intermediate Representation (IR)
Args:
nlp: The pipeline to be exported
opset: The actual version of the ONNX operator set to use
output: Path where the generated ONNX model will be stored
Notes: TensorFlow cannot export models bigger than 2GB due to an internal constraint in TensorFlow
"""
if not is_tf_available():
raise Exception("Cannot convert because TF is not installed. Please install tensorflow first.")
......@@ -170,7 +275,7 @@ def convert_tensorflow(nlp: Pipeline, opset: int, output: str):
import tensorflow as tf
from keras2onnx import convert_keras, save_model, __version__ as k2ov
print("Using framework TensorFlow: {}, keras2onnx: {}".format(tf.version.VERSION, k2ov))
print(f"Using framework TensorFlow: {tf.version.VERSION}, keras2onnx: {k2ov}")
# Build
input_names, output_names, dynamic_axes, tokens = infer_shapes(nlp, "tf")
......@@ -178,34 +283,45 @@ def convert_tensorflow(nlp: Pipeline, opset: int, output: str):
# Forward
nlp.model.predict(tokens.data)
onnx_model = convert_keras(nlp.model, nlp.model.name, target_opset=opset)
save_model(onnx_model, output)
save_model(onnx_model, output.as_posix())
except ImportError as e:
raise Exception(
"Cannot import {} required to convert TF model to ONNX. Please install {} first.".format(e.name, e.name)
)
raise Exception(f"Cannot import {e.name} required to convert TF model to ONNX. Please install {e.name} first.")
def convert(
framework: str,
model: str,
output: str,
output: Path,
opset: int,
tokenizer: Optional[str] = None,
use_external_format: bool = False,
pipeline_name: str = "feature-extraction",
):
print("ONNX opset version set to: {}".format(opset))
"""
Convert the pipeline object to the ONNX Intermediate Representation (IR) format.
Args:
framework: The framework the pipeline is backed by ("pt" or "tf")
model: The name of the model to load for the pipeline
output: The path where the ONNX graph will be stored
opset: The actual version of the ONNX operator set to use
tokenizer: The name of the tokenizer to load for the pipeline, defaults to the model's name if not provided
use_external_format: Split the model definition from its parameters to allow models bigger than 2GB (PyTorch only)
pipeline_name: The kind of pipeline to instantiate (ner, question-answering, etc.)
Returns:
"""
print(f"ONNX opset version set to: {opset}")
# Load the pipeline
nlp = load_graph_from_args(pipeline_name, framework, model, tokenizer)
parent = dirname(output)
if not exists(parent):
print("Creating folder {}".format(parent))
makedirs(parent)
elif len(listdir(parent)) > 0:
raise Exception("Folder {} is not empty, aborting conversion".format(parent))
if not output.parent.exists():
print(f"Creating folder {output.parent}")
makedirs(output.parent.as_posix())
elif len(listdir(output.parent.as_posix())) > 0:
raise Exception(f"Folder {output.parent.as_posix()} is not empty, aborting conversion")
# Export the graph
if framework == "pt":
......@@ -214,17 +330,52 @@ def convert(
convert_tensorflow(nlp, opset, output)
def verify(path: str):
def quantize(onnx_model_path: Path) -> Path:
"""
Quantize the weights of the model from float32 to int8 to allow very efficient inference on modern CPUs.
Args:
onnx_model_path: Path to the location where the exported ONNX model is stored
Returns: The Path generated for the quantized model
"""
try:
ensure_onnxruntime_installed()
import onnx
from onnxruntime import __version__ as ort_version
from onnxruntime.quantization import quantize, QuantizationMode
print(f"Found ONNX: {onnx.__version__}")
print(f"Found ONNXRuntime: {ort_version}")
onnx_model = onnx.load(onnx_model_path.as_posix())
quantized_model = quantize(
model=onnx_model, quantization_mode=QuantizationMode.IntegerOps, force_fusions=True, symmetric_weight=True,
)
# Append "-quantized" at the end of the model's name
quantized_model_path = generate_identified_filename(onnx_model_path, "-quantized")
# Save model
print(f"Storing quantized model at {quantized_model_path}")
onnx.save(quantized_model, quantized_model_path.as_posix())
return quantized_model_path
except ImportError as ie:
print(f"Error while quantizing the model:\n{str(ie)}")
def verify(path: Path):
from onnxruntime import InferenceSession, SessionOptions
from onnxruntime.capi.onnxruntime_pybind11_state import RuntimeException
print("Checking ONNX model loading from: {}".format(path))
print(f"Checking ONNX model loading from: {path}")
try:
onnx_options = SessionOptions()
_ = InferenceSession(path, onnx_options, providers=["CPUExecutionProvider"])
print("Model correctly loaded")
_ = InferenceSession(path.as_posix(), onnx_options, providers=["CPUExecutionProvider"])
print(f"Model {path} correctly loaded: \N{heavy check mark}")
except RuntimeException as re:
print("Error while loading the model: {}".format(re))
print(f"Error while loading the model {re}: \N{heavy ballot x}")
if __name__ == "__main__":
......@@ -232,7 +383,7 @@ if __name__ == "__main__":
args = parser.parse_args()
# Make sure output is absolute path
args.output = abspath(args.output)
args.output = Path(args.output).absolute()
try:
# Convert
......@@ -246,9 +397,16 @@ if __name__ == "__main__":
args.pipeline,
)
if args.quantize:
args.quantized_output = quantize(args.output)
# And verify
if args.check_loading:
verify(args.output)
if hasattr(args, "quantized_output"):
verify(args.quantized_output)
except Exception as e:
print("Error while converting the model: {}".format(e))
print(f"Error while converting the model: {e}")
exit(1)
import unittest
from os.path import dirname, exists
from pathlib import Path
from shutil import rmtree
from tempfile import NamedTemporaryFile, TemporaryDirectory
from transformers import BertConfig, BertTokenizerFast, FeatureExtractionPipeline
from transformers.convert_graph_to_onnx import convert, ensure_valid_input, infer_shapes
from transformers.convert_graph_to_onnx import (
convert,
ensure_valid_input,
generate_identified_filename,
infer_shapes,
quantize,
)
from transformers.testing_utils import require_tf, require_torch, slow
......@@ -25,13 +32,13 @@ class OnnxExportTestCase(unittest.TestCase):
@slow
def test_export_tensorflow(self):
for model in OnnxExportTestCase.MODEL_TO_TEST:
self._test_export(model, "tf", 11)
self._test_export(model, "tf", 12)
@require_torch
@slow
def test_export_pytorch(self):
for model in OnnxExportTestCase.MODEL_TO_TEST:
self._test_export(model, "pt", 11)
self._test_export(model, "pt", 12)
@require_torch
@slow
......@@ -47,7 +54,29 @@ class OnnxExportTestCase(unittest.TestCase):
with TemporaryDirectory() as bert_save_dir:
model = BertModel(BertConfig(vocab_size=len(vocab)))
model.save_pretrained(bert_save_dir)
self._test_export(bert_save_dir, "pt", 11, tokenizer)
self._test_export(bert_save_dir, "pt", 12, tokenizer)
@require_tf
@slow
def test_quantize_tf(self):
for model in OnnxExportTestCase.MODEL_TO_TEST:
path = self._test_export(model, "tf", 12)
quantized_path = quantize(Path(path))
# Ensure the actual quantized model is not bigger than the original one
if quantized_path.stat().st_size >= Path(path).stat().st_size:
self.fail("Quantized model is bigger than initial ONNX model")
@require_torch
@slow
def test_quantize_pytorch(self):
for model in OnnxExportTestCase.MODEL_TO_TEST:
path = self._test_export(model, "pt", 12)
quantized_path = quantize(Path(path))
# Ensure the actual quantized model is not bigger than the original one
if quantized_path.stat().st_size >= Path(path).stat().st_size:
self.fail("Quantized model is bigger than initial ONNX model")
def _test_export(self, model, framework, opset, tokenizer=None):
try:
......@@ -61,6 +90,8 @@ class OnnxExportTestCase(unittest.TestCase):
# Export
convert(framework, model, path, opset, tokenizer)
return path
except Exception as e:
self.fail(e)
......@@ -138,3 +169,7 @@ class OnnxExportTestCase(unittest.TestCase):
# Should have only "input_ids"
self.assertEqual(inputs_args[0], tokens["input_ids"])
self.assertEqual(ordered_input_names[0], "input_ids")
def test_generate_identified_name(self):
generated = generate_identified_filename(Path("/home/something/my_fake_model.onnx"), "-test")
self.assertEqual("/home/something/my_fake_model-test.onnx", generated.as_posix())