Commit a6bcfb80 authored by thomwolf

fix tests

parent 78863f6b
@@ -56,8 +56,6 @@ from .configuration_distilbert import DistilBertConfig, DISTILBERT_PRETRAINED_CO
 # Modeling
 if is_torch_available():
-    logger.info("PyTorch version {} available.".format(torch.__version__))
     from .modeling_utils import (PreTrainedModel, prune_layer, Conv1D)
     from .modeling_auto import (AutoModel, AutoModelForSequenceClassification, AutoModelForQuestionAnswering,
                                 AutoModelWithLMHead)
@@ -96,8 +94,6 @@ if is_torch_available():
 # TensorFlow
 if is_tf_available():
-    logger.info("TensorFlow version {} available.".format(tf.__version__))
     from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, TFSequenceSummary
     from .modeling_tf_auto import (TFAutoModel, TFAutoModelForSequenceClassification, TFAutoModelForQuestionAnswering,
                                    TFAutoModelWithLMHead)
...
@@ -23,16 +23,20 @@ from botocore.exceptions import ClientError
 import requests
 from tqdm import tqdm
+logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
 try:
     import tensorflow as tf
     assert int(tf.__version__[0]) >= 2
     _tf_available = True  # pylint: disable=invalid-name
+    logger.info("TensorFlow version {} available.".format(tf.__version__))
 except (ImportError, AssertionError):
     _tf_available = False  # pylint: disable=invalid-name
 try:
     import torch
     _torch_available = True  # pylint: disable=invalid-name
+    logger.info("PyTorch version {} available.".format(torch.__version__))
 except ImportError:
     _torch_available = False  # pylint: disable=invalid-name
@@ -67,8 +71,6 @@ TF2_WEIGHTS_NAME = 'tf_model.h5'
 TF_WEIGHTS_NAME = 'model.ckpt'
 CONFIG_NAME = "config.json"
-logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
 def is_torch_available():
     return _torch_available
...
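Read together, the two hunks above (in the .file_utils module, judging by the import in the last hunk below) move the logger definition ahead of the optional-import checks, so framework availability is logged in one place instead of in the package __init__. A minimal, self-contained sketch of the resulting pattern, reconstructed from these hunks (the body of is_tf_available() is not shown in the diff and is assumed to mirror is_torch_available()):

import logging

logger = logging.getLogger(__name__)  # pylint: disable=invalid-name

# Probe each optional framework once at import time and remember the result.
try:
    import tensorflow as tf
    assert int(tf.__version__[0]) >= 2  # require TF 2.x
    _tf_available = True  # pylint: disable=invalid-name
    logger.info("TensorFlow version {} available.".format(tf.__version__))
except (ImportError, AssertionError):
    _tf_available = False  # pylint: disable=invalid-name

try:
    import torch
    _torch_available = True  # pylint: disable=invalid-name
    logger.info("PyTorch version {} available.".format(torch.__version__))
except ImportError:
    _torch_available = False  # pylint: disable=invalid-name


def is_torch_available():
    return _torch_available


def is_tf_available():
    return _tf_available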
@@ -27,7 +27,7 @@ from .file_utils import cached_path, is_tf_available, is_torch_available
 if is_tf_available():
     import tensorflow as tf
-if is_torch_available()
+if is_torch_available():
     import torch
 logger = logging.getLogger(__name__)
...
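Downstream modules then gate framework-specific imports on these helpers, which is where the trailing colon restored by the last hunk matters. A short usage sketch under that assumption (only the helper names and the guard pattern come from the diff; the guarded imports are illustrative):

from .file_utils import is_tf_available, is_torch_available

if is_tf_available():
    import tensorflow as tf  # only imported when TF 2.x was detected

if is_torch_available():  # the missing ':' here is the syntax error fixed above
    import torch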