"git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "d7f33a73d94f462f63878864b88dc0cd96e5a030"
Unverified Commit 83eda643 authored by Alessandro Pietro Bardelli's avatar Alessandro Pietro Bardelli Committed by GitHub
Browse files

Better check for packages availability (#23163)

* Better check for packages availability

* amend _optimumneuron_available

* amend torch_version

* amend PIL detection and lint

* lint

* amend _faiss_available

* remove overloaded signatures of _is_package_available

* fix sklearn and decord detection

* remove unused checks

* revert
parent d51296d9
...@@ -72,6 +72,7 @@ from .utils import ( ...@@ -72,6 +72,7 @@ from .utils import (
get_cached_models, get_cached_models,
get_file_from_repo, get_file_from_repo,
get_full_repo_name, get_full_repo_name,
get_torch_version,
has_file, has_file,
http_user_agent, http_user_agent,
is_apex_available, is_apex_available,
...@@ -125,5 +126,4 @@ from .utils import ( ...@@ -125,5 +126,4 @@ from .utils import (
to_numpy, to_numpy,
to_py_obj, to_py_obj,
torch_only_method, torch_only_method,
torch_version,
) )
...@@ -232,9 +232,9 @@ class OnnxConfig(ABC): ...@@ -232,9 +232,9 @@ class OnnxConfig(ABC):
`bool`: Whether the installed version of PyTorch is compatible with the model. `bool`: Whether the installed version of PyTorch is compatible with the model.
""" """
if is_torch_available(): if is_torch_available():
from transformers.utils import torch_version from transformers.utils import get_torch_version
return torch_version >= self.torch_onnx_minimum_version return get_torch_version() >= self.torch_onnx_minimum_version
else: else:
return False return False
......
...@@ -334,12 +334,12 @@ def export( ...@@ -334,12 +334,12 @@ def export(
preprocessor = tokenizer preprocessor = tokenizer
if is_torch_available(): if is_torch_available():
from ..utils import torch_version from ..utils import get_torch_version
if not config.is_torch_support_available: if not config.is_torch_support_available:
logger.warning( logger.warning(
f"Unsupported PyTorch version for this model. Minimum required is {config.torch_onnx_minimum_version}," f"Unsupported PyTorch version for this model. Minimum required is {config.torch_onnx_minimum_version},"
f" got: {torch_version}" f" got: {get_torch_version()}"
) )
if is_torch_available() and issubclass(type(model), PreTrainedModel): if is_torch_available() and issubclass(type(model), PreTrainedModel):
......
...@@ -99,6 +99,7 @@ from .import_utils import ( ...@@ -99,6 +99,7 @@ from .import_utils import (
_LazyModule, _LazyModule,
ccl_version, ccl_version,
direct_transformers_import, direct_transformers_import,
get_torch_version,
is_accelerate_available, is_accelerate_available,
is_apex_available, is_apex_available,
is_bitsandbytes_available, is_bitsandbytes_available,
...@@ -170,7 +171,6 @@ from .import_utils import ( ...@@ -170,7 +171,6 @@ from .import_utils import (
is_vision_available, is_vision_available,
requires_backends, requires_backends,
torch_only_method, torch_only_method,
torch_version,
) )
......
...@@ -25,7 +25,6 @@ import warnings ...@@ -25,7 +25,6 @@ import warnings
from typing import Any, Callable, Dict, List, Optional, Type, Union from typing import Any, Callable, Dict, List, Optional, Type, Union
import torch import torch
from packaging import version
from torch import nn from torch import nn
from torch.fx import Graph, GraphModule, Proxy, Tracer from torch.fx import Graph, GraphModule, Proxy, Tracer
from torch.fx._compatibility import compatibility from torch.fx._compatibility import compatibility
...@@ -54,8 +53,13 @@ from ..models.auto.modeling_auto import ( ...@@ -54,8 +53,13 @@ from ..models.auto.modeling_auto import (
MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES, MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES,
MODEL_MAPPING_NAMES, MODEL_MAPPING_NAMES,
) )
from ..utils import ENV_VARS_TRUE_VALUES, TORCH_FX_REQUIRED_VERSION, is_peft_available, is_torch_fx_available from ..utils import (
from ..utils.versions import importlib_metadata ENV_VARS_TRUE_VALUES,
TORCH_FX_REQUIRED_VERSION,
get_torch_version,
is_peft_available,
is_torch_fx_available,
)
if is_peft_available(): if is_peft_available():
...@@ -737,9 +741,8 @@ class HFTracer(Tracer): ...@@ -737,9 +741,8 @@ class HFTracer(Tracer):
super().__init__(autowrap_modules=autowrap_modules, autowrap_functions=autowrap_functions) super().__init__(autowrap_modules=autowrap_modules, autowrap_functions=autowrap_functions)
if not is_torch_fx_available(): if not is_torch_fx_available():
torch_version = version.parse(importlib_metadata.version("torch"))
raise ImportError( raise ImportError(
f"Found an incompatible version of torch. Found version {torch_version}, but only version " f"Found an incompatible version of torch. Found version {get_torch_version()}, but only version "
f"{TORCH_FX_REQUIRED_VERSION} is supported." f"{TORCH_FX_REQUIRED_VERSION} is supported."
) )
......
...@@ -25,7 +25,7 @@ from collections import OrderedDict ...@@ -25,7 +25,7 @@ from collections import OrderedDict
from functools import lru_cache from functools import lru_cache
from itertools import chain from itertools import chain
from types import ModuleType from types import ModuleType
from typing import Any from typing import Any, Tuple, Union
from packaging import version from packaging import version
...@@ -35,6 +35,24 @@ from .versions import importlib_metadata ...@@ -35,6 +35,24 @@ from .versions import importlib_metadata
logger = logging.get_logger(__name__) # pylint: disable=invalid-name logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[Tuple[bool, str], bool]:
# Check we're not importing a "pkg_name" directory somewhere but the actual library by trying to grab the version
package_exists = importlib.util.find_spec(pkg_name) is not None
package_version = "N/A"
if package_exists:
try:
package_version = importlib_metadata.version(pkg_name)
package_exists = True
except importlib_metadata.PackageNotFoundError:
package_exists = False
logger.debug(f"Detected {pkg_name} version {package_version}")
if return_version:
return package_exists, package_version
else:
return package_exists
ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"} ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}
ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"}) ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"})
...@@ -44,26 +62,80 @@ USE_JAX = os.environ.get("USE_FLAX", "AUTO").upper() ...@@ -44,26 +62,80 @@ USE_JAX = os.environ.get("USE_FLAX", "AUTO").upper()
FORCE_TF_AVAILABLE = os.environ.get("FORCE_TF_AVAILABLE", "AUTO").upper() FORCE_TF_AVAILABLE = os.environ.get("FORCE_TF_AVAILABLE", "AUTO").upper()
# This is the version of torch required to run torch.fx features and torch.onnx with dictionary inputs.
TORCH_FX_REQUIRED_VERSION = version.parse("1.10")
_accelerate_available, _accelerate_version = _is_package_available("accelerate", return_version=True)
_apex_available = _is_package_available("apex")
_bitsandbytes_available = _is_package_available("bitsandbytes")
_bs4_available = _is_package_available("bs4")
_coloredlogs_available = _is_package_available("coloredlogs")
_datasets_available = _is_package_available("datasets")
_decord_available = importlib.util.find_spec("decord") is not None
_detectron2_available = _is_package_available("detectron2")
_faiss_available = _is_package_available("faiss") or _is_package_available("faiss-cpu")
_ftfy_available = _is_package_available("ftfy")
_ipex_available, _ipex_version = _is_package_available("intel_extension_for_pytorch", return_version=True)
_jieba_available = _is_package_available("jieba")
_kenlm_available = _is_package_available("kenlm")
_keras_nlp_available = _is_package_available("keras_nlp")
_librosa_available = _is_package_available("librosa")
_natten_available = _is_package_available("natten")
_ninja_available = _is_package_available("ninja")
_onnx_available = _is_package_available("onnx")
_openai_available = _is_package_available("openai")
_optimum_available = _is_package_available("optimum")
_optimumneuron_available = _optimum_available and _is_package_available("optimum.neuron")
_pandas_available = _is_package_available("pandas")
_peft_available = _is_package_available("peft")
_phonemizer_available = _is_package_available("phonemizer")
_psutil_available = _is_package_available("psutil")
_py3nvml_available = _is_package_available("py3nvml")
_pyctcdecode_available = _is_package_available("pyctcdecode")
_pytesseract_available = _is_package_available("pytesseract")
_pytorch_quantization_available = _is_package_available("pytorch_quantization")
_rjieba_available = _is_package_available("rjieba")
_sacremoses_available = _is_package_available("sacremoses")
_safetensors_available = _is_package_available("safetensors")
_scipy_available = _is_package_available("scipy")
_sentencepiece_available = _is_package_available("sentencepiece")
_sklearn_available = importlib.util.find_spec("sklearn") is not None
if _sklearn_available:
try:
importlib_metadata.version("scikit-learn")
except importlib_metadata.PackageNotFoundError:
_sklearn_available = False
_smdistributed_available = _is_package_available("smdistributed")
_soundfile_available = _is_package_available("soundfile")
_spacy_available = _is_package_available("spacy")
_sudachipy_available = _is_package_available("sudachipy")
_tensorflow_probability_available = _is_package_available("tensorflow_probability")
_tensorflow_text_available = _is_package_available("tensorflow_text")
_tf2onnx_available = _is_package_available("tf2onnx")
_timm_available = _is_package_available("timm")
_tokenizers_available = _is_package_available("tokenizers")
_torchaudio_available = _is_package_available("torchaudio")
_torchdistx_available = _is_package_available("torchdistx")
_torchvision_available = _is_package_available("torchvision")
_torch_version = "N/A" _torch_version = "N/A"
_torch_available = False
if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES: if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES:
_torch_available = importlib.util.find_spec("torch") is not None _torch_available, _torch_version = _is_package_available("torch", return_version=True)
if _torch_available:
try:
_torch_version = importlib_metadata.version("torch")
logger.info(f"PyTorch version {_torch_version} available.")
except importlib_metadata.PackageNotFoundError:
_torch_available = False
else: else:
logger.info("Disabling PyTorch because USE_TF is set") logger.info("Disabling PyTorch because USE_TF is set")
_torch_available = False _torch_available = False
_tf_version = "N/A" _tf_version = "N/A"
_tf_available = False
if FORCE_TF_AVAILABLE in ENV_VARS_TRUE_VALUES: if FORCE_TF_AVAILABLE in ENV_VARS_TRUE_VALUES:
_tf_available = True _tf_available = True
else: else:
if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES: if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES:
_tf_available = importlib.util.find_spec("tensorflow") is not None _tf_available = _is_package_available("tensorflow")
if _tf_available: if _tf_available:
candidates = ( candidates = (
"tensorflow", "tensorflow",
...@@ -93,179 +165,9 @@ else: ...@@ -93,179 +165,9 @@ else:
f"TensorFlow found but with version {_tf_version}. Transformers requires version 2 minimum." f"TensorFlow found but with version {_tf_version}. Transformers requires version 2 minimum."
) )
_tf_available = False _tf_available = False
else:
logger.info(f"TensorFlow version {_tf_version} available.")
else: else:
logger.info("Disabling Tensorflow because USE_TORCH is set") logger.info("Disabling Tensorflow because USE_TORCH is set")
_tf_available = False
if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES:
_flax_available = importlib.util.find_spec("jax") is not None and importlib.util.find_spec("flax") is not None
if _flax_available:
try:
_jax_version = importlib_metadata.version("jax")
_flax_version = importlib_metadata.version("flax")
logger.info(f"JAX version {_jax_version}, Flax version {_flax_version} available.")
except importlib_metadata.PackageNotFoundError:
_flax_available = False
else:
_flax_available = False
_datasets_available = importlib.util.find_spec("datasets") is not None
try:
# Check we're not importing a "datasets" directory somewhere but the actual library by trying to grab the version
# AND checking it has an author field in the metadata that is HuggingFace.
_ = importlib_metadata.version("datasets")
_datasets_metadata = importlib_metadata.metadata("datasets")
if _datasets_metadata.get("author", "") != "HuggingFace Inc.":
_datasets_available = False
except importlib_metadata.PackageNotFoundError:
_datasets_available = False
_diffusers_available = importlib.util.find_spec("diffusers") is not None
try:
_diffusers_version = importlib_metadata.version("diffusers")
logger.debug(f"Successfully imported diffusers version {_diffusers_version}")
except importlib_metadata.PackageNotFoundError:
_diffusers_available = False
_detectron2_available = importlib.util.find_spec("detectron2") is not None
try:
_detectron2_version = importlib_metadata.version("detectron2")
logger.debug(f"Successfully imported detectron2 version {_detectron2_version}")
except importlib_metadata.PackageNotFoundError:
_detectron2_available = False
_faiss_available = importlib.util.find_spec("faiss") is not None
try:
_faiss_version = importlib_metadata.version("faiss")
logger.debug(f"Successfully imported faiss version {_faiss_version}")
except importlib_metadata.PackageNotFoundError:
try:
_faiss_version = importlib_metadata.version("faiss-cpu")
logger.debug(f"Successfully imported faiss version {_faiss_version}")
except importlib_metadata.PackageNotFoundError:
_faiss_available = False
_ftfy_available = importlib.util.find_spec("ftfy") is not None
try:
_ftfy_version = importlib_metadata.version("ftfy")
logger.debug(f"Successfully imported ftfy version {_ftfy_version}")
except importlib_metadata.PackageNotFoundError:
_ftfy_available = False
coloredlogs = importlib.util.find_spec("coloredlogs") is not None
try:
_coloredlogs_available = importlib_metadata.version("coloredlogs")
logger.debug(f"Successfully imported sympy version {_coloredlogs_available}")
except importlib_metadata.PackageNotFoundError:
_coloredlogs_available = False
sympy_available = importlib.util.find_spec("sympy") is not None
try:
_sympy_available = importlib_metadata.version("sympy")
logger.debug(f"Successfully imported sympy version {_sympy_available}")
except importlib_metadata.PackageNotFoundError:
_sympy_available = False
_tf2onnx_available = importlib.util.find_spec("tf2onnx") is not None
try:
_tf2onnx_version = importlib_metadata.version("tf2onnx")
logger.debug(f"Successfully imported tf2onnx version {_tf2onnx_version}")
except importlib_metadata.PackageNotFoundError:
_tf2onnx_available = False
_onnx_available = importlib.util.find_spec("onnxruntime") is not None
try:
_onxx_version = importlib_metadata.version("onnx")
logger.debug(f"Successfully imported onnx version {_onxx_version}")
except importlib_metadata.PackageNotFoundError:
_onnx_available = False
_opencv_available = importlib.util.find_spec("cv2") is not None
_pytorch_quantization_available = importlib.util.find_spec("pytorch_quantization") is not None
try:
_pytorch_quantization_version = importlib_metadata.version("pytorch_quantization")
logger.debug(f"Successfully imported pytorch-quantization version {_pytorch_quantization_version}")
except importlib_metadata.PackageNotFoundError:
_pytorch_quantization_available = False
_soundfile_available = importlib.util.find_spec("soundfile") is not None
try:
_soundfile_version = importlib_metadata.version("soundfile")
logger.debug(f"Successfully imported soundfile version {_soundfile_version}")
except importlib_metadata.PackageNotFoundError:
_soundfile_available = False
_tensorflow_probability_available = importlib.util.find_spec("tensorflow_probability") is not None
try:
_tensorflow_probability_version = importlib_metadata.version("tensorflow_probability")
logger.debug(f"Successfully imported tensorflow-probability version {_tensorflow_probability_version}")
except importlib_metadata.PackageNotFoundError:
_tensorflow_probability_available = False
_timm_available = importlib.util.find_spec("timm") is not None
try:
_timm_version = importlib_metadata.version("timm")
logger.debug(f"Successfully imported timm version {_timm_version}")
except importlib_metadata.PackageNotFoundError:
_timm_available = False
_natten_available = importlib.util.find_spec("natten") is not None
try:
_natten_version = importlib_metadata.version("natten")
logger.debug(f"Successfully imported natten version {_natten_version}")
except importlib_metadata.PackageNotFoundError:
_natten_available = False
_torchaudio_available = importlib.util.find_spec("torchaudio") is not None
try:
_torchaudio_version = importlib_metadata.version("torchaudio")
logger.debug(f"Successfully imported torchaudio version {_torchaudio_version}")
except importlib_metadata.PackageNotFoundError:
_torchaudio_available = False
_phonemizer_available = importlib.util.find_spec("phonemizer") is not None
try:
_phonemizer_version = importlib_metadata.version("phonemizer")
logger.debug(f"Successfully imported phonemizer version {_phonemizer_version}")
except importlib_metadata.PackageNotFoundError:
_phonemizer_available = False
_pyctcdecode_available = importlib.util.find_spec("pyctcdecode") is not None
try:
_pyctcdecode_version = importlib_metadata.version("pyctcdecode")
logger.debug(f"Successfully imported pyctcdecode version {_pyctcdecode_version}")
except importlib_metadata.PackageNotFoundError:
_pyctcdecode_available = False
_librosa_available = importlib.util.find_spec("librosa") is not None
try:
_librosa_version = importlib_metadata.version("librosa")
logger.debug(f"Successfully imported librosa version {_librosa_version}")
except importlib_metadata.PackageNotFoundError:
_librosa_available = False
ccl_version = "N/A" ccl_version = "N/A"
_is_ccl_available = ( _is_ccl_available = (
...@@ -274,38 +176,46 @@ _is_ccl_available = ( ...@@ -274,38 +176,46 @@ _is_ccl_available = (
) )
try: try:
ccl_version = importlib_metadata.version("oneccl_bind_pt") ccl_version = importlib_metadata.version("oneccl_bind_pt")
logger.debug(f"Successfully imported oneccl_bind_pt version {ccl_version}") logger.debug(f"Detected oneccl_bind_pt version {ccl_version}")
except importlib_metadata.PackageNotFoundError: except importlib_metadata.PackageNotFoundError:
_is_ccl_available = False _is_ccl_available = False
_decord_availale = importlib.util.find_spec("decord") is not None
try:
_decord_version = importlib_metadata.version("decord")
logger.debug(f"Successfully imported decord version {_decord_version}")
except importlib_metadata.PackageNotFoundError:
_decord_availale = False
_jieba_available = importlib.util.find_spec("jieba") is not None _flax_available = False
try: if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES:
_jieba_version = importlib_metadata.version("jieba") _flax_available, _flax_version = _is_package_available("flax", return_version=True)
logger.debug(f"Successfully imported jieba version {_jieba_version}") if _flax_available:
except importlib_metadata.PackageNotFoundError: _jax_available, _jax_version = _is_package_available("jax", return_version=True)
_jieba_available = False if _jax_available:
logger.info(f"JAX version {_jax_version}, Flax version {_flax_version} available.")
else:
_flax_available = _jax_available = False
_jax_version = _flax_version = "N/A"
# This is the version of torch required to run torch.fx features and torch.onnx with dictionary inputs.
TORCH_FX_REQUIRED_VERSION = version.parse("1.10") _torch_fx_available = False
if _torch_available:
torch_version = version.parse(_torch_version)
_torch_fx_available = (torch_version.major, torch_version.minor) >= (
TORCH_FX_REQUIRED_VERSION.major,
TORCH_FX_REQUIRED_VERSION.minor,
)
def is_kenlm_available(): def is_kenlm_available():
return importlib.util.find_spec("kenlm") is not None return _kenlm_available
def is_torch_available(): def is_torch_available():
return _torch_available return _torch_available
def get_torch_version():
return _torch_version
def is_torchvision_available(): def is_torchvision_available():
return importlib.util.find_spec("torchvision") is not None return _torchvision_available
def is_pyctcdecode_available(): def is_pyctcdecode_available():
...@@ -404,26 +314,16 @@ def is_torch_tf32_available(): ...@@ -404,26 +314,16 @@ def is_torch_tf32_available():
return True return True
torch_version = None
_torch_fx_available = False
if _torch_available:
torch_version = version.parse(importlib_metadata.version("torch"))
_torch_fx_available = (torch_version.major, torch_version.minor) >= (
TORCH_FX_REQUIRED_VERSION.major,
TORCH_FX_REQUIRED_VERSION.minor,
)
def is_torch_fx_available(): def is_torch_fx_available():
return _torch_fx_available return _torch_fx_available
def is_peft_available(): def is_peft_available():
return importlib.util.find_spec("peft") is not None return _peft_available
def is_bs4_available(): def is_bs4_available():
return importlib.util.find_spec("bs4") is not None return _bs4_available
def is_tf_available(): def is_tf_available():
...@@ -443,7 +343,7 @@ def is_onnx_available(): ...@@ -443,7 +343,7 @@ def is_onnx_available():
def is_openai_available(): def is_openai_available():
return importlib.util.find_spec("openai") is not None return _openai_available
def is_flax_available(): def is_flax_available():
...@@ -517,40 +417,36 @@ def is_detectron2_available(): ...@@ -517,40 +417,36 @@ def is_detectron2_available():
def is_rjieba_available(): def is_rjieba_available():
return importlib.util.find_spec("rjieba") is not None return _rjieba_available
def is_psutil_available(): def is_psutil_available():
return importlib.util.find_spec("psutil") is not None return _psutil_available
def is_py3nvml_available(): def is_py3nvml_available():
return importlib.util.find_spec("py3nvml") is not None return _py3nvml_available
def is_sacremoses_available(): def is_sacremoses_available():
return importlib.util.find_spec("sacremoses") is not None return _sacremoses_available
def is_apex_available(): def is_apex_available():
return importlib.util.find_spec("apex") is not None return _apex_available
def is_ninja_available(): def is_ninja_available():
return importlib.util.find_spec("ninja") is not None return _ninja_available
def is_ipex_available(): def is_ipex_available():
def get_major_and_minor_from_version(full_version): def get_major_and_minor_from_version(full_version):
return str(version.parse(full_version).major) + "." + str(version.parse(full_version).minor) return str(version.parse(full_version).major) + "." + str(version.parse(full_version).minor)
if not is_torch_available() or importlib.util.find_spec("intel_extension_for_pytorch") is None: if not is_torch_available() or not _ipex_available:
return False
_ipex_version = "N/A"
try:
_ipex_version = importlib_metadata.version("intel_extension_for_pytorch")
except importlib_metadata.PackageNotFoundError:
return False return False
torch_major_and_minor = get_major_and_minor_from_version(_torch_version) torch_major_and_minor = get_major_and_minor_from_version(_torch_version)
ipex_major_and_minor = get_major_and_minor_from_version(_ipex_version) ipex_major_and_minor = get_major_and_minor_from_version(_ipex_version)
if torch_major_and_minor != ipex_major_and_minor: if torch_major_and_minor != ipex_major_and_minor:
...@@ -563,11 +459,11 @@ def is_ipex_available(): ...@@ -563,11 +459,11 @@ def is_ipex_available():
def is_bitsandbytes_available(): def is_bitsandbytes_available():
return importlib.util.find_spec("bitsandbytes") is not None return _bitsandbytes_available
def is_torchdistx_available(): def is_torchdistx_available():
return importlib.util.find_spec("torchdistx") is not None return _torchdistx_available
def is_faiss_available(): def is_faiss_available():
...@@ -575,17 +471,15 @@ def is_faiss_available(): ...@@ -575,17 +471,15 @@ def is_faiss_available():
def is_scipy_available(): def is_scipy_available():
return importlib.util.find_spec("scipy") is not None return _scipy_available
def is_sklearn_available(): def is_sklearn_available():
if importlib.util.find_spec("sklearn") is None: return _sklearn_available
return False
return is_scipy_available() and importlib.util.find_spec("sklearn.metrics")
def is_sentencepiece_available(): def is_sentencepiece_available():
return importlib.util.find_spec("sentencepiece") is not None return _sentencepiece_available
def is_protobuf_available(): def is_protobuf_available():
...@@ -595,56 +489,54 @@ def is_protobuf_available(): ...@@ -595,56 +489,54 @@ def is_protobuf_available():
def is_accelerate_available(check_partial_state=False): def is_accelerate_available(check_partial_state=False):
accelerate_available = importlib.util.find_spec("accelerate") is not None if check_partial_state:
if accelerate_available: return _accelerate_available and version.parse(_accelerate_version) >= version.parse("0.17.0")
if check_partial_state: return _accelerate_available
return version.parse(importlib_metadata.version("accelerate")) >= version.parse("0.17.0")
else:
return True
else:
return False
def is_optimum_available(): def is_optimum_available():
return importlib.util.find_spec("optimum") is not None return _optimum_available
def is_optimum_neuron_available(): def is_optimum_neuron_available():
return importlib.util.find_spec("optimum.neuron") is not None return _optimumneuron_available
def is_safetensors_available(): def is_safetensors_available():
if is_torch_available(): if is_torch_available() and version.parse(_torch_version) < version.parse("1.10"):
if version.parse(_torch_version) >= version.parse("1.10"): return False
return importlib.util.find_spec("safetensors") is not None return _safetensors_available
else:
return False
else:
return importlib.util.find_spec("safetensors") is not None
def is_tokenizers_available(): def is_tokenizers_available():
return importlib.util.find_spec("tokenizers") is not None return _tokenizers_available
def is_vision_available(): def is_vision_available():
return importlib.util.find_spec("PIL") is not None _pil_available = importlib.util.find_spec("PIL") is not None
if _pil_available:
try:
package_version = importlib_metadata.version("Pillow")
except importlib_metadata.PackageNotFoundError:
return False
logger.debug(f"Detected PIL version {package_version}")
return _pil_available
def is_pytesseract_available(): def is_pytesseract_available():
return importlib.util.find_spec("pytesseract") is not None return _pytesseract_available
def is_spacy_available(): def is_spacy_available():
return importlib.util.find_spec("spacy") is not None return _spacy_available
def is_tensorflow_text_available(): def is_tensorflow_text_available():
return is_tf_available() and importlib.util.find_spec("tensorflow_text") is not None return is_tf_available() and _tensorflow_text_available
def is_keras_nlp_available(): def is_keras_nlp_available():
return is_tensorflow_text_available() and importlib.util.find_spec("keras_nlp") is not None return is_tensorflow_text_available() and _keras_nlp_available
def is_in_notebook(): def is_in_notebook():
...@@ -674,7 +566,7 @@ def is_tensorflow_probability_available(): ...@@ -674,7 +566,7 @@ def is_tensorflow_probability_available():
def is_pandas_available(): def is_pandas_available():
return importlib.util.find_spec("pandas") is not None return _pandas_available
def is_sagemaker_dp_enabled(): def is_sagemaker_dp_enabled():
...@@ -688,7 +580,7 @@ def is_sagemaker_dp_enabled(): ...@@ -688,7 +580,7 @@ def is_sagemaker_dp_enabled():
except json.JSONDecodeError: except json.JSONDecodeError:
return False return False
# Lastly, check if the `smdistributed` module is present. # Lastly, check if the `smdistributed` module is present.
return importlib.util.find_spec("smdistributed") is not None return _smdistributed_available
def is_sagemaker_mp_enabled(): def is_sagemaker_mp_enabled():
...@@ -712,7 +604,7 @@ def is_sagemaker_mp_enabled(): ...@@ -712,7 +604,7 @@ def is_sagemaker_mp_enabled():
except json.JSONDecodeError: except json.JSONDecodeError:
return False return False
# Lastly, check if the `smdistributed` module is present. # Lastly, check if the `smdistributed` module is present.
return importlib.util.find_spec("smdistributed") is not None return _smdistributed_available
def is_training_run_on_sagemaker(): def is_training_run_on_sagemaker():
...@@ -762,11 +654,11 @@ def is_ccl_available(): ...@@ -762,11 +654,11 @@ def is_ccl_available():
def is_decord_available(): def is_decord_available():
return _decord_availale return _decord_available
def is_sudachi_available(): def is_sudachi_available():
return importlib.util.find_spec("sudachipy") is not None return _sudachipy_available
def is_jumanpp_available(): def is_jumanpp_available():
......
...@@ -319,12 +319,12 @@ class OnnxExportTestCaseV2(TestCase): ...@@ -319,12 +319,12 @@ class OnnxExportTestCaseV2(TestCase):
onnx_config = onnx_config_class_constructor(model.config) onnx_config = onnx_config_class_constructor(model.config)
if is_torch_available(): if is_torch_available():
from transformers.utils import torch_version from transformers.utils import get_torch_version
if torch_version < onnx_config.torch_onnx_minimum_version: if get_torch_version() < onnx_config.torch_onnx_minimum_version:
pytest.skip( pytest.skip(
"Skipping due to incompatible PyTorch version. Minimum required is" "Skipping due to incompatible PyTorch version. Minimum required is"
f" {onnx_config.torch_onnx_minimum_version}, got: {torch_version}" f" {onnx_config.torch_onnx_minimum_version}, got: {get_torch_version()}"
) )
preprocessor = get_preprocessor(model_name) preprocessor = get_preprocessor(model_name)
...@@ -362,12 +362,12 @@ class OnnxExportTestCaseV2(TestCase): ...@@ -362,12 +362,12 @@ class OnnxExportTestCaseV2(TestCase):
onnx_config = onnx_config_class_constructor(model.config) onnx_config = onnx_config_class_constructor(model.config)
if is_torch_available(): if is_torch_available():
from transformers.utils import torch_version from transformers.utils import get_torch_version
if torch_version < onnx_config.torch_onnx_minimum_version: if get_torch_version() < onnx_config.torch_onnx_minimum_version:
pytest.skip( pytest.skip(
"Skipping due to incompatible PyTorch version. Minimum required is" "Skipping due to incompatible PyTorch version. Minimum required is"
f" {onnx_config.torch_onnx_minimum_version}, got: {torch_version}" f" {onnx_config.torch_onnx_minimum_version}, got: {get_torch_version()}"
) )
encoder_model = model.get_encoder() encoder_model = model.get_encoder()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment