"tests/vscode:/vscode.git/clone" did not exist on "fa3c86beaf04e297d4b0e824692e3bd4edfb5f22"
Commit 7e98e211 authored by Aymeric Augustin's avatar Aymeric Augustin
Browse files

Remove unittest.main() in test modules.

This construct isn't used anymore these days.

Running python tests/test_foo.py puts the tests/ directory on
PYTHONPATH, which isn't representative of how we run tests.

Use python -m unittest tests/test_foo.py instead.
parent 6be7cdda
...@@ -98,7 +98,3 @@ class SummarizationDataProcessingTest(unittest.TestCase): ...@@ -98,7 +98,3 @@ class SummarizationDataProcessingTest(unittest.TestCase):
result = compute_token_type_ids(batch, separator) result = compute_token_type_ids(batch, separator)
np.testing.assert_array_equal(result, expected) np.testing.assert_array_equal(result, expected)
if __name__ == "__main__":
unittest.main()
...@@ -104,7 +104,3 @@ class ExamplesTests(unittest.TestCase): ...@@ -104,7 +104,3 @@ class ExamplesTests(unittest.TestCase):
with patch.object(sys, "argv", testargs + [model_type, model_name]): with patch.object(sys, "argv", testargs + [model_type, model_name]):
result = run_generation.main() result = run_generation.main()
self.assertGreaterEqual(len(result), 10) self.assertGreaterEqual(len(result), 10)
if __name__ == "__main__":
unittest.main()
...@@ -14,8 +14,6 @@ ...@@ -14,8 +14,6 @@
# limitations under the License. # limitations under the License.
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
import unittest
from transformers import XxxConfig, is_tf_available from transformers import XxxConfig, is_tf_available
from .test_configuration_common import ConfigTester from .test_configuration_common import ConfigTester
...@@ -251,7 +249,3 @@ class TFXxxModelTest(TFCommonTestCases.TFCommonModelTester): ...@@ -251,7 +249,3 @@ class TFXxxModelTest(TFCommonTestCases.TFCommonModelTester):
for model_name in ["xxx-base-uncased"]: for model_name in ["xxx-base-uncased"]:
model = TFXxxModel.from_pretrained(model_name, cache_dir=CACHE_DIR) model = TFXxxModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
self.assertIsNotNone(model) self.assertIsNotNone(model)
if __name__ == "__main__":
unittest.main()
...@@ -14,8 +14,6 @@ ...@@ -14,8 +14,6 @@
# limitations under the License. # limitations under the License.
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
import unittest
from transformers import is_torch_available from transformers import is_torch_available
from .test_configuration_common import ConfigTester from .test_configuration_common import ConfigTester
...@@ -272,7 +270,3 @@ class XxxModelTest(CommonTestCases.CommonModelTester): ...@@ -272,7 +270,3 @@ class XxxModelTest(CommonTestCases.CommonModelTester):
for model_name in list(XXX_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: for model_name in list(XXX_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = XxxModel.from_pretrained(model_name, cache_dir=CACHE_DIR) model = XxxModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
self.assertIsNotNone(model) self.assertIsNotNone(model)
if __name__ == "__main__":
unittest.main()
...@@ -15,7 +15,6 @@ ...@@ -15,7 +15,6 @@
from __future__ import absolute_import, division, print_function, unicode_literals from __future__ import absolute_import, division, print_function, unicode_literals
import os import os
import unittest
from io import open from io import open
from transformers.tokenization_bert import VOCAB_FILES_NAMES, XxxTokenizer from transformers.tokenization_bert import VOCAB_FILES_NAMES, XxxTokenizer
...@@ -63,7 +62,3 @@ class XxxTokenizationTest(CommonTestCases.CommonTokenizerTester): ...@@ -63,7 +62,3 @@ class XxxTokenizationTest(CommonTestCases.CommonTokenizerTester):
tokens = tokenizer.tokenize("UNwant\u00E9d,running") tokens = tokenizer.tokenize("UNwant\u00E9d,running")
self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"]) self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])
self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9]) self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])
if __name__ == "__main__":
unittest.main()
...@@ -16,7 +16,6 @@ from __future__ import absolute_import, division, print_function ...@@ -16,7 +16,6 @@ from __future__ import absolute_import, division, print_function
import json import json
import os import os
import unittest
from .tokenization_tests_commons import TemporaryDirectory from .tokenization_tests_commons import TemporaryDirectory
...@@ -64,7 +63,3 @@ class ConfigTester(object): ...@@ -64,7 +63,3 @@ class ConfigTester(object):
self.create_and_test_config_to_json_string() self.create_and_test_config_to_json_string()
self.create_and_test_config_to_json_file() self.create_and_test_config_to_json_file()
self.create_and_test_config_from_and_save_pretrained() self.create_and_test_config_from_and_save_pretrained()
if __name__ == "__main__":
unittest.main()
...@@ -102,7 +102,3 @@ class HfFolderTest(unittest.TestCase): ...@@ -102,7 +102,3 @@ class HfFolderTest(unittest.TestCase):
# ^^ not an error, we test that the # ^^ not an error, we test that the
# second call does not fail. # second call does not fail.
self.assertEqual(HfFolder.get_token(), None) self.assertEqual(HfFolder.get_token(), None)
if __name__ == "__main__":
unittest.main()
...@@ -80,7 +80,3 @@ class ModelCardTester(unittest.TestCase): ...@@ -80,7 +80,3 @@ class ModelCardTester(unittest.TestCase):
model_card_second = ModelCard.from_pretrained(tmpdirname) model_card_second = ModelCard.from_pretrained(tmpdirname)
self.assertEqual(model_card_second.to_dict(), model_card_first.to_dict()) self.assertEqual(model_card_second.to_dict(), model_card_first.to_dict())
if __name__ == "__main__":
unittest.main()
...@@ -14,8 +14,6 @@ ...@@ -14,8 +14,6 @@
# limitations under the License. # limitations under the License.
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
import unittest
from transformers import is_torch_available from transformers import is_torch_available
from .test_configuration_common import ConfigTester from .test_configuration_common import ConfigTester
...@@ -249,7 +247,3 @@ class AlbertModelTest(CommonTestCases.CommonModelTester): ...@@ -249,7 +247,3 @@ class AlbertModelTest(CommonTestCases.CommonModelTester):
for model_name in list(ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: for model_name in list(ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = AlbertModel.from_pretrained(model_name, cache_dir=CACHE_DIR) model = AlbertModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
self.assertIsNotNone(model) self.assertIsNotNone(model)
if __name__ == "__main__":
unittest.main()
...@@ -100,7 +100,3 @@ class AutoModelTest(unittest.TestCase): ...@@ -100,7 +100,3 @@ class AutoModelTest(unittest.TestCase):
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
model = AutoModelWithLMHead.from_pretrained(SMALL_MODEL_IDENTIFIER) model = AutoModelWithLMHead.from_pretrained(SMALL_MODEL_IDENTIFIER)
self.assertIsInstance(model, BertForMaskedLM) self.assertIsInstance(model, BertForMaskedLM)
if __name__ == "__main__":
unittest.main()
...@@ -14,8 +14,6 @@ ...@@ -14,8 +14,6 @@
# limitations under the License. # limitations under the License.
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
import unittest
from transformers import is_torch_available from transformers import is_torch_available
from .test_configuration_common import ConfigTester from .test_configuration_common import ConfigTester
...@@ -475,7 +473,3 @@ class BertModelTest(CommonTestCases.CommonModelTester): ...@@ -475,7 +473,3 @@ class BertModelTest(CommonTestCases.CommonModelTester):
for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = BertModel.from_pretrained(model_name, cache_dir=CACHE_DIR) model = BertModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
self.assertIsNotNone(model) self.assertIsNotNone(model)
if __name__ == "__main__":
unittest.main()
...@@ -892,7 +892,3 @@ class ModelUtilsTest(unittest.TestCase): ...@@ -892,7 +892,3 @@ class ModelUtilsTest(unittest.TestCase):
self.assertEqual(model.config.output_attentions, True) self.assertEqual(model.config.output_attentions, True)
self.assertEqual(model.config.output_hidden_states, True) self.assertEqual(model.config.output_hidden_states, True)
self.assertEqual(model.config, config) self.assertEqual(model.config, config)
if __name__ == "__main__":
unittest.main()
...@@ -13,8 +13,6 @@ ...@@ -13,8 +13,6 @@
# limitations under the License. # limitations under the License.
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
import unittest
from transformers import is_torch_available from transformers import is_torch_available
from .test_configuration_common import ConfigTester from .test_configuration_common import ConfigTester
...@@ -211,7 +209,3 @@ class CTRLModelTest(CommonTestCases.CommonModelTester): ...@@ -211,7 +209,3 @@ class CTRLModelTest(CommonTestCases.CommonModelTester):
for model_name in list(CTRL_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: for model_name in list(CTRL_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = CTRLModel.from_pretrained(model_name, cache_dir=CACHE_DIR) model = CTRLModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
self.assertIsNotNone(model) self.assertIsNotNone(model)
if __name__ == "__main__":
unittest.main()
...@@ -14,8 +14,6 @@ ...@@ -14,8 +14,6 @@
# limitations under the License. # limitations under the License.
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
import unittest
from transformers import is_torch_available from transformers import is_torch_available
from .test_configuration_common import ConfigTester from .test_configuration_common import ConfigTester
...@@ -250,7 +248,3 @@ class DistilBertModelTest(CommonTestCases.CommonModelTester): ...@@ -250,7 +248,3 @@ class DistilBertModelTest(CommonTestCases.CommonModelTester):
# for model_name in list(DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: # for model_name in list(DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
# model = DistilBertModel.from_pretrained(model_name, cache_dir=CACHE_DIR) # model = DistilBertModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
# self.assertIsNotNone(model) # self.assertIsNotNone(model)
if __name__ == "__main__":
unittest.main()
...@@ -48,7 +48,3 @@ class EncoderDecoderModelTest(unittest.TestCase): ...@@ -48,7 +48,3 @@ class EncoderDecoderModelTest(unittest.TestCase):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
_ = Model2Model.from_pretrained("does-not-exist") _ = Model2Model.from_pretrained("does-not-exist")
if __name__ == "__main__":
unittest.main()
...@@ -14,8 +14,6 @@ ...@@ -14,8 +14,6 @@
# limitations under the License. # limitations under the License.
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
import unittest
from transformers import is_torch_available from transformers import is_torch_available
from .test_configuration_common import ConfigTester from .test_configuration_common import ConfigTester
...@@ -248,7 +246,3 @@ class GPT2ModelTest(CommonTestCases.CommonModelTester): ...@@ -248,7 +246,3 @@ class GPT2ModelTest(CommonTestCases.CommonModelTester):
for model_name in list(GPT2_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: for model_name in list(GPT2_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = GPT2Model.from_pretrained(model_name, cache_dir=CACHE_DIR) model = GPT2Model.from_pretrained(model_name, cache_dir=CACHE_DIR)
self.assertIsNotNone(model) self.assertIsNotNone(model)
if __name__ == "__main__":
unittest.main()
...@@ -14,8 +14,6 @@ ...@@ -14,8 +14,6 @@
# limitations under the License. # limitations under the License.
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
import unittest
from transformers import is_torch_available from transformers import is_torch_available
from .test_configuration_common import ConfigTester from .test_configuration_common import ConfigTester
...@@ -205,7 +203,3 @@ class OpenAIGPTModelTest(CommonTestCases.CommonModelTester): ...@@ -205,7 +203,3 @@ class OpenAIGPTModelTest(CommonTestCases.CommonModelTester):
for model_name in list(OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: for model_name in list(OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = OpenAIGPTModel.from_pretrained(model_name, cache_dir=CACHE_DIR) model = OpenAIGPTModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
self.assertIsNotNone(model) self.assertIsNotNone(model)
if __name__ == "__main__":
unittest.main()
...@@ -298,7 +298,3 @@ class RobertaModelIntegrationTest(unittest.TestCase): ...@@ -298,7 +298,3 @@ class RobertaModelIntegrationTest(unittest.TestCase):
self.assertEqual(output.shape, expected_shape) self.assertEqual(output.shape, expected_shape)
expected_tensor = torch.Tensor([[-0.9469, 0.3913, 0.5118]]) expected_tensor = torch.Tensor([[-0.9469, 0.3913, 0.5118]])
self.assertTrue(torch.allclose(output, expected_tensor, atol=1e-3)) self.assertTrue(torch.allclose(output, expected_tensor, atol=1e-3))
if __name__ == "__main__":
unittest.main()
...@@ -14,8 +14,6 @@ ...@@ -14,8 +14,6 @@
# limitations under the License. # limitations under the License.
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
import unittest
from transformers import is_torch_available from transformers import is_torch_available
from .test_configuration_common import ConfigTester from .test_configuration_common import ConfigTester
...@@ -212,7 +210,3 @@ class T5ModelTest(CommonTestCases.CommonModelTester): ...@@ -212,7 +210,3 @@ class T5ModelTest(CommonTestCases.CommonModelTester):
for model_name in list(T5_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: for model_name in list(T5_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = T5Model.from_pretrained(model_name, cache_dir=CACHE_DIR) model = T5Model.from_pretrained(model_name, cache_dir=CACHE_DIR)
self.assertIsNotNone(model) self.assertIsNotNone(model)
if __name__ == "__main__":
unittest.main()
...@@ -14,8 +14,6 @@ ...@@ -14,8 +14,6 @@
# limitations under the License. # limitations under the License.
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
import unittest
from transformers import AlbertConfig, is_tf_available from transformers import AlbertConfig, is_tf_available
from .test_configuration_common import ConfigTester from .test_configuration_common import ConfigTester
...@@ -213,7 +211,3 @@ class TFAlbertModelTest(TFCommonTestCases.TFCommonModelTester): ...@@ -213,7 +211,3 @@ class TFAlbertModelTest(TFCommonTestCases.TFCommonModelTester):
for model_name in list(TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: for model_name in list(TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = TFAlbertModel.from_pretrained(model_name, cache_dir=CACHE_DIR) model = TFAlbertModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
self.assertIsNotNone(model) self.assertIsNotNone(model)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment