Unverified Commit 83eda643 authored by Alessandro Pietro Bardelli's avatar Alessandro Pietro Bardelli Committed by GitHub
Browse files

Better check for packages availability (#23163)

* Better check for packages availability

* amend _optimumneuron_available

* amend torch_version

* amend PIL detection and lint

* lint

* amend _faiss_available

* remove overloaded signatures of _is_package_available

* fix sklearn and decord detection

* remove unused checks

* revert
parent d51296d9
......@@ -72,6 +72,7 @@ from .utils import (
get_cached_models,
get_file_from_repo,
get_full_repo_name,
get_torch_version,
has_file,
http_user_agent,
is_apex_available,
......@@ -125,5 +126,4 @@ from .utils import (
to_numpy,
to_py_obj,
torch_only_method,
torch_version,
)
......@@ -232,9 +232,9 @@ class OnnxConfig(ABC):
`bool`: Whether the installed version of PyTorch is compatible with the model.
"""
if is_torch_available():
from transformers.utils import torch_version
from transformers.utils import get_torch_version
return torch_version >= self.torch_onnx_minimum_version
return get_torch_version() >= self.torch_onnx_minimum_version
else:
return False
......
......@@ -334,12 +334,12 @@ def export(
preprocessor = tokenizer
if is_torch_available():
from ..utils import torch_version
from ..utils import get_torch_version
if not config.is_torch_support_available:
logger.warning(
f"Unsupported PyTorch version for this model. Minimum required is {config.torch_onnx_minimum_version},"
f" got: {torch_version}"
f" got: {get_torch_version()}"
)
if is_torch_available() and issubclass(type(model), PreTrainedModel):
......
......@@ -99,6 +99,7 @@ from .import_utils import (
_LazyModule,
ccl_version,
direct_transformers_import,
get_torch_version,
is_accelerate_available,
is_apex_available,
is_bitsandbytes_available,
......@@ -170,7 +171,6 @@ from .import_utils import (
is_vision_available,
requires_backends,
torch_only_method,
torch_version,
)
......
......@@ -25,7 +25,6 @@ import warnings
from typing import Any, Callable, Dict, List, Optional, Type, Union
import torch
from packaging import version
from torch import nn
from torch.fx import Graph, GraphModule, Proxy, Tracer
from torch.fx._compatibility import compatibility
......@@ -54,8 +53,13 @@ from ..models.auto.modeling_auto import (
MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES,
MODEL_MAPPING_NAMES,
)
from ..utils import ENV_VARS_TRUE_VALUES, TORCH_FX_REQUIRED_VERSION, is_peft_available, is_torch_fx_available
from ..utils.versions import importlib_metadata
from ..utils import (
ENV_VARS_TRUE_VALUES,
TORCH_FX_REQUIRED_VERSION,
get_torch_version,
is_peft_available,
is_torch_fx_available,
)
if is_peft_available():
......@@ -737,9 +741,8 @@ class HFTracer(Tracer):
super().__init__(autowrap_modules=autowrap_modules, autowrap_functions=autowrap_functions)
if not is_torch_fx_available():
torch_version = version.parse(importlib_metadata.version("torch"))
raise ImportError(
f"Found an incompatible version of torch. Found version {torch_version}, but only version "
f"Found an incompatible version of torch. Found version {get_torch_version()}, but only version "
f"{TORCH_FX_REQUIRED_VERSION} is supported."
)
......
This diff is collapsed.
......@@ -319,12 +319,12 @@ class OnnxExportTestCaseV2(TestCase):
onnx_config = onnx_config_class_constructor(model.config)
if is_torch_available():
from transformers.utils import torch_version
from transformers.utils import get_torch_version
if torch_version < onnx_config.torch_onnx_minimum_version:
if get_torch_version() < onnx_config.torch_onnx_minimum_version:
pytest.skip(
"Skipping due to incompatible PyTorch version. Minimum required is"
f" {onnx_config.torch_onnx_minimum_version}, got: {torch_version}"
f" {onnx_config.torch_onnx_minimum_version}, got: {get_torch_version()}"
)
preprocessor = get_preprocessor(model_name)
......@@ -362,12 +362,12 @@ class OnnxExportTestCaseV2(TestCase):
onnx_config = onnx_config_class_constructor(model.config)
if is_torch_available():
from transformers.utils import torch_version
from transformers.utils import get_torch_version
if torch_version < onnx_config.torch_onnx_minimum_version:
if get_torch_version() < onnx_config.torch_onnx_minimum_version:
pytest.skip(
"Skipping due to incompatible PyTorch version. Minimum required is"
f" {onnx_config.torch_onnx_minimum_version}, got: {torch_version}"
f" {onnx_config.torch_onnx_minimum_version}, got: {get_torch_version()}"
)
encoder_model = model.get_encoder()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment