Commit 1e47dee2 authored by thomwolf's avatar thomwolf
Browse files
parents c9591f6f 798da627
......@@ -23,8 +23,7 @@ import logging
from pytorch_transformers import is_tf_available
# if is_tf_available():
if False:
if is_tf_available():
from pytorch_transformers import (AutoConfig, BertConfig,
TFAutoModel, TFBertModel,
TFAutoModelWithLMHead, TFBertForMaskedLM,
......@@ -44,7 +43,8 @@ class TFAutoModelTest(unittest.TestCase):
self.assertTrue(h5py.version.hdf5_version.startswith("1.10"))
logging.basicConfig(level=logging.INFO)
for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
# for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
for model_name in ['bert-base-uncased']:
config = AutoConfig.from_pretrained(model_name, force_download=True)
self.assertIsNotNone(config)
self.assertIsInstance(config, BertConfig)
......@@ -55,7 +55,8 @@ class TFAutoModelTest(unittest.TestCase):
def test_lmhead_model_from_pretrained(self):
logging.basicConfig(level=logging.INFO)
for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
# for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
for model_name in ['bert-base-uncased']:
config = AutoConfig.from_pretrained(model_name, force_download=True)
self.assertIsNotNone(config)
self.assertIsInstance(config, BertConfig)
......@@ -66,7 +67,8 @@ class TFAutoModelTest(unittest.TestCase):
def test_sequence_classification_model_from_pretrained(self):
logging.basicConfig(level=logging.INFO)
for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
# for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
for model_name in ['bert-base-uncased']:
config = AutoConfig.from_pretrained(model_name, force_download=True)
self.assertIsNotNone(config)
self.assertIsInstance(config, BertConfig)
......@@ -77,7 +79,8 @@ class TFAutoModelTest(unittest.TestCase):
def test_question_answering_model_from_pretrained(self):
logging.basicConfig(level=logging.INFO)
for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
# for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
for model_name in ['bert-base-uncased']:
config = AutoConfig.from_pretrained(model_name, force_download=True)
self.assertIsNotNone(config)
self.assertIsInstance(config, BertConfig)
......
......@@ -316,7 +316,8 @@ class TFBertModelTest(TFCommonTestCases.TFCommonModelTester):
@pytest.mark.slow
def test_model_from_pretrained(self):
cache_dir = "/tmp/pytorch_transformers_test/"
for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
# for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
for model_name in ['bert-base-uncased']:
model = TFBertModel.from_pretrained(model_name, cache_dir=cache_dir)
shutil.rmtree(cache_dir)
self.assertIsNotNone(model)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment