Unverified commit 243439a8, authored by lewtun and committed by GitHub

Fix ONNX tests for ONNX Runtime v1.13.1 (#19950)



* Fix ONNX tests for ONNX Runtime v1.13.1
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 0b294c23
@@ -435,6 +435,7 @@ def quantize(onnx_model_path: Path) -> Path:
     Returns: The Path generated for the quantized
     """
     import onnx
+    import onnxruntime
     from onnx.onnx_pb import ModelProto
     from onnxruntime.quantization import QuantizationMode
     from onnxruntime.quantization.onnx_quantizer import ONNXQuantizer
@@ -454,19 +455,37 @@ def quantize(onnx_model_path: Path) -> Path:
     copy_model.CopyFrom(onnx_model)

     # Construct quantizer
-    quantizer = ONNXQuantizer(
-        model=copy_model,
-        per_channel=False,
-        reduce_range=False,
-        mode=QuantizationMode.IntegerOps,
-        static=False,
-        weight_qType=True,
-        input_qType=False,
-        tensors_range=None,
-        nodes_to_quantize=None,
-        nodes_to_exclude=None,
-        op_types_to_quantize=list(IntegerOpsRegistry),
-    )
+    # onnxruntime renamed input_qType to activation_qType in v1.13.1, so we
+    # check the onnxruntime version to ensure backward compatibility.
+    # See also: https://github.com/microsoft/onnxruntime/pull/12873
+    if parse(onnxruntime.__version__) < parse("1.13.1"):
+        quantizer = ONNXQuantizer(
+            model=copy_model,
+            per_channel=False,
+            reduce_range=False,
+            mode=QuantizationMode.IntegerOps,
+            static=False,
+            weight_qType=True,
+            input_qType=False,
+            tensors_range=None,
+            nodes_to_quantize=None,
+            nodes_to_exclude=None,
+            op_types_to_quantize=list(IntegerOpsRegistry),
+        )
+    else:
+        quantizer = ONNXQuantizer(
+            model=copy_model,
+            per_channel=False,
+            reduce_range=False,
+            mode=QuantizationMode.IntegerOps,
+            static=False,
+            weight_qType=True,
+            activation_qType=False,
+            tensors_range=None,
+            nodes_to_quantize=None,
+            nodes_to_exclude=None,
+            op_types_to_quantize=list(IntegerOpsRegistry),
+        )

     # Quantize and export
     quantizer.quantize_model()
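
The duplicated ONNXQuantizer call above could also be collapsed by selecting the keyword name at runtime. A minimal sketch of that alternative, assuming the packaging library is installed (transformers already imports parse from it) and using a hypothetical helper name that is not part of this PR:

import onnxruntime
from packaging.version import parse

def activation_qtype_kwarg(value):
    """Return the activation-dtype kwarg under the name the installed
    onnxruntime expects (hypothetical helper, not part of this commit)."""
    # onnxruntime v1.13.1 renamed input_qType to activation_qType
    # (https://github.com/microsoft/onnxruntime/pull/12873).
    if parse(onnxruntime.__version__) < parse("1.13.1"):
        return {"input_qType": value}
    return {"activation_qType": value}

# Usage inside quantize(), keeping every other argument identical:
# quantizer = ONNXQuantizer(model=copy_model, ..., **activation_qtype_kwarg(False))

The commit instead duplicates the full call in both branches, which is more verbose but keeps each branch a plain, greppable call signature.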