"vscode:/vscode.git/clone" did not exist on "e7aa64838cc604abf7a49e69ca0ffe7af683d8ca"
Commit 0b3d45eb authored by Stefan Schweter's avatar Stefan Schweter
Browse files

camembert: add implementation for save_vocabulary method

parent 33753d91
......@@ -16,9 +16,14 @@
from __future__ import (absolute_import, division, print_function,
unicode_literals)
import logging
import os
from shutil import copyfile
import sentencepiece as spm
from transformers.tokenization_utils import PreTrainedTokenizer
logger = logging.getLogger(__name__)
VOCAB_FILES_NAMES = {'vocab_file': 'sentencepiece.bpe.model'}
......@@ -55,6 +60,7 @@ class CamembertTokenizer(PreTrainedTokenizer):
self.max_len_sentences_pair = self.max_len - 4 # take into account special tokens
self.sp_model = spm.SentencePieceProcessor()
self.sp_model.Load(str(vocab_file))
self.vocab_file = vocab_file
# HACK: These tokens were added by fairseq but don't seem to be actually used when duplicated in the actual
# sentencepiece vocabulary (this is the case for <s> and </s>
self.fairseq_tokens_to_ids = {'<s>NOTUSED': 0, '<pad>': 1, '</s>NOTUSED': 2, '<unk>': 3}
......@@ -135,3 +141,17 @@ class CamembertTokenizer(PreTrainedTokenizer):
if index in self.fairseq_ids_to_tokens:
return self.fairseq_ids_to_tokens[index]
return self.sp_model.IdToPiece(index - self.fairseq_offset)
def save_vocabulary(self, save_directory):
""" Save the sentencepiece vocabulary (copy original file) and special tokens file
to a directory.
"""
if not os.path.isdir(save_directory):
logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
return
out_vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES['vocab_file'])
if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
copyfile(self.vocab_file, out_vocab_file)
return (out_vocab_file,)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment