Commit 4f8b5f68 authored by thomwolf

add fix for serialization of tokenizer

parent d9184620
@@ -182,6 +182,21 @@ class XLNetTokenizer(object):
    def __len__(self):
        return len(self.encoder) + len(self.special_tokens)
    def __getstate__(self):
        # Drop the SentencePiece processor before pickling; it wraps native
        # state and cannot be serialized directly.
        state = self.__dict__.copy()
        state["sp_model"] = None
        return state

    def __setstate__(self, d):
        # Restore the plain attributes, then rebuild the SentencePiece
        # processor from the saved vocab_file.
        self.__dict__ = d
        try:
            import sentencepiece as spm
        except ImportError:
            logger.warning("You need to install SentencePiece to use XLNetTokenizer: "
                           "https://github.com/google/sentencepiece "
                           "pip install sentencepiece")
        self.sp_model = spm.SentencePieceProcessor()
        self.sp_model.Load(self.vocab_file)
    def set_special_tokens(self, special_tokens):
        """ Add a list of additional tokens to the encoder.
            The additional tokens are indexed starting from the last index of the
......
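The added __getstate__/__setstate__ methods above make the tokenizer picklable: the SentencePiece processor holds native state that pickle cannot serialize, so it is dropped before pickling and rebuilt from vocab_file when the object is restored. The snippet below is a minimal sketch of that pattern; FileBackedTokenizer and _load_backend are hypothetical stand-ins for XLNetTokenizer and spm.SentencePieceProcessor, not part of pytorch_pretrained_bert.

import pickle


def _load_backend(vocab_file):
    # Hypothetical stand-in for spm.SentencePieceProcessor().Load(vocab_file);
    # the real object wraps native state and cannot be pickled directly.
    return {"vocab_file": vocab_file, "loaded": True}


class FileBackedTokenizer(object):
    def __init__(self, vocab_file):
        self.vocab_file = vocab_file
        self.sp_model = _load_backend(vocab_file)

    def __getstate__(self):
        # Drop the non-picklable backend; everything else is plain Python data.
        state = self.__dict__.copy()
        state["sp_model"] = None
        return state

    def __setstate__(self, d):
        # Restore attributes, then rebuild the backend from the saved path.
        self.__dict__ = d
        self.sp_model = _load_backend(self.vocab_file)


if __name__ == "__main__":
    tok = FileBackedTokenizer("spiece.model")
    restored = pickle.loads(pickle.dumps(tok))
    assert restored.sp_model["loaded"]  # backend was rebuilt on unpickling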
@@ -15,11 +15,17 @@
from __future__ import absolute_import, division, print_function, unicode_literals

import os
import sys
import unittest
from io import open
import shutil
import pytest
if sys.version_info[0] == 2:
    import cPickle as pickle
else:
    import pickle
from pytorch_pretrained_bert.tokenization_xlnet import (XLNetTokenizer,
                                                         PRETRAINED_VOCAB_ARCHIVE_MAP,
                                                         SPIECE_UNDERLINE)
@@ -43,8 +49,6 @@ class XLNetTokenizationTest(unittest.TestCase):
        vocab_file, special_tokens_file = tokenizer.save_vocabulary(vocab_path)
        tokenizer = tokenizer.from_pretrained(vocab_path,
                                              keep_accents=True)
        os.remove(vocab_file)
        os.remove(special_tokens_file)
        tokens = tokenizer.tokenize(u"I was born in 92000, and this is falsé.")
        self.assertListEqual(tokens, [SPIECE_UNDERLINE + u'I', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b',
@@ -65,6 +69,22 @@ class XLNetTokenizationTest(unittest.TestCase):
                                      SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u's',
                                      u'<unk>', u'.'])
text = "Munich and Berlin are nice cities"
filename = u"/tmp/tokenizer.bin"
subwords = tokenizer.tokenize(text)
pickle.dump(tokenizer, open(filename, "wb"))
tokenizer_new = pickle.load(open(filename, "rb"))
subwords_loaded = tokenizer_new.tokenize(text)
self.assertListEqual(subwords, subwords_loaded)
os.remove(filename)
os.remove(vocab_file)
os.remove(special_tokens_file)
    @pytest.mark.slow
    def test_tokenizer_from_pretrained(self):
        cache_dir = "/tmp/pytorch_pretrained_bert_test/"
......