Commit 983c484f authored by Branden Chan, committed by Lysandre Debut

add __getstate__ and __setstate__ to XLMRobertaTokenizer

parent cefd51c5
@@ -19,8 +19,6 @@ import logging
 import os
 from shutil import copyfile
-import sentencepiece as spm
 from transformers.tokenization_utils import PreTrainedTokenizer
 from .tokenization_xlnet import SPIECE_UNDERLINE
@@ -87,6 +85,16 @@ class XLMRobertaTokenizer(PreTrainedTokenizer):
         )
         self.max_len_single_sentence = self.max_len - 2  # take into account special tokens
         self.max_len_sentences_pair = self.max_len - 4  # take into account special tokens
+        try:
+            import sentencepiece as spm
+        except ImportError:
+            logger.warning(
+                "You need to install SentencePiece to use XLMRobertaTokenizer: "
+                "https://github.com/google/sentencepiece "
+                "pip install sentencepiece"
+            )
+            raise
         self.sp_model = spm.SentencePieceProcessor()
         self.sp_model.Load(str(vocab_file))
         self.vocab_file = vocab_file
@@ -106,6 +114,24 @@ class XLMRobertaTokenizer(PreTrainedTokenizer):
         self.fairseq_tokens_to_ids["<mask>"] = len(self.sp_model) + len(self.fairseq_tokens_to_ids)
         self.fairseq_ids_to_tokens = {v: k for k, v in self.fairseq_tokens_to_ids.items()}
 
+    def __getstate__(self):
+        # sp_model wraps a C++ SentencePieceProcessor and cannot be pickled,
+        # so drop it from the state; __setstate__ rebuilds it on load.
+        state = self.__dict__.copy()
+        state["sp_model"] = None
+        return state
+
+    def __setstate__(self, d):
+        self.__dict__ = d
+        try:
+            import sentencepiece as spm
+        except ImportError:
+            logger.warning(
+                "You need to install SentencePiece to use XLMRobertaTokenizer: "
+                "https://github.com/google/sentencepiece "
+                "pip install sentencepiece"
+            )
+            raise
+        # Rebuild the processor from the vocab file path saved in __init__.
+        self.sp_model = spm.SentencePieceProcessor()
+        self.sp_model.Load(self.vocab_file)
+
     def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
         """
         Build model inputs from a sequence or a pair of sequences for sequence classification tasks
......
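For context, a minimal sketch of what this change enables: pickling the tokenizer, e.g. to send it to worker processes. The SentencePiece processor itself is not picklable, so `__getstate__` drops it and `__setstate__` reloads it from the saved vocab file. The checkpoint name `xlm-roberta-base` below is assumed for illustration and is not part of this commit.

import pickle

from transformers import XLMRobertaTokenizer

# Assumed checkpoint name, for illustration only.
tokenizer = XLMRobertaTokenizer.from_pretrained("xlm-roberta-base")

# Before this commit, pickle.dumps raised, because sp_model (a wrapped
# C++ SentencePieceProcessor) cannot be pickled. With this commit, the
# processor is dropped on dump and rebuilt from self.vocab_file on load.
restored = pickle.loads(pickle.dumps(tokenizer))

assert restored.tokenize("Hello world!") == tokenizer.tokenize("Hello world!")

Note that `__setstate__` reloads from `self.vocab_file`, so unpickling only works where that path is still readable, e.g. on the same machine or a shared filesystem.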