Unverified Commit e6ec24fa authored by Anthony MOI
Browse files

Better added_tokens handling

parent 599db139
...@@ -1413,6 +1413,9 @@ class PreTrainedTokenizer(object): ...@@ -1413,6 +1413,9 @@ class PreTrainedTokenizer(object):
class PreTrainedTokenizerFast(PreTrainedTokenizer): class PreTrainedTokenizerFast(PreTrainedTokenizer):
_tokenizer = None
_decoder = None
def __init__(self, **kwargs): def __init__(self, **kwargs):
super(PreTrainedTokenizerFast, self).__init__(**kwargs) super(PreTrainedTokenizerFast, self).__init__(**kwargs)
...@@ -1435,8 +1438,49 @@ class PreTrainedTokenizerFast(PreTrainedTokenizer): ...@@ -1435,8 +1438,49 @@ class PreTrainedTokenizerFast(PreTrainedTokenizer):
def __len__(self):
    """Return the vocabulary size, counting tokens added on top of the base vocab."""
    vocab_size = self.tokenizer.get_vocab_size(with_added_tokens=True)
    return vocab_size
@PreTrainedTokenizer.bos_token.setter
def bos_token(self, value):
    """Set the beginning-of-sequence token and re-sync the backing tokenizer."""
    self._bos_token = value
    self._update_special_tokens()
@PreTrainedTokenizer.eos_token.setter
def eos_token(self, value):
    """Set the end-of-sequence token and re-sync the backing tokenizer."""
    self._eos_token = value
    self._update_special_tokens()
@PreTrainedTokenizer.unk_token.setter
def unk_token(self, value):
    """Set the unknown token and re-sync the backing tokenizer."""
    self._unk_token = value
    self._update_special_tokens()
@PreTrainedTokenizer.sep_token.setter
def sep_token(self, value):
    """Set the separator token and re-sync the backing tokenizer."""
    self._sep_token = value
    self._update_special_tokens()
@PreTrainedTokenizer.pad_token.setter
def pad_token(self, value):
    """Set the padding token and re-sync the backing tokenizer."""
    self._pad_token = value
    self._update_special_tokens()
@PreTrainedTokenizer.cls_token.setter
def cls_token(self, value):
    """Set the classifier token and re-sync the backing tokenizer."""
    self._cls_token = value
    self._update_special_tokens()
@PreTrainedTokenizer.mask_token.setter
def mask_token(self, value):
    """Set the mask token and re-sync the backing tokenizer."""
    self._mask_token = value
    self._update_special_tokens()
@PreTrainedTokenizer.additional_special_tokens.setter
def additional_special_tokens(self, value):
    """Set the list of extra special tokens and re-sync the backing tokenizer."""
    self._additional_special_tokens = value
    self._update_special_tokens()
def _update_special_tokens(self):
    """Push the currently configured special tokens into the fast tokenizer.

    Silently does nothing while ``_tokenizer`` is still ``None`` (its
    class-level default), so special-token attributes can be assigned
    before the backend tokenizer exists.
    """
    backend = self._tokenizer
    if backend is not None:
        backend.add_special_tokens(self.all_special_tokens)
@staticmethod @staticmethod
def _convert_encoding( def _convert_encoding(
...@@ -1522,6 +1566,11 @@ class PreTrainedTokenizerFast(PreTrainedTokenizer): ...@@ -1522,6 +1566,11 @@ class PreTrainedTokenizerFast(PreTrainedTokenizer):
def add_tokens(self, new_tokens):
    """Add *new_tokens* to the underlying fast tokenizer's vocabulary.

    Returns the value reported by the backend's ``add_tokens`` (the number
    of tokens actually added). Propagating it keeps this method consistent
    with ``add_special_tokens``, which returns its count; callers that
    ignored the previous ``None`` return are unaffected.
    """
    return self.tokenizer.add_tokens(new_tokens)
def add_special_tokens(self, special_tokens_dict):
    """Register special tokens through the base class, then sync the backend.

    Returns whatever the base-class implementation reports (the count of
    tokens it added).
    """
    num_added = super().add_special_tokens(special_tokens_dict)
    self._update_special_tokens()
    return num_added
def encode_batch( def encode_batch(
self, self,
texts, texts,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment