Commit 36bca545 authored by thomwolf

tokenization abstract class - tests for examples

parent a4f98054
@@ -16,35 +16,33 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

-import os
+import sys
 import unittest
-import json
-import random
-import shutil
-import pytest

-import torch
+import argparse

-from pytorch_transformers import PretrainedConfig, PreTrainedModel
-from pytorch_transformers.modeling_bert import BertModel, BertConfig, PRETRAINED_MODEL_ARCHIVE_MAP, PRETRAINED_CONFIG_ARCHIVE_MAP
+try:
+    # python 3.4+ can use builtin unittest.mock instead of mock package
+    from unittest.mock import patch
+except ImportError:
+    from mock import patch

+import run_bert_squad as rbs

+def get_setup_file():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('-f')
+    args = parser.parse_args()
+    return args.f

-class ModelUtilsTest(unittest.TestCase):
-    def test_model_from_pretrained(self):
-        for model_name in list(PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
-            config = BertConfig.from_pretrained(model_name)
-            self.assertIsNotNone(config)
-            self.assertIsInstance(config, PretrainedConfig)
-
-            model = BertModel.from_pretrained(model_name)
-            self.assertIsNotNone(model)
-            self.assertIsInstance(model, PreTrainedModel)
-
-        config = BertConfig.from_pretrained(model_name, output_attentions=True, output_hidden_states=True)
-        model = BertModel.from_pretrained(model_name, output_attentions=True, output_hidden_states=True)
-        self.assertEqual(model.config.output_attentions, True)
-        self.assertEqual(model.config.output_hidden_states, True)
-        self.assertEqual(model.config, config)
+class ExamplesTests(unittest.TestCase):
+
+    def test_run_squad(self):
+        testargs = ["prog", "-f", "/home/test/setup.py"]
+        with patch.object(sys, 'argv', testargs):
+            setup = get_setup_file()
+            assert setup == "/home/test/setup.py"
+        # rbs.main()

 if __name__ == "__main__":
     unittest.main()
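For reference, the pattern in test_run_squad above — temporarily replacing sys.argv with unittest.mock.patch.object so an argparse-based script can be exercised in-process instead of in a subprocess — reduces to the minimal, self-contained sketch below. The parse_args helper and --train_file flag are hypothetical stand-ins for illustration, not run_bert_squad's real interface.

import sys
import argparse
import unittest
from unittest.mock import patch  # the commit falls back to the mock package on Python 2

def parse_args():
    # Hypothetical CLI with a single flag, standing in for an example script's parser.
    parser = argparse.ArgumentParser()
    parser.add_argument('--train_file')
    return parser.parse_args()

class ArgvPatchTest(unittest.TestCase):
    def test_cli_args(self):
        # argv[0] is the program name; argparse only parses argv[1:].
        testargs = ["prog", "--train_file", "train.json"]
        with patch.object(sys, 'argv', testargs):
            args = parse_args()
        # Outside the with-block, sys.argv is restored automatically.
        self.assertEqual(args.train_file, "train.json")

if __name__ == "__main__":
    unittest.main()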
@@ -5,6 +5,7 @@ from .tokenization_transfo_xl import (TransfoXLTokenizer, TransfoXLCorpus)
 from .tokenization_gpt2 import GPT2Tokenizer
 from .tokenization_xlnet import XLNetTokenizer, SPIECE_UNDERLINE
 from .tokenization_xlm import XLMTokenizer
+from .tokenization_utils import (PreTrainedTokenizer, clean_up_tokenization)

 from .modeling_bert import (BertConfig, BertModel, BertForPreTraining,
                             BertForMaskedLM, BertForNextSentencePrediction,
@@ -26,11 +27,10 @@ from .modeling_xlnet import (XLNetConfig,
 from .modeling_xlm import (XLMConfig, XLMModel,
                            XLMWithLMHeadModel, XLMForSequenceClassification,
                            XLMForQuestionAnswering)
+from .modeling_utils import (WEIGHTS_NAME, CONFIG_NAME, TF_WEIGHTS_NAME,
+                             PretrainedConfig, PreTrainedModel, prune_layer, Conv1D)

 from .optimization import BertAdam
 from .optimization_openai import OpenAIAdam

 from .file_utils import (PYTORCH_PRETRAINED_BERT_CACHE, cached_path)
-from .model_utils import (WEIGHTS_NAME, CONFIG_NAME, TF_WEIGHTS_NAME,
-                          PretrainedConfig, PreTrainedModel, prune_layer, Conv1D)
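Taken together, the two hunks above re-export the new tokenizer utilities and the renamed modeling utilities from the package root, so the following imports should all resolve — a sanity-check sketch, assuming the package is installed at this commit:

# All of these names are re-exported by pytorch_transformers/__init__.py after this change.
from pytorch_transformers import PreTrainedTokenizer, clean_up_tokenization
from pytorch_transformers import PretrainedConfig, PreTrainedModel, Conv1D
from pytorch_transformers import WEIGHTS_NAME, CONFIG_NAME

print(WEIGHTS_NAME, CONFIG_NAME)  # expected: pytorch_model.bin config.json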
@@ -29,7 +29,7 @@ from torch import nn
 from torch.nn import CrossEntropyLoss, MSELoss

 from .file_utils import cached_path
-from .model_utils import WEIGHTS_NAME, CONFIG_NAME, PretrainedConfig, PreTrainedModel, prune_linear_layer
+from .modeling_utils import WEIGHTS_NAME, CONFIG_NAME, PretrainedConfig, PreTrainedModel, prune_linear_layer

 logger = logging.getLogger(__name__)
@@ -31,7 +31,7 @@ from torch.nn import CrossEntropyLoss
 from torch.nn.parameter import Parameter

 from .file_utils import cached_path
-from .model_utils import (Conv1D, CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig,
-                          PreTrainedModel, prune_conv1d_layer, SequenceSummary)
+from .modeling_utils import (Conv1D, CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig,
+                             PreTrainedModel, prune_conv1d_layer, SequenceSummary)
 from .modeling_bert import BertLayerNorm as LayerNorm
@@ -31,7 +31,7 @@ from torch.nn import CrossEntropyLoss
 from torch.nn.parameter import Parameter

 from .file_utils import cached_path
-from .model_utils import (Conv1D, CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig,
-                          PreTrainedModel, prune_conv1d_layer, SequenceSummary)
+from .modeling_utils import (Conv1D, CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig,
+                             PreTrainedModel, prune_conv1d_layer, SequenceSummary)
 from .modeling_bert import BertLayerNorm as LayerNorm
@@ -37,7 +37,7 @@ from torch.nn.parameter import Parameter
 from .modeling_bert import BertLayerNorm as LayerNorm
 from .modeling_transfo_xl_utilities import ProjectedAdaptiveLogSoftmax, sample_logits
 from .file_utils import cached_path
-from .model_utils import CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, PreTrainedModel
+from .modeling_utils import CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, PreTrainedModel

 logger = logging.getLogger(__name__)
@@ -598,9 +598,3 @@ def prune_layer(layer, index, dim=None):
         return prune_conv1d_layer(layer, index, dim=1 if dim is None else dim)
     else:
         raise ValueError("Can't prune layer of class {}".format(layer.__class__))
-
-def clean_up_tokenization(out_string):
-    out_string = out_string.replace(' .', '.').replace(' ?', '?').replace(' !', '!').replace(' ,', ','
-                    ).replace(" ' ", "'").replace(" n't", "n't").replace(" 'm", "'m").replace(" do not", " don't"
-                    ).replace(" 's", "'s").replace(" 've", "'ve").replace(" 're", "'re")
-    return out_string
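The helper deleted here moves to tokenization_utils.py and is re-exported from the package root (see the __init__.py hunk above). Note that str.replace returns a new string, so the chained calls only take effect because the result is bound back to out_string. A short, hypothetical usage sketch of the relocated helper:

from pytorch_transformers import clean_up_tokenization

# Decoded token streams typically carry extra spaces before punctuation and contractions.
text = "hello , my dog is n't cute . is he ?"
print(clean_up_tokenization(text))
# -> hello, my dog isn't cute. is he?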
@@ -35,7 +35,7 @@ from torch.nn import functional as F
 from torch.nn import CrossEntropyLoss, MSELoss

 from .file_utils import cached_path
-from .model_utils import (CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, PreTrainedModel,
-                          prune_linear_layer, SequenceSummary, SQuADHead)
+from .modeling_utils import (CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, PreTrainedModel,
+                             prune_linear_layer, SequenceSummary, SQuADHead)

 logger = logging.getLogger(__name__)
@@ -32,7 +32,7 @@ from torch.nn import functional as F
 from torch.nn import CrossEntropyLoss, MSELoss

 from .file_utils import cached_path
-from .model_utils import (CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, PreTrainedModel,
-                          SequenceSummary, PoolerAnswerClass, PoolerEndLogits, PoolerStartLogits)
+from .modeling_utils import (CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, PreTrainedModel,
+                             SequenceSummary, PoolerAnswerClass, PoolerEndLogits, PoolerStartLogits)
@@ -26,7 +26,7 @@ from pytorch_transformers import (BertConfig, BertModel, BertForMaskedLM,
                                   BertForTokenClassification, BertForMultipleChoice)
 from pytorch_transformers.modeling_bert import PRETRAINED_MODEL_ARCHIVE_MAP

-from .model_tests_commons import (create_and_check_commons, ConfigTester, ids_tensor)
+from .modeling_tests_commons import (create_and_check_commons, ConfigTester, ids_tensor)

 class BertModelTest(unittest.TestCase):
@@ -28,7 +28,7 @@ import torch
 from pytorch_transformers import (GPT2Config, GPT2Model,
                                   GPT2LMHeadModel, GPT2DoubleHeadsModel)

-from .model_tests_commons import (create_and_check_commons, ConfigTester, GPTModelTester)
+from .modeling_tests_commons import (create_and_check_commons, ConfigTester, GPTModelTester)

 class GPT2ModelTest(unittest.TestCase):
@@ -24,7 +24,7 @@ import torch
 from pytorch_transformers import (OpenAIGPTConfig, OpenAIGPTModel,
                                   OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel)

-from .model_tests_commons import (create_and_check_commons, ConfigTester, GPTModelTester)
+from .modeling_tests_commons import (create_and_check_commons, ConfigTester, GPTModelTester)

 class OpenAIModelTest(unittest.TestCase):
@@ -28,7 +28,7 @@ import torch
 from pytorch_transformers import (TransfoXLConfig, TransfoXLModel, TransfoXLLMHeadModel)
 from pytorch_transformers.modeling_transfo_xl import PRETRAINED_MODEL_ARCHIVE_MAP

-from .model_tests_commons import ConfigTester, create_and_check_commons, ids_tensor
+from .modeling_tests_commons import ConfigTester, create_and_check_commons, ids_tensor

 class TransfoXLModelTest(unittest.TestCase):
     class TransfoXLModelTester(object):
@@ -16,17 +16,10 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

-import os
 import unittest
-import json
-import random
-import shutil
-import pytest
-
-import torch

 from pytorch_transformers import PretrainedConfig, PreTrainedModel
-from pytorch_transformers.modeling_bert import BertModel, BertConfig, PRETRAINED_MODEL_ARCHIVE_MAP, PRETRAINED_CONFIG_ARCHIVE_MAP
+from pytorch_transformers.modeling_bert import BertModel, BertConfig, PRETRAINED_MODEL_ARCHIVE_MAP

 class ModelUtilsTest(unittest.TestCase):
@@ -23,7 +23,7 @@ import pytest
 from pytorch_transformers import (XLMConfig, XLMModel, XLMWithLMHeadModel, XLMForQuestionAnswering, XLMForSequenceClassification)
 from pytorch_transformers.modeling_xlm import PRETRAINED_MODEL_ARCHIVE_MAP

-from .model_tests_commons import (create_and_check_commons, ConfigTester, ids_tensor)
+from .modeling_tests_commons import (create_and_check_commons, ConfigTester, ids_tensor)

 class XLMModelTest(unittest.TestCase):
@@ -28,7 +28,7 @@ import torch
 from pytorch_transformers import (XLNetConfig, XLNetModel, XLNetLMHeadModel, XLNetForSequenceClassification, XLNetForQuestionAnswering)
 from pytorch_transformers.modeling_xlnet import PRETRAINED_MODEL_ARCHIVE_MAP

-from .model_tests_commons import ConfigTester, create_and_check_commons, ids_tensor
+from .modeling_tests_commons import ConfigTester, create_and_check_commons, ids_tensor

 class XLNetModelTest(unittest.TestCase):
     class XLNetModelTester(object):
@@ -24,7 +24,7 @@ from pytorch_transformers.tokenization_bert import (BasicTokenizer,
                                                     BertTokenizer,
                                                     WordpieceTokenizer,
                                                     _is_control, _is_punctuation,
-                                                    _is_whitespace, PRETRAINED_VOCAB_ARCHIVE_MAP)
+                                                    _is_whitespace)
 from .tokenization_tests_commons import create_and_check_tokenizer_commons
@@ -49,14 +49,6 @@ class TokenizationTest(unittest.TestCase):
         os.remove(vocab_file)

-    @pytest.mark.slow
-    def test_tokenizer_from_pretrained(self):
-        cache_dir = "/tmp/pytorch_transformers_test/"
-        for model_name in list(PRETRAINED_VOCAB_ARCHIVE_MAP.keys())[:1]:
-            tokenizer = BertTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
-            shutil.rmtree(cache_dir)
-            self.assertIsNotNone(tokenizer)
-
     def test_chinese(self):
         tokenizer = BasicTokenizer()
@@ -17,10 +17,8 @@ from __future__ import absolute_import, division, print_function, unicode_literals
 import os
 import unittest
 import json
-import shutil
-import pytest

-from pytorch_transformers.tokenization_gpt2 import GPT2Tokenizer, PRETRAINED_VOCAB_ARCHIVE_MAP
+from pytorch_transformers.tokenization_gpt2 import GPT2Tokenizer
 from .tokenization_tests_commons import create_and_check_tokenizer_commons
@@ -56,13 +54,6 @@ class GPT2TokenizationTest(unittest.TestCase):
         os.remove(vocab_file)
         os.remove(merges_file)

-    # @pytest.mark.slow
-    def test_tokenizer_from_pretrained(self):
-        cache_dir = "/tmp/pytorch_transformers_test/"
-        for model_name in list(PRETRAINED_VOCAB_ARCHIVE_MAP.keys())[:1]:
-            tokenizer = GPT2Tokenizer.from_pretrained(model_name, cache_dir=cache_dir)
-            shutil.rmtree(cache_dir)
-            self.assertIsNotNone(tokenizer)

 if __name__ == '__main__':
     unittest.main()