Commit b4fcd59a authored by thomwolf

add sentinels in tokenizer

parent 15e53c4e
@@ -18,6 +18,7 @@ from __future__ import absolute_import, division, print_function, unicode_literals
import logging
import os
import re
import six
from shutil import copyfile
@@ -31,7 +32,7 @@ SPIECE_UNDERLINE = u'▁'
# Mapping from the keyword arguments names of Tokenizer `__init__`
# to file names for serializing Tokenizer instances
####################################################
VOCAB_FILES_NAMES = {'vocab_file': 'vocab.txt'}
VOCAB_FILES_NAMES = {'vocab_file': 'spiece.model'}
####################################################
# Mapping from the keyword arguments names of Tokenizer `__init__`
@@ -56,15 +57,27 @@ class T5Tokenizer(PreTrainedTokenizer):
        SentencePiece based tokenizer. Peculiarities:
            - requires `SentencePiece <https://github.com/google/sentencepiece>`_
            - `extra_ids` adds a number of extra ids to the end of the vocabulary for use as sentinels.
              These tokens are accessible as `<extra_id_{%d}>`, where `{%d}` is a number between 0 and extra_ids-1.
              Extra tokens are indexed from the end of the vocabulary up to the beginning (`<extra_id_0>` is the last token in the vocabulary)
              (as in T5 preprocessing; see:
              https://github.com/google-research/text-to-text-transfer-transformer/blob/9fd7b14a769417be33bc6c850f9598764913c833/t5/data/preprocessors.py#L2117)
    """
    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
    def __init__(self, vocab_file, eos_token="</s>", unk_token="<unk>",
                 pad_token="<pad>", **kwargs):
                 pad_token="<pad>", extra_ids=100, additional_special_tokens=None, **kwargs):
        # Add extra_ids to the special token list
        if extra_ids > 0:
            if additional_special_tokens is None:
                additional_special_tokens = []
            additional_special_tokens.extend([u"<extra_id_{}>".format(i) for i in range(extra_ids)])
        super(T5Tokenizer, self).__init__(eos_token=eos_token, unk_token=unk_token,
                                          pad_token=pad_token, **kwargs)
                                          pad_token=pad_token, additional_special_tokens=additional_special_tokens,
                                          **kwargs)
        try:
            import sentencepiece as spm
@@ -74,13 +87,14 @@ class T5Tokenizer(PreTrainedTokenizer):
"pip install sentencepiece")
self.vocab_file = vocab_file
self._extra_ids = extra_ids
self.sp_model = spm.SentencePieceProcessor()
self.sp_model.Load(vocab_file)
    @property
    def vocab_size(self):
        return self.sp_model.get_piece_size()
        return self.sp_model.get_piece_size() + self._extra_ids
    def __getstate__(self):
        state = self.__dict__.copy()
@@ -118,11 +132,18 @@ class T5Tokenizer(PreTrainedTokenizer):
    def _convert_token_to_id(self, token):
        """ Converts a token (str/unicode) into an id using the vocab. """
        if token.startswith(u"<extra_id_"):
            match = re.match(r'<extra_id_(\d+)>', token)
            num = int(match.group(1))
            return self.vocab_size - num - 1
        return self.sp_model.piece_to_id(token)
    def _convert_id_to_token(self, index, return_unicode=True):
        """Converts an index (integer) into a token (string/unicode) using the vocab."""
        if index < self.sp_model.get_piece_size():
            token = self.sp_model.IdToPiece(index)
        else:
            token = u"<extra_id_{}>".format(self.vocab_size - 1 - index)
        if six.PY2 and return_unicode and isinstance(token, str):
            token = token.decode('utf-8')
        return token
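
For reference, a minimal sketch of the id arithmetic this change introduces. It is illustrative only: the SentencePiece vocabulary size of 32000 and the default extra_ids=100 are assumptions, and the two helper functions simply mirror the mapping used by _convert_token_to_id and _convert_id_to_token in the diff above.

import re

# Assumed sizes (illustration only): a 32,000-piece SentencePiece model
# plus the 100 sentinel tokens appended by this change.
sp_size = 32000
extra_ids = 100
vocab_size = sp_size + extra_ids  # what vocab_size now reports

def sentinel_token_to_id(token):
    # Sentinels count down from the end of the vocabulary:
    # <extra_id_0> -> vocab_size - 1, <extra_id_1> -> vocab_size - 2, ...
    match = re.match(r'<extra_id_(\d+)>', token)
    return vocab_size - int(match.group(1)) - 1

def sentinel_id_to_token(index):
    # Inverse mapping for ids above the SentencePiece range.
    return "<extra_id_{}>".format(vocab_size - 1 - index)

assert sentinel_token_to_id("<extra_id_0>") == 32099   # last id in the vocabulary
assert sentinel_token_to_id("<extra_id_99>") == 32000  # first id past the sp model
assert sentinel_id_to_token(32099) == "<extra_id_0>"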