Unverified Commit fb2b8984 authored by Stas Bekman's avatar Stas Bekman Committed by GitHub
Browse files

[file_utils] import refactor (#10859)

* import refactor

* fix the fallback
parent 3f48b2bc
...@@ -84,31 +84,24 @@ else: ...@@ -84,31 +84,24 @@ else:
if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES: if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES:
_tf_available = importlib.util.find_spec("tensorflow") is not None _tf_available = importlib.util.find_spec("tensorflow") is not None
if _tf_available: if _tf_available:
candidates = (
"tensorflow",
"tensorflow-cpu",
"tensorflow-gpu",
"tf-nightly",
"tf-nightly-cpu",
"tf-nightly-gpu",
"intel-tensorflow",
)
_tf_version = None
# For the metadata, we have to look for both tensorflow and tensorflow-cpu # For the metadata, we have to look for both tensorflow and tensorflow-cpu
try: for pkg in candidates:
_tf_version = importlib_metadata.version("tensorflow")
except importlib_metadata.PackageNotFoundError:
try: try:
_tf_version = importlib_metadata.version("tensorflow-cpu") _tf_version = importlib_metadata.version(pkg)
break
except importlib_metadata.PackageNotFoundError: except importlib_metadata.PackageNotFoundError:
try: pass
_tf_version = importlib_metadata.version("tensorflow-gpu") _tf_available = _tf_version is not None
except importlib_metadata.PackageNotFoundError:
try:
_tf_version = importlib_metadata.version("tf-nightly")
except importlib_metadata.PackageNotFoundError:
try:
_tf_version = importlib_metadata.version("tf-nightly-cpu")
except importlib_metadata.PackageNotFoundError:
try:
_tf_version = importlib_metadata.version("tf-nightly-gpu")
except importlib_metadata.PackageNotFoundError:
# Support for intel-tensorflow version
try:
_tf_version = importlib_metadata.version("intel-tensorflow")
except importlib_metadata.PackageNotFoundError:
_tf_version = None
_tf_available = False
if _tf_available: if _tf_available:
if version.parse(_tf_version) < version.parse("2"): if version.parse(_tf_version) < version.parse("2"):
logger.info(f"TensorFlow found but with version {_tf_version}. Transformers requires version 2 minimum.") logger.info(f"TensorFlow found but with version {_tf_version}. Transformers requires version 2 minimum.")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment