You need to sign in or sign up before continuing.
Commit faef6f61 authored by Morgan Funtowicz's avatar Morgan Funtowicz
Browse files

Fix logic order for USE_TF/USE_TORCH

parent 5664327c
...@@ -29,25 +29,27 @@ logger = logging.getLogger(__name__) # pylint: disable=invalid-name ...@@ -29,25 +29,27 @@ logger = logging.getLogger(__name__) # pylint: disable=invalid-name
try: try:
os.environ.setdefault('USE_TF', 'YES') os.environ.setdefault('USE_TF', 'YES')
if os.environ['USE_TF'].upper() in ('1', 'ON', 'YES'): if os.environ['USE_TF'].upper() in ('1', 'ON', 'YES'):
logger.info("USE_TF override through env variable, disabling Tensorflow")
_tf_available = False
else:
import tensorflow as tf import tensorflow as tf
assert hasattr(tf, '__version__') and int(tf.__version__[0]) >= 2 assert hasattr(tf, '__version__') and int(tf.__version__[0]) >= 2
_tf_available = True # pylint: disable=invalid-name _tf_available = True # pylint: disable=invalid-name
logger.info("TensorFlow version {} available.".format(tf.__version__)) logger.info("TensorFlow version {} available.".format(tf.__version__))
else:
logger.info("USE_TF override through env variable, disabling Tensorflow")
_tf_available = False
except (ImportError, AssertionError): except (ImportError, AssertionError):
_tf_available = False # pylint: disable=invalid-name _tf_available = False # pylint: disable=invalid-name
try: try:
os.environ.setdefault('USE_TORCH', 'YES') os.environ.setdefault('USE_TORCH', 'YES')
if os.environ['USE_TORCH'].upper() in ('1', 'ON', 'YES'): if os.environ['USE_TORCH'].upper() in ('1', 'ON', 'YES'):
logger.info("USE_TORCH override through env variable, disabling PyTorch")
_torch_available = False
else:
import torch import torch
_torch_available = True # pylint: disable=invalid-name _torch_available = True # pylint: disable=invalid-name
logger.info("PyTorch version {} available.".format(torch.__version__)) logger.info("PyTorch version {} available.".format(torch.__version__))
else:
logger.info("USE_TORCH override through env variable, disabling PyTorch")
_torch_available = False
except ImportError: except ImportError:
_torch_available = False # pylint: disable=invalid-name _torch_available = False # pylint: disable=invalid-name
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment