Unverified commit e0db8276, authored by Philip May, committed by GitHub

add sp_model_kwargs to unpickle of xlm roberta tok (#11430)

add test for pickle

simplify test

fix test code style

add missing pickle import

fix test

fix test

fix test
parent b43e3f93
@@ -135,7 +135,7 @@ class XLMRobertaTokenizer(PreTrainedTokenizer):
         # Mask token behave like a normal word, i.e. include the space before it
         mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token
 
-        sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
+        self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
 
         super().__init__(
             bos_token=bos_token,
@@ -145,11 +145,11 @@ class XLMRobertaTokenizer(PreTrainedTokenizer):
             cls_token=cls_token,
             pad_token=pad_token,
             mask_token=mask_token,
-            sp_model_kwargs=sp_model_kwargs,
+            sp_model_kwargs=self.sp_model_kwargs,
             **kwargs,
         )
 
-        self.sp_model = spm.SentencePieceProcessor(**sp_model_kwargs)
+        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
         self.sp_model.Load(str(vocab_file))
         self.vocab_file = vocab_file
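For context, self.sp_model_kwargs is forwarded straight to sentencepiece, and it is what enables subword regularization (sampled segmentations instead of the deterministic 1-best one). A minimal sketch of that effect, assuming a locally trained model file named spiece.model (hypothetical path) and a sentencepiece version that accepts these kwargs in the constructor (>= 0.1.91):

    import sentencepiece as spm

    # Hypothetical model path; any trained SentencePiece model works here.
    sp = spm.SentencePieceProcessor(
        model_file="spiece.model",
        enable_sampling=True,  # sample a segmentation instead of taking the 1-best
        nbest_size=-1,         # -1: sample from all hypotheses (forward-filtering)
        alpha=0.1,             # smoothing parameter for the sampling distribution
    )

    # With sampling enabled, repeated calls may split the same text differently.
    print(sp.encode("New York", out_type=str))
    print(sp.encode("New York", out_type=str))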
@@ -175,7 +175,12 @@ class XLMRobertaTokenizer(PreTrainedTokenizer):
     def __setstate__(self, d):
         self.__dict__ = d
-        self.sp_model = spm.SentencePieceProcessor()
+
+        # for backward compatibility
+        if not hasattr(self, "sp_model_kwargs"):
+            self.sp_model_kwargs = {}
+
+        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
         self.sp_model.Load(self.vocab_file)
 
     def build_inputs_with_special_tokens(
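The __setstate__ hook exists because the underlying SentencePieceProcessor is a SWIG-wrapped C++ object that cannot be pickled, so the tokenizer drops it when serializing and rebuilds it on load; before this change the rebuild silently ignored sp_model_kwargs. An illustrative, self-contained sketch of the pattern (PickleableWrapper is a hypothetical stand-in, not the transformers class):

    import pickle

    import sentencepiece as spm


    class PickleableWrapper:
        def __init__(self, vocab_file, sp_model_kwargs=None):
            self.vocab_file = vocab_file
            self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
            self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
            self.sp_model.Load(self.vocab_file)

        def __getstate__(self):
            # Drop the unpicklable SWIG object; keep everything needed to rebuild it.
            state = self.__dict__.copy()
            state["sp_model"] = None
            return state

        def __setstate__(self, d):
            self.__dict__ = d
            # Old pickles may predate sp_model_kwargs; default it for them.
            if not hasattr(self, "sp_model_kwargs"):
                self.sp_model_kwargs = {}
            # Rebuild the processor with the same kwargs the original instance used.
            self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
            self.sp_model.Load(self.vocab_file)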
@@ -16,6 +16,7 @@
 import itertools
 import os
+import pickle
 import unittest
 
 from transformers import SPIECE_UNDERLINE, XLMRobertaTokenizer, XLMRobertaTokenizerFast
@@ -142,6 +143,18 @@ class XLMRobertaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
         self.assertFalse(all_equal)
 
+    def test_pickle_subword_regularization_tokenizer(self):
+        """Google pickle __getstate__ __setstate__ if you are struggling with this."""
+        # Subword regularization is only available for the slow tokenizer.
+        sp_model_kwargs = {"enable_sampling": True, "alpha": 0.1, "nbest_size": -1}
+        tokenizer = XLMRobertaTokenizer(SAMPLE_VOCAB, keep_accents=True, sp_model_kwargs=sp_model_kwargs)
+        tokenizer_bin = pickle.dumps(tokenizer)
+        tokenizer_new = pickle.loads(tokenizer_bin)
+
+        self.assertIsNotNone(tokenizer_new.sp_model_kwargs)
+        self.assertTrue(isinstance(tokenizer_new.sp_model_kwargs, dict))
+        self.assertEqual(tokenizer_new.sp_model_kwargs, sp_model_kwargs)
+
     @cached_property
     def big_tokenizer(self):
         return XLMRobertaTokenizer.from_pretrained("xlm-roberta-base")
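In practice the fix means a pickled tokenizer keeps its sampling behaviour across a round trip. A hedged usage example against the public checkpoint (assumes network access and sentencepiece installed):

    import pickle

    from transformers import XLMRobertaTokenizer

    tok = XLMRobertaTokenizer.from_pretrained(
        "xlm-roberta-base",
        sp_model_kwargs={"enable_sampling": True, "alpha": 0.1, "nbest_size": -1},
    )
    restored = pickle.loads(pickle.dumps(tok))

    # Both instances now sample segmentations; before this commit, `restored`
    # silently fell back to deterministic tokenization.
    print(tok.tokenize("Hello world"))
    print(restored.tokenize("Hello world"))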