Unverified Commit 5daca95d authored by Thomas Wolf's avatar Thomas Wolf Committed by GitHub
Browse files

Merge pull request #2268 from aaugustin/improve-repository-structure

Improve repository structure
parents 54abc67a 00204f2b
......@@ -18,8 +18,8 @@ import unittest
from transformers import XxxConfig, is_tf_available
from .configuration_common_test import ConfigTester
from .modeling_tf_common_test import TFCommonTestCases, ids_tensor
from .test_configuration_common import ConfigTester
from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor
from .utils import CACHE_DIR, require_tf, slow
......@@ -34,7 +34,7 @@ if is_tf_available():
@require_tf
class TFXxxModelTest(TFCommonTestCases.TFCommonModelTester):
class TFXxxModelTest(TFModelTesterMixin, unittest.TestCase):
all_model_classes = (
(
......@@ -251,7 +251,3 @@ class TFXxxModelTest(TFCommonTestCases.TFCommonModelTester):
for model_name in ["xxx-base-uncased"]:
model = TFXxxModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
self.assertIsNotNone(model)
if __name__ == "__main__":
unittest.main()
......@@ -18,8 +18,8 @@ import unittest
from transformers import is_torch_available
from .configuration_common_test import ConfigTester
from .modeling_common_test import CommonTestCases, ids_tensor
from .test_configuration_common import ConfigTester
from .test_modeling_common import ModelTesterMixin, ids_tensor
from .utils import CACHE_DIR, require_torch, slow, torch_device
......@@ -36,7 +36,7 @@ if is_torch_available():
@require_torch
class XxxModelTest(CommonTestCases.CommonModelTester):
class XxxModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (
(XxxModel, XxxForMaskedLM, XxxForQuestionAnswering, XxxForSequenceClassification, XxxForTokenClassification)
......@@ -272,7 +272,3 @@ class XxxModelTest(CommonTestCases.CommonModelTester):
for model_name in list(XXX_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = XxxModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
self.assertIsNotNone(model)
if __name__ == "__main__":
unittest.main()
......@@ -20,10 +20,10 @@ from io import open
from transformers.tokenization_bert import VOCAB_FILES_NAMES, XxxTokenizer
from .tokenization_tests_commons import CommonTestCases
from .test_tokenization_common import TokenizerTesterMixin
class XxxTokenizationTest(CommonTestCases.CommonTokenizerTester):
class XxxTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
tokenizer_class = XxxTokenizer
......@@ -63,7 +63,3 @@ class XxxTokenizationTest(CommonTestCases.CommonTokenizerTester):
tokens = tokenizer.tokenize("UNwant\u00E9d,running")
self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])
self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])
if __name__ == "__main__":
unittest.main()
......@@ -16,9 +16,8 @@ from __future__ import absolute_import, division, print_function
import json
import os
import unittest
from .tokenization_tests_commons import TemporaryDirectory
from .test_tokenization_common import TemporaryDirectory
class ConfigTester(object):
......@@ -64,7 +63,3 @@ class ConfigTester(object):
self.create_and_test_config_to_json_string()
self.create_and_test_config_to_json_file()
self.create_and_test_config_from_and_save_pretrained()
if __name__ == "__main__":
unittest.main()
......@@ -102,7 +102,3 @@ class HfFolderTest(unittest.TestCase):
# ^^ not an error, we test that the
# second call does not fail.
self.assertEqual(HfFolder.get_token(), None)
if __name__ == "__main__":
unittest.main()
......@@ -20,7 +20,7 @@ import unittest
from transformers.modelcard import ModelCard
from .tokenization_tests_commons import TemporaryDirectory
from .test_tokenization_common import TemporaryDirectory
class ModelCardTester(unittest.TestCase):
......@@ -80,7 +80,3 @@ class ModelCardTester(unittest.TestCase):
model_card_second = ModelCard.from_pretrained(tmpdirname)
self.assertEqual(model_card_second.to_dict(), model_card_first.to_dict())
if __name__ == "__main__":
unittest.main()
......@@ -18,8 +18,8 @@ import unittest
from transformers import is_torch_available
from .configuration_common_test import ConfigTester
from .modeling_common_test import CommonTestCases, ids_tensor
from .test_configuration_common import ConfigTester
from .test_modeling_common import ModelTesterMixin, ids_tensor
from .utils import CACHE_DIR, require_torch, slow, torch_device
......@@ -35,7 +35,7 @@ if is_torch_available():
@require_torch
class AlbertModelTest(CommonTestCases.CommonModelTester):
class AlbertModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (AlbertModel, AlbertForMaskedLM) if is_torch_available() else ()
......@@ -249,7 +249,3 @@ class AlbertModelTest(CommonTestCases.CommonModelTester):
for model_name in list(ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = AlbertModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
self.assertIsNotNone(model)
if __name__ == "__main__":
unittest.main()
......@@ -100,7 +100,3 @@ class AutoModelTest(unittest.TestCase):
logging.basicConfig(level=logging.INFO)
model = AutoModelWithLMHead.from_pretrained(SMALL_MODEL_IDENTIFIER)
self.assertIsInstance(model, BertForMaskedLM)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment