Commit e37ca8e1 authored by thomwolf

fix camembert and XLM-R tokenizer

parent ceae85ad
transformers/tokenization_camembert.py

@@ -22,6 +22,7 @@ from shutil import copyfile
 
 import sentencepiece as spm
 
 from transformers.tokenization_utils import PreTrainedTokenizer
+from .tokenization_xlnet import SPIECE_UNDERLINE
 
 logger = logging.getLogger(__name__)
@@ -145,6 +146,11 @@ class CamembertTokenizer(PreTrainedTokenizer):
             return self.fairseq_ids_to_tokens[index]
         return self.sp_model.IdToPiece(index - self.fairseq_offset)
 
+    def convert_tokens_to_string(self, tokens):
+        """Converts a sequence of tokens (strings for sub-words) into a single string."""
+        out_string = ''.join(tokens).replace(SPIECE_UNDERLINE, ' ').strip()
+        return out_string
+
     def save_vocabulary(self, save_directory):
         """ Save the sentencepiece vocabulary (copy original file) and special tokens file
             to a directory.
...
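For readers of the diff: SPIECE_UNDERLINE, imported above from tokenization_xlnet, is the U+2581 "▁" marker SentencePiece prepends to pieces that begin a new word. The new method simply concatenates the pieces and turns those markers back into spaces. A minimal sketch of the mechanics, using illustrative pieces rather than real model output:

```python
# SPIECE_UNDERLINE as defined in tokenization_xlnet: the SentencePiece
# word-boundary marker U+2581 ("▁").
SPIECE_UNDERLINE = u'\u2581'

# Illustrative sub-word pieces for "J'aime le camembert" (not necessarily
# the exact segmentation a real CamemBERT sentencepiece model produces).
tokens = [u'\u2581J', u"'", u'aime', u'\u2581le', u'\u2581camembert']

# Same logic as the patch: join the pieces, restore spaces at word
# boundaries, strip the leading space left by the first marker.
out_string = ''.join(tokens).replace(SPIECE_UNDERLINE, ' ').strip()
print(out_string)  # J'aime le camembert
```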
transformers/tokenization_xlm_roberta.py

@@ -22,6 +22,7 @@ from shutil import copyfile
 
 import sentencepiece as spm
 
 from transformers.tokenization_utils import PreTrainedTokenizer
+from .tokenization_xlnet import SPIECE_UNDERLINE
 
 logger = logging.getLogger(__name__)
@@ -161,6 +162,11 @@ class XLMRobertaTokenizer(PreTrainedTokenizer):
             return self.fairseq_ids_to_tokens[index]
         return self.sp_model.IdToPiece(index - self.fairseq_offset)
 
+    def convert_tokens_to_string(self, tokens):
+        """Converts a sequence of tokens (strings for sub-words) into a single string."""
+        out_string = ''.join(tokens).replace(SPIECE_UNDERLINE, ' ').strip()
+        return out_string
+
     def save_vocabulary(self, save_directory):
         """ Save the sentencepiece vocabulary (copy original file) and special tokens file
             to a directory.
...
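With both tokenizers patched the same way, tokenize followed by convert_tokens_to_string should round-trip ordinary text, up to whitespace normalization. A hedged usage sketch, assuming the 'camembert-base' and 'xlm-roberta-base' pretrained checkpoints are reachable:

```python
from transformers import CamembertTokenizer, XLMRobertaTokenizer

text = "J'aime le camembert"

# from_pretrained downloads the sentencepiece model on first use.
for cls, name in [(CamembertTokenizer, 'camembert-base'),
                  (XLMRobertaTokenizer, 'xlm-roberta-base')]:
    tokenizer = cls.from_pretrained(name)
    pieces = tokenizer.tokenize(text)  # e.g. ['▁J', "'", 'aime', ...]
    print(name, '->', tokenizer.convert_tokens_to_string(pieces))
```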