Unverified Commit ebc4edfe authored by Joao Gante's avatar Joao Gante Committed by GitHub
Browse files

update from keras2onnx to tf2onnx (#15162)

parent 1b730c3d
...@@ -114,7 +114,6 @@ _deps = [ ...@@ -114,7 +114,6 @@ _deps = [
"jax>=0.2.8", "jax>=0.2.8",
"jaxlib>=0.1.65", "jaxlib>=0.1.65",
"jieba", "jieba",
"keras2onnx",
"nltk", "nltk",
"numpy>=1.17", "numpy>=1.17",
"onnxconverter-common", "onnxconverter-common",
...@@ -147,6 +146,7 @@ _deps = [ ...@@ -147,6 +146,7 @@ _deps = [
"starlette", "starlette",
"tensorflow-cpu>=2.3", "tensorflow-cpu>=2.3",
"tensorflow>=2.3", "tensorflow>=2.3",
"tf2onnx",
"timeout-decorator", "timeout-decorator",
"timm", "timm",
"tokenizers>=0.10.1", "tokenizers>=0.10.1",
...@@ -229,8 +229,8 @@ extras = {} ...@@ -229,8 +229,8 @@ extras = {}
extras["ja"] = deps_list("fugashi", "ipadic", "unidic_lite", "unidic") extras["ja"] = deps_list("fugashi", "ipadic", "unidic_lite", "unidic")
extras["sklearn"] = deps_list("scikit-learn") extras["sklearn"] = deps_list("scikit-learn")
extras["tf"] = deps_list("tensorflow", "onnxconverter-common", "keras2onnx") extras["tf"] = deps_list("tensorflow", "onnxconverter-common", "tf2onnx")
extras["tf-cpu"] = deps_list("tensorflow-cpu", "onnxconverter-common", "keras2onnx") extras["tf-cpu"] = deps_list("tensorflow-cpu", "onnxconverter-common", "tf2onnx")
extras["torch"] = deps_list("torch") extras["torch"] = deps_list("torch")
...@@ -243,7 +243,7 @@ else: ...@@ -243,7 +243,7 @@ else:
extras["tokenizers"] = deps_list("tokenizers") extras["tokenizers"] = deps_list("tokenizers")
extras["onnxruntime"] = deps_list("onnxruntime", "onnxruntime-tools") extras["onnxruntime"] = deps_list("onnxruntime", "onnxruntime-tools")
extras["onnx"] = deps_list("onnxconverter-common", "keras2onnx") + extras["onnxruntime"] extras["onnx"] = deps_list("onnxconverter-common", "tf2onnx") + extras["onnxruntime"]
extras["modelcreation"] = deps_list("cookiecutter") extras["modelcreation"] = deps_list("cookiecutter")
extras["sagemaker"] = deps_list("sagemaker") extras["sagemaker"] = deps_list("sagemaker")
......
...@@ -294,7 +294,7 @@ def convert_pytorch(nlp: Pipeline, opset: int, output: Path, use_external_format ...@@ -294,7 +294,7 @@ def convert_pytorch(nlp: Pipeline, opset: int, output: Path, use_external_format
def convert_tensorflow(nlp: Pipeline, opset: int, output: Path): def convert_tensorflow(nlp: Pipeline, opset: int, output: Path):
""" """
Export a TensorFlow backed pipeline to ONNX Intermediate Representation (IR Export a TensorFlow backed pipeline to ONNX Intermediate Representation (IR)
Args: Args:
nlp: The pipeline to be exported nlp: The pipeline to be exported
...@@ -312,10 +312,10 @@ def convert_tensorflow(nlp: Pipeline, opset: int, output: Path): ...@@ -312,10 +312,10 @@ def convert_tensorflow(nlp: Pipeline, opset: int, output: Path):
try: try:
import tensorflow as tf import tensorflow as tf
from keras2onnx import __version__ as k2ov from tf2onnx import __version__ as t2ov
from keras2onnx import convert_keras, save_model from tf2onnx import convert_keras, save_model
print(f"Using framework TensorFlow: {tf.version.VERSION}, keras2onnx: {k2ov}") print(f"Using framework TensorFlow: {tf.version.VERSION}, tf2onnx: {t2ov}")
# Build # Build
input_names, output_names, dynamic_axes, tokens = infer_shapes(nlp, "tf") input_names, output_names, dynamic_axes, tokens = infer_shapes(nlp, "tf")
......
...@@ -24,7 +24,6 @@ deps = { ...@@ -24,7 +24,6 @@ deps = {
"jax": "jax>=0.2.8", "jax": "jax>=0.2.8",
"jaxlib": "jaxlib>=0.1.65", "jaxlib": "jaxlib>=0.1.65",
"jieba": "jieba", "jieba": "jieba",
"keras2onnx": "keras2onnx",
"nltk": "nltk", "nltk": "nltk",
"numpy": "numpy>=1.17", "numpy": "numpy>=1.17",
"onnxconverter-common": "onnxconverter-common", "onnxconverter-common": "onnxconverter-common",
...@@ -57,6 +56,7 @@ deps = { ...@@ -57,6 +56,7 @@ deps = {
"starlette": "starlette", "starlette": "starlette",
"tensorflow-cpu": "tensorflow-cpu>=2.3", "tensorflow-cpu": "tensorflow-cpu>=2.3",
"tensorflow": "tensorflow>=2.3", "tensorflow": "tensorflow>=2.3",
"tf2onnx": "tf2onnx",
"timeout-decorator": "timeout-decorator", "timeout-decorator": "timeout-decorator",
"timm": "timm", "timm": "timm",
"tokenizers": "tokenizers>=0.10.1", "tokenizers": "tokenizers>=0.10.1",
......
...@@ -175,12 +175,12 @@ except importlib_metadata.PackageNotFoundError: ...@@ -175,12 +175,12 @@ except importlib_metadata.PackageNotFoundError:
_sympy_available = False _sympy_available = False
_keras2onnx_available = importlib.util.find_spec("keras2onnx") is not None _tf2onnx_available = importlib.util.find_spec("tf2onnx") is not None
try: try:
_keras2onnx_version = importlib_metadata.version("keras2onnx") _tf2onnx_version = importlib_metadata.version("tf2onnx")
logger.debug(f"Successfully imported keras2onnx version {_keras2onnx_version}") logger.debug(f"Successfully imported tf2onnx version {_tf2onnx_version}")
except importlib_metadata.PackageNotFoundError: except importlib_metadata.PackageNotFoundError:
_keras2onnx_available = False _tf2onnx_available = False
_onnx_available = importlib.util.find_spec("onnxruntime") is not None _onnx_available = importlib.util.find_spec("onnxruntime") is not None
try: try:
...@@ -429,8 +429,8 @@ def is_coloredlogs_available(): ...@@ -429,8 +429,8 @@ def is_coloredlogs_available():
return _coloredlogs_available return _coloredlogs_available
def is_keras2onnx_available(): def is_tf2onnx_available():
return _keras2onnx_available return _tf2onnx_available
def is_onnx_available(): def is_onnx_available():
......
...@@ -35,7 +35,6 @@ from .file_utils import ( ...@@ -35,7 +35,6 @@ from .file_utils import (
is_faiss_available, is_faiss_available,
is_flax_available, is_flax_available,
is_ftfy_available, is_ftfy_available,
is_keras2onnx_available,
is_librosa_available, is_librosa_available,
is_onnx_available, is_onnx_available,
is_pandas_available, is_pandas_available,
...@@ -49,6 +48,7 @@ from .file_utils import ( ...@@ -49,6 +48,7 @@ from .file_utils import (
is_soundfile_availble, is_soundfile_availble,
is_spacy_available, is_spacy_available,
is_tensorflow_probability_available, is_tensorflow_probability_available,
is_tf2onnx_available,
is_tf_available, is_tf_available,
is_timm_available, is_timm_available,
is_tokenizers_available, is_tokenizers_available,
...@@ -246,9 +246,9 @@ def require_rjieba(test_case): ...@@ -246,9 +246,9 @@ def require_rjieba(test_case):
return test_case return test_case
def require_keras2onnx(test_case): def require_tf2onnx(test_case):
if not is_keras2onnx_available(): if not is_tf2onnx_available():
return unittest.skip("test requires keras2onnx")(test_case) return unittest.skip("test requires tf2onnx")(test_case)
else: else:
return test_case return test_case
......
...@@ -36,8 +36,8 @@ from transformers.testing_utils import ( ...@@ -36,8 +36,8 @@ from transformers.testing_utils import (
_tf_gpu_memory_limit, _tf_gpu_memory_limit,
is_pt_tf_cross_test, is_pt_tf_cross_test,
is_staging_test, is_staging_test,
require_keras2onnx,
require_tf, require_tf,
require_tf2onnx,
slow, slow,
) )
from transformers.utils import logging from transformers.utils import logging
...@@ -254,14 +254,14 @@ class TFModelTesterMixin: ...@@ -254,14 +254,14 @@ class TFModelTesterMixin:
self.assertEqual(len(incompatible_ops), 0, incompatible_ops) self.assertEqual(len(incompatible_ops), 0, incompatible_ops)
@require_keras2onnx @require_tf2onnx
@slow @slow
def test_onnx_runtime_optimize(self): def test_onnx_runtime_optimize(self):
if not self.test_onnx: if not self.test_onnx:
return return
import keras2onnx
import onnxruntime import onnxruntime
import tf2onnx
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
...@@ -269,9 +269,9 @@ class TFModelTesterMixin: ...@@ -269,9 +269,9 @@ class TFModelTesterMixin:
model = model_class(config) model = model_class(config)
model(model.dummy_inputs) model(model.dummy_inputs)
onnx_model = keras2onnx.convert_keras(model, model.name, target_opset=self.onnx_min_opset) onnx_model_proto, _ = tf2onnx.convert.from_keras(model, opset=self.onnx_min_opset)
onnxruntime.InferenceSession(onnx_model.SerializeToString()) onnxruntime.InferenceSession(onnx_model_proto.SerializeToString())
def test_keras_save_load(self): def test_keras_save_load(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment