Unverified Commit 410e26c7 authored by lewtun's avatar lewtun Committed by GitHub
Browse files

Fix (deprecated) ONNX exporter to account for new tf2onnx API (#15856)

* Fix (deprecated) ONNX exporter to account for new tf2onnx API
parent e3342edc
...@@ -327,8 +327,8 @@ def convert_tensorflow(nlp: Pipeline, opset: int, output: Path): ...@@ -327,8 +327,8 @@ def convert_tensorflow(nlp: Pipeline, opset: int, output: Path):
try: try:
import tensorflow as tf import tensorflow as tf
import tf2onnx
from tf2onnx import __version__ as t2ov from tf2onnx import __version__ as t2ov
from tf2onnx import convert_keras, save_model
print(f"Using framework TensorFlow: {tf.version.VERSION}, tf2onnx: {t2ov}") print(f"Using framework TensorFlow: {tf.version.VERSION}, tf2onnx: {t2ov}")
...@@ -337,11 +337,15 @@ def convert_tensorflow(nlp: Pipeline, opset: int, output: Path): ...@@ -337,11 +337,15 @@ def convert_tensorflow(nlp: Pipeline, opset: int, output: Path):
# Forward # Forward
nlp.model.predict(tokens.data) nlp.model.predict(tokens.data)
onnx_model = convert_keras(nlp.model, nlp.model.name, target_opset=opset) input_signature = [tf.TensorSpec.from_tensor(tensor, name=key) for key, tensor in tokens.items()]
save_model(onnx_model, output.as_posix()) model_proto, _ = tf2onnx.convert.from_keras(
nlp.model, input_signature, opset=opset, output_path=output.as_posix()
)
except ImportError as e: except ImportError as e:
raise Exception(f"Cannot import {e.name} required to convert TF model to ONNX. Please install {e.name} first.") raise Exception(
f"Cannot import {e.name} required to convert TF model to ONNX. Please install {e.name} first. {e}"
)
def convert( def convert(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment