Unverified Commit 83eda643 authored by Alessandro Pietro Bardelli's avatar Alessandro Pietro Bardelli Committed by GitHub
Browse files

Better check for packages availability (#23163)

* Better check for packages availability

* amend _optimumneuron_available

* amend torch_version

* amend PIL detection and lint

* lint

* amend _faiss_available

* remove overloaded signatures of _is_package_available

* fix sklearn and decord detection

* remove unused checks

* revert
parent d51296d9
...@@ -72,6 +72,7 @@ from .utils import ( ...@@ -72,6 +72,7 @@ from .utils import (
get_cached_models, get_cached_models,
get_file_from_repo, get_file_from_repo,
get_full_repo_name, get_full_repo_name,
get_torch_version,
has_file, has_file,
http_user_agent, http_user_agent,
is_apex_available, is_apex_available,
...@@ -125,5 +126,4 @@ from .utils import ( ...@@ -125,5 +126,4 @@ from .utils import (
to_numpy, to_numpy,
to_py_obj, to_py_obj,
torch_only_method, torch_only_method,
torch_version,
) )
...@@ -232,9 +232,9 @@ class OnnxConfig(ABC): ...@@ -232,9 +232,9 @@ class OnnxConfig(ABC):
`bool`: Whether the installed version of PyTorch is compatible with the model. `bool`: Whether the installed version of PyTorch is compatible with the model.
""" """
if is_torch_available(): if is_torch_available():
from transformers.utils import torch_version from transformers.utils import get_torch_version
return torch_version >= self.torch_onnx_minimum_version return get_torch_version() >= self.torch_onnx_minimum_version
else: else:
return False return False
......
...@@ -334,12 +334,12 @@ def export( ...@@ -334,12 +334,12 @@ def export(
preprocessor = tokenizer preprocessor = tokenizer
if is_torch_available(): if is_torch_available():
from ..utils import torch_version from ..utils import get_torch_version
if not config.is_torch_support_available: if not config.is_torch_support_available:
logger.warning( logger.warning(
f"Unsupported PyTorch version for this model. Minimum required is {config.torch_onnx_minimum_version}," f"Unsupported PyTorch version for this model. Minimum required is {config.torch_onnx_minimum_version},"
f" got: {torch_version}" f" got: {get_torch_version()}"
) )
if is_torch_available() and issubclass(type(model), PreTrainedModel): if is_torch_available() and issubclass(type(model), PreTrainedModel):
......
...@@ -99,6 +99,7 @@ from .import_utils import ( ...@@ -99,6 +99,7 @@ from .import_utils import (
_LazyModule, _LazyModule,
ccl_version, ccl_version,
direct_transformers_import, direct_transformers_import,
get_torch_version,
is_accelerate_available, is_accelerate_available,
is_apex_available, is_apex_available,
is_bitsandbytes_available, is_bitsandbytes_available,
...@@ -170,7 +171,6 @@ from .import_utils import ( ...@@ -170,7 +171,6 @@ from .import_utils import (
is_vision_available, is_vision_available,
requires_backends, requires_backends,
torch_only_method, torch_only_method,
torch_version,
) )
......
...@@ -25,7 +25,6 @@ import warnings ...@@ -25,7 +25,6 @@ import warnings
from typing import Any, Callable, Dict, List, Optional, Type, Union from typing import Any, Callable, Dict, List, Optional, Type, Union
import torch import torch
from packaging import version
from torch import nn from torch import nn
from torch.fx import Graph, GraphModule, Proxy, Tracer from torch.fx import Graph, GraphModule, Proxy, Tracer
from torch.fx._compatibility import compatibility from torch.fx._compatibility import compatibility
...@@ -54,8 +53,13 @@ from ..models.auto.modeling_auto import ( ...@@ -54,8 +53,13 @@ from ..models.auto.modeling_auto import (
MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES, MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES,
MODEL_MAPPING_NAMES, MODEL_MAPPING_NAMES,
) )
from ..utils import ENV_VARS_TRUE_VALUES, TORCH_FX_REQUIRED_VERSION, is_peft_available, is_torch_fx_available from ..utils import (
from ..utils.versions import importlib_metadata ENV_VARS_TRUE_VALUES,
TORCH_FX_REQUIRED_VERSION,
get_torch_version,
is_peft_available,
is_torch_fx_available,
)
if is_peft_available(): if is_peft_available():
...@@ -737,9 +741,8 @@ class HFTracer(Tracer): ...@@ -737,9 +741,8 @@ class HFTracer(Tracer):
super().__init__(autowrap_modules=autowrap_modules, autowrap_functions=autowrap_functions) super().__init__(autowrap_modules=autowrap_modules, autowrap_functions=autowrap_functions)
if not is_torch_fx_available(): if not is_torch_fx_available():
torch_version = version.parse(importlib_metadata.version("torch"))
raise ImportError( raise ImportError(
f"Found an incompatible version of torch. Found version {torch_version}, but only version " f"Found an incompatible version of torch. Found version {get_torch_version()}, but only version "
f"{TORCH_FX_REQUIRED_VERSION} is supported." f"{TORCH_FX_REQUIRED_VERSION} is supported."
) )
......
This diff is collapsed.
...@@ -319,12 +319,12 @@ class OnnxExportTestCaseV2(TestCase): ...@@ -319,12 +319,12 @@ class OnnxExportTestCaseV2(TestCase):
onnx_config = onnx_config_class_constructor(model.config) onnx_config = onnx_config_class_constructor(model.config)
if is_torch_available(): if is_torch_available():
from transformers.utils import torch_version from transformers.utils import get_torch_version
if torch_version < onnx_config.torch_onnx_minimum_version: if get_torch_version() < onnx_config.torch_onnx_minimum_version:
pytest.skip( pytest.skip(
"Skipping due to incompatible PyTorch version. Minimum required is" "Skipping due to incompatible PyTorch version. Minimum required is"
f" {onnx_config.torch_onnx_minimum_version}, got: {torch_version}" f" {onnx_config.torch_onnx_minimum_version}, got: {get_torch_version()}"
) )
preprocessor = get_preprocessor(model_name) preprocessor = get_preprocessor(model_name)
...@@ -362,12 +362,12 @@ class OnnxExportTestCaseV2(TestCase): ...@@ -362,12 +362,12 @@ class OnnxExportTestCaseV2(TestCase):
onnx_config = onnx_config_class_constructor(model.config) onnx_config = onnx_config_class_constructor(model.config)
if is_torch_available(): if is_torch_available():
from transformers.utils import torch_version from transformers.utils import get_torch_version
if torch_version < onnx_config.torch_onnx_minimum_version: if get_torch_version() < onnx_config.torch_onnx_minimum_version:
pytest.skip( pytest.skip(
"Skipping due to incompatible PyTorch version. Minimum required is" "Skipping due to incompatible PyTorch version. Minimum required is"
f" {onnx_config.torch_onnx_minimum_version}, got: {torch_version}" f" {onnx_config.torch_onnx_minimum_version}, got: {get_torch_version()}"
) )
encoder_model = model.get_encoder() encoder_model = model.get_encoder()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment