"docs/vscode:/vscode.git/clone" did not exist on "f84d85ba67ea8a812427f8a4f1e77c9d9fc3ebac"
Unverified commit e0db8276, authored by Philip May, committed by GitHub

add sp_model_kwargs to unpickle of xlm roberta tok (#11430)

add test for pickle

simplify test

fix test code style

add missing pickle import

fix test

fix test

fix test
parent b43e3f93
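
The bug: sp_model_kwargs was accepted by __init__ but never stored on the instance, so __setstate__ rebuilt the SentencePieceProcessor with default arguments and a pickled tokenizer silently lost its subword-regularization settings. A minimal round-trip check of the fixed behavior (a sketch; "sample.model" is a hypothetical local SentencePiece model file):

    import pickle

    from transformers import XLMRobertaTokenizer

    # "sample.model" is a hypothetical local SentencePiece model file.
    tokenizer = XLMRobertaTokenizer(
        "sample.model",
        sp_model_kwargs={"enable_sampling": True, "alpha": 0.1, "nbest_size": -1},
    )
    restored = pickle.loads(pickle.dumps(tokenizer))

    # Before this commit, __setstate__ rebuilt sp_model without any kwargs,
    # so the restored tokenizer lost sampling. With the fix, the kwargs survive:
    assert restored.sp_model_kwargs == tokenizer.sp_model_kwargs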
@@ -135,7 +135,7 @@ class XLMRobertaTokenizer(PreTrainedTokenizer):
         # Mask token behave like a normal word, i.e. include the space before it
         mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token
-        sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
+        self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs

         super().__init__(
             bos_token=bos_token,
@@ -145,11 +145,11 @@ class XLMRobertaTokenizer(PreTrainedTokenizer):
             cls_token=cls_token,
             pad_token=pad_token,
             mask_token=mask_token,
-            sp_model_kwargs=sp_model_kwargs,
+            sp_model_kwargs=self.sp_model_kwargs,
             **kwargs,
         )

-        self.sp_model = spm.SentencePieceProcessor(**sp_model_kwargs)
+        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
         self.sp_model.Load(str(vocab_file))
         self.vocab_file = vocab_file
@@ -175,7 +175,12 @@ class XLMRobertaTokenizer(PreTrainedTokenizer):
     def __setstate__(self, d):
         self.__dict__ = d
-        self.sp_model = spm.SentencePieceProcessor()
+
+        # for backward compatibility
+        if not hasattr(self, "sp_model_kwargs"):
+            self.sp_model_kwargs = {}
+
+        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
         self.sp_model.Load(self.vocab_file)

     def build_inputs_with_special_tokens(
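
The hasattr guard exists because SentencePieceProcessor is not picklable: the tokenizer drops it in __getstate__ and rebuilds it here on unpickle, and pickles written before sp_model_kwargs was introduced fall back to an empty dict instead of raising AttributeError. A self-contained sketch of that pattern, with SpWrapper as an illustrative stand-in (not the actual transformers class):

    import sentencepiece as spm

    class SpWrapper:
        """Illustrative wrapper around an unpicklable C-extension handle."""

        def __init__(self, vocab_file, sp_model_kwargs=None):
            self.vocab_file = vocab_file
            self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
            self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
            self.sp_model.Load(self.vocab_file)

        def __getstate__(self):
            state = self.__dict__.copy()
            state["sp_model"] = None  # the C handle cannot be pickled
            return state

        def __setstate__(self, d):
            self.__dict__ = d
            if not hasattr(self, "sp_model_kwargs"):  # pickles from older versions
                self.sp_model_kwargs = {}
            self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
            self.sp_model.Load(self.vocab_file)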
@@ -16,6 +16,7 @@

 import itertools
 import os
+import pickle
 import unittest

 from transformers import SPIECE_UNDERLINE, XLMRobertaTokenizer, XLMRobertaTokenizerFast
@@ -142,6 +143,18 @@ class XLMRobertaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
         self.assertFalse(all_equal)

+    def test_pickle_subword_regularization_tokenizer(self):
+        """Google pickle __getstate__ __setstate__ if you are struggling with this."""
+        # Subword regularization is only available for the slow tokenizer.
+        sp_model_kwargs = {"enable_sampling": True, "alpha": 0.1, "nbest_size": -1}
+        tokenizer = XLMRobertaTokenizer(SAMPLE_VOCAB, keep_accents=True, sp_model_kwargs=sp_model_kwargs)
+        tokenizer_bin = pickle.dumps(tokenizer)
+        tokenizer_new = pickle.loads(tokenizer_bin)
+
+        self.assertIsNotNone(tokenizer_new.sp_model_kwargs)
+        self.assertTrue(isinstance(tokenizer_new.sp_model_kwargs, dict))
+        self.assertEqual(tokenizer_new.sp_model_kwargs, sp_model_kwargs)
@cached_property
def big_tokenizer(self):
return XLMRobertaTokenizer.from_pretrained("xlm-roberta-base")
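
For context, the kwargs that now survive pickling are what turn on subword regularization: with enable_sampling=True the same text may tokenize differently on each call, which is what the assertFalse(all_equal) test above checks. A quick sketch, again assuming a hypothetical local sample.model:

    from transformers import XLMRobertaTokenizer

    tokenizer = XLMRobertaTokenizer(
        "sample.model",  # hypothetical local SentencePiece model file
        sp_model_kwargs={"enable_sampling": True, "alpha": 0.1, "nbest_size": -1},
    )
    print(tokenizer.tokenize("This is a test"))
    print(tokenizer.tokenize("This is a test"))  # sampling: may differ from the first call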