Commit 7e98e211 authored by Aymeric Augustin's avatar Aymeric Augustin
Browse files

Remove unittest.main() in test modules.

This construct isn't used anymore these days.

Running python tests/test_foo.py puts the tests/ directory on
PYTHONPATH, which isn't representative of how we run tests.

Use python -m unittest tests/test_foo.py instead.
parent 6be7cdda
...@@ -98,7 +98,3 @@ class SummarizationDataProcessingTest(unittest.TestCase): ...@@ -98,7 +98,3 @@ class SummarizationDataProcessingTest(unittest.TestCase):
result = compute_token_type_ids(batch, separator) result = compute_token_type_ids(batch, separator)
np.testing.assert_array_equal(result, expected) np.testing.assert_array_equal(result, expected)
if __name__ == "__main__":
unittest.main()
...@@ -104,7 +104,3 @@ class ExamplesTests(unittest.TestCase): ...@@ -104,7 +104,3 @@ class ExamplesTests(unittest.TestCase):
with patch.object(sys, "argv", testargs + [model_type, model_name]): with patch.object(sys, "argv", testargs + [model_type, model_name]):
result = run_generation.main() result = run_generation.main()
self.assertGreaterEqual(len(result), 10) self.assertGreaterEqual(len(result), 10)
if __name__ == "__main__":
unittest.main()
...@@ -14,8 +14,6 @@ ...@@ -14,8 +14,6 @@
# limitations under the License. # limitations under the License.
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
import unittest
from transformers import XxxConfig, is_tf_available from transformers import XxxConfig, is_tf_available
from .test_configuration_common import ConfigTester from .test_configuration_common import ConfigTester
...@@ -251,7 +249,3 @@ class TFXxxModelTest(TFCommonTestCases.TFCommonModelTester): ...@@ -251,7 +249,3 @@ class TFXxxModelTest(TFCommonTestCases.TFCommonModelTester):
for model_name in ["xxx-base-uncased"]: for model_name in ["xxx-base-uncased"]:
model = TFXxxModel.from_pretrained(model_name, cache_dir=CACHE_DIR) model = TFXxxModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
self.assertIsNotNone(model) self.assertIsNotNone(model)
if __name__ == "__main__":
unittest.main()
...@@ -14,8 +14,6 @@ ...@@ -14,8 +14,6 @@
# limitations under the License. # limitations under the License.
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
import unittest
from transformers import is_torch_available from transformers import is_torch_available
from .test_configuration_common import ConfigTester from .test_configuration_common import ConfigTester
...@@ -272,7 +270,3 @@ class XxxModelTest(CommonTestCases.CommonModelTester): ...@@ -272,7 +270,3 @@ class XxxModelTest(CommonTestCases.CommonModelTester):
for model_name in list(XXX_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: for model_name in list(XXX_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = XxxModel.from_pretrained(model_name, cache_dir=CACHE_DIR) model = XxxModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
self.assertIsNotNone(model) self.assertIsNotNone(model)
if __name__ == "__main__":
unittest.main()
...@@ -15,7 +15,6 @@ ...@@ -15,7 +15,6 @@
from __future__ import absolute_import, division, print_function, unicode_literals from __future__ import absolute_import, division, print_function, unicode_literals
import os import os
import unittest
from io import open from io import open
from transformers.tokenization_bert import VOCAB_FILES_NAMES, XxxTokenizer from transformers.tokenization_bert import VOCAB_FILES_NAMES, XxxTokenizer
...@@ -63,7 +62,3 @@ class XxxTokenizationTest(CommonTestCases.CommonTokenizerTester): ...@@ -63,7 +62,3 @@ class XxxTokenizationTest(CommonTestCases.CommonTokenizerTester):
tokens = tokenizer.tokenize("UNwant\u00E9d,running") tokens = tokenizer.tokenize("UNwant\u00E9d,running")
self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"]) self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])
self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9]) self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])
if __name__ == "__main__":
unittest.main()
...@@ -16,7 +16,6 @@ from __future__ import absolute_import, division, print_function ...@@ -16,7 +16,6 @@ from __future__ import absolute_import, division, print_function
import json import json
import os import os
import unittest
from .tokenization_tests_commons import TemporaryDirectory from .tokenization_tests_commons import TemporaryDirectory
...@@ -64,7 +63,3 @@ class ConfigTester(object): ...@@ -64,7 +63,3 @@ class ConfigTester(object):
self.create_and_test_config_to_json_string() self.create_and_test_config_to_json_string()
self.create_and_test_config_to_json_file() self.create_and_test_config_to_json_file()
self.create_and_test_config_from_and_save_pretrained() self.create_and_test_config_from_and_save_pretrained()
if __name__ == "__main__":
unittest.main()
...@@ -102,7 +102,3 @@ class HfFolderTest(unittest.TestCase): ...@@ -102,7 +102,3 @@ class HfFolderTest(unittest.TestCase):
# ^^ not an error, we test that the # ^^ not an error, we test that the
# second call does not fail. # second call does not fail.
self.assertEqual(HfFolder.get_token(), None) self.assertEqual(HfFolder.get_token(), None)
if __name__ == "__main__":
unittest.main()
...@@ -80,7 +80,3 @@ class ModelCardTester(unittest.TestCase): ...@@ -80,7 +80,3 @@ class ModelCardTester(unittest.TestCase):
model_card_second = ModelCard.from_pretrained(tmpdirname) model_card_second = ModelCard.from_pretrained(tmpdirname)
self.assertEqual(model_card_second.to_dict(), model_card_first.to_dict()) self.assertEqual(model_card_second.to_dict(), model_card_first.to_dict())
if __name__ == "__main__":
unittest.main()
...@@ -14,8 +14,6 @@ ...@@ -14,8 +14,6 @@
# limitations under the License. # limitations under the License.
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
import unittest
from transformers import is_torch_available from transformers import is_torch_available
from .test_configuration_common import ConfigTester from .test_configuration_common import ConfigTester
...@@ -249,7 +247,3 @@ class AlbertModelTest(CommonTestCases.CommonModelTester): ...@@ -249,7 +247,3 @@ class AlbertModelTest(CommonTestCases.CommonModelTester):
for model_name in list(ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: for model_name in list(ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = AlbertModel.from_pretrained(model_name, cache_dir=CACHE_DIR) model = AlbertModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
self.assertIsNotNone(model) self.assertIsNotNone(model)
if __name__ == "__main__":
unittest.main()
...@@ -100,7 +100,3 @@ class AutoModelTest(unittest.TestCase): ...@@ -100,7 +100,3 @@ class AutoModelTest(unittest.TestCase):
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
model = AutoModelWithLMHead.from_pretrained(SMALL_MODEL_IDENTIFIER) model = AutoModelWithLMHead.from_pretrained(SMALL_MODEL_IDENTIFIER)
self.assertIsInstance(model, BertForMaskedLM) self.assertIsInstance(model, BertForMaskedLM)
if __name__ == "__main__":
unittest.main()
...@@ -14,8 +14,6 @@ ...@@ -14,8 +14,6 @@
# limitations under the License. # limitations under the License.
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
import unittest
from transformers import is_torch_available from transformers import is_torch_available
from .test_configuration_common import ConfigTester from .test_configuration_common import ConfigTester
...@@ -475,7 +473,3 @@ class BertModelTest(CommonTestCases.CommonModelTester): ...@@ -475,7 +473,3 @@ class BertModelTest(CommonTestCases.CommonModelTester):
for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = BertModel.from_pretrained(model_name, cache_dir=CACHE_DIR) model = BertModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
self.assertIsNotNone(model) self.assertIsNotNone(model)
if __name__ == "__main__":
unittest.main()
...@@ -892,7 +892,3 @@ class ModelUtilsTest(unittest.TestCase): ...@@ -892,7 +892,3 @@ class ModelUtilsTest(unittest.TestCase):
self.assertEqual(model.config.output_attentions, True) self.assertEqual(model.config.output_attentions, True)
self.assertEqual(model.config.output_hidden_states, True) self.assertEqual(model.config.output_hidden_states, True)
self.assertEqual(model.config, config) self.assertEqual(model.config, config)
if __name__ == "__main__":
unittest.main()
...@@ -13,8 +13,6 @@ ...@@ -13,8 +13,6 @@
# limitations under the License. # limitations under the License.
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
import unittest
from transformers import is_torch_available from transformers import is_torch_available
from .test_configuration_common import ConfigTester from .test_configuration_common import ConfigTester
...@@ -211,7 +209,3 @@ class CTRLModelTest(CommonTestCases.CommonModelTester): ...@@ -211,7 +209,3 @@ class CTRLModelTest(CommonTestCases.CommonModelTester):
for model_name in list(CTRL_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: for model_name in list(CTRL_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = CTRLModel.from_pretrained(model_name, cache_dir=CACHE_DIR) model = CTRLModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
self.assertIsNotNone(model) self.assertIsNotNone(model)
if __name__ == "__main__":
unittest.main()
...@@ -14,8 +14,6 @@ ...@@ -14,8 +14,6 @@
# limitations under the License. # limitations under the License.
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
import unittest
from transformers import is_torch_available from transformers import is_torch_available
from .test_configuration_common import ConfigTester from .test_configuration_common import ConfigTester
...@@ -250,7 +248,3 @@ class DistilBertModelTest(CommonTestCases.CommonModelTester): ...@@ -250,7 +248,3 @@ class DistilBertModelTest(CommonTestCases.CommonModelTester):
# for model_name in list(DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: # for model_name in list(DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
# model = DistilBertModel.from_pretrained(model_name, cache_dir=CACHE_DIR) # model = DistilBertModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
# self.assertIsNotNone(model) # self.assertIsNotNone(model)
if __name__ == "__main__":
unittest.main()
...@@ -48,7 +48,3 @@ class EncoderDecoderModelTest(unittest.TestCase): ...@@ -48,7 +48,3 @@ class EncoderDecoderModelTest(unittest.TestCase):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
_ = Model2Model.from_pretrained("does-not-exist") _ = Model2Model.from_pretrained("does-not-exist")
if __name__ == "__main__":
unittest.main()
...@@ -14,8 +14,6 @@ ...@@ -14,8 +14,6 @@
# limitations under the License. # limitations under the License.
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
import unittest
from transformers import is_torch_available from transformers import is_torch_available
from .test_configuration_common import ConfigTester from .test_configuration_common import ConfigTester
...@@ -248,7 +246,3 @@ class GPT2ModelTest(CommonTestCases.CommonModelTester): ...@@ -248,7 +246,3 @@ class GPT2ModelTest(CommonTestCases.CommonModelTester):
for model_name in list(GPT2_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: for model_name in list(GPT2_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = GPT2Model.from_pretrained(model_name, cache_dir=CACHE_DIR) model = GPT2Model.from_pretrained(model_name, cache_dir=CACHE_DIR)
self.assertIsNotNone(model) self.assertIsNotNone(model)
if __name__ == "__main__":
unittest.main()
...@@ -14,8 +14,6 @@ ...@@ -14,8 +14,6 @@
# limitations under the License. # limitations under the License.
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
import unittest
from transformers import is_torch_available from transformers import is_torch_available
from .test_configuration_common import ConfigTester from .test_configuration_common import ConfigTester
...@@ -205,7 +203,3 @@ class OpenAIGPTModelTest(CommonTestCases.CommonModelTester): ...@@ -205,7 +203,3 @@ class OpenAIGPTModelTest(CommonTestCases.CommonModelTester):
for model_name in list(OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: for model_name in list(OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = OpenAIGPTModel.from_pretrained(model_name, cache_dir=CACHE_DIR) model = OpenAIGPTModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
self.assertIsNotNone(model) self.assertIsNotNone(model)
if __name__ == "__main__":
unittest.main()
...@@ -298,7 +298,3 @@ class RobertaModelIntegrationTest(unittest.TestCase): ...@@ -298,7 +298,3 @@ class RobertaModelIntegrationTest(unittest.TestCase):
self.assertEqual(output.shape, expected_shape) self.assertEqual(output.shape, expected_shape)
expected_tensor = torch.Tensor([[-0.9469, 0.3913, 0.5118]]) expected_tensor = torch.Tensor([[-0.9469, 0.3913, 0.5118]])
self.assertTrue(torch.allclose(output, expected_tensor, atol=1e-3)) self.assertTrue(torch.allclose(output, expected_tensor, atol=1e-3))
if __name__ == "__main__":
unittest.main()
...@@ -14,8 +14,6 @@ ...@@ -14,8 +14,6 @@
# limitations under the License. # limitations under the License.
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
import unittest
from transformers import is_torch_available from transformers import is_torch_available
from .test_configuration_common import ConfigTester from .test_configuration_common import ConfigTester
...@@ -212,7 +210,3 @@ class T5ModelTest(CommonTestCases.CommonModelTester): ...@@ -212,7 +210,3 @@ class T5ModelTest(CommonTestCases.CommonModelTester):
for model_name in list(T5_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: for model_name in list(T5_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = T5Model.from_pretrained(model_name, cache_dir=CACHE_DIR) model = T5Model.from_pretrained(model_name, cache_dir=CACHE_DIR)
self.assertIsNotNone(model) self.assertIsNotNone(model)
if __name__ == "__main__":
unittest.main()
...@@ -14,8 +14,6 @@ ...@@ -14,8 +14,6 @@
# limitations under the License. # limitations under the License.
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
import unittest
from transformers import AlbertConfig, is_tf_available from transformers import AlbertConfig, is_tf_available
from .test_configuration_common import ConfigTester from .test_configuration_common import ConfigTester
...@@ -213,7 +211,3 @@ class TFAlbertModelTest(TFCommonTestCases.TFCommonModelTester): ...@@ -213,7 +211,3 @@ class TFAlbertModelTest(TFCommonTestCases.TFCommonModelTester):
for model_name in list(TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: for model_name in list(TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = TFAlbertModel.from_pretrained(model_name, cache_dir=CACHE_DIR) model = TFAlbertModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
self.assertIsNotNone(model) self.assertIsNotNone(model)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment