Commit 36bca545 authored by thomwolf's avatar thomwolf

tokenization abstract class - tests for examples

parent a4f98054
......@@ -16,35 +16,33 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 import os
+import sys
 import unittest
-import json
-import random
-import shutil
-import pytest
+import argparse
 import torch
+try:
+    # python 3.4+ can use builtin unittest.mock instead of mock package
+    from unittest.mock import patch
+except ImportError:
+    from mock import patch
-from pytorch_transformers import PretrainedConfig, PreTrainedModel
-from pytorch_transformers.modeling_bert import BertModel, BertConfig, PRETRAINED_MODEL_ARCHIVE_MAP, PRETRAINED_CONFIG_ARCHIVE_MAP
+import run_bert_squad as rbs
+def get_setup_file():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('-f')
+    args = parser.parse_args()
+    return args.f
-class ModelUtilsTest(unittest.TestCase):
-    def test_model_from_pretrained(self):
-        for model_name in list(PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
-            config = BertConfig.from_pretrained(model_name)
-            self.assertIsNotNone(config)
-            self.assertIsInstance(config, PretrainedConfig)
+class ExamplesTests(unittest.TestCase):
-            model = BertModel.from_pretrained(model_name)
-            self.assertIsNotNone(model)
-            self.assertIsInstance(model, PreTrainedModel)
+    def test_run_squad(self):
+        testargs = ["prog", "-f", "/home/test/setup.py"]
+        with patch.object(sys, 'argv', testargs):
+            setup = get_setup_file()
+            assert setup == "/home/test/setup.py"
+            # rbs.main()
-            config = BertConfig.from_pretrained(model_name, output_attentions=True, output_hidden_states=True)
-            model = BertModel.from_pretrained(model_name, output_attentions=True, output_hidden_states=True)
-            self.assertEqual(model.config.output_attentions, True)
-            self.assertEqual(model.config.output_hidden_states, True)
-            self.assertEqual(model.config, config)
 if __name__ == "__main__":
     unittest.main()
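
The new ExamplesTests above exercises an example script's command-line handling by temporarily swapping sys.argv, so the argparse parser running inside the test process sees the fake arguments. A minimal standalone sketch of that pattern (the parser, flag, and test names are illustrative, not taken from this commit):

import sys
import argparse
import unittest

try:
    # Python 3 ships unittest.mock; Python 2 falls back to the external mock package.
    from unittest.mock import patch
except ImportError:
    from mock import patch


def parse_args():
    # Hypothetical stand-in for an example script's argument parser.
    parser = argparse.ArgumentParser()
    parser.add_argument('--output_dir')
    return parser.parse_args()


class ArgvPatchingTest(unittest.TestCase):
    def test_parse_args(self):
        # argv[0] is the program name; the remaining entries are the flags under test.
        fake_argv = ["prog", "--output_dir", "/tmp/out"]
        with patch.object(sys, 'argv', fake_argv):
            args = parse_args()
        # sys.argv is restored automatically when the with-block exits.
        self.assertEqual(args.output_dir, "/tmp/out")


if __name__ == "__main__":
    unittest.main()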
......@@ -5,6 +5,7 @@ from .tokenization_transfo_xl import (TransfoXLTokenizer, TransfoXLCorpus)
 from .tokenization_gpt2 import GPT2Tokenizer
 from .tokenization_xlnet import XLNetTokenizer, SPIECE_UNDERLINE
 from .tokenization_xlm import XLMTokenizer
+from .tokenization_utils import (PreTrainedTokenizer, clean_up_tokenization)
 from .modeling_bert import (BertConfig, BertModel, BertForPreTraining,
                             BertForMaskedLM, BertForNextSentencePrediction,
......@@ -26,11 +27,10 @@ from .modeling_xlnet import (XLNetConfig,
 from .modeling_xlm import (XLMConfig, XLMModel,
                            XLMWithLMHeadModel, XLMForSequenceClassification,
                            XLMForQuestionAnswering)
+from .modeling_utils import (WEIGHTS_NAME, CONFIG_NAME, TF_WEIGHTS_NAME,
+                             PretrainedConfig, PreTrainedModel, prune_layer, Conv1D)
 from .optimization import BertAdam
 from .optimization_openai import OpenAIAdam
 from .file_utils import (PYTORCH_PRETRAINED_BERT_CACHE, cached_path)
-from .model_utils import (WEIGHTS_NAME, CONFIG_NAME, TF_WEIGHTS_NAME,
-                          PretrainedConfig, PreTrainedModel, prune_layer, Conv1D)
......@@ -29,7 +29,7 @@ from torch import nn
 from torch.nn import CrossEntropyLoss, MSELoss
 from .file_utils import cached_path
-from .model_utils import WEIGHTS_NAME, CONFIG_NAME, PretrainedConfig, PreTrainedModel, prune_linear_layer
+from .modeling_utils import WEIGHTS_NAME, CONFIG_NAME, PretrainedConfig, PreTrainedModel, prune_linear_layer
 logger = logging.getLogger(__name__)
......
......@@ -31,7 +31,7 @@ from torch.nn import CrossEntropyLoss
 from torch.nn.parameter import Parameter
 from .file_utils import cached_path
-from .model_utils import (Conv1D, CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig,
+from .modeling_utils import (Conv1D, CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig,
                              PreTrainedModel, prune_conv1d_layer, SequenceSummary)
 from .modeling_bert import BertLayerNorm as LayerNorm
......
......@@ -31,7 +31,7 @@ from torch.nn import CrossEntropyLoss
 from torch.nn.parameter import Parameter
 from .file_utils import cached_path
-from .model_utils import (Conv1D, CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig,
+from .modeling_utils import (Conv1D, CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig,
                              PreTrainedModel, prune_conv1d_layer, SequenceSummary)
 from .modeling_bert import BertLayerNorm as LayerNorm
......
......@@ -37,7 +37,7 @@ from torch.nn.parameter import Parameter
 from .modeling_bert import BertLayerNorm as LayerNorm
 from .modeling_transfo_xl_utilities import ProjectedAdaptiveLogSoftmax, sample_logits
 from .file_utils import cached_path
-from .model_utils import CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, PreTrainedModel
+from .modeling_utils import CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, PreTrainedModel
 logger = logging.getLogger(__name__)
......
......@@ -598,9 +598,3 @@ def prune_layer(layer, index, dim=None):
         return prune_conv1d_layer(layer, index, dim=1 if dim is None else dim)
     else:
         raise ValueError("Can't prune layer of class {}".format(layer.__class__))
-def clean_up_tokenization(out_string):
-    out_string = out_string.replace(' .', '.').replace(' ?', '?').replace(' !', '!').replace(' ,', ','
-                    ).replace(" ' ", "'").replace(" n't", "n't").replace(" 'm", "'m").replace(" do not", " don't"
-                    ).replace(" 's", "'s").replace(" 've", "'ve").replace(" 're", "'re")
-    return out_string
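
clean_up_tokenization, removed here and re-exported from tokenization_utils by this commit, collapses the extra spaces a word-piece detokenizer leaves before punctuation and around contractions. A small illustrative call, assuming the assignment shown above so the cleaned string is actually returned (the sample sentence is made up):

from pytorch_transformers import clean_up_tokenization

# Spaces before punctuation and around split contractions are stitched back together.
print(clean_up_tokenization("hello , how are you ? i 'm fine , thanks !"))
# hello, how are you? i'm fine, thanks!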
......@@ -35,7 +35,7 @@ from torch.nn import functional as F
 from torch.nn import CrossEntropyLoss, MSELoss
 from .file_utils import cached_path
-from .model_utils import (CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, PreTrainedModel,
+from .modeling_utils import (CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, PreTrainedModel,
                              prune_linear_layer, SequenceSummary, SQuADHead)
 logger = logging.getLogger(__name__)
......
......@@ -32,7 +32,7 @@ from torch.nn import functional as F
 from torch.nn import CrossEntropyLoss, MSELoss
 from .file_utils import cached_path
-from .model_utils import (CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, PreTrainedModel,
+from .modeling_utils import (CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, PreTrainedModel,
                              SequenceSummary, PoolerAnswerClass, PoolerEndLogits, PoolerStartLogits)
......
......@@ -26,7 +26,7 @@ from pytorch_transformers import (BertConfig, BertModel, BertForMaskedLM,
                                   BertForTokenClassification, BertForMultipleChoice)
 from pytorch_transformers.modeling_bert import PRETRAINED_MODEL_ARCHIVE_MAP
-from .model_tests_commons import (create_and_check_commons, ConfigTester, ids_tensor)
+from .modeling_tests_commons import (create_and_check_commons, ConfigTester, ids_tensor)
 class BertModelTest(unittest.TestCase):
......
......@@ -28,7 +28,7 @@ import torch
 from pytorch_transformers import (GPT2Config, GPT2Model,
                                   GPT2LMHeadModel, GPT2DoubleHeadsModel)
-from .model_tests_commons import (create_and_check_commons, ConfigTester, GPTModelTester)
+from .modeling_tests_commons import (create_and_check_commons, ConfigTester, GPTModelTester)
 class GPT2ModelTest(unittest.TestCase):
......
......@@ -24,7 +24,7 @@ import torch
 from pytorch_transformers import (OpenAIGPTConfig, OpenAIGPTModel,
                                   OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel)
-from .model_tests_commons import (create_and_check_commons, ConfigTester, GPTModelTester)
+from .modeling_tests_commons import (create_and_check_commons, ConfigTester, GPTModelTester)
 class OpenAIModelTest(unittest.TestCase):
......
......@@ -28,7 +28,7 @@ import torch
 from pytorch_transformers import (TransfoXLConfig, TransfoXLModel, TransfoXLLMHeadModel)
 from pytorch_transformers.modeling_transfo_xl import PRETRAINED_MODEL_ARCHIVE_MAP
-from .model_tests_commons import ConfigTester, create_and_check_commons, ids_tensor
+from .modeling_tests_commons import ConfigTester, create_and_check_commons, ids_tensor
 class TransfoXLModelTest(unittest.TestCase):
     class TransfoXLModelTester(object):
......
......@@ -16,17 +16,10 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 import os
 import unittest
-import json
-import random
-import shutil
-import pytest
 import torch
 from pytorch_transformers import PretrainedConfig, PreTrainedModel
-from pytorch_transformers.modeling_bert import BertModel, BertConfig, PRETRAINED_MODEL_ARCHIVE_MAP, PRETRAINED_CONFIG_ARCHIVE_MAP
+from pytorch_transformers.modeling_bert import BertModel, BertConfig, PRETRAINED_MODEL_ARCHIVE_MAP
 class ModelUtilsTest(unittest.TestCase):
......
......@@ -23,7 +23,7 @@ import pytest
 from pytorch_transformers import (XLMConfig, XLMModel, XLMWithLMHeadModel, XLMForQuestionAnswering, XLMForSequenceClassification)
 from pytorch_transformers.modeling_xlm import PRETRAINED_MODEL_ARCHIVE_MAP
-from .model_tests_commons import (create_and_check_commons, ConfigTester, ids_tensor)
+from .modeling_tests_commons import (create_and_check_commons, ConfigTester, ids_tensor)
 class XLMModelTest(unittest.TestCase):
......
......@@ -28,7 +28,7 @@ import torch
 from pytorch_transformers import (XLNetConfig, XLNetModel, XLNetLMHeadModel, XLNetForSequenceClassification, XLNetForQuestionAnswering)
 from pytorch_transformers.modeling_xlnet import PRETRAINED_MODEL_ARCHIVE_MAP
-from .model_tests_commons import ConfigTester, create_and_check_commons, ids_tensor
+from .modeling_tests_commons import ConfigTester, create_and_check_commons, ids_tensor
 class XLNetModelTest(unittest.TestCase):
     class XLNetModelTester(object):
......
......@@ -24,7 +24,7 @@ from pytorch_transformers.tokenization_bert import (BasicTokenizer,
                                                      BertTokenizer,
                                                      WordpieceTokenizer,
                                                      _is_control, _is_punctuation,
-                                                     _is_whitespace, PRETRAINED_VOCAB_ARCHIVE_MAP)
+                                                     _is_whitespace)
 from .tokenization_tests_commons import create_and_check_tokenizer_commons
......@@ -49,14 +49,6 @@ class TokenizationTest(unittest.TestCase):
         os.remove(vocab_file)
-    @pytest.mark.slow
-    def test_tokenizer_from_pretrained(self):
-        cache_dir = "/tmp/pytorch_transformers_test/"
-        for model_name in list(PRETRAINED_VOCAB_ARCHIVE_MAP.keys())[:1]:
-            tokenizer = BertTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
-            shutil.rmtree(cache_dir)
-            self.assertIsNotNone(tokenizer)
     def test_chinese(self):
         tokenizer = BasicTokenizer()
......
......@@ -17,10 +17,8 @@ from __future__ import absolute_import, division, print_function, unicode_litera
 import os
 import unittest
 import json
-import shutil
-import pytest
-from pytorch_transformers.tokenization_gpt2 import GPT2Tokenizer, PRETRAINED_VOCAB_ARCHIVE_MAP
+from pytorch_transformers.tokenization_gpt2 import GPT2Tokenizer
 from .tokenization_tests_commons import create_and_check_tokenizer_commons
......@@ -56,13 +54,6 @@ class GPT2TokenizationTest(unittest.TestCase):
         os.remove(vocab_file)
         os.remove(merges_file)
-    # @pytest.mark.slow
-    def test_tokenizer_from_pretrained(self):
-        cache_dir = "/tmp/pytorch_transformers_test/"
-        for model_name in list(PRETRAINED_VOCAB_ARCHIVE_MAP.keys())[:1]:
-            tokenizer = GPT2Tokenizer.from_pretrained(model_name, cache_dir=cache_dir)
-            shutil.rmtree(cache_dir)
-            self.assertIsNotNone(tokenizer)
 if __name__ == '__main__':
     unittest.main()