Commit c079d7dd authored by thomwolf

fix python 2 tests

parent b1978698
@@ -24,7 +24,7 @@ from pytorch_transformers.tokenization_bert import (BasicTokenizer,
                                                     _is_control, _is_punctuation,
                                                     _is_whitespace, VOCAB_FILES_NAMES)
-from .tokenization_tests_commons import create_and_check_tokenizer_commons
+from .tokenization_tests_commons import create_and_check_tokenizer_commons, TemporaryDirectory
 class TokenizationTest(unittest.TestCase):
@@ -33,21 +33,18 @@ class TokenizationTest(unittest.TestCase):
             "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
             "##ing", ",", "low", "lowest",
         ]
-        vocab_directory = "/tmp/"
-        vocab_file = os.path.join(vocab_directory, VOCAB_FILES_NAMES['vocab_file'])
-        with open(vocab_file, "w", encoding='utf-8') as vocab_writer:
-            vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
-        vocab_file = vocab_writer.name
-        create_and_check_tokenizer_commons(self, BertTokenizer, pretrained_model_name_or_path=vocab_directory)
-        tokenizer = BertTokenizer(vocab_file)
-        tokens = tokenizer.tokenize(u"UNwant\u00E9d,running")
-        self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])
-        self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])
-        os.remove(vocab_file)
+        with TemporaryDirectory() as tmpdirname:
+            vocab_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
+            with open(vocab_file, "w", encoding='utf-8') as vocab_writer:
+                vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
+            create_and_check_tokenizer_commons(self, BertTokenizer, tmpdirname)
+            tokenizer = BertTokenizer(vocab_file)
+            tokens = tokenizer.tokenize(u"UNwant\u00E9d,running")
+            self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])
+            self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])
     def test_chinese(self):
         tokenizer = BasicTokenizer()
...
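The same fix repeats in the tokenizer tests that follow: fixture files now live in a temporary directory that cleans itself up, rather than a hard-coded "/tmp/" path paired with a manual os.remove(). A rough standalone sketch of the pattern (Python 3 shown for brevity; the TemporaryDirectory shim added to tokenization_tests_commons below gives Python 2 the same API):

import os
import tempfile

from pytorch_transformers.tokenization_bert import BertTokenizer, VOCAB_FILES_NAMES

vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
                "##ing", ",", "low", "lowest"]

with tempfile.TemporaryDirectory() as tmpdirname:
    # Write the toy vocab inside the managed directory instead of a hard-coded "/tmp/".
    vocab_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
    with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
        vocab_writer.write("".join(token + "\n" for token in vocab_tokens))

    tokenizer = BertTokenizer(vocab_file)
    assert tokenizer.tokenize(u"UNwant\u00E9d,running") == ["un", "##want", "##ed", ",", "runn", "##ing"]

# Cleanup happens on exit from the with-block; no os.remove() is needed.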
@@ -17,11 +17,10 @@ from __future__ import absolute_import, division, print_function, unicode_literals
 import os
 import unittest
 import json
-import tempfile
 from pytorch_transformers.tokenization_gpt2 import GPT2Tokenizer, VOCAB_FILES_NAMES
-from .tokenization_tests_commons import create_and_check_tokenizer_commons
+from .tokenization_tests_commons import create_and_check_tokenizer_commons, TemporaryDirectory
 class GPT2TokenizationTest(unittest.TestCase):
@@ -34,7 +33,7 @@ class GPT2TokenizationTest(unittest.TestCase):
         merges = ["#version: 0.2", "l o", "lo w", "e r", ""]
         special_tokens_map = {"unk_token": "<unk>"}
-        with tempfile.TemporaryDirectory() as tmpdirname:
+        with TemporaryDirectory() as tmpdirname:
             vocab_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
             merges_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['merges_file'])
             with open(vocab_file, "w") as fp:
...
@@ -17,11 +17,10 @@ from __future__ import absolute_import, division, print_function, unicode_literals
 import os
 import unittest
 import json
-import tempfile
 from pytorch_transformers.tokenization_openai import OpenAIGPTTokenizer, VOCAB_FILES_NAMES
-from .tokenization_tests_commons import create_and_check_tokenizer_commons
+from .tokenization_tests_commons import create_and_check_tokenizer_commons, TemporaryDirectory
 class OpenAIGPTTokenizationTest(unittest.TestCase):
@@ -35,7 +34,7 @@ class OpenAIGPTTokenizationTest(unittest.TestCase):
         vocab_tokens = dict(zip(vocab, range(len(vocab))))
         merges = ["#version: 0.2", "l o", "lo w", "e r</w>", ""]
-        with tempfile.TemporaryDirectory() as tmpdirname:
+        with TemporaryDirectory() as tmpdirname:
             vocab_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
             merges_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['merges_file'])
             with open(vocab_file, "w") as fp:
...
@@ -14,18 +14,25 @@
 # limitations under the License.
 from __future__ import absolute_import, division, print_function, unicode_literals
+import os
 import sys
 from io import open
 import tempfile
+import shutil
-if sys.version_info[0] == 3:
-    unicode = str
 if sys.version_info[0] == 2:
     import cPickle as pickle
+
+    class TemporaryDirectory(object):
+        """Context manager for tempfile.mkdtemp() so it's usable with "with" statement."""
+        def __enter__(self):
+            self.name = tempfile.mkdtemp()
+            return self.name
+        def __exit__(self, exc_type, exc_value, traceback):
+            shutil.rmtree(self.name)
 else:
     import pickle
+    TemporaryDirectory = tempfile.TemporaryDirectory
+    unicode = str
 def create_and_check_save_and_load_tokenizer(tester, tokenizer_class, *inputs, **kwargs):
@@ -33,7 +40,7 @@ def create_and_check_save_and_load_tokenizer(tester, tokenizer_class, *inputs, **kwargs):
     before_tokens = tokenizer.encode(u"He is very happy, UNwant\u00E9d,running")
-    with tempfile.TemporaryDirectory() as tmpdirname:
+    with TemporaryDirectory() as tmpdirname:
         tokenizer.save_pretrained(tmpdirname)
         tokenizer = tokenizer.from_pretrained(tmpdirname)
...
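Python 2's tempfile module has no TemporaryDirectory, which is what these tests tripped over; the shim above wraps tempfile.mkdtemp() and shutil.rmtree() so both interpreters get the same context-manager API. A small self-contained sketch of that dispatch and its cleanup behaviour:

import os
import sys
import shutil
import tempfile

if sys.version_info[0] == 2:
    # Python 2: wrap mkdtemp() in a context manager, mirroring the shim above.
    class TemporaryDirectory(object):
        def __enter__(self):
            self.name = tempfile.mkdtemp()
            return self.name
        def __exit__(self, exc_type, exc_value, traceback):
            shutil.rmtree(self.name)
else:
    # Python 3: the standard library already provides it.
    TemporaryDirectory = tempfile.TemporaryDirectory

with TemporaryDirectory() as tmpdirname:
    scratch = os.path.join(tmpdirname, "scratch.txt")
    with open(scratch, "w") as f:
        f.write(u"hello")
    assert os.path.exists(scratch)

# The directory and its contents are removed when the block exits, on both interpreters.
assert not os.path.exists(tmpdirname)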
@@ -17,11 +17,10 @@ from __future__ import absolute_import, division, print_function, unicode_literals
 import os
 import unittest
 from io import open
-import tempfile
 from pytorch_transformers.tokenization_transfo_xl import TransfoXLTokenizer, VOCAB_FILES_NAMES
-from .tokenization_tests_commons import create_and_check_tokenizer_commons
+from .tokenization_tests_commons import create_and_check_tokenizer_commons, TemporaryDirectory
 class TransfoXLTokenizationTest(unittest.TestCase):
@@ -30,7 +29,7 @@ class TransfoXLTokenizationTest(unittest.TestCase):
             "<unk>", "[CLS]", "[SEP]", "want", "unwanted", "wa", "un",
             "running", ",", "low", "l",
         ]
-        with tempfile.TemporaryDirectory() as tmpdirname:
+        with TemporaryDirectory() as tmpdirname:
             vocab_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
             with open(vocab_file, "w", encoding='utf-8') as vocab_writer:
                 vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
...
@@ -17,11 +17,10 @@ from __future__ import absolute_import, division, print_function, unicode_literals
 import os
 import unittest
 import json
-import tempfile
 from pytorch_transformers.tokenization_xlm import XLMTokenizer, VOCAB_FILES_NAMES
-from .tokenization_tests_commons import create_and_check_tokenizer_commons
+from .tokenization_tests_commons import create_and_check_tokenizer_commons, TemporaryDirectory
 class XLMTokenizationTest(unittest.TestCase):
@@ -34,7 +33,7 @@ class XLMTokenizationTest(unittest.TestCase):
         vocab_tokens = dict(zip(vocab, range(len(vocab))))
         merges = ["l o 123", "lo w 1456", "e r</w> 1789", ""]
-        with tempfile.TemporaryDirectory() as tmpdirname:
+        with TemporaryDirectory() as tmpdirname:
             vocab_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
             merges_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['merges_file'])
             with open(vocab_file, "w") as fp:
...
@@ -16,11 +16,10 @@ from __future__ import absolute_import, division, print_function, unicode_literals
 import os
 import unittest
-import tempfile
-from pytorch_transformers.tokenization_xlnet import (XLNetTokenizer, SPIECE_UNDERLINE, VOCAB_FILES_NAMES)
-from .tokenization_tests_commons import create_and_check_tokenizer_commons
+from pytorch_transformers.tokenization_xlnet import (XLNetTokenizer, SPIECE_UNDERLINE)
+from .tokenization_tests_commons import create_and_check_tokenizer_commons, TemporaryDirectory
 SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                             'fixtures/test_sentencepiece.model')
@@ -30,7 +29,7 @@ class XLNetTokenizationTest(unittest.TestCase):
     def test_full_tokenizer(self):
         tokenizer = XLNetTokenizer(SAMPLE_VOCAB, keep_accents=True)
-        with tempfile.TemporaryDirectory() as tmpdirname:
+        with TemporaryDirectory() as tmpdirname:
             tokenizer.save_pretrained(tmpdirname)
             create_and_check_tokenizer_commons(self, XLNetTokenizer, tmpdirname)
...
@@ -231,8 +231,7 @@ class PreTrainedTokenizer(object):
         # Add supplementary tokens.
         if added_tokens_file is not None:
-            added_tokens = json.load(open(added_tokens_file, encoding="utf-8"))
-            added_tok_encoder = dict((tok, len(tokenizer) + i) for i, tok in enumerate(added_tokens))
+            added_tok_encoder = json.load(open(added_tokens_file, encoding="utf-8"))
             added_tok_decoder = {v:k for k, v in added_tok_encoder.items()}
             tokenizer.added_tokens_encoder.update(added_tok_encoder)
             tokenizer.added_tokens_decoder.update(added_tok_decoder)
@@ -256,7 +255,11 @@ class PreTrainedTokenizer(object):
             f.write(json.dumps(self.special_tokens_map, ensure_ascii=False))
         with open(added_tokens_file, 'w', encoding='utf-8') as f:
-            f.write(json.dumps(self.added_tokens_decoder, ensure_ascii=False))
+            if self.added_tokens_encoder:
+                out_str = json.dumps(self.added_tokens_decoder, ensure_ascii=False)
+            else:
+                out_str = u"{}"
+            f.write(out_str)
         vocab_files = self.save_vocabulary(save_directory)
...
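The load path now expects the added-tokens file to contain the token-to-id mapping directly, rather than a bare token list whose ids were recomputed as len(tokenizer) + i, and the save path guards against writing invalid JSON when no tokens were added. A minimal sketch of such a round-trip (Python 3 shown; the file name and dicts are illustrative stand-ins, not the tokenizer's actual attributes):

import json

# Hypothetical stand-ins for the tokenizer's internal state and file name.
added_tokens_file = "added_tokens.json"
added_tok_encoder = {"<extra_1>": 30522, "<extra_2>": 30523}   # token -> id

# Save side: write the mapping itself (or an empty JSON object when nothing was added).
with open(added_tokens_file, "w", encoding="utf-8") as f:
    f.write(json.dumps(added_tok_encoder, ensure_ascii=False) if added_tok_encoder else "{}")

# Load side: read the mapping back as-is and derive the id -> token map from it,
# instead of recomputing ids from the current vocabulary size.
with open(added_tokens_file, encoding="utf-8") as f:
    loaded_encoder = json.load(f)
loaded_decoder = {v: k for k, v in loaded_encoder.items()}

assert loaded_encoder == added_tok_encoder
assert loaded_decoder[30522] == "<extra_1>"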