Unverified Commit 2e5e058c authored by David Fan's avatar David Fan Committed by GitHub
Browse files

Adapt to new pytorch onnx exporter API for dictionary (#3160)

* Adapt to new torch export API for dictionary
parent 0963ff71
...@@ -157,7 +157,7 @@ jobs: ...@@ -157,7 +157,7 @@ jobs:
# need to install torchvision dependencies due to transitive imports # need to install torchvision dependencies due to transitive imports
pip install --user --progress-bar off --editable . pip install --user --progress-bar off --editable .
pip install --user onnx pip install --user onnx
pip install --user -i https://test.pypi.org/simple/ ort-nightly==1.5.2.dev202012031 pip install --user onnxruntime
python test/test_onnx.py python test/test_onnx.py
binary_linux_wheel: binary_linux_wheel:
......
...@@ -157,7 +157,7 @@ jobs: ...@@ -157,7 +157,7 @@ jobs:
# need to install torchvision dependencies due to transitive imports # need to install torchvision dependencies due to transitive imports
pip install --user --progress-bar off --editable . pip install --user --progress-bar off --editable .
pip install --user onnx pip install --user onnx
pip install --user -i https://test.pypi.org/simple/ ort-nightly==1.5.2.dev202012031 pip install --user onnxruntime
python test/test_onnx.py python test/test_onnx.py
binary_linux_wheel: binary_linux_wheel:
......
...@@ -33,8 +33,12 @@ class ONNXExporterTester(unittest.TestCase): ...@@ -33,8 +33,12 @@ class ONNXExporterTester(unittest.TestCase):
model.eval() model.eval()
onnx_io = io.BytesIO() onnx_io = io.BytesIO()
if isinstance(inputs_list[0][-1], dict):
torch_onnx_input = inputs_list[0] + ({},)
else:
torch_onnx_input = inputs_list[0]
# export to onnx with the first input # export to onnx with the first input
torch.onnx.export(model, inputs_list[0], onnx_io, torch.onnx.export(model, torch_onnx_input, onnx_io,
do_constant_folding=do_constant_folding, opset_version=_onnx_opset_version, do_constant_folding=do_constant_folding, opset_version=_onnx_opset_version,
dynamic_axes=dynamic_axes, input_names=input_names, output_names=output_names) dynamic_axes=dynamic_axes, input_names=input_names, output_names=output_names)
# validate the exported model with onnx runtime # validate the exported model with onnx runtime
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment