Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
37a9fc49
Unverified
Commit
37a9fc49
authored
Mar 14, 2022
by
Michael Benayoun
Committed by
GitHub
Mar 14, 2022
Browse files
Choose framework for ONNX export (#16018)
* Can choose framework for ONNX export * Fix docstring
parent
3f8360a7
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
58 additions
and
25 deletions
+58
-25
src/transformers/onnx/__main__.py
src/transformers/onnx/__main__.py
+4
-1
src/transformers/onnx/features.py
src/transformers/onnx/features.py
+54
-24
No files found.
src/transformers/onnx/__main__.py
View file @
37a9fc49
...
...
@@ -38,6 +38,9 @@ def main():
parser
.
add_argument
(
"--atol"
,
type
=
float
,
default
=
None
,
help
=
"Absolute difference tolerence when validating the model."
)
parser
.
add_argument
(
"--framework"
,
type
=
str
,
choices
=
[
"pt"
,
"tf"
],
default
=
"pt"
,
help
=
"The framework to use for the ONNX export."
)
parser
.
add_argument
(
"output"
,
type
=
Path
,
help
=
"Path indicating where to store generated ONNX model."
)
# Retrieve CLI arguments
...
...
@@ -58,7 +61,7 @@ def main():
raise
ValueError
(
f
"Unsupported model type:
{
config
.
model_type
}
"
)
# Allocate the model
model
=
FeaturesManager
.
get_model_from_feature
(
args
.
feature
,
args
.
model
)
model
=
FeaturesManager
.
get_model_from_feature
(
args
.
feature
,
args
.
model
,
framework
=
args
.
framework
)
model_kind
,
model_onnx_config
=
FeaturesManager
.
check_supported_model_or_raise
(
model
,
feature
=
args
.
feature
)
onnx_config
=
model_onnx_config
(
model
.
config
)
...
...
src/transformers/onnx/features.py
View file @
37a9fc49
...
...
@@ -37,7 +37,7 @@ if is_torch_available():
AutoModelForSequenceClassification
,
AutoModelForTokenClassification
,
)
el
if
is_tf_available
():
if
is_tf_available
():
from
transformers.models.auto
import
(
TFAutoModel
,
TFAutoModelForCausalLM
,
...
...
@@ -48,7 +48,7 @@ elif is_tf_available():
TFAutoModelForSequenceClassification
,
TFAutoModelForTokenClassification
,
)
else
:
if
not
is_torch_available
()
and
not
is_tf_available
()
:
logger
.
warning
(
"The ONNX export features are only supported for PyTorch or TensorFlow. You will not be able to export models without one of these libraries installed."
)
...
...
@@ -82,6 +82,8 @@ def supported_features_mapping(
class
FeaturesManager
:
_TASKS_TO_AUTOMODELS
=
{}
_TASKS_TO_TF_AUTOMODELS
=
{}
if
is_torch_available
():
_TASKS_TO_AUTOMODELS
=
{
"default"
:
AutoModel
,
...
...
@@ -94,8 +96,8 @@ class FeaturesManager:
"question-answering"
:
AutoModelForQuestionAnswering
,
"image-classification"
:
AutoModelForImageClassification
,
}
el
if
is_tf_available
():
_TASKS_TO_AUTOMODELS
=
{
if
is_tf_available
():
_TASKS_TO_
TF_
AUTOMODELS
=
{
"default"
:
TFAutoModel
,
"masked-lm"
:
TFAutoModelForMaskedLM
,
"causal-lm"
:
TFAutoModelForCausalLM
,
...
...
@@ -105,8 +107,6 @@ class FeaturesManager:
"multiple-choice"
:
TFAutoModelForMultipleChoice
,
"question-answering"
:
TFAutoModelForQuestionAnswering
,
}
else
:
_TASKS_TO_AUTOMODELS
=
{}
# Set of model topologies we support associated to the features supported by each topology and the factory
_SUPPORTED_MODEL_TYPE
=
{
...
...
@@ -257,11 +257,13 @@ class FeaturesManager:
model_type
:
str
,
model_name
:
Optional
[
str
]
=
None
)
->
Dict
[
str
,
Callable
[[
PretrainedConfig
],
OnnxConfig
]]:
"""
Tr
y
to retrieve the feature -> OnnxConfig constructor map from the model type.
Tr
ies
to retrieve the feature -> OnnxConfig constructor map from the model type.
Args:
model_type: The model type to retrieve the supported features for.
model_name: The name attribute of the model object, only used for the exception message.
model_type (`str`):
The model type to retrieve the supported features for.
model_name (`str`, *optional*):
The name attribute of the model object, only used for the exception message.
Returns:
The dictionary mapping each feature to a corresponding OnnxConfig constructor.
...
...
@@ -281,45 +283,73 @@ class FeaturesManager:
return
feature
.
replace
(
"-with-past"
,
""
)
@
staticmethod
def
get_model_class_for_feature
(
feature
:
str
)
->
Type
:
def
_validate_framework_choice
(
framework
:
str
):
"""
Validates if the framework requested for the export is both correct and available, otherwise throws an
exception.
"""
if
framework
not
in
[
"pt"
,
"tf"
]:
raise
ValueError
(
f
"Only two frameworks are supported for ONNX export: pt or tf, but
{
framework
}
was provided."
)
elif
framework
==
"pt"
and
not
is_torch_available
():
raise
RuntimeError
(
"Cannot export model to ONNX using PyTorch because no PyTorch package was found."
)
elif
framework
==
"tf"
and
not
is_tf_available
():
raise
RuntimeError
(
"Cannot export model to ONNX using TensorFlow because no TensorFlow package was found."
)
@
staticmethod
def
get_model_class_for_feature
(
feature
:
str
,
framework
:
str
=
"pt"
)
->
Type
:
"""
Attempt to retrieve an AutoModel class from a feature name.
Attempt
s
to retrieve an AutoModel class from a feature name.
Args:
feature: The feature required.
feature (`str`):
The feature required.
framework (`str`, *optional*, defaults to `"pt"`):
The framework to use for the export.
Returns:
The AutoModel class corresponding to the feature.
"""
task
=
FeaturesManager
.
feature_to_task
(
feature
)
if
task
not
in
FeaturesManager
.
_TASKS_TO_AUTOMODELS
:
FeaturesManager
.
_validate_framework_choice
(
framework
)
if
framework
==
"pt"
:
task_to_automodel
=
FeaturesManager
.
_TASKS_TO_AUTOMODELS
else
:
task_to_automodel
=
FeaturesManager
.
_TASKS_TO_TF_AUTOMODELS
if
task
not
in
task_to_automodel
:
raise
KeyError
(
f
"Unknown task:
{
feature
}
. "
f
"Possible values are
{
list
(
FeaturesManager
.
_TASKS_TO_AUTOMODELS
.
values
())
}
"
)
return
FeaturesManager
.
_TASKS_TO_AUTOMODELS
[
task
]
return
task_to_automodel
[
task
]
def
get_model_from_feature
(
feature
:
str
,
model
:
str
)
->
Union
[
PreTrainedModel
,
TFPreTrainedModel
]:
def
get_model_from_feature
(
feature
:
str
,
model
:
str
,
framework
:
str
=
"pt"
)
->
Union
[
PreTrainedModel
,
TFPreTrainedModel
]:
"""
Attempt to retrieve a model from a model's name and the feature to be enabled.
Attempt
s
to retrieve a model from a model's name and the feature to be enabled.
Args:
feature: The feature required.
model: The name of the model to export.
feature (`str`):
The feature required.
model (`str`):
The name of the model to export.
framework (`str`, *optional*, defaults to `"pt"`):
The framework to use for the export.
Returns:
The instance of the model.
"""
# If PyTorch and TensorFlow are installed in the same environment, we
# load an AutoModel class by default
model_class
=
FeaturesManager
.
get_model_class_for_feature
(
feature
)
model_class
=
FeaturesManager
.
get_model_class_for_feature
(
feature
,
framework
)
try
:
model
=
model_class
.
from_pretrained
(
model
)
# Load TensorFlow weights in an AutoModel instance if PyTorch and
# TensorFlow are installed in the same environment
except
OSError
:
model
=
model_class
.
from_pretrained
(
model
,
from_tf
=
True
)
if
framework
==
"pt"
:
model
=
model_class
.
from_pretrained
(
model
,
from_tf
=
True
)
else
:
model
=
model_class
.
from_pretrained
(
model
,
from_pt
=
True
)
return
model
@
staticmethod
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment