"...git@developer.sourcefind.cn:OpenDAS/lmdeploy.git" did not exist on "e2d98ddce2d31dcc95c58604d9ddae8232344a1d"
Unverified commit fb3b22c3, authored by Yuchao Dai, committed by GitHub
Browse files

LlamaTokenizer should be picklable (#24681)

* LlamaTokenizer should be picklable

* make fixup
parent 9a5d468b
...@@ -98,12 +98,13 @@ class LlamaTokenizer(PreTrainedTokenizer): ...@@ -98,12 +98,13 @@ class LlamaTokenizer(PreTrainedTokenizer):
def __getstate__(self): def __getstate__(self):
state = self.__dict__.copy() state = self.__dict__.copy()
state["sp_model"] = None state["sp_model"] = None
state["sp_model_proto"] = self.sp_model.serialized_model_proto()
return state return state
def __setstate__(self, d): def __setstate__(self, d):
self.__dict__ = d self.__dict__ = d
self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs) self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
self.sp_model.Load(self.vocab_file) self.sp_model.LoadFromSerializedProto(self.sp_model_proto)
@property @property
def vocab_size(self): def vocab_size(self):
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
# limitations under the License. # limitations under the License.
import os import os
import pickle
import shutil import shutil
import tempfile import tempfile
import unittest import unittest
...@@ -285,6 +286,13 @@ class LlamaTokenizationTest(TokenizerTesterMixin, unittest.TestCase): ...@@ -285,6 +286,13 @@ class LlamaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
padding=False, padding=False,
) )
def test_picklable(self):
with tempfile.NamedTemporaryFile() as f:
shutil.copyfile(SAMPLE_VOCAB, f.name)
tokenizer = LlamaTokenizer(f.name, keep_accents=True)
pickled_tokenizer = pickle.dumps(tokenizer)
pickle.loads(pickled_tokenizer)
@require_torch @require_torch
@require_sentencepiece @require_sentencepiece
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment