Unverified Commit e02ed0ee authored by Benjamin Davidson, committed by GitHub

XLMR tokenizer is fully picklable (#13577)

* made tokenizer fully picklable

* remove whitespace

* added test case
parent af5c6ae5
@@ -171,6 +171,7 @@ class XLMRobertaTokenizer(PreTrainedTokenizer):
     def __getstate__(self):
         state = self.__dict__.copy()
         state["sp_model"] = None
+        state["sp_model_proto"] = self.sp_model.serialized_model_proto()
         return state

     def __setstate__(self, d):
@@ -181,7 +182,7 @@ class XLMRobertaTokenizer(PreTrainedTokenizer):
             self.sp_model_kwargs = {}

         self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
-        self.sp_model.Load(self.vocab_file)
+        self.sp_model.LoadFromSerializedProto(self.sp_model_proto)

     def build_inputs_with_special_tokens(
         self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
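The change replaces the reload-from-disk path: on unpickling, the tokenizer now rebuilds its SentencePiece processor from the serialized proto carried in the pickled state, instead of calling Load(self.vocab_file). A minimal sketch of the same round-trip using the sentencepiece package directly (the model path below is only a placeholder, not part of this diff):

import pickle

import sentencepiece as spm

# Load a SentencePiece model from disk (placeholder path).
sp = spm.SentencePieceProcessor()
sp.Load("sentencepiece.bpe.model")

# The processor object itself is not picklable, but its model proto (bytes) is.
proto = sp.serialized_model_proto()
payload = pickle.dumps(proto)

# Rebuild a processor from the proto alone; the original file is no longer needed.
restored = spm.SentencePieceProcessor()
restored.LoadFromSerializedProto(pickle.loads(payload))
assert restored.GetPieceSize() == sp.GetPieceSize()

This is why __getstate__ stores sp_model_proto and sets sp_model to None: the bytes pickle cleanly, while the C++-backed processor does not.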
@@ -14,6 +14,9 @@
 # limitations under the License.

 import os
+import pickle
+import shutil
+import tempfile
 import unittest

 from transformers import SPIECE_UNDERLINE, XLMRobertaTokenizer, XLMRobertaTokenizerFast
@@ -141,6 +144,13 @@ class XLMRobertaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
     def big_tokenizer(self):
         return XLMRobertaTokenizer.from_pretrained("xlm-roberta-base")

+    def test_picklable_without_disk(self):
+        with tempfile.NamedTemporaryFile() as f:
+            shutil.copyfile(SAMPLE_VOCAB, f.name)
+            tokenizer = XLMRobertaTokenizer(f.name, keep_accents=True)
+            pickled_tokenizer = pickle.dumps(tokenizer)
+        pickle.loads(pickled_tokenizer)
+
     def test_rust_and_python_full_tokenizers(self):
         if not self.test_rust_tokenizer:
             return
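The test pickles the tokenizer while the temporary vocab file exists, then unpickles after the file is gone, which is what previously failed. Beyond the unit test, the practical effect is that a slow XLM-R tokenizer can be sent to worker processes via pickle even when the original vocab file is no longer on disk. A rough usage sketch, assuming the xlm-roberta-base checkpoint can be downloaded:

import pickle

from transformers import XLMRobertaTokenizer

tokenizer = XLMRobertaTokenizer.from_pretrained("xlm-roberta-base")

# Pickle round-trip; unpickling no longer needs the vocab file on disk because
# the serialized SentencePiece proto travels inside the pickled state.
restored = pickle.loads(pickle.dumps(tokenizer))
assert restored("Hello world!")["input_ids"] == tokenizer("Hello world!")["input_ids"]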