Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
f116cf59
Commit
f116cf59
authored
Dec 09, 2019
by
Morgan Funtowicz
Browse files
Allow hidding frameworks through environment variables (NO_TF, NO_TORCH).
parent
6e61e060
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
16 additions
and
7 deletions
+16
-7
transformers/file_utils.py
transformers/file_utils.py
+16
-7
No files found.
transformers/file_utils.py
View file @
f116cf59
...
@@ -27,17 +27,25 @@ from contextlib import contextmanager
...
@@ -27,17 +27,25 @@ from contextlib import contextmanager
logger
=
logging
.
getLogger
(
__name__
)
# pylint: disable=invalid-name
logger
=
logging
.
getLogger
(
__name__
)
# pylint: disable=invalid-name
try
:
try
:
import
tensorflow
as
tf
if
'NO_TF'
in
os
.
environ
and
os
.
environ
[
'NO_TF'
].
upper
()
in
(
'1'
,
'ON'
):
assert
hasattr
(
tf
,
'__version__'
)
and
int
(
tf
.
__version__
[
0
])
>=
2
logger
.
info
(
"Found NO_TF, disabling TensorFlow"
)
_tf_available
=
True
# pylint: disable=invalid-name
_tf_available
=
False
logger
.
info
(
"TensorFlow version {} available."
.
format
(
tf
.
__version__
))
else
:
import
tensorflow
as
tf
assert
hasattr
(
tf
,
'__version__'
)
and
int
(
tf
.
__version__
[
0
])
>=
2
_tf_available
=
True
# pylint: disable=invalid-name
logger
.
info
(
"TensorFlow version {} available."
.
format
(
tf
.
__version__
))
except
(
ImportError
,
AssertionError
):
except
(
ImportError
,
AssertionError
):
_tf_available
=
False
# pylint: disable=invalid-name
_tf_available
=
False
# pylint: disable=invalid-name
try
:
try
:
import
torch
if
'NO_TORCH'
in
os
.
environ
and
os
.
environ
[
'NO_TORCH'
].
upper
()
in
(
'1'
,
'ON'
):
_torch_available
=
True
# pylint: disable=invalid-name
logger
.
info
(
"Found NO_TORCH, disabling PyTorch"
)
logger
.
info
(
"PyTorch version {} available."
.
format
(
torch
.
__version__
))
_torch_available
=
False
else
:
import
torch
_torch_available
=
True
# pylint: disable=invalid-name
logger
.
info
(
"PyTorch version {} available."
.
format
(
torch
.
__version__
))
except
ImportError
:
except
ImportError
:
_torch_available
=
False
# pylint: disable=invalid-name
_torch_available
=
False
# pylint: disable=invalid-name
...
@@ -77,6 +85,7 @@ def is_torch_available():
...
@@ -77,6 +85,7 @@ def is_torch_available():
return
_torch_available
return
_torch_available
def
is_tf_available
():
def
is_tf_available
():
return
_tf_available
return
_tf_available
if
not
six
.
PY2
:
if
not
six
.
PY2
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment