Commit 1e47dee2 authored by thomwolf's avatar thomwolf
Browse files
parents c9591f6f 798da627
...@@ -23,8 +23,7 @@ import logging ...@@ -23,8 +23,7 @@ import logging
from pytorch_transformers import is_tf_available from pytorch_transformers import is_tf_available
# if is_tf_available(): if is_tf_available():
if False:
from pytorch_transformers import (AutoConfig, BertConfig, from pytorch_transformers import (AutoConfig, BertConfig,
TFAutoModel, TFBertModel, TFAutoModel, TFBertModel,
TFAutoModelWithLMHead, TFBertForMaskedLM, TFAutoModelWithLMHead, TFBertForMaskedLM,
...@@ -44,7 +43,8 @@ class TFAutoModelTest(unittest.TestCase): ...@@ -44,7 +43,8 @@ class TFAutoModelTest(unittest.TestCase):
self.assertTrue(h5py.version.hdf5_version.startswith("1.10")) self.assertTrue(h5py.version.hdf5_version.startswith("1.10"))
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: # for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
for model_name in ['bert-base-uncased']:
config = AutoConfig.from_pretrained(model_name, force_download=True) config = AutoConfig.from_pretrained(model_name, force_download=True)
self.assertIsNotNone(config) self.assertIsNotNone(config)
self.assertIsInstance(config, BertConfig) self.assertIsInstance(config, BertConfig)
...@@ -55,7 +55,8 @@ class TFAutoModelTest(unittest.TestCase): ...@@ -55,7 +55,8 @@ class TFAutoModelTest(unittest.TestCase):
def test_lmhead_model_from_pretrained(self): def test_lmhead_model_from_pretrained(self):
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: # for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
for model_name in ['bert-base-uncased']:
config = AutoConfig.from_pretrained(model_name, force_download=True) config = AutoConfig.from_pretrained(model_name, force_download=True)
self.assertIsNotNone(config) self.assertIsNotNone(config)
self.assertIsInstance(config, BertConfig) self.assertIsInstance(config, BertConfig)
...@@ -66,7 +67,8 @@ class TFAutoModelTest(unittest.TestCase): ...@@ -66,7 +67,8 @@ class TFAutoModelTest(unittest.TestCase):
def test_sequence_classification_model_from_pretrained(self): def test_sequence_classification_model_from_pretrained(self):
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: # for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
for model_name in ['bert-base-uncased']:
config = AutoConfig.from_pretrained(model_name, force_download=True) config = AutoConfig.from_pretrained(model_name, force_download=True)
self.assertIsNotNone(config) self.assertIsNotNone(config)
self.assertIsInstance(config, BertConfig) self.assertIsInstance(config, BertConfig)
...@@ -77,7 +79,8 @@ class TFAutoModelTest(unittest.TestCase): ...@@ -77,7 +79,8 @@ class TFAutoModelTest(unittest.TestCase):
def test_question_answering_model_from_pretrained(self): def test_question_answering_model_from_pretrained(self):
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: # for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
for model_name in ['bert-base-uncased']:
config = AutoConfig.from_pretrained(model_name, force_download=True) config = AutoConfig.from_pretrained(model_name, force_download=True)
self.assertIsNotNone(config) self.assertIsNotNone(config)
self.assertIsInstance(config, BertConfig) self.assertIsInstance(config, BertConfig)
......
...@@ -316,7 +316,8 @@ class TFBertModelTest(TFCommonTestCases.TFCommonModelTester): ...@@ -316,7 +316,8 @@ class TFBertModelTest(TFCommonTestCases.TFCommonModelTester):
@pytest.mark.slow @pytest.mark.slow
def test_model_from_pretrained(self): def test_model_from_pretrained(self):
cache_dir = "/tmp/pytorch_transformers_test/" cache_dir = "/tmp/pytorch_transformers_test/"
for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: # for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
for model_name in ['bert-base-uncased']:
model = TFBertModel.from_pretrained(model_name, cache_dir=cache_dir) model = TFBertModel.from_pretrained(model_name, cache_dir=cache_dir)
shutil.rmtree(cache_dir) shutil.rmtree(cache_dir)
self.assertIsNotNone(model) self.assertIsNotNone(model)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment