chenpangpang / transformers · Commits · ebc4edfe

Unverified commit ebc4edfe, authored Jan 14, 2022 by Joao Gante, committed via GitHub on Jan 14, 2022.

    update from keras2onnx to tf2onnx (#15162)

parent 1b730c3d

Showing 6 changed files with 24 additions and 24 deletions (+24 -24)
setup.py                                        +4 -4
src/transformers/convert_graph_to_onnx.py       +4 -4
src/transformers/dependency_versions_table.py   +1 -1
src/transformers/file_utils.py                  +6 -6
src/transformers/testing_utils.py               +4 -4
tests/test_modeling_tf_common.py                +5 -5
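Taken together, the commit swaps the converter entry point: keras2onnx.convert_keras (removed) for tf2onnx.convert.from_keras, which hands back an ONNX model proto that can be fed straight into ONNX Runtime. A minimal sketch of the swap on a stand-in Keras model (the toy model and the opset below are illustrative, not taken from the commit); the per-file diffs follow.

import onnxruntime
import tensorflow as tf
import tf2onnx

# Stand-in model; the commit itself only touches library and test code.
model = tf.keras.Sequential([tf.keras.layers.Dense(2, input_shape=(4,))], name="toy")

# Old path, removed by this commit:
#   onnx_model = keras2onnx.convert_keras(model, model.name, target_opset=13)
#   onnxruntime.InferenceSession(onnx_model.SerializeToString())

# New path: tf2onnx returns a (model_proto, external_tensor_storage) tuple.
onnx_model_proto, _ = tf2onnx.convert.from_keras(model, opset=13)
onnxruntime.InferenceSession(
    onnx_model_proto.SerializeToString(), providers=["CPUExecutionProvider"]
)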
setup.py

@@ -114,7 +114,6 @@ _deps = [
     "jax>=0.2.8",
     "jaxlib>=0.1.65",
     "jieba",
-    "keras2onnx",
     "nltk",
     "numpy>=1.17",
     "onnxconverter-common",
@@ -147,6 +146,7 @@ _deps = [
     "starlette",
     "tensorflow-cpu>=2.3",
     "tensorflow>=2.3",
+    "tf2onnx",
     "timeout-decorator",
     "timm",
     "tokenizers>=0.10.1",
@@ -229,8 +229,8 @@ extras = {}
 extras["ja"] = deps_list("fugashi", "ipadic", "unidic_lite", "unidic")
 extras["sklearn"] = deps_list("scikit-learn")

-extras["tf"] = deps_list("tensorflow", "onnxconverter-common", "keras2onnx")
-extras["tf-cpu"] = deps_list("tensorflow-cpu", "onnxconverter-common", "keras2onnx")
+extras["tf"] = deps_list("tensorflow", "onnxconverter-common", "tf2onnx")
+extras["tf-cpu"] = deps_list("tensorflow-cpu", "onnxconverter-common", "tf2onnx")

 extras["torch"] = deps_list("torch")
@@ -243,7 +243,7 @@ else:
 extras["tokenizers"] = deps_list("tokenizers")
 extras["onnxruntime"] = deps_list("onnxruntime", "onnxruntime-tools")
-extras["onnx"] = deps_list("onnxconverter-common", "keras2onnx") + extras["onnxruntime"]
+extras["onnx"] = deps_list("onnxconverter-common", "tf2onnx") + extras["onnxruntime"]
 extras["modelcreation"] = deps_list("cookiecutter")

 extras["sagemaker"] = deps_list("sagemaker")
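setup.py keeps a single flat _deps list and derives the pip extras from it with deps_list, which is why the "onnx", "tf", and "tf-cpu" extras change in lockstep with the list entry. A simplified sketch of that mechanism (the name parsing here is deliberately naive and the entries are limited to ones visible in the hunks above; the real setup.py splits version specifiers with a regex):

# Simplified sketch of how setup.py resolves extras from the flat _deps list.
_deps = ["onnxconverter-common", "tensorflow>=2.3", "tensorflow-cpu>=2.3", "tf2onnx"]

# Map bare package name -> full requirement string ("tensorflow" -> "tensorflow>=2.3").
deps = {spec.split(">=")[0].split("==")[0]: spec for spec in _deps}

def deps_list(*pkgs):
    return [deps[pkg] for pkg in pkgs]

# After this commit, `pip install transformers[tf]` pulls in tf2onnx instead of keras2onnx.
extras = {"tf": deps_list("tensorflow", "onnxconverter-common", "tf2onnx")}
print(extras["tf"])  # ['tensorflow>=2.3', 'onnxconverter-common', 'tf2onnx']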
src/transformers/convert_graph_to_onnx.py

@@ -294,7 +294,7 @@ def convert_pytorch(nlp: Pipeline, opset: int, output: Path, use_external_format
 def convert_tensorflow(nlp: Pipeline, opset: int, output: Path):
     """
-    Export a TensorFlow backed pipeline to ONNX Intermediate Representation (IR
+    Export a TensorFlow backed pipeline to ONNX Intermediate Representation (IR)

     Args:
         nlp: The pipeline to be exported
@@ -312,10 +312,10 @@ def convert_tensorflow(nlp: Pipeline, opset: int, output: Path):
     try:
         import tensorflow as tf

-        from keras2onnx import __version__ as k2ov
-        from keras2onnx import convert_keras, save_model
+        from tf2onnx import __version__ as t2ov
+        from tf2onnx import convert_keras, save_model

-        print(f"Using framework TensorFlow: {tf.version.VERSION}, keras2onnx: {k2ov}")
+        print(f"Using framework TensorFlow: {tf.version.VERSION}, tf2onnx: {t2ov}")

         # Build
         input_names, output_names, dynamic_axes, tokens = infer_shapes(nlp, "tf")
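The hunks shown here only swap the imports and the version banner; the conversion call itself sits outside the displayed context. For reference, a minimal sketch of exporting a Keras model to an ONNX file through tf2onnx's public converter, assuming an illustrative model, input signature, and output path (a real pipeline export would derive the signature from the tokenizer outputs via infer_shapes, as in the file above):

from pathlib import Path

import tensorflow as tf
import tf2onnx
from tf2onnx import __version__ as t2ov

print(f"Using framework TensorFlow: {tf.version.VERSION}, tf2onnx: {t2ov}")

# Illustrative model and output path, not taken from the hidden part of the diff.
model = tf.keras.Sequential([tf.keras.layers.Dense(2, input_shape=(8,))], name="toy")
output = Path("model.onnx")

# Describe the model inputs so the exported graph gets stable names and dynamic batch size.
input_signature = [tf.TensorSpec([None, 8], tf.float32, name="inputs")]
model_proto, _ = tf2onnx.convert.from_keras(
    model, input_signature=input_signature, opset=13, output_path=output.as_posix()
)
print("Exported", output, "with", len(model_proto.graph.node), "nodes")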
src/transformers/dependency_versions_table.py

@@ -24,7 +24,6 @@ deps = {
     "jax": "jax>=0.2.8",
     "jaxlib": "jaxlib>=0.1.65",
     "jieba": "jieba",
-    "keras2onnx": "keras2onnx",
     "nltk": "nltk",
     "numpy": "numpy>=1.17",
     "onnxconverter-common": "onnxconverter-common",
@@ -57,6 +56,7 @@ deps = {
     "starlette": "starlette",
     "tensorflow-cpu": "tensorflow-cpu>=2.3",
     "tensorflow": "tensorflow>=2.3",
+    "tf2onnx": "tf2onnx",
     "timeout-decorator": "timeout-decorator",
     "timm": "timm",
     "tokenizers": "tokenizers>=0.10.1",
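dependency_versions_table.py is the runtime mirror of the setup.py pins: a plain name-to-requirement dict that version checks can consult without importing any packaging machinery. A minimal sketch of that lookup, using only entries visible in the hunks above:

# Minimal sketch: the pin table is consumed as a plain name -> requirement dict.
deps = {
    "tensorflow": "tensorflow>=2.3",
    "tf2onnx": "tf2onnx",  # added by this commit; keras2onnx's entry is removed
}

print(deps["tf2onnx"])  # "tf2onnx" -- an unpinned requirement string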
src/transformers/file_utils.py

@@ -175,12 +175,12 @@ except importlib_metadata.PackageNotFoundError:
     _sympy_available = False

-_keras2onnx_available = importlib.util.find_spec("keras2onnx") is not None
+_tf2onnx_available = importlib.util.find_spec("tf2onnx") is not None
 try:
-    _keras2onnx_version = importlib_metadata.version("keras2onnx")
-    logger.debug(f"Successfully imported keras2onnx version {_keras2onnx_version}")
+    _tf2onnx_version = importlib_metadata.version("tf2onnx")
+    logger.debug(f"Successfully imported tf2onnx version {_tf2onnx_version}")
 except importlib_metadata.PackageNotFoundError:
-    _keras2onnx_available = False
+    _tf2onnx_available = False

 _onnx_available = importlib.util.find_spec("onnxruntime") is not None
 try:
@@ -429,8 +429,8 @@ def is_coloredlogs_available():
     return _coloredlogs_available


-def is_keras2onnx_available():
-    return _keras2onnx_available
+def is_tf2onnx_available():
+    return _tf2onnx_available


 def is_onnx_available():
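The new flag follows the same pattern file_utils.py uses for its other optional backends: probe the import machinery for the package, then confirm an installed distribution and cache the result behind a small accessor. A standalone sketch of that pattern, using only the standard library (importlib.metadata is stdlib from Python 3.8; it is aliased here to match the names in the diff):

import importlib.util
import importlib.metadata as importlib_metadata
import logging

logger = logging.getLogger(__name__)

# Is the package importable at all?
_tf2onnx_available = importlib.util.find_spec("tf2onnx") is not None
try:
    # Is a distribution actually installed (and which version)?
    _tf2onnx_version = importlib_metadata.version("tf2onnx")
    logger.debug(f"Successfully imported tf2onnx version {_tf2onnx_version}")
except importlib_metadata.PackageNotFoundError:
    _tf2onnx_available = False


def is_tf2onnx_available():
    return _tf2onnx_available


print("tf2onnx available:", is_tf2onnx_available())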
src/transformers/testing_utils.py

@@ -35,7 +35,6 @@ from .file_utils import (
     is_faiss_available,
     is_flax_available,
     is_ftfy_available,
-    is_keras2onnx_available,
     is_librosa_available,
     is_onnx_available,
     is_pandas_available,
@@ -49,6 +48,7 @@ from .file_utils import (
     is_soundfile_availble,
     is_spacy_available,
     is_tensorflow_probability_available,
+    is_tf2onnx_available,
     is_tf_available,
     is_timm_available,
     is_tokenizers_available,
@@ -246,9 +246,9 @@ def require_rjieba(test_case):
     return test_case


-def require_keras2onnx(test_case):
-    if not is_keras2onnx_available():
-        return unittest.skip("test requires keras2onnx")(test_case)
+def require_tf2onnx(test_case):
+    if not is_tf2onnx_available():
+        return unittest.skip("test requires tf2onnx")(test_case)
     else:
         return test_case
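require_tf2onnx keeps the library's usual require_* decorator shape: when the optional dependency is missing, the decorated test is replaced by a skipped one; otherwise it is returned untouched. A self-contained sketch of the same pattern (the availability helper and toy test case here are illustrative stand-ins for the transformers versions):

import importlib.util
import unittest


def is_tf2onnx_available():
    # Stand-in for transformers.file_utils.is_tf2onnx_available.
    return importlib.util.find_spec("tf2onnx") is not None


def require_tf2onnx(test_case):
    # Same shape as the decorator in the diff: wrap the test in unittest.skip
    # when tf2onnx is not installed, otherwise return it untouched.
    if not is_tf2onnx_available():
        return unittest.skip("test requires tf2onnx")(test_case)
    else:
        return test_case


class ExampleTest(unittest.TestCase):
    @require_tf2onnx
    def test_needs_tf2onnx(self):
        import tf2onnx  # only reached when the package is installed

        self.assertTrue(hasattr(tf2onnx, "convert"))


if __name__ == "__main__":
    unittest.main()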
tests/test_modeling_tf_common.py

@@ -36,8 +36,8 @@ from transformers.testing_utils import (
     _tf_gpu_memory_limit,
     is_pt_tf_cross_test,
     is_staging_test,
-    require_keras2onnx,
     require_tf,
+    require_tf2onnx,
     slow,
 )
 from transformers.utils import logging
@@ -254,14 +254,14 @@ class TFModelTesterMixin:
             self.assertEqual(len(incompatible_ops), 0, incompatible_ops)

-    @require_keras2onnx
+    @require_tf2onnx
     @slow
     def test_onnx_runtime_optimize(self):
         if not self.test_onnx:
             return

-        import keras2onnx
         import onnxruntime
+        import tf2onnx

         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@@ -269,9 +269,9 @@ class TFModelTesterMixin:
             model = model_class(config)
             model(model.dummy_inputs)

-            onnx_model = keras2onnx.convert_keras(model, model.name, target_opset=self.onnx_min_opset)
+            onnx_model_proto, _ = tf2onnx.convert.from_keras(model, opset=self.onnx_min_opset)

-            onnxruntime.InferenceSession(onnx_model.SerializeToString())
+            onnxruntime.InferenceSession(onnx_model_proto.SerializeToString())

     def test_keras_save_load(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
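For context, a minimal end-to-end version of what the rewritten test exercises: convert a Keras model with tf2onnx, load the serialized proto into an ONNX Runtime session, and run it on dummy inputs. The toy model, input name, shapes, and opset below are illustrative; the real test uses the TF model classes under test and self.onnx_min_opset.

import numpy as np
import onnxruntime
import tensorflow as tf
import tf2onnx

# Toy stand-in for the TF model classes exercised by the real test.
model = tf.keras.Sequential(
    [tf.keras.layers.Dense(4, activation="relu", input_shape=(8,)), tf.keras.layers.Dense(2)]
)

# Convert in memory; only the model proto is needed for the runtime check.
onnx_model_proto, _ = tf2onnx.convert.from_keras(
    model, input_signature=[tf.TensorSpec([None, 8], tf.float32, name="inputs")], opset=13
)

session = onnxruntime.InferenceSession(
    onnx_model_proto.SerializeToString(), providers=["CPUExecutionProvider"]
)

# Feed a random batch through the exported graph and check the output shape.
dummy = np.random.rand(3, 8).astype(np.float32)
(outputs,) = session.run(None, {session.get_inputs()[0].name: dummy})
print(outputs.shape)  # (3, 2)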