"git@developer.sourcefind.cn:chenpangpang/ComfyUI.git" did not exist on "3fce8881ca0f24e268fac1dc6e85d2b4cbdb0355"
Commit 6e6c8c52 authored by Morgan Funtowicz's avatar Morgan Funtowicz Committed by Lysandre Debut
Browse files

Fix bad handling of env variable USE_TF / USE_TORCH leading to invalid framework being used.


Signed-off-by: default avatarMorgan Funtowicz <morgan@huggingface.co>
parent 23c6998b
...@@ -4,7 +4,6 @@ This file is adapted from the AllenNLP library at https://github.com/allenai/all ...@@ -4,7 +4,6 @@ This file is adapted from the AllenNLP library at https://github.com/allenai/all
Copyright by the AllenNLP authors. Copyright by the AllenNLP authors.
""" """
import fnmatch import fnmatch
import json import json
import logging import logging
...@@ -26,32 +25,31 @@ from tqdm.auto import tqdm ...@@ -26,32 +25,31 @@ from tqdm.auto import tqdm
from . import __version__ from . import __version__
logger = logging.getLogger(__name__) # pylint: disable=invalid-name logger = logging.getLogger(__name__) # pylint: disable=invalid-name
try: try:
os.environ.setdefault("USE_TORCH", "YES") if os.environ.get("USE_TORCH", 'AUTO').upper() in ("1", "ON", "YES", "AUTO") and \
if os.environ["USE_TORCH"].upper() in ("1", "ON", "YES"): os.environ.get("USE_TF", 'AUTO').upper() not in ("1", "ON", "YES"):
import torch import torch
_torch_available = True # pylint: disable=invalid-name _torch_available = True # pylint: disable=invalid-name
logger.info("PyTorch version {} available.".format(torch.__version__)) logger.info("PyTorch version {} available.".format(torch.__version__))
else: else:
logger.info("USE_TORCH override through env variable, disabling PyTorch") logger.info("Disabling PyTorch because USE_TF is set")
_torch_available = False _torch_available = False
except ImportError: except ImportError:
_torch_available = False # pylint: disable=invalid-name _torch_available = False # pylint: disable=invalid-name
try: try:
os.environ.setdefault("USE_TF", "YES") if os.environ.get("USE_TF", 'AUTO').upper() in ("1", "ON", "YES", "AUTO") and \
if os.environ["USE_TF"].upper() in ("1", "ON", "YES"): os.environ.get("USE_TORCH", 'AUTO').upper() not in ("1", "ON", "YES"):
import tensorflow as tf import tensorflow as tf
assert hasattr(tf, "__version__") and int(tf.__version__[0]) >= 2 assert hasattr(tf, "__version__") and int(tf.__version__[0]) >= 2
_tf_available = True # pylint: disable=invalid-name _tf_available = True # pylint: disable=invalid-name
logger.info("TensorFlow version {} available.".format(tf.__version__)) logger.info("TensorFlow version {} available.".format(tf.__version__))
else: else:
logger.info("USE_TF override through env variable, disabling Tensorflow") logger.info("Disabling Tensorflow because USE_TORCH is set")
_tf_available = False _tf_available = False
except (ImportError, AssertionError): except (ImportError, AssertionError):
_tf_available = False # pylint: disable=invalid-name _tf_available = False # pylint: disable=invalid-name
...@@ -66,7 +64,6 @@ except ImportError: ...@@ -66,7 +64,6 @@ except ImportError:
) )
default_cache_path = os.path.join(torch_cache_home, "transformers") default_cache_path = os.path.join(torch_cache_home, "transformers")
try: try:
from pathlib import Path from pathlib import Path
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment