Unverified commit 5daca95d authored by Thomas Wolf, committed by GitHub

Merge pull request #2268 from aaugustin/improve-repository-structure

Improve repository structure
parents 54abc67a 00204f2b
@@ -18,8 +18,8 @@ import unittest
 from transformers import XxxConfig, is_tf_available
-from .configuration_common_test import ConfigTester
-from .modeling_tf_common_test import TFCommonTestCases, ids_tensor
+from .test_configuration_common import ConfigTester
+from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor
 from .utils import CACHE_DIR, require_tf, slow
@@ -34,7 +34,7 @@ if is_tf_available():
 @require_tf
-class TFXxxModelTest(TFCommonTestCases.TFCommonModelTester):
+class TFXxxModelTest(TFModelTesterMixin, unittest.TestCase):
     all_model_classes = (
         (
@@ -251,7 +251,3 @@ class TFXxxModelTest(TFCommonTestCases.TFCommonModelTester):
         for model_name in ["xxx-base-uncased"]:
             model = TFXxxModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
             self.assertIsNotNone(model)
-if __name__ == "__main__":
-    unittest.main()
@@ -18,8 +18,8 @@ import unittest
 from transformers import is_torch_available
-from .configuration_common_test import ConfigTester
-from .modeling_common_test import CommonTestCases, ids_tensor
+from .test_configuration_common import ConfigTester
+from .test_modeling_common import ModelTesterMixin, ids_tensor
 from .utils import CACHE_DIR, require_torch, slow, torch_device
@@ -36,7 +36,7 @@ if is_torch_available():
 @require_torch
-class XxxModelTest(CommonTestCases.CommonModelTester):
+class XxxModelTest(ModelTesterMixin, unittest.TestCase):
     all_model_classes = (
         (XxxModel, XxxForMaskedLM, XxxForQuestionAnswering, XxxForSequenceClassification, XxxForTokenClassification)
@@ -272,7 +272,3 @@ class XxxModelTest(CommonTestCases.CommonModelTester):
         for model_name in list(XXX_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
            model = XxxModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
             self.assertIsNotNone(model)
-if __name__ == "__main__":
-    unittest.main()
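Both model-test templates follow the same pattern: the nested CommonTestCases.CommonModelTester / TFCommonTestCases.TFCommonModelTester base classes are replaced by flat mixins (ModelTesterMixin, TFModelTesterMixin) that each concrete test combines with unittest.TestCase. A minimal, self-contained sketch of that mixin shape, with made-up class names that are not part of this PR:

# Illustrative sketch of the mixin-plus-TestCase pattern used after the
# refactor; the class names here are invented for this example only.
import unittest


class SharedTesterMixin(object):
    # The mixin carries the shared test methods but is not a TestCase itself,
    # so test discovery does not collect it as a standalone suite.
    all_model_classes = ()

    def test_all_model_classes_defined(self):
        self.assertIsInstance(self.all_model_classes, tuple)


class DummyModelTest(SharedTesterMixin, unittest.TestCase):
    # Concrete test files mix the shared checks into a real TestCase and only
    # override the model-specific attributes.
    all_model_classes = (object,)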
@@ -20,10 +20,10 @@ from io import open
 from transformers.tokenization_bert import VOCAB_FILES_NAMES, XxxTokenizer
-from .tokenization_tests_commons import CommonTestCases
+from .test_tokenization_common import TokenizerTesterMixin
-class XxxTokenizationTest(CommonTestCases.CommonTokenizerTester):
+class XxxTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
     tokenizer_class = XxxTokenizer
@@ -63,7 +63,3 @@ class XxxTokenizationTest(CommonTestCases.CommonTokenizerTester):
         tokens = tokenizer.tokenize("UNwant\u00E9d,running")
         self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])
         self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])
-if __name__ == "__main__":
-    unittest.main()
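With the shared helper modules renamed to the test_* pattern and the per-file if __name__ == "__main__": unittest.main() guards removed throughout, the suite is meant to be collected by a test runner rather than run file by file. A minimal sketch using the standard library's discovery; the "tests" start directory is an assumption here, not taken from the PR:

# Collect and run everything matching test_*.py; adjust the start directory
# to the repository's actual layout.
import unittest

if __name__ == "__main__":
    suite = unittest.defaultTestLoader.discover("tests", pattern="test_*.py")
    unittest.TextTestRunner(verbosity=2).run(suite)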
@@ -16,9 +16,8 @@ from __future__ import absolute_import, division, print_function
 import json
 import os
-import unittest
-from .tokenization_tests_commons import TemporaryDirectory
+from .test_tokenization_common import TemporaryDirectory
 class ConfigTester(object):
@@ -64,7 +63,3 @@ class ConfigTester(object):
         self.create_and_test_config_to_json_string()
         self.create_and_test_config_to_json_file()
         self.create_and_test_config_from_and_save_pretrained()
-if __name__ == "__main__":
-    unittest.main()
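ConfigTester now takes TemporaryDirectory from the renamed test_tokenization_common module and drops its own unittest import. As a rough idea of what such a helper provides, here is a hypothetical sketch built only on the standard library; the repository's actual implementation may differ:

# Hypothetical context manager that creates a temporary directory and removes
# it when the block exits; not the repository's actual code.
import shutil
import tempfile
from contextlib import contextmanager


@contextmanager
def TemporaryDirectory():
    tmpdir = tempfile.mkdtemp()
    try:
        yield tmpdir
    finally:
        shutil.rmtree(tmpdir)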
@@ -102,7 +102,3 @@ class HfFolderTest(unittest.TestCase):
         # ^^ not an error, we test that the
         # second call does not fail.
         self.assertEqual(HfFolder.get_token(), None)
-if __name__ == "__main__":
-    unittest.main()
@@ -20,7 +20,7 @@ import unittest
 from transformers.modelcard import ModelCard
-from .tokenization_tests_commons import TemporaryDirectory
+from .test_tokenization_common import TemporaryDirectory
 class ModelCardTester(unittest.TestCase):
@@ -80,7 +80,3 @@ class ModelCardTester(unittest.TestCase):
         model_card_second = ModelCard.from_pretrained(tmpdirname)
         self.assertEqual(model_card_second.to_dict(), model_card_first.to_dict())
-if __name__ == "__main__":
-    unittest.main()
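The ModelCard test keeps the same save/re-load round trip it already had, only importing TemporaryDirectory from the renamed module. A sketch of that round trip using the standard library's temporary directory; it assumes ModelCard.save_pretrained exists alongside from_pretrained, which should be checked against the installed transformers version:

# Round-trip sketch: write a model card to a temporary directory, reload it,
# and compare the dictionaries. save_pretrained is assumed, not shown in the diff.
import tempfile

from transformers.modelcard import ModelCard

card = ModelCard(model_details="A short description of the model.")
with tempfile.TemporaryDirectory() as tmpdirname:
    card.save_pretrained(tmpdirname)
    reloaded = ModelCard.from_pretrained(tmpdirname)
    assert reloaded.to_dict() == card.to_dict()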
@@ -18,8 +18,8 @@ import unittest
 from transformers import is_torch_available
-from .configuration_common_test import ConfigTester
-from .modeling_common_test import CommonTestCases, ids_tensor
+from .test_configuration_common import ConfigTester
+from .test_modeling_common import ModelTesterMixin, ids_tensor
 from .utils import CACHE_DIR, require_torch, slow, torch_device
@@ -35,7 +35,7 @@ if is_torch_available():
 @require_torch
-class AlbertModelTest(CommonTestCases.CommonModelTester):
+class AlbertModelTest(ModelTesterMixin, unittest.TestCase):
     all_model_classes = (AlbertModel, AlbertForMaskedLM) if is_torch_available() else ()
@@ -249,7 +249,3 @@ class AlbertModelTest(CommonTestCases.CommonModelTester):
         for model_name in list(ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
             model = AlbertModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
             self.assertIsNotNone(model)
-if __name__ == "__main__":
-    unittest.main()
@@ -100,7 +100,3 @@ class AutoModelTest(unittest.TestCase):
         logging.basicConfig(level=logging.INFO)
         model = AutoModelWithLMHead.from_pretrained(SMALL_MODEL_IDENTIFIER)
         self.assertIsInstance(model, BertForMaskedLM)
-if __name__ == "__main__":
-    unittest.main()