Commit f116cf59 authored by Morgan Funtowicz's avatar Morgan Funtowicz
Browse files

Allow hidding frameworks through environment variables (NO_TF, NO_TORCH).

parent 6e61e060
...@@ -27,17 +27,25 @@ from contextlib import contextmanager ...@@ -27,17 +27,25 @@ from contextlib import contextmanager
logger = logging.getLogger(__name__) # pylint: disable=invalid-name logger = logging.getLogger(__name__) # pylint: disable=invalid-name
try: try:
import tensorflow as tf if 'NO_TF' in os.environ and os.environ['NO_TF'].upper() in ('1', 'ON'):
assert hasattr(tf, '__version__') and int(tf.__version__[0]) >= 2 logger.info("Found NO_TF, disabling TensorFlow")
_tf_available = True # pylint: disable=invalid-name _tf_available = False
logger.info("TensorFlow version {} available.".format(tf.__version__)) else:
import tensorflow as tf
assert hasattr(tf, '__version__') and int(tf.__version__[0]) >= 2
_tf_available = True # pylint: disable=invalid-name
logger.info("TensorFlow version {} available.".format(tf.__version__))
except (ImportError, AssertionError): except (ImportError, AssertionError):
_tf_available = False # pylint: disable=invalid-name _tf_available = False # pylint: disable=invalid-name
try: try:
import torch if 'NO_TORCH' in os.environ and os.environ['NO_TORCH'].upper() in ('1', 'ON'):
_torch_available = True # pylint: disable=invalid-name logger.info("Found NO_TORCH, disabling PyTorch")
logger.info("PyTorch version {} available.".format(torch.__version__)) _torch_available = False
else:
import torch
_torch_available = True # pylint: disable=invalid-name
logger.info("PyTorch version {} available.".format(torch.__version__))
except ImportError: except ImportError:
_torch_available = False # pylint: disable=invalid-name _torch_available = False # pylint: disable=invalid-name
...@@ -77,6 +85,7 @@ def is_torch_available(): ...@@ -77,6 +85,7 @@ def is_torch_available():
return _torch_available return _torch_available
def is_tf_available(): def is_tf_available():
return _tf_available return _tf_available
if not six.PY2: if not six.PY2:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment