Unverified Commit 640421c0 authored by Funtowicz Morgan's avatar Funtowicz Morgan Committed by GitHub
Browse files

ONNX v2 raises an Exception when using PyTorch < 1.8.0 (#12933)

* Raise an issue if the pytorch version is < 1.8.0

* Attempt to add a test to ensure it correctly raises.

* Missing docstring.

* Second attempt, patch with string absolute import.

* Let's do the call before checking it was called ...

* use the correct function ... 🤦

* Raise ImportError and AssertionError respectively when unable to find torch and torch version is not sufficient.

* Correct path mock patching

* relax constraint for torch_onnx_dict_inputs to ge instead of eq.

* Style.

* Split each version requirements for torch.

* Let's compare version directly.

* Import torch_version after checking pytorch is installed.

* @require_torch
parent 9160d81c
...@@ -274,8 +274,9 @@ PRESET_MIRROR_DICT = { ...@@ -274,8 +274,9 @@ PRESET_MIRROR_DICT = {
"bfsu": "https://mirrors.bfsu.edu.cn/hugging-face-models", "bfsu": "https://mirrors.bfsu.edu.cn/hugging-face-models",
} }
# This is the version of torch required to run torch.fx features. # This is the version of torch required to run torch.fx features and torch.onnx with dictionary inputs.
TORCH_FX_REQUIRED_VERSION = version.parse("1.8") TORCH_FX_REQUIRED_VERSION = version.parse("1.8")
TORCH_ONNX_DICT_INPUTS_MINIMUM_VERSION = version.parse("1.8")
_is_offline_mode = True if os.environ.get("TRANSFORMERS_OFFLINE", "0").upper() in ENV_VARS_TRUE_VALUES else False _is_offline_mode = True if os.environ.get("TRANSFORMERS_OFFLINE", "0").upper() in ENV_VARS_TRUE_VALUES else False
...@@ -297,7 +298,7 @@ def is_torch_cuda_available(): ...@@ -297,7 +298,7 @@ def is_torch_cuda_available():
return False return False
_torch_fx_available = False _torch_fx_available = _torch_onnx_dict_inputs_support_available = False
if _torch_available: if _torch_available:
torch_version = version.parse(importlib_metadata.version("torch")) torch_version = version.parse(importlib_metadata.version("torch"))
_torch_fx_available = (torch_version.major, torch_version.minor) == ( _torch_fx_available = (torch_version.major, torch_version.minor) == (
...@@ -305,11 +306,17 @@ if _torch_available: ...@@ -305,11 +306,17 @@ if _torch_available:
TORCH_FX_REQUIRED_VERSION.minor, TORCH_FX_REQUIRED_VERSION.minor,
) )
_torch_onnx_dict_inputs_support_available = torch_version >= TORCH_ONNX_DICT_INPUTS_MINIMUM_VERSION
def is_torch_fx_available(): def is_torch_fx_available():
return _torch_fx_available return _torch_fx_available
def is_torch_onnx_dict_inputs_support_available():
return _torch_onnx_dict_inputs_support_available
def is_tf_available(): def is_tf_available():
return _tf_available return _tf_available
......
...@@ -21,6 +21,7 @@ import numpy as np ...@@ -21,6 +21,7 @@ import numpy as np
from packaging.version import Version, parse from packaging.version import Version, parse
from .. import PreTrainedModel, PreTrainedTokenizer, TensorType, TFPreTrainedModel, is_torch_available from .. import PreTrainedModel, PreTrainedTokenizer, TensorType, TFPreTrainedModel, is_torch_available
from ..file_utils import is_torch_onnx_dict_inputs_support_available
from ..utils import logging from ..utils import logging
from .config import OnnxConfig from .config import OnnxConfig
from .utils import flatten_output_collection_property from .utils import flatten_output_collection_property
...@@ -79,11 +80,16 @@ def export( ...@@ -79,11 +80,16 @@ def export(
""" """
if not is_torch_available(): if not is_torch_available():
raise Exception("Cannot convert because PyTorch is not installed. Please install torch first.") raise ImportError("Cannot convert because PyTorch is not installed. Please install torch first.")
import torch import torch
from torch.onnx import export from torch.onnx import export
from ..file_utils import torch_version
if not is_torch_onnx_dict_inputs_support_available():
raise AssertionError(f"Unsupported PyTorch version, minimum required is 1.8.0, got: {torch_version}")
logger.info(f"Using framework PyTorch: {torch.__version__}") logger.info(f"Using framework PyTorch: {torch.__version__}")
torch.set_grad_enabled(False) torch.set_grad_enabled(False)
model.config.return_dict = True model.config.return_dict = True
......
...@@ -24,7 +24,13 @@ from transformers.models.roberta import RobertaOnnxConfig ...@@ -24,7 +24,13 @@ from transformers.models.roberta import RobertaOnnxConfig
# from transformers.models.t5 import T5OnnxConfig # from transformers.models.t5 import T5OnnxConfig
from transformers.models.xlm_roberta import XLMRobertaOnnxConfig from transformers.models.xlm_roberta import XLMRobertaOnnxConfig
from transformers.onnx import EXTERNAL_DATA_FORMAT_SIZE_LIMIT, OnnxConfig, ParameterFormat, validate_model_outputs from transformers.onnx import (
EXTERNAL_DATA_FORMAT_SIZE_LIMIT,
OnnxConfig,
ParameterFormat,
export,
validate_model_outputs,
)
from transformers.onnx.config import DEFAULT_ONNX_OPSET, OnnxConfigWithPast from transformers.onnx.config import DEFAULT_ONNX_OPSET, OnnxConfigWithPast
from transformers.onnx.utils import ( from transformers.onnx.utils import (
compute_effective_axis_dimension, compute_effective_axis_dimension,
...@@ -40,6 +46,15 @@ class OnnxUtilsTestCaseV2(TestCase): ...@@ -40,6 +46,15 @@ class OnnxUtilsTestCaseV2(TestCase):
Cover all the utilities involved to export ONNX models Cover all the utilities involved to export ONNX models
""" """
@require_torch
@patch("transformers.onnx.convert.is_torch_onnx_dict_inputs_support_available", return_value=False)
def test_ensure_pytorch_version_ge_1_8_0(self, mock_is_torch_onnx_dict_inputs_support_available):
"""
Ensure we raise an Exception if the pytorch version is unsupported (< 1.8.0)
"""
self.assertRaises(AssertionError, export, None, None, None, None, None)
mock_is_torch_onnx_dict_inputs_support_available.assert_called()
def test_compute_effective_axis_dimension(self): def test_compute_effective_axis_dimension(self):
""" """
When exporting ONNX model with dynamic axis (batch or sequence) we set batch_size and/or sequence_length = -1. When exporting ONNX model with dynamic axis (batch or sequence) we set batch_size and/or sequence_length = -1.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment