"sims/nic/git@developer.sourcefind.cn:cnjsdfcy/simbricks.git" did not exist on "fd1211c7d3203843f5ac328f6ee905f51ac022b6"
Unverified Commit b4ce313e authored by lewtun's avatar lewtun Committed by GitHub
Browse files

Prepare ONNX export for torch v1.11 (#15270)

* Prepare ONNX export for torch v1.11
parent 126bddd1
......@@ -112,7 +112,11 @@ def export(
config.patch_ops()
# export can works with named args but the dict containing named args as to be last element of the args tuple
# PyTorch deprecated the `enable_onnx_checker` and `use_external_data_format` arguments in v1.11,
# so we check the torch version for backwards compatibility
if parse(torch.__version__) <= parse("1.10.99"):
# export can work with named args but the dict containing named args
# has to be the last element of the args tuple.
export(
model,
(model_inputs,),
......@@ -125,6 +129,17 @@ def export(
enable_onnx_checker=True,
opset_version=opset,
)
else:
export(
model,
(model_inputs,),
f=output.as_posix(),
input_names=list(config.inputs.keys()),
output_names=onnx_outputs,
dynamic_axes={name: axes for name, axes in chain(config.inputs.items(), config.outputs.items())},
do_constant_folding=True,
opset_version=opset,
)
config.restore_ops()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment