Unverified Commit 37a9fc49 authored by Michael Benayoun's avatar Michael Benayoun Committed by GitHub
Browse files

Choose framework for ONNX export (#16018)

* Can choose framework for ONNX export

* Fix docstring
parent 3f8360a7
...@@ -38,6 +38,9 @@ def main(): ...@@ -38,6 +38,9 @@ def main():
parser.add_argument( parser.add_argument(
"--atol", type=float, default=None, help="Absolute difference tolerence when validating the model." "--atol", type=float, default=None, help="Absolute difference tolerence when validating the model."
) )
parser.add_argument(
"--framework", type=str, choices=["pt", "tf"], default="pt", help="The framework to use for the ONNX export."
)
parser.add_argument("output", type=Path, help="Path indicating where to store generated ONNX model.") parser.add_argument("output", type=Path, help="Path indicating where to store generated ONNX model.")
# Retrieve CLI arguments # Retrieve CLI arguments
...@@ -58,7 +61,7 @@ def main(): ...@@ -58,7 +61,7 @@ def main():
raise ValueError(f"Unsupported model type: {config.model_type}") raise ValueError(f"Unsupported model type: {config.model_type}")
# Allocate the model # Allocate the model
model = FeaturesManager.get_model_from_feature(args.feature, args.model) model = FeaturesManager.get_model_from_feature(args.feature, args.model, framework=args.framework)
model_kind, model_onnx_config = FeaturesManager.check_supported_model_or_raise(model, feature=args.feature) model_kind, model_onnx_config = FeaturesManager.check_supported_model_or_raise(model, feature=args.feature)
onnx_config = model_onnx_config(model.config) onnx_config = model_onnx_config(model.config)
......
...@@ -37,7 +37,7 @@ if is_torch_available(): ...@@ -37,7 +37,7 @@ if is_torch_available():
AutoModelForSequenceClassification, AutoModelForSequenceClassification,
AutoModelForTokenClassification, AutoModelForTokenClassification,
) )
elif is_tf_available(): if is_tf_available():
from transformers.models.auto import ( from transformers.models.auto import (
TFAutoModel, TFAutoModel,
TFAutoModelForCausalLM, TFAutoModelForCausalLM,
...@@ -48,7 +48,7 @@ elif is_tf_available(): ...@@ -48,7 +48,7 @@ elif is_tf_available():
TFAutoModelForSequenceClassification, TFAutoModelForSequenceClassification,
TFAutoModelForTokenClassification, TFAutoModelForTokenClassification,
) )
else: if not is_torch_available() and not is_tf_available():
logger.warning( logger.warning(
"The ONNX export features are only supported for PyTorch or TensorFlow. You will not be able to export models without one of these libraries installed." "The ONNX export features are only supported for PyTorch or TensorFlow. You will not be able to export models without one of these libraries installed."
) )
...@@ -82,6 +82,8 @@ def supported_features_mapping( ...@@ -82,6 +82,8 @@ def supported_features_mapping(
class FeaturesManager: class FeaturesManager:
_TASKS_TO_AUTOMODELS = {}
_TASKS_TO_TF_AUTOMODELS = {}
if is_torch_available(): if is_torch_available():
_TASKS_TO_AUTOMODELS = { _TASKS_TO_AUTOMODELS = {
"default": AutoModel, "default": AutoModel,
...@@ -94,8 +96,8 @@ class FeaturesManager: ...@@ -94,8 +96,8 @@ class FeaturesManager:
"question-answering": AutoModelForQuestionAnswering, "question-answering": AutoModelForQuestionAnswering,
"image-classification": AutoModelForImageClassification, "image-classification": AutoModelForImageClassification,
} }
elif is_tf_available(): if is_tf_available():
_TASKS_TO_AUTOMODELS = { _TASKS_TO_TF_AUTOMODELS = {
"default": TFAutoModel, "default": TFAutoModel,
"masked-lm": TFAutoModelForMaskedLM, "masked-lm": TFAutoModelForMaskedLM,
"causal-lm": TFAutoModelForCausalLM, "causal-lm": TFAutoModelForCausalLM,
...@@ -105,8 +107,6 @@ class FeaturesManager: ...@@ -105,8 +107,6 @@ class FeaturesManager:
"multiple-choice": TFAutoModelForMultipleChoice, "multiple-choice": TFAutoModelForMultipleChoice,
"question-answering": TFAutoModelForQuestionAnswering, "question-answering": TFAutoModelForQuestionAnswering,
} }
else:
_TASKS_TO_AUTOMODELS = {}
# Set of model topologies we support associated to the features supported by each topology and the factory # Set of model topologies we support associated to the features supported by each topology and the factory
_SUPPORTED_MODEL_TYPE = { _SUPPORTED_MODEL_TYPE = {
...@@ -257,11 +257,13 @@ class FeaturesManager: ...@@ -257,11 +257,13 @@ class FeaturesManager:
model_type: str, model_name: Optional[str] = None model_type: str, model_name: Optional[str] = None
) -> Dict[str, Callable[[PretrainedConfig], OnnxConfig]]: ) -> Dict[str, Callable[[PretrainedConfig], OnnxConfig]]:
""" """
Try to retrieve the feature -> OnnxConfig constructor map from the model type. Tries to retrieve the feature -> OnnxConfig constructor map from the model type.
Args: Args:
model_type: The model type to retrieve the supported features for. model_type (`str`):
model_name: The name attribute of the model object, only used for the exception message. The model type to retrieve the supported features for.
model_name (`str`, *optional*):
The name attribute of the model object, only used for the exception message.
Returns: Returns:
The dictionary mapping each feature to a corresponding OnnxConfig constructor. The dictionary mapping each feature to a corresponding OnnxConfig constructor.
...@@ -281,45 +283,73 @@ class FeaturesManager: ...@@ -281,45 +283,73 @@ class FeaturesManager:
return feature.replace("-with-past", "") return feature.replace("-with-past", "")
@staticmethod @staticmethod
def get_model_class_for_feature(feature: str) -> Type: def _validate_framework_choice(framework: str):
"""
Validates if the framework requested for the export is both correct and available, otherwise throws an
exception.
"""
if framework not in ["pt", "tf"]:
raise ValueError(
f"Only two frameworks are supported for ONNX export: pt or tf, but {framework} was provided."
)
elif framework == "pt" and not is_torch_available():
raise RuntimeError("Cannot export model to ONNX using PyTorch because no PyTorch package was found.")
elif framework == "tf" and not is_tf_available():
raise RuntimeError("Cannot export model to ONNX using TensorFlow because no TensorFlow package was found.")
@staticmethod
def get_model_class_for_feature(feature: str, framework: str = "pt") -> Type:
""" """
Attempt to retrieve an AutoModel class from a feature name. Attempts to retrieve an AutoModel class from a feature name.
Args: Args:
feature: The feature required. feature (`str`):
The feature required.
framework (`str`, *optional*, defaults to `"pt"`):
The framework to use for the export.
Returns: Returns:
The AutoModel class corresponding to the feature. The AutoModel class corresponding to the feature.
""" """
task = FeaturesManager.feature_to_task(feature) task = FeaturesManager.feature_to_task(feature)
if task not in FeaturesManager._TASKS_TO_AUTOMODELS: FeaturesManager._validate_framework_choice(framework)
if framework == "pt":
task_to_automodel = FeaturesManager._TASKS_TO_AUTOMODELS
else:
task_to_automodel = FeaturesManager._TASKS_TO_TF_AUTOMODELS
if task not in task_to_automodel:
raise KeyError( raise KeyError(
f"Unknown task: {feature}. " f"Unknown task: {feature}. "
f"Possible values are {list(FeaturesManager._TASKS_TO_AUTOMODELS.values())}" f"Possible values are {list(FeaturesManager._TASKS_TO_AUTOMODELS.values())}"
) )
return FeaturesManager._TASKS_TO_AUTOMODELS[task] return task_to_automodel[task]
def get_model_from_feature(feature: str, model: str) -> Union[PreTrainedModel, TFPreTrainedModel]: def get_model_from_feature(
feature: str, model: str, framework: str = "pt"
) -> Union[PreTrainedModel, TFPreTrainedModel]:
""" """
Attempt to retrieve a model from a model's name and the feature to be enabled. Attempts to retrieve a model from a model's name and the feature to be enabled.
Args: Args:
feature: The feature required. feature (`str`):
model: The name of the model to export. The feature required.
model (`str`):
The name of the model to export.
framework (`str`, *optional*, defaults to `"pt"`):
The framework to use for the export.
Returns: Returns:
The instance of the model. The instance of the model.
""" """
# If PyTorch and TensorFlow are installed in the same environment, we model_class = FeaturesManager.get_model_class_for_feature(feature, framework)
# load an AutoModel class by default
model_class = FeaturesManager.get_model_class_for_feature(feature)
try: try:
model = model_class.from_pretrained(model) model = model_class.from_pretrained(model)
# Load TensorFlow weights in an AutoModel instance if PyTorch and
# TensorFlow are installed in the same environment
except OSError: except OSError:
model = model_class.from_pretrained(model, from_tf=True) if framework == "pt":
model = model_class.from_pretrained(model, from_tf=True)
else:
model = model_class.from_pretrained(model, from_pt=True)
return model return model
@staticmethod @staticmethod
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment