Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
0e1fce3c
Unverified
Commit
0e1fce3c
authored
Jun 25, 2020
by
Anthony MOI
Committed by
GitHub
Jun 25, 2020
Browse files
Fix convert_graph_to_onnx (#5230)
parent
5543efd5
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
20 additions
and
7 deletions
+20
-7
src/transformers/convert_graph_to_onnx.py
src/transformers/convert_graph_to_onnx.py
+20
-7
No files found.
src/transformers/convert_graph_to_onnx.py
View file @
0e1fce3c
...
@@ -114,15 +114,21 @@ def infer_shapes(nlp: Pipeline, framework: str) -> Tuple[List[str], List[str], D
...
@@ -114,15 +114,21 @@ def infer_shapes(nlp: Pipeline, framework: str) -> Tuple[List[str], List[str], D
return
input_vars
,
output_names
,
dynamic_axes
,
tokens
return
input_vars
,
output_names
,
dynamic_axes
,
tokens
def
load_graph_from_args
(
framework
:
str
,
model
:
str
,
tokenizer
:
Optional
[
str
]
=
None
)
->
Pipeline
:
def
load_graph_from_args
(
pipeline_name
:
str
,
framework
:
str
,
model
:
str
,
tokenizer
:
Optional
[
str
]
=
None
)
->
Pipeline
:
# If no tokenizer provided
# If no tokenizer provided
if
tokenizer
is
None
:
if
tokenizer
is
None
:
tokenizer
=
model
tokenizer
=
model
# Check the wanted framework is available
if
framework
==
"pt"
and
not
is_torch_available
():
raise
Exception
(
"Cannot convert because PyTorch is not installed. Please install torch first."
)
if
framework
==
"tf"
and
not
is_tf_available
():
raise
Exception
(
"Cannot convert because TF is not installed. Please install tensorflow first."
)
print
(
"Loading pipeline (model: {}, tokenizer: {})"
.
format
(
model
,
tokenizer
))
print
(
"Loading pipeline (model: {}, tokenizer: {})"
.
format
(
model
,
tokenizer
))
# Allocate tokenizer and model
# Allocate tokenizer and model
return
pipeline
(
args
.
pipeline
,
model
=
model
,
tokenizer
=
tokenizer
,
framework
=
framework
)
return
pipeline
(
pipeline
_name
,
model
=
model
,
tokenizer
=
tokenizer
,
framework
=
framework
)
def
convert_pytorch
(
nlp
:
Pipeline
,
opset
:
int
,
output
:
str
,
use_external_format
:
bool
):
def
convert_pytorch
(
nlp
:
Pipeline
,
opset
:
int
,
output
:
str
,
use_external_format
:
bool
):
...
@@ -154,9 +160,7 @@ def convert_pytorch(nlp: Pipeline, opset: int, output: str, use_external_format:
...
@@ -154,9 +160,7 @@ def convert_pytorch(nlp: Pipeline, opset: int, output: str, use_external_format:
def
convert_tensorflow
(
nlp
:
Pipeline
,
opset
:
int
,
output
:
str
):
def
convert_tensorflow
(
nlp
:
Pipeline
,
opset
:
int
,
output
:
str
):
if
not
is_tf_available
():
if
not
is_tf_available
():
raise
Exception
(
raise
Exception
(
"Cannot convert because TF is not installed. Please install tensorflow first."
)
"Cannot convert {} because TF is not installed. Please install torch first."
.
format
(
args
.
model
)
)
print
(
"/!
\\
Please note TensorFlow doesn't support exporting model > 2Gb /!
\\
"
)
print
(
"/!
\\
Please note TensorFlow doesn't support exporting model > 2Gb /!
\\
"
)
...
@@ -187,11 +191,12 @@ def convert(
...
@@ -187,11 +191,12 @@ def convert(
opset
:
int
,
opset
:
int
,
tokenizer
:
Optional
[
str
]
=
None
,
tokenizer
:
Optional
[
str
]
=
None
,
use_external_format
:
bool
=
False
,
use_external_format
:
bool
=
False
,
pipeline_name
:
str
=
"feature-extraction"
,
):
):
print
(
"ONNX opset version set to: {}"
.
format
(
opset
))
print
(
"ONNX opset version set to: {}"
.
format
(
opset
))
# Load the pipeline
# Load the pipeline
nlp
=
load_graph_from_args
(
framework
,
model
,
tokenizer
)
nlp
=
load_graph_from_args
(
pipeline_name
,
framework
,
model
,
tokenizer
)
parent
=
dirname
(
output
)
parent
=
dirname
(
output
)
if
not
exists
(
parent
):
if
not
exists
(
parent
):
...
@@ -229,7 +234,15 @@ if __name__ == "__main__":
...
@@ -229,7 +234,15 @@ if __name__ == "__main__":
try
:
try
:
# Convert
# Convert
convert
(
args
.
framework
,
args
.
model
,
args
.
output
,
args
.
opset
,
args
.
tokenizer
,
args
.
use_external_format
)
convert
(
args
.
framework
,
args
.
model
,
args
.
output
,
args
.
opset
,
args
.
tokenizer
,
args
.
use_external_format
,
args
.
pipeline
,
)
# And verify
# And verify
if
args
.
check_loading
:
if
args
.
check_loading
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment