Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
0e1fce3c
Unverified
Commit
0e1fce3c
authored
Jun 25, 2020
by
Anthony MOI
Committed by
GitHub
Jun 25, 2020
Browse files
Fix convert_graph_to_onnx (#5230)
parent
5543efd5
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
20 additions
and
7 deletions
+20
-7
src/transformers/convert_graph_to_onnx.py
src/transformers/convert_graph_to_onnx.py
+20
-7
No files found.
src/transformers/convert_graph_to_onnx.py
View file @
0e1fce3c
...
@@ -114,15 +114,21 @@ def infer_shapes(nlp: Pipeline, framework: str) -> Tuple[List[str], List[str], D
...
@@ -114,15 +114,21 @@ def infer_shapes(nlp: Pipeline, framework: str) -> Tuple[List[str], List[str], D
return
input_vars
,
output_names
,
dynamic_axes
,
tokens
return
input_vars
,
output_names
,
dynamic_axes
,
tokens
def
load_graph_from_args
(
framework
:
str
,
model
:
str
,
tokenizer
:
Optional
[
str
]
=
None
)
->
Pipeline
:
def
load_graph_from_args
(
pipeline_name
:
str
,
framework
:
str
,
model
:
str
,
tokenizer
:
Optional
[
str
]
=
None
)
->
Pipeline
:
# If no tokenizer provided
# If no tokenizer provided
if
tokenizer
is
None
:
if
tokenizer
is
None
:
tokenizer
=
model
tokenizer
=
model
# Check the wanted framework is available
if
framework
==
"pt"
and
not
is_torch_available
():
raise
Exception
(
"Cannot convert because PyTorch is not installed. Please install torch first."
)
if
framework
==
"tf"
and
not
is_tf_available
():
raise
Exception
(
"Cannot convert because TF is not installed. Please install tensorflow first."
)
print
(
"Loading pipeline (model: {}, tokenizer: {})"
.
format
(
model
,
tokenizer
))
print
(
"Loading pipeline (model: {}, tokenizer: {})"
.
format
(
model
,
tokenizer
))
# Allocate tokenizer and model
# Allocate tokenizer and model
return
pipeline
(
args
.
pipeline
,
model
=
model
,
tokenizer
=
tokenizer
,
framework
=
framework
)
return
pipeline
(
pipeline
_name
,
model
=
model
,
tokenizer
=
tokenizer
,
framework
=
framework
)
def
convert_pytorch
(
nlp
:
Pipeline
,
opset
:
int
,
output
:
str
,
use_external_format
:
bool
):
def
convert_pytorch
(
nlp
:
Pipeline
,
opset
:
int
,
output
:
str
,
use_external_format
:
bool
):
...
@@ -154,9 +160,7 @@ def convert_pytorch(nlp: Pipeline, opset: int, output: str, use_external_format:
...
@@ -154,9 +160,7 @@ def convert_pytorch(nlp: Pipeline, opset: int, output: str, use_external_format:
def
convert_tensorflow
(
nlp
:
Pipeline
,
opset
:
int
,
output
:
str
):
def
convert_tensorflow
(
nlp
:
Pipeline
,
opset
:
int
,
output
:
str
):
if
not
is_tf_available
():
if
not
is_tf_available
():
raise
Exception
(
raise
Exception
(
"Cannot convert because TF is not installed. Please install tensorflow first."
)
"Cannot convert {} because TF is not installed. Please install torch first."
.
format
(
args
.
model
)
)
print
(
"/!
\\
Please note TensorFlow doesn't support exporting model > 2Gb /!
\\
"
)
print
(
"/!
\\
Please note TensorFlow doesn't support exporting model > 2Gb /!
\\
"
)
...
@@ -187,11 +191,12 @@ def convert(
...
@@ -187,11 +191,12 @@ def convert(
opset
:
int
,
opset
:
int
,
tokenizer
:
Optional
[
str
]
=
None
,
tokenizer
:
Optional
[
str
]
=
None
,
use_external_format
:
bool
=
False
,
use_external_format
:
bool
=
False
,
pipeline_name
:
str
=
"feature-extraction"
,
):
):
print
(
"ONNX opset version set to: {}"
.
format
(
opset
))
print
(
"ONNX opset version set to: {}"
.
format
(
opset
))
# Load the pipeline
# Load the pipeline
nlp
=
load_graph_from_args
(
framework
,
model
,
tokenizer
)
nlp
=
load_graph_from_args
(
pipeline_name
,
framework
,
model
,
tokenizer
)
parent
=
dirname
(
output
)
parent
=
dirname
(
output
)
if
not
exists
(
parent
):
if
not
exists
(
parent
):
...
@@ -229,7 +234,15 @@ if __name__ == "__main__":
...
@@ -229,7 +234,15 @@ if __name__ == "__main__":
try
:
try
:
# Convert
# Convert
convert
(
args
.
framework
,
args
.
model
,
args
.
output
,
args
.
opset
,
args
.
tokenizer
,
args
.
use_external_format
)
convert
(
args
.
framework
,
args
.
model
,
args
.
output
,
args
.
opset
,
args
.
tokenizer
,
args
.
use_external_format
,
args
.
pipeline
,
)
# And verify
# And verify
if
args
.
check_loading
:
if
args
.
check_loading
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment