Commit 35401fe5, authored by Aymeric Augustin, committed by Julien Chaumond
Browse files

Remove dependency on pytest for running tests (#2055)

* Switch to plain unittest for skipping slow tests.

Add a RUN_SLOW environment variable for running them.

* Switch to plain unittest for PyTorch dependency.

* Switch to plain unittest for TensorFlow dependency.

* Avoid leaking open files in the test suite.

This prevents spurious warnings when running tests.

* Fix unicode warning on Python 2 when running tests.

The warning was:

    UnicodeWarning: Unicode equal comparison failed to convert both arguments to Unicode - interpreting them as being unequal

* Support running PyTorch tests on a GPU.

Reverts 27e015bd.

* Tests no longer require pytest.

* Make tests pass on CUDA.
parent e4679cdd
...@@ -25,18 +25,17 @@ import unittest ...@@ -25,18 +25,17 @@ import unittest
import uuid import uuid
import tempfile import tempfile
import pytest
import sys import sys
from transformers import is_tf_available, is_torch_available from transformers import is_tf_available, is_torch_available
from .utils import require_tf, slow
if is_tf_available(): if is_tf_available():
import tensorflow as tf import tensorflow as tf
import numpy as np import numpy as np
from transformers import TFPreTrainedModel from transformers import TFPreTrainedModel
# from transformers.modeling_bert import BertModel, BertConfig, BERT_PRETRAINED_MODEL_ARCHIVE_MAP # from transformers.modeling_bert import BertModel, BertConfig, BERT_PRETRAINED_MODEL_ARCHIVE_MAP
else:
pytestmark = pytest.mark.skip("Require TensorFlow")
if sys.version_info[0] == 2: if sys.version_info[0] == 2:
import cPickle as pickle import cPickle as pickle
...@@ -62,6 +61,7 @@ def _config_zero_init(config): ...@@ -62,6 +61,7 @@ def _config_zero_init(config):
class TFCommonTestCases: class TFCommonTestCases:
@require_tf
class TFCommonModelTester(unittest.TestCase): class TFCommonModelTester(unittest.TestCase):
model_tester = None model_tester = None
......
...@@ -18,11 +18,11 @@ from __future__ import print_function ...@@ -18,11 +18,11 @@ from __future__ import print_function
import unittest import unittest
import shutil import shutil
import pytest
import sys import sys
from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor) from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester from .configuration_common_test import ConfigTester
from .utils import require_tf, slow
from transformers import CTRLConfig, is_tf_available from transformers import CTRLConfig, is_tf_available
...@@ -30,10 +30,9 @@ if is_tf_available(): ...@@ -30,10 +30,9 @@ if is_tf_available():
import tensorflow as tf import tensorflow as tf
from transformers.modeling_tf_ctrl import (TFCTRLModel, TFCTRLLMHeadModel, from transformers.modeling_tf_ctrl import (TFCTRLModel, TFCTRLLMHeadModel,
TF_CTRL_PRETRAINED_MODEL_ARCHIVE_MAP) TF_CTRL_PRETRAINED_MODEL_ARCHIVE_MAP)
else:
pytestmark = pytest.mark.skip("Require TensorFlow")
@require_tf
class TFCTRLModelTest(TFCommonTestCases.TFCommonModelTester): class TFCTRLModelTest(TFCommonTestCases.TFCommonModelTester):
all_model_classes = (TFCTRLModel, TFCTRLLMHeadModel) if is_tf_available() else () all_model_classes = (TFCTRLModel, TFCTRLLMHeadModel) if is_tf_available() else ()
...@@ -188,7 +187,7 @@ class TFCTRLModelTest(TFCommonTestCases.TFCommonModelTester): ...@@ -188,7 +187,7 @@ class TFCTRLModelTest(TFCommonTestCases.TFCommonModelTester):
config_and_inputs = self.model_tester.prepare_config_and_inputs() config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_ctrl_lm_head(*config_and_inputs) self.model_tester.create_and_check_ctrl_lm_head(*config_and_inputs)
@pytest.mark.slow @slow
def test_model_from_pretrained(self): def test_model_from_pretrained(self):
cache_dir = "/tmp/transformers_test/" cache_dir = "/tmp/transformers_test/"
for model_name in list(TF_CTRL_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: for model_name in list(TF_CTRL_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
......
...@@ -17,10 +17,10 @@ from __future__ import division ...@@ -17,10 +17,10 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import unittest import unittest
import pytest
from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor) from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester from .configuration_common_test import ConfigTester
from .utils import require_tf, slow
from transformers import DistilBertConfig, is_tf_available from transformers import DistilBertConfig, is_tf_available
...@@ -30,10 +30,9 @@ if is_tf_available(): ...@@ -30,10 +30,9 @@ if is_tf_available():
TFDistilBertForMaskedLM, TFDistilBertForMaskedLM,
TFDistilBertForQuestionAnswering, TFDistilBertForQuestionAnswering,
TFDistilBertForSequenceClassification) TFDistilBertForSequenceClassification)
else:
pytestmark = pytest.mark.skip("Require TensorFlow")
@require_tf
class TFDistilBertModelTest(TFCommonTestCases.TFCommonModelTester): class TFDistilBertModelTest(TFCommonTestCases.TFCommonModelTester):
all_model_classes = (TFDistilBertModel, TFDistilBertForMaskedLM, TFDistilBertForQuestionAnswering, all_model_classes = (TFDistilBertModel, TFDistilBertForMaskedLM, TFDistilBertForQuestionAnswering,
...@@ -210,7 +209,7 @@ class TFDistilBertModelTest(TFCommonTestCases.TFCommonModelTester): ...@@ -210,7 +209,7 @@ class TFDistilBertModelTest(TFCommonTestCases.TFCommonModelTester):
config_and_inputs = self.model_tester.prepare_config_and_inputs() config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_distilbert_for_sequence_classification(*config_and_inputs) self.model_tester.create_and_check_distilbert_for_sequence_classification(*config_and_inputs)
# @pytest.mark.slow # @slow
# def test_model_from_pretrained(self): # def test_model_from_pretrained(self):
# cache_dir = "/tmp/transformers_test/" # cache_dir = "/tmp/transformers_test/"
# for model_name in list(DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: # for model_name in list(DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
......
...@@ -18,11 +18,11 @@ from __future__ import print_function ...@@ -18,11 +18,11 @@ from __future__ import print_function
import unittest import unittest
import shutil import shutil
import pytest
import sys import sys
from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor) from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester from .configuration_common_test import ConfigTester
from .utils import require_tf, slow
from transformers import GPT2Config, is_tf_available from transformers import GPT2Config, is_tf_available
...@@ -31,10 +31,9 @@ if is_tf_available(): ...@@ -31,10 +31,9 @@ if is_tf_available():
from transformers.modeling_tf_gpt2 import (TFGPT2Model, TFGPT2LMHeadModel, from transformers.modeling_tf_gpt2 import (TFGPT2Model, TFGPT2LMHeadModel,
TFGPT2DoubleHeadsModel, TFGPT2DoubleHeadsModel,
TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP) TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP)
else:
pytestmark = pytest.mark.skip("Require TensorFlow")
@require_tf
class TFGPT2ModelTest(TFCommonTestCases.TFCommonModelTester): class TFGPT2ModelTest(TFCommonTestCases.TFCommonModelTester):
all_model_classes = (TFGPT2Model, TFGPT2LMHeadModel, all_model_classes = (TFGPT2Model, TFGPT2LMHeadModel,
...@@ -219,7 +218,7 @@ class TFGPT2ModelTest(TFCommonTestCases.TFCommonModelTester): ...@@ -219,7 +218,7 @@ class TFGPT2ModelTest(TFCommonTestCases.TFCommonModelTester):
config_and_inputs = self.model_tester.prepare_config_and_inputs() config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_gpt2_double_head(*config_and_inputs) self.model_tester.create_and_check_gpt2_double_head(*config_and_inputs)
@pytest.mark.slow @slow
def test_model_from_pretrained(self): def test_model_from_pretrained(self):
cache_dir = "/tmp/transformers_test/" cache_dir = "/tmp/transformers_test/"
for model_name in list(TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: for model_name in list(TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
......
...@@ -18,11 +18,11 @@ from __future__ import print_function ...@@ -18,11 +18,11 @@ from __future__ import print_function
import unittest import unittest
import shutil import shutil
import pytest
import sys import sys
from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor) from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester from .configuration_common_test import ConfigTester
from .utils import require_tf, slow
from transformers import OpenAIGPTConfig, is_tf_available from transformers import OpenAIGPTConfig, is_tf_available
...@@ -31,10 +31,9 @@ if is_tf_available(): ...@@ -31,10 +31,9 @@ if is_tf_available():
from transformers.modeling_tf_openai import (TFOpenAIGPTModel, TFOpenAIGPTLMHeadModel, from transformers.modeling_tf_openai import (TFOpenAIGPTModel, TFOpenAIGPTLMHeadModel,
TFOpenAIGPTDoubleHeadsModel, TFOpenAIGPTDoubleHeadsModel,
TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP) TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP)
else:
pytestmark = pytest.mark.skip("Require TensorFlow")
@require_tf
class TFOpenAIGPTModelTest(TFCommonTestCases.TFCommonModelTester): class TFOpenAIGPTModelTest(TFCommonTestCases.TFCommonModelTester):
all_model_classes = (TFOpenAIGPTModel, TFOpenAIGPTLMHeadModel, all_model_classes = (TFOpenAIGPTModel, TFOpenAIGPTLMHeadModel,
...@@ -218,7 +217,7 @@ class TFOpenAIGPTModelTest(TFCommonTestCases.TFCommonModelTester): ...@@ -218,7 +217,7 @@ class TFOpenAIGPTModelTest(TFCommonTestCases.TFCommonModelTester):
config_and_inputs = self.model_tester.prepare_config_and_inputs() config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_openai_gpt_double_head(*config_and_inputs) self.model_tester.create_and_check_openai_gpt_double_head(*config_and_inputs)
@pytest.mark.slow @slow
def test_model_from_pretrained(self): def test_model_from_pretrained(self):
cache_dir = "/tmp/transformers_test/" cache_dir = "/tmp/transformers_test/"
for model_name in list(TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: for model_name in list(TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
......
...@@ -18,10 +18,10 @@ from __future__ import print_function ...@@ -18,10 +18,10 @@ from __future__ import print_function
import unittest import unittest
import shutil import shutil
import pytest
from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor) from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester from .configuration_common_test import ConfigTester
from .utils import require_tf, slow
from transformers import RobertaConfig, is_tf_available from transformers import RobertaConfig, is_tf_available
...@@ -32,10 +32,9 @@ if is_tf_available(): ...@@ -32,10 +32,9 @@ if is_tf_available():
TFRobertaForSequenceClassification, TFRobertaForSequenceClassification,
TFRobertaForTokenClassification, TFRobertaForTokenClassification,
TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP) TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP)
else:
pytestmark = pytest.mark.skip("Require TensorFlow")
@require_tf
class TFRobertaModelTest(TFCommonTestCases.TFCommonModelTester): class TFRobertaModelTest(TFCommonTestCases.TFCommonModelTester):
all_model_classes = (TFRobertaModel,TFRobertaForMaskedLM, all_model_classes = (TFRobertaModel,TFRobertaForMaskedLM,
...@@ -191,7 +190,7 @@ class TFRobertaModelTest(TFCommonTestCases.TFCommonModelTester): ...@@ -191,7 +190,7 @@ class TFRobertaModelTest(TFCommonTestCases.TFCommonModelTester):
config_and_inputs = self.model_tester.prepare_config_and_inputs() config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_roberta_for_masked_lm(*config_and_inputs) self.model_tester.create_and_check_roberta_for_masked_lm(*config_and_inputs)
@pytest.mark.slow @slow
def test_model_from_pretrained(self): def test_model_from_pretrained(self):
cache_dir = "/tmp/transformers_test/" cache_dir = "/tmp/transformers_test/"
for model_name in list(TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: for model_name in list(TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
...@@ -203,7 +202,7 @@ class TFRobertaModelTest(TFCommonTestCases.TFCommonModelTester): ...@@ -203,7 +202,7 @@ class TFRobertaModelTest(TFCommonTestCases.TFCommonModelTester):
class TFRobertaModelIntegrationTest(unittest.TestCase): class TFRobertaModelIntegrationTest(unittest.TestCase):
@pytest.mark.slow @slow
def test_inference_masked_lm(self): def test_inference_masked_lm(self):
model = TFRobertaForMaskedLM.from_pretrained('roberta-base') model = TFRobertaForMaskedLM.from_pretrained('roberta-base')
...@@ -224,7 +223,7 @@ class TFRobertaModelIntegrationTest(unittest.TestCase): ...@@ -224,7 +223,7 @@ class TFRobertaModelIntegrationTest(unittest.TestCase):
numpy.allclose(output[:, :3, :3].numpy(), expected_slice.numpy(), atol=1e-3) numpy.allclose(output[:, :3, :3].numpy(), expected_slice.numpy(), atol=1e-3)
) )
@pytest.mark.slow @slow
def test_inference_no_head(self): def test_inference_no_head(self):
model = TFRobertaModel.from_pretrained('roberta-base') model = TFRobertaModel.from_pretrained('roberta-base')
...@@ -240,7 +239,7 @@ class TFRobertaModelIntegrationTest(unittest.TestCase): ...@@ -240,7 +239,7 @@ class TFRobertaModelIntegrationTest(unittest.TestCase):
numpy.allclose(output[:, :3, :3].numpy(), expected_slice.numpy(), atol=1e-3) numpy.allclose(output[:, :3, :3].numpy(), expected_slice.numpy(), atol=1e-3)
) )
@pytest.mark.slow @slow
def test_inference_classification_head(self): def test_inference_classification_head(self):
model = TFRobertaForSequenceClassification.from_pretrained('roberta-large-mnli') model = TFRobertaForSequenceClassification.from_pretrained('roberta-large-mnli')
......
...@@ -19,10 +19,10 @@ from __future__ import print_function ...@@ -19,10 +19,10 @@ from __future__ import print_function
import unittest import unittest
import random import random
import shutil import shutil
import pytest
from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor) from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester from .configuration_common_test import ConfigTester
from .utils import require_tf, slow
from transformers import TransfoXLConfig, is_tf_available from transformers import TransfoXLConfig, is_tf_available
...@@ -31,10 +31,9 @@ if is_tf_available(): ...@@ -31,10 +31,9 @@ if is_tf_available():
from transformers.modeling_tf_transfo_xl import (TFTransfoXLModel, from transformers.modeling_tf_transfo_xl import (TFTransfoXLModel,
TFTransfoXLLMHeadModel, TFTransfoXLLMHeadModel,
TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP) TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP)
else:
pytestmark = pytest.mark.skip("Require TensorFlow")
@require_tf
class TFTransfoXLModelTest(TFCommonTestCases.TFCommonModelTester): class TFTransfoXLModelTest(TFCommonTestCases.TFCommonModelTester):
all_model_classes = (TFTransfoXLModel, TFTransfoXLLMHeadModel) if is_tf_available() else () all_model_classes = (TFTransfoXLModel, TFTransfoXLLMHeadModel) if is_tf_available() else ()
...@@ -204,7 +203,7 @@ class TFTransfoXLModelTest(TFCommonTestCases.TFCommonModelTester): ...@@ -204,7 +203,7 @@ class TFTransfoXLModelTest(TFCommonTestCases.TFCommonModelTester):
config_and_inputs = self.model_tester.prepare_config_and_inputs() config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_transfo_xl_lm_head(*config_and_inputs) self.model_tester.create_and_check_transfo_xl_lm_head(*config_and_inputs)
@pytest.mark.slow @slow
def test_model_from_pretrained(self): def test_model_from_pretrained(self):
cache_dir = "/tmp/transformers_test/" cache_dir = "/tmp/transformers_test/"
for model_name in list(TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: for model_name in list(TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
......
...@@ -18,7 +18,6 @@ from __future__ import print_function ...@@ -18,7 +18,6 @@ from __future__ import print_function
import unittest import unittest
import shutil import shutil
import pytest
from transformers import is_tf_available from transformers import is_tf_available
...@@ -29,13 +28,13 @@ if is_tf_available(): ...@@ -29,13 +28,13 @@ if is_tf_available():
TFXLMForSequenceClassification, TFXLMForSequenceClassification,
TFXLMForQuestionAnsweringSimple, TFXLMForQuestionAnsweringSimple,
TF_XLM_PRETRAINED_MODEL_ARCHIVE_MAP) TF_XLM_PRETRAINED_MODEL_ARCHIVE_MAP)
else:
pytestmark = pytest.mark.skip("Require TensorFlow")
from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor) from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester from .configuration_common_test import ConfigTester
from .utils import require_tf, slow
@require_tf
class TFXLMModelTest(TFCommonTestCases.TFCommonModelTester): class TFXLMModelTest(TFCommonTestCases.TFCommonModelTester):
all_model_classes = (TFXLMModel, TFXLMWithLMHeadModel, all_model_classes = (TFXLMModel, TFXLMWithLMHeadModel,
...@@ -251,7 +250,7 @@ class TFXLMModelTest(TFCommonTestCases.TFCommonModelTester): ...@@ -251,7 +250,7 @@ class TFXLMModelTest(TFCommonTestCases.TFCommonModelTester):
config_and_inputs = self.model_tester.prepare_config_and_inputs() config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_xlm_sequence_classif(*config_and_inputs) self.model_tester.create_and_check_xlm_sequence_classif(*config_and_inputs)
@pytest.mark.slow @slow
def test_model_from_pretrained(self): def test_model_from_pretrained(self):
cache_dir = "/tmp/transformers_test/" cache_dir = "/tmp/transformers_test/"
for model_name in list(TF_XLM_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: for model_name in list(TF_XLM_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
......
...@@ -21,7 +21,6 @@ import unittest ...@@ -21,7 +21,6 @@ import unittest
import json import json
import random import random
import shutil import shutil
import pytest
from transformers import XLNetConfig, is_tf_available from transformers import XLNetConfig, is_tf_available
...@@ -33,12 +32,13 @@ if is_tf_available(): ...@@ -33,12 +32,13 @@ if is_tf_available():
TFXLNetForTokenClassification, TFXLNetForTokenClassification,
TFXLNetForQuestionAnsweringSimple, TFXLNetForQuestionAnsweringSimple,
TF_XLNET_PRETRAINED_MODEL_ARCHIVE_MAP) TF_XLNET_PRETRAINED_MODEL_ARCHIVE_MAP)
else:
pytestmark = pytest.mark.skip("Require TensorFlow")
from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor) from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester from .configuration_common_test import ConfigTester
from .utils import require_tf, slow
@require_tf
class TFXLNetModelTest(TFCommonTestCases.TFCommonModelTester): class TFXLNetModelTest(TFCommonTestCases.TFCommonModelTester):
all_model_classes=(TFXLNetModel, TFXLNetLMHeadModel, all_model_classes=(TFXLNetModel, TFXLNetLMHeadModel,
...@@ -320,7 +320,7 @@ class TFXLNetModelTest(TFCommonTestCases.TFCommonModelTester): ...@@ -320,7 +320,7 @@ class TFXLNetModelTest(TFCommonTestCases.TFCommonModelTester):
config_and_inputs = self.model_tester.prepare_config_and_inputs() config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_xlnet_qa(*config_and_inputs) self.model_tester.create_and_check_xlnet_qa(*config_and_inputs)
@pytest.mark.slow @slow
def test_model_from_pretrained(self): def test_model_from_pretrained(self):
cache_dir = "/tmp/transformers_test/" cache_dir = "/tmp/transformers_test/"
for model_name in list(TF_XLNET_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: for model_name in list(TF_XLNET_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
......
...@@ -19,7 +19,6 @@ from __future__ import print_function ...@@ -19,7 +19,6 @@ from __future__ import print_function
import unittest import unittest
import random import random
import shutil import shutil
import pytest
from transformers import is_torch_available from transformers import is_torch_available
...@@ -27,12 +26,13 @@ if is_torch_available(): ...@@ -27,12 +26,13 @@ if is_torch_available():
import torch import torch
from transformers import (TransfoXLConfig, TransfoXLModel, TransfoXLLMHeadModel) from transformers import (TransfoXLConfig, TransfoXLModel, TransfoXLLMHeadModel)
from transformers.modeling_transfo_xl import TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP from transformers.modeling_transfo_xl import TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP
else:
pytestmark = pytest.mark.skip("Require Torch")
from .modeling_common_test import (CommonTestCases, ids_tensor) from .modeling_common_test import (CommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester from .configuration_common_test import ConfigTester
from .utils import require_torch, slow, torch_device
@require_torch
class TransfoXLModelTest(CommonTestCases.CommonModelTester): class TransfoXLModelTest(CommonTestCases.CommonModelTester):
all_model_classes = (TransfoXLModel, TransfoXLLMHeadModel) if is_torch_available() else () all_model_classes = (TransfoXLModel, TransfoXLLMHeadModel) if is_torch_available() else ()
...@@ -111,6 +111,7 @@ class TransfoXLModelTest(CommonTestCases.CommonModelTester): ...@@ -111,6 +111,7 @@ class TransfoXLModelTest(CommonTestCases.CommonModelTester):
def create_transfo_xl_model(self, config, input_ids_1, input_ids_2, lm_labels): def create_transfo_xl_model(self, config, input_ids_1, input_ids_2, lm_labels):
model = TransfoXLModel(config) model = TransfoXLModel(config)
model.to(torch_device)
model.eval() model.eval()
hidden_states_1, mems_1 = model(input_ids_1) hidden_states_1, mems_1 = model(input_ids_1)
...@@ -140,6 +141,7 @@ class TransfoXLModelTest(CommonTestCases.CommonModelTester): ...@@ -140,6 +141,7 @@ class TransfoXLModelTest(CommonTestCases.CommonModelTester):
def create_transfo_xl_lm_head(self, config, input_ids_1, input_ids_2, lm_labels): def create_transfo_xl_lm_head(self, config, input_ids_1, input_ids_2, lm_labels):
model = TransfoXLLMHeadModel(config) model = TransfoXLLMHeadModel(config)
model.to(torch_device)
model.eval() model.eval()
lm_logits_1, mems_1 = model(input_ids_1) lm_logits_1, mems_1 = model(input_ids_1)
...@@ -204,7 +206,7 @@ class TransfoXLModelTest(CommonTestCases.CommonModelTester): ...@@ -204,7 +206,7 @@ class TransfoXLModelTest(CommonTestCases.CommonModelTester):
output_result = self.model_tester.create_transfo_xl_lm_head(*config_and_inputs) output_result = self.model_tester.create_transfo_xl_lm_head(*config_and_inputs)
self.model_tester.check_transfo_xl_lm_head_output(output_result) self.model_tester.check_transfo_xl_lm_head_output(output_result)
@pytest.mark.slow @slow
def test_model_from_pretrained(self): def test_model_from_pretrained(self):
cache_dir = "/tmp/transformers_test/" cache_dir = "/tmp/transformers_test/"
for model_name in list(TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: for model_name in list(TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
......
...@@ -18,7 +18,6 @@ from __future__ import print_function ...@@ -18,7 +18,6 @@ from __future__ import print_function
import unittest import unittest
import shutil import shutil
import pytest
from transformers import is_torch_available from transformers import is_torch_available
...@@ -26,13 +25,13 @@ if is_torch_available(): ...@@ -26,13 +25,13 @@ if is_torch_available():
from transformers import (XLMConfig, XLMModel, XLMWithLMHeadModel, XLMForQuestionAnswering, from transformers import (XLMConfig, XLMModel, XLMWithLMHeadModel, XLMForQuestionAnswering,
XLMForSequenceClassification, XLMForQuestionAnsweringSimple) XLMForSequenceClassification, XLMForQuestionAnsweringSimple)
from transformers.modeling_xlm import XLM_PRETRAINED_MODEL_ARCHIVE_MAP from transformers.modeling_xlm import XLM_PRETRAINED_MODEL_ARCHIVE_MAP
else:
pytestmark = pytest.mark.skip("Require Torch")
from .modeling_common_test import (CommonTestCases, ids_tensor) from .modeling_common_test import (CommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester from .configuration_common_test import ConfigTester
from .utils import require_torch, slow, torch_device
@require_torch
class XLMModelTest(CommonTestCases.CommonModelTester): class XLMModelTest(CommonTestCases.CommonModelTester):
all_model_classes = (XLMModel, XLMWithLMHeadModel, XLMForQuestionAnswering, all_model_classes = (XLMModel, XLMWithLMHeadModel, XLMForQuestionAnswering,
...@@ -148,6 +147,7 @@ class XLMModelTest(CommonTestCases.CommonModelTester): ...@@ -148,6 +147,7 @@ class XLMModelTest(CommonTestCases.CommonModelTester):
def create_and_check_xlm_model(self, config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, is_impossible_labels, input_mask): def create_and_check_xlm_model(self, config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, is_impossible_labels, input_mask):
model = XLMModel(config=config) model = XLMModel(config=config)
model.to(torch_device)
model.eval() model.eval()
outputs = model(input_ids, lengths=input_lengths, langs=token_type_ids) outputs = model(input_ids, lengths=input_lengths, langs=token_type_ids)
outputs = model(input_ids, langs=token_type_ids) outputs = model(input_ids, langs=token_type_ids)
...@@ -163,6 +163,7 @@ class XLMModelTest(CommonTestCases.CommonModelTester): ...@@ -163,6 +163,7 @@ class XLMModelTest(CommonTestCases.CommonModelTester):
def create_and_check_xlm_lm_head(self, config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, is_impossible_labels, input_mask): def create_and_check_xlm_lm_head(self, config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, is_impossible_labels, input_mask):
model = XLMWithLMHeadModel(config) model = XLMWithLMHeadModel(config)
model.to(torch_device)
model.eval() model.eval()
loss, logits = model(input_ids, token_type_ids=token_type_ids, labels=token_labels) loss, logits = model(input_ids, token_type_ids=token_type_ids, labels=token_labels)
...@@ -182,6 +183,7 @@ class XLMModelTest(CommonTestCases.CommonModelTester): ...@@ -182,6 +183,7 @@ class XLMModelTest(CommonTestCases.CommonModelTester):
def create_and_check_xlm_simple_qa(self, config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, is_impossible_labels, input_mask): def create_and_check_xlm_simple_qa(self, config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, is_impossible_labels, input_mask):
model = XLMForQuestionAnsweringSimple(config) model = XLMForQuestionAnsweringSimple(config)
model.to(torch_device)
model.eval() model.eval()
outputs = model(input_ids) outputs = model(input_ids)
...@@ -206,6 +208,7 @@ class XLMModelTest(CommonTestCases.CommonModelTester): ...@@ -206,6 +208,7 @@ class XLMModelTest(CommonTestCases.CommonModelTester):
def create_and_check_xlm_qa(self, config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, is_impossible_labels, input_mask): def create_and_check_xlm_qa(self, config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, is_impossible_labels, input_mask):
model = XLMForQuestionAnswering(config) model = XLMForQuestionAnswering(config)
model.to(torch_device)
model.eval() model.eval()
outputs = model(input_ids) outputs = model(input_ids)
...@@ -260,6 +263,7 @@ class XLMModelTest(CommonTestCases.CommonModelTester): ...@@ -260,6 +263,7 @@ class XLMModelTest(CommonTestCases.CommonModelTester):
def create_and_check_xlm_sequence_classif(self, config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, is_impossible_labels, input_mask): def create_and_check_xlm_sequence_classif(self, config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, is_impossible_labels, input_mask):
model = XLMForSequenceClassification(config) model = XLMForSequenceClassification(config)
model.to(torch_device)
model.eval() model.eval()
(logits,) = model(input_ids) (logits,) = model(input_ids)
...@@ -312,7 +316,7 @@ class XLMModelTest(CommonTestCases.CommonModelTester): ...@@ -312,7 +316,7 @@ class XLMModelTest(CommonTestCases.CommonModelTester):
config_and_inputs = self.model_tester.prepare_config_and_inputs() config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_xlm_sequence_classif(*config_and_inputs) self.model_tester.create_and_check_xlm_sequence_classif(*config_and_inputs)
@pytest.mark.slow @slow
def test_model_from_pretrained(self): def test_model_from_pretrained(self):
cache_dir = "/tmp/transformers_test/" cache_dir = "/tmp/transformers_test/"
for model_name in list(XLM_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: for model_name in list(XLM_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
......
...@@ -21,7 +21,6 @@ import unittest ...@@ -21,7 +21,6 @@ import unittest
import json import json
import random import random
import shutil import shutil
import pytest
from transformers import is_torch_available from transformers import is_torch_available
...@@ -31,12 +30,13 @@ if is_torch_available(): ...@@ -31,12 +30,13 @@ if is_torch_available():
from transformers import (XLNetConfig, XLNetModel, XLNetLMHeadModel, XLNetForSequenceClassification, from transformers import (XLNetConfig, XLNetModel, XLNetLMHeadModel, XLNetForSequenceClassification,
XLNetForTokenClassification, XLNetForQuestionAnswering) XLNetForTokenClassification, XLNetForQuestionAnswering)
from transformers.modeling_xlnet import XLNET_PRETRAINED_MODEL_ARCHIVE_MAP from transformers.modeling_xlnet import XLNET_PRETRAINED_MODEL_ARCHIVE_MAP
else:
pytestmark = pytest.mark.skip("Require Torch")
from .modeling_common_test import (CommonTestCases, ids_tensor) from .modeling_common_test import (CommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester from .configuration_common_test import ConfigTester
from .utils import require_torch, slow, torch_device
@require_torch
class XLNetModelTest(CommonTestCases.CommonModelTester): class XLNetModelTest(CommonTestCases.CommonModelTester):
all_model_classes=(XLNetModel, XLNetLMHeadModel, XLNetForTokenClassification, all_model_classes=(XLNetModel, XLNetLMHeadModel, XLNetForTokenClassification,
...@@ -100,9 +100,9 @@ class XLNetModelTest(CommonTestCases.CommonModelTester): ...@@ -100,9 +100,9 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
input_mask = ids_tensor([self.batch_size, self.seq_length], 2).float() input_mask = ids_tensor([self.batch_size, self.seq_length], 2).float()
input_ids_q = ids_tensor([self.batch_size, self.seq_length + 1], self.vocab_size) input_ids_q = ids_tensor([self.batch_size, self.seq_length + 1], self.vocab_size)
perm_mask = torch.zeros(self.batch_size, self.seq_length + 1, self.seq_length + 1, dtype=torch.float) perm_mask = torch.zeros(self.batch_size, self.seq_length + 1, self.seq_length + 1, dtype=torch.float, device=torch_device)
perm_mask[:, :, -1] = 1.0 # Previous tokens don't see last token perm_mask[:, :, -1] = 1.0 # Previous tokens don't see last token
target_mapping = torch.zeros(self.batch_size, 1, self.seq_length + 1, dtype=torch.float) target_mapping = torch.zeros(self.batch_size, 1, self.seq_length + 1, dtype=torch.float, device=torch_device)
target_mapping[:, 0, -1] = 1.0 # predict last token target_mapping[:, 0, -1] = 1.0 # predict last token
sequence_labels = None sequence_labels = None
...@@ -141,6 +141,7 @@ class XLNetModelTest(CommonTestCases.CommonModelTester): ...@@ -141,6 +141,7 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
def create_and_check_xlnet_base_model(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask, def create_and_check_xlnet_base_model(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels, token_labels): target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels, token_labels):
model = XLNetModel(config) model = XLNetModel(config)
model.to(torch_device)
model.eval() model.eval()
_, _ = model(input_ids_1, input_mask=input_mask) _, _ = model(input_ids_1, input_mask=input_mask)
...@@ -155,6 +156,7 @@ class XLNetModelTest(CommonTestCases.CommonModelTester): ...@@ -155,6 +156,7 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
config.mem_len = 0 config.mem_len = 0
model = XLNetModel(config) model = XLNetModel(config)
model.to(torch_device)
model.eval() model.eval()
no_mems_outputs = model(input_ids_1) no_mems_outputs = model(input_ids_1)
self.parent.assertEqual(len(no_mems_outputs), 1) self.parent.assertEqual(len(no_mems_outputs), 1)
...@@ -169,6 +171,7 @@ class XLNetModelTest(CommonTestCases.CommonModelTester): ...@@ -169,6 +171,7 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
def create_and_check_xlnet_base_model_with_att_output(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask, def create_and_check_xlnet_base_model_with_att_output(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels, token_labels): target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels, token_labels):
model = XLNetModel(config) model = XLNetModel(config)
model.to(torch_device)
model.eval() model.eval()
_, _, attentions = model(input_ids_1, target_mapping=target_mapping) _, _, attentions = model(input_ids_1, target_mapping=target_mapping)
...@@ -181,6 +184,7 @@ class XLNetModelTest(CommonTestCases.CommonModelTester): ...@@ -181,6 +184,7 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
def create_and_check_xlnet_lm_head(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask, def create_and_check_xlnet_lm_head(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels, token_labels): target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels, token_labels):
model = XLNetLMHeadModel(config) model = XLNetLMHeadModel(config)
model.to(torch_device)
model.eval() model.eval()
loss_1, all_logits_1, mems_1 = model(input_ids_1, token_type_ids=segment_ids, labels=lm_labels) loss_1, all_logits_1, mems_1 = model(input_ids_1, token_type_ids=segment_ids, labels=lm_labels)
...@@ -221,6 +225,7 @@ class XLNetModelTest(CommonTestCases.CommonModelTester): ...@@ -221,6 +225,7 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
def create_and_check_xlnet_qa(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask, def create_and_check_xlnet_qa(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels, token_labels): target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels, token_labels):
model = XLNetForQuestionAnswering(config) model = XLNetForQuestionAnswering(config)
model.to(torch_device)
model.eval() model.eval()
outputs = model(input_ids_1) outputs = model(input_ids_1)
...@@ -279,6 +284,7 @@ class XLNetModelTest(CommonTestCases.CommonModelTester): ...@@ -279,6 +284,7 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
def create_and_check_xlnet_token_classif(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask, def create_and_check_xlnet_token_classif(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels, token_labels): target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels, token_labels):
model = XLNetForTokenClassification(config) model = XLNetForTokenClassification(config)
model.to(torch_device)
model.eval() model.eval()
logits, mems_1 = model(input_ids_1) logits, mems_1 = model(input_ids_1)
...@@ -311,6 +317,7 @@ class XLNetModelTest(CommonTestCases.CommonModelTester): ...@@ -311,6 +317,7 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
def create_and_check_xlnet_sequence_classif(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask, def create_and_check_xlnet_sequence_classif(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels, token_labels): target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels, token_labels):
model = XLNetForSequenceClassification(config) model = XLNetForSequenceClassification(config)
model.to(torch_device)
model.eval() model.eval()
logits, mems_1 = model(input_ids_1) logits, mems_1 = model(input_ids_1)
...@@ -379,7 +386,7 @@ class XLNetModelTest(CommonTestCases.CommonModelTester): ...@@ -379,7 +386,7 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
config_and_inputs = self.model_tester.prepare_config_and_inputs() config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_xlnet_qa(*config_and_inputs) self.model_tester.create_and_check_xlnet_qa(*config_and_inputs)
@pytest.mark.slow @slow
def test_model_from_pretrained(self): def test_model_from_pretrained(self):
cache_dir = "/tmp/transformers_test/" cache_dir = "/tmp/transformers_test/"
for model_name in list(XLNET_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: for model_name in list(XLNET_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
......
...@@ -18,7 +18,6 @@ from __future__ import print_function ...@@ -18,7 +18,6 @@ from __future__ import print_function
import unittest import unittest
import os import os
import pytest
from transformers import is_torch_available from transformers import is_torch_available
...@@ -31,10 +30,9 @@ if is_torch_available(): ...@@ -31,10 +30,9 @@ if is_torch_available():
get_cosine_schedule_with_warmup, get_cosine_schedule_with_warmup,
get_cosine_with_hard_restarts_schedule_with_warmup, get_cosine_with_hard_restarts_schedule_with_warmup,
get_linear_schedule_with_warmup) get_linear_schedule_with_warmup)
else:
pytestmark = pytest.mark.skip("Require Torch")
from .tokenization_tests_commons import TemporaryDirectory from .tokenization_tests_commons import TemporaryDirectory
from .utils import require_torch
def unwrap_schedule(scheduler, num_steps=10): def unwrap_schedule(scheduler, num_steps=10):
...@@ -58,6 +56,7 @@ def unwrap_and_save_reload_schedule(scheduler, num_steps=10): ...@@ -58,6 +56,7 @@ def unwrap_and_save_reload_schedule(scheduler, num_steps=10):
scheduler.load_state_dict(state_dict) scheduler.load_state_dict(state_dict)
return lrs return lrs
@require_torch
class OptimizationTest(unittest.TestCase): class OptimizationTest(unittest.TestCase):
def assertListAlmostEqual(self, list1, list2, tol): def assertListAlmostEqual(self, list1, list2, tol):
...@@ -80,6 +79,7 @@ class OptimizationTest(unittest.TestCase): ...@@ -80,6 +79,7 @@ class OptimizationTest(unittest.TestCase):
self.assertListAlmostEqual(w.tolist(), [0.4, 0.2, -0.5], tol=1e-2) self.assertListAlmostEqual(w.tolist(), [0.4, 0.2, -0.5], tol=1e-2)
@require_torch
class ScheduleInitTest(unittest.TestCase): class ScheduleInitTest(unittest.TestCase):
m = torch.nn.Linear(50, 50) if is_torch_available() else None m = torch.nn.Linear(50, 50) if is_torch_available() else None
optimizer = AdamW(m.parameters(), lr=10.) if is_torch_available() else None optimizer = AdamW(m.parameters(), lr=10.) if is_torch_available() else None
......
...@@ -18,15 +18,16 @@ from __future__ import print_function ...@@ -18,15 +18,16 @@ from __future__ import print_function
import unittest import unittest
import shutil import shutil
import pytest
import logging import logging
from transformers import AutoTokenizer, BertTokenizer, AutoTokenizer, GPT2Tokenizer from transformers import AutoTokenizer, BertTokenizer, AutoTokenizer, GPT2Tokenizer
from transformers import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP from transformers import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP
from .utils import slow
class AutoTokenizerTest(unittest.TestCase): class AutoTokenizerTest(unittest.TestCase):
@pytest.mark.slow @slow
def test_tokenizer_from_pretrained(self): def test_tokenizer_from_pretrained(self):
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
for model_name in list(BERT_PRETRAINED_CONFIG_ARCHIVE_MAP.keys())[:1]: for model_name in list(BERT_PRETRAINED_CONFIG_ARCHIVE_MAP.keys())[:1]:
......
...@@ -16,7 +16,6 @@ from __future__ import absolute_import, division, print_function, unicode_litera ...@@ -16,7 +16,6 @@ from __future__ import absolute_import, division, print_function, unicode_litera
import os import os
import unittest import unittest
import pytest
from io import open from io import open
from transformers.tokenization_bert import (BasicTokenizer, from transformers.tokenization_bert import (BasicTokenizer,
...@@ -26,6 +25,7 @@ from transformers.tokenization_bert import (BasicTokenizer, ...@@ -26,6 +25,7 @@ from transformers.tokenization_bert import (BasicTokenizer,
_is_whitespace, VOCAB_FILES_NAMES) _is_whitespace, VOCAB_FILES_NAMES)
from .tokenization_tests_commons import CommonTestCases from .tokenization_tests_commons import CommonTestCases
from .utils import slow
class BertTokenizationTest(CommonTestCases.CommonTokenizerTester): class BertTokenizationTest(CommonTestCases.CommonTokenizerTester):
...@@ -126,7 +126,7 @@ class BertTokenizationTest(CommonTestCases.CommonTokenizerTester): ...@@ -126,7 +126,7 @@ class BertTokenizationTest(CommonTestCases.CommonTokenizerTester):
self.assertFalse(_is_punctuation(u"A")) self.assertFalse(_is_punctuation(u"A"))
self.assertFalse(_is_punctuation(u" ")) self.assertFalse(_is_punctuation(u" "))
@pytest.mark.slow @slow
def test_sequence_builders(self): def test_sequence_builders(self):
tokenizer = self.tokenizer_class.from_pretrained("bert-base-uncased") tokenizer = self.tokenizer_class.from_pretrained("bert-base-uncased")
......
...@@ -16,13 +16,13 @@ from __future__ import absolute_import, division, print_function, unicode_litera ...@@ -16,13 +16,13 @@ from __future__ import absolute_import, division, print_function, unicode_litera
import os import os
import unittest import unittest
import pytest
from io import open from io import open
from transformers.tokenization_distilbert import (DistilBertTokenizer) from transformers.tokenization_distilbert import (DistilBertTokenizer)
from .tokenization_tests_commons import CommonTestCases from .tokenization_tests_commons import CommonTestCases
from .tokenization_bert_test import BertTokenizationTest from .tokenization_bert_test import BertTokenizationTest
from .utils import slow
class DistilBertTokenizationTest(BertTokenizationTest): class DistilBertTokenizationTest(BertTokenizationTest):
...@@ -31,7 +31,7 @@ class DistilBertTokenizationTest(BertTokenizationTest): ...@@ -31,7 +31,7 @@ class DistilBertTokenizationTest(BertTokenizationTest):
def get_tokenizer(self, **kwargs): def get_tokenizer(self, **kwargs):
return DistilBertTokenizer.from_pretrained(self.tmpdirname, **kwargs) return DistilBertTokenizer.from_pretrained(self.tmpdirname, **kwargs)
@pytest.mark.slow @slow
def test_sequence_builders(self): def test_sequence_builders(self):
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased") tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
......
...@@ -17,11 +17,11 @@ from __future__ import absolute_import, division, print_function, unicode_litera ...@@ -17,11 +17,11 @@ from __future__ import absolute_import, division, print_function, unicode_litera
import os import os
import json import json
import unittest import unittest
import pytest
from io import open from io import open
from transformers.tokenization_roberta import RobertaTokenizer, VOCAB_FILES_NAMES from transformers.tokenization_roberta import RobertaTokenizer, VOCAB_FILES_NAMES
from .tokenization_tests_commons import CommonTestCases from .tokenization_tests_commons import CommonTestCases
from .utils import slow
class RobertaTokenizationTest(CommonTestCases.CommonTokenizerTester): class RobertaTokenizationTest(CommonTestCases.CommonTokenizerTester):
...@@ -79,7 +79,7 @@ class RobertaTokenizationTest(CommonTestCases.CommonTokenizerTester): ...@@ -79,7 +79,7 @@ class RobertaTokenizationTest(CommonTestCases.CommonTokenizerTester):
[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2] [0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]
) )
@pytest.mark.slow @slow
def test_sequence_builders(self): def test_sequence_builders(self):
tokenizer = RobertaTokenizer.from_pretrained("roberta-base") tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
......
...@@ -102,9 +102,11 @@ class CommonTestCases: ...@@ -102,9 +102,11 @@ class CommonTestCases:
with TemporaryDirectory() as tmpdirname: with TemporaryDirectory() as tmpdirname:
filename = os.path.join(tmpdirname, u"tokenizer.bin") filename = os.path.join(tmpdirname, u"tokenizer.bin")
pickle.dump(tokenizer, open(filename, "wb")) with open(filename, "wb") as handle:
pickle.dump(tokenizer, handle)
tokenizer_new = pickle.load(open(filename, "rb")) with open(filename, "rb") as handle:
tokenizer_new = pickle.load(handle)
subwords_loaded = tokenizer_new.tokenize(text) subwords_loaded = tokenizer_new.tokenize(text)
......
...@@ -16,7 +16,6 @@ from __future__ import absolute_import, division, print_function, unicode_litera ...@@ -16,7 +16,6 @@ from __future__ import absolute_import, division, print_function, unicode_litera
import os import os
import unittest import unittest
import pytest
from io import open from io import open
from transformers import is_torch_available from transformers import is_torch_available
...@@ -24,11 +23,12 @@ from transformers import is_torch_available ...@@ -24,11 +23,12 @@ from transformers import is_torch_available
if is_torch_available(): if is_torch_available():
import torch import torch
from transformers.tokenization_transfo_xl import TransfoXLTokenizer, VOCAB_FILES_NAMES from transformers.tokenization_transfo_xl import TransfoXLTokenizer, VOCAB_FILES_NAMES
else:
pytestmark = pytest.mark.skip("Require Torch") # TODO: untangle Transfo-XL tokenizer from torch.load and torch.save
from .tokenization_tests_commons import CommonTestCases from .tokenization_tests_commons import CommonTestCases
from .utils import require_torch
@require_torch
class TransfoXLTokenizationTest(CommonTestCases.CommonTokenizerTester): class TransfoXLTokenizationTest(CommonTestCases.CommonTokenizerTester):
tokenizer_class = TransfoXLTokenizer if is_torch_available() else None tokenizer_class = TransfoXLTokenizer if is_torch_available() else None
......
...@@ -18,13 +18,14 @@ from __future__ import print_function ...@@ -18,13 +18,14 @@ from __future__ import print_function
import unittest import unittest
import six import six
import pytest
from transformers import PreTrainedTokenizer from transformers import PreTrainedTokenizer
from transformers.tokenization_gpt2 import GPT2Tokenizer from transformers.tokenization_gpt2 import GPT2Tokenizer
from .utils import slow
class TokenizerUtilsTest(unittest.TestCase): class TokenizerUtilsTest(unittest.TestCase):
@pytest.mark.slow
def check_tokenizer_from_pretrained(self, tokenizer_class): def check_tokenizer_from_pretrained(self, tokenizer_class):
s3_models = list(tokenizer_class.max_model_input_sizes.keys()) s3_models = list(tokenizer_class.max_model_input_sizes.keys())
for model_name in s3_models[:1]: for model_name in s3_models[:1]:
...@@ -41,6 +42,7 @@ class TokenizerUtilsTest(unittest.TestCase): ...@@ -41,6 +42,7 @@ class TokenizerUtilsTest(unittest.TestCase):
special_tok_id = tokenizer.convert_tokens_to_ids(special_tok) special_tok_id = tokenizer.convert_tokens_to_ids(special_tok)
self.assertIsInstance(special_tok_id, int) self.assertIsInstance(special_tok_id, int)
@slow
def test_pretrained_tokenizers(self): def test_pretrained_tokenizers(self):
self.check_tokenizer_from_pretrained(GPT2Tokenizer) self.check_tokenizer_from_pretrained(GPT2Tokenizer)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment