Unverified Commit 37a9fc49 authored by Michael Benayoun's avatar Michael Benayoun Committed by GitHub
Browse files

Choose framework for ONNX export (#16018)

* Can choose framework for ONNX export

* Fix docstring
parent 3f8360a7
......@@ -38,6 +38,9 @@ def main():
parser.add_argument(
"--atol", type=float, default=None, help="Absolute difference tolerence when validating the model."
)
parser.add_argument(
"--framework", type=str, choices=["pt", "tf"], default="pt", help="The framework to use for the ONNX export."
)
parser.add_argument("output", type=Path, help="Path indicating where to store generated ONNX model.")
# Retrieve CLI arguments
......@@ -58,7 +61,7 @@ def main():
raise ValueError(f"Unsupported model type: {config.model_type}")
# Allocate the model
model = FeaturesManager.get_model_from_feature(args.feature, args.model)
model = FeaturesManager.get_model_from_feature(args.feature, args.model, framework=args.framework)
model_kind, model_onnx_config = FeaturesManager.check_supported_model_or_raise(model, feature=args.feature)
onnx_config = model_onnx_config(model.config)
......
......@@ -37,7 +37,7 @@ if is_torch_available():
AutoModelForSequenceClassification,
AutoModelForTokenClassification,
)
elif is_tf_available():
if is_tf_available():
from transformers.models.auto import (
TFAutoModel,
TFAutoModelForCausalLM,
......@@ -48,7 +48,7 @@ elif is_tf_available():
TFAutoModelForSequenceClassification,
TFAutoModelForTokenClassification,
)
else:
if not is_torch_available() and not is_tf_available():
logger.warning(
"The ONNX export features are only supported for PyTorch or TensorFlow. You will not be able to export models without one of these libraries installed."
)
......@@ -82,6 +82,8 @@ def supported_features_mapping(
class FeaturesManager:
_TASKS_TO_AUTOMODELS = {}
_TASKS_TO_TF_AUTOMODELS = {}
if is_torch_available():
_TASKS_TO_AUTOMODELS = {
"default": AutoModel,
......@@ -94,8 +96,8 @@ class FeaturesManager:
"question-answering": AutoModelForQuestionAnswering,
"image-classification": AutoModelForImageClassification,
}
elif is_tf_available():
_TASKS_TO_AUTOMODELS = {
if is_tf_available():
_TASKS_TO_TF_AUTOMODELS = {
"default": TFAutoModel,
"masked-lm": TFAutoModelForMaskedLM,
"causal-lm": TFAutoModelForCausalLM,
......@@ -105,8 +107,6 @@ class FeaturesManager:
"multiple-choice": TFAutoModelForMultipleChoice,
"question-answering": TFAutoModelForQuestionAnswering,
}
else:
_TASKS_TO_AUTOMODELS = {}
# Set of model topologies we support associated to the features supported by each topology and the factory
_SUPPORTED_MODEL_TYPE = {
......@@ -257,11 +257,13 @@ class FeaturesManager:
model_type: str, model_name: Optional[str] = None
) -> Dict[str, Callable[[PretrainedConfig], OnnxConfig]]:
"""
Try to retrieve the feature -> OnnxConfig constructor map from the model type.
Tries to retrieve the feature -> OnnxConfig constructor map from the model type.
Args:
model_type: The model type to retrieve the supported features for.
model_name: The name attribute of the model object, only used for the exception message.
model_type (`str`):
The model type to retrieve the supported features for.
model_name (`str`, *optional*):
The name attribute of the model object, only used for the exception message.
Returns:
The dictionary mapping each feature to a corresponding OnnxConfig constructor.
......@@ -281,45 +283,73 @@ class FeaturesManager:
return feature.replace("-with-past", "")
@staticmethod
def get_model_class_for_feature(feature: str) -> Type:
def _validate_framework_choice(framework: str):
"""
Validates if the framework requested for the export is both correct and available, otherwise throws an
exception.
"""
if framework not in ["pt", "tf"]:
raise ValueError(
f"Only two frameworks are supported for ONNX export: pt or tf, but {framework} was provided."
)
elif framework == "pt" and not is_torch_available():
raise RuntimeError("Cannot export model to ONNX using PyTorch because no PyTorch package was found.")
elif framework == "tf" and not is_tf_available():
raise RuntimeError("Cannot export model to ONNX using TensorFlow because no TensorFlow package was found.")
@staticmethod
def get_model_class_for_feature(feature: str, framework: str = "pt") -> Type:
"""
Attempt to retrieve an AutoModel class from a feature name.
Attempts to retrieve an AutoModel class from a feature name.
Args:
feature: The feature required.
feature (`str`):
The feature required.
framework (`str`, *optional*, defaults to `"pt"`):
The framework to use for the export.
Returns:
The AutoModel class corresponding to the feature.
"""
task = FeaturesManager.feature_to_task(feature)
if task not in FeaturesManager._TASKS_TO_AUTOMODELS:
FeaturesManager._validate_framework_choice(framework)
if framework == "pt":
task_to_automodel = FeaturesManager._TASKS_TO_AUTOMODELS
else:
task_to_automodel = FeaturesManager._TASKS_TO_TF_AUTOMODELS
if task not in task_to_automodel:
raise KeyError(
f"Unknown task: {feature}. "
f"Possible values are {list(FeaturesManager._TASKS_TO_AUTOMODELS.values())}"
)
return FeaturesManager._TASKS_TO_AUTOMODELS[task]
return task_to_automodel[task]
def get_model_from_feature(feature: str, model: str) -> Union[PreTrainedModel, TFPreTrainedModel]:
def get_model_from_feature(
feature: str, model: str, framework: str = "pt"
) -> Union[PreTrainedModel, TFPreTrainedModel]:
"""
Attempt to retrieve a model from a model's name and the feature to be enabled.
Attempts to retrieve a model from a model's name and the feature to be enabled.
Args:
feature: The feature required.
model: The name of the model to export.
feature (`str`):
The feature required.
model (`str`):
The name of the model to export.
framework (`str`, *optional*, defaults to `"pt"`):
The framework to use for the export.
Returns:
The instance of the model.
"""
# If PyTorch and TensorFlow are installed in the same environment, we
# load an AutoModel class by default
model_class = FeaturesManager.get_model_class_for_feature(feature)
model_class = FeaturesManager.get_model_class_for_feature(feature, framework)
try:
model = model_class.from_pretrained(model)
# Load TensorFlow weights in an AutoModel instance if PyTorch and
# TensorFlow are installed in the same environment
except OSError:
if framework == "pt":
model = model_class.from_pretrained(model, from_tf=True)
else:
model = model_class.from_pretrained(model, from_pt=True)
return model
@staticmethod
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment