Unverified Commit b4ce313e authored by lewtun's avatar lewtun Committed by GitHub
Browse files

Prepare ONNX export for torch v1.11 (#15270)

* Prepare ONNX export for torch v1.11
parent 126bddd1
...@@ -112,19 +112,34 @@ def export( ...@@ -112,19 +112,34 @@ def export(
config.patch_ops() config.patch_ops()
# export can works with named args but the dict containing named args as to be last element of the args tuple # PyTorch deprecated the `enable_onnx_checker` and `use_external_data_format` arguments in v1.11,
export( # so we check the torch version for backwards compatibility
model, if parse(torch.__version__) <= parse("1.10.99"):
(model_inputs,), # export can work with named args but the dict containing named args
f=output.as_posix(), # has to be the last element of the args tuple.
input_names=list(config.inputs.keys()), export(
output_names=onnx_outputs, model,
dynamic_axes={name: axes for name, axes in chain(config.inputs.items(), config.outputs.items())}, (model_inputs,),
do_constant_folding=True, f=output.as_posix(),
use_external_data_format=config.use_external_data_format(model.num_parameters()), input_names=list(config.inputs.keys()),
enable_onnx_checker=True, output_names=onnx_outputs,
opset_version=opset, dynamic_axes={name: axes for name, axes in chain(config.inputs.items(), config.outputs.items())},
) do_constant_folding=True,
use_external_data_format=config.use_external_data_format(model.num_parameters()),
enable_onnx_checker=True,
opset_version=opset,
)
else:
export(
model,
(model_inputs,),
f=output.as_posix(),
input_names=list(config.inputs.keys()),
output_names=onnx_outputs,
dynamic_axes={name: axes for name, axes in chain(config.inputs.items(), config.outputs.items())},
do_constant_folding=True,
opset_version=opset,
)
config.restore_ops() config.restore_ops()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment