Unverified commit 6fc0454b authored by Arthur, committed by GitHub

[LlamaTokenizerFast] nit update `post_processor` on the fly (#23855)

* Update the processor when changing add_eos and add_bos

* fixup

* update

* add a test

* fix failing tests

* fixup
parent 0623f08e
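
As a quick usage sketch of the behaviour this commit enables (illustrative only, not part of the diff; the checkpoint name and expected token ids are taken from the test added below, and the checkpoint is assumed to default to `add_bos_token=True`, `add_eos_token=False`): flipping `add_eos_token` / `add_bos_token` on an already-instantiated `LlamaTokenizerFast` now rebuilds the backend post-processor instead of having no effect.

```python
from transformers import LlamaTokenizerFast

tok = LlamaTokenizerFast.from_pretrained("hf-internal-testing/llama-tokenizer")

print(tok.encode("A sample test"))  # [1, 319, 4559, 1243] -> BOS only (assumed default)

tok.add_eos_token = True            # setter now triggers update_post_processor()
print(tok.encode("A sample test"))  # [1, 319, 4559, 1243, 2] -> BOS and EOS

tok.add_bos_token = False
print(tok.encode("A sample test"))  # [319, 4559, 1243, 2] -> EOS only
```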
@@ -16,6 +16,8 @@ import os
 from shutil import copyfile
 from typing import Optional, Tuple
 
+from tokenizers import processors
+
 from ...tokenization_utils_fast import PreTrainedTokenizerFast
 from ...utils import is_sentencepiece_available, logging
 from ...utils.versions import require_version
@@ -84,6 +86,8 @@ class LlamaTokenizerFast(PreTrainedTokenizerFast):
         unk_token="<unk>",
         bos_token="<s>",
         eos_token="</s>",
+        add_bos_token=True,
+        add_eos_token=False,
         **kwargs,
     ):
         super().__init__(
@@ -95,10 +99,50 @@ class LlamaTokenizerFast(PreTrainedTokenizerFast):
             eos_token=eos_token,
             **kwargs,
         )
+        self._add_bos_token = add_bos_token
+        self._add_eos_token = add_eos_token
+        self.update_post_processor()
+
         self.vocab_file = vocab_file
         self.can_save_slow_tokenizer = False if not self.vocab_file else True
 
+    def update_post_processor(self):
+        bos = self.bos_token
+        bos_token_id = self.bos_token_id
+
+        eos = self.eos_token
+        eos_token_id = self.eos_token_id
+
+        single = f"{(bos+':0 ') * self.add_bos_token}$A:0{(' '+eos+':0') * self.add_eos_token}"
+        pair = f"{single}{(' '+bos+':1') * self.add_bos_token} $B:1{(' '+eos+':1') * self.add_eos_token}"
+
+        special_tokens = []
+        if self.add_bos_token:
+            special_tokens.append((bos, bos_token_id))
+        if self.add_eos_token:
+            special_tokens.append((eos, eos_token_id))
+        self._tokenizer.post_processor = processors.TemplateProcessing(
+            single=single, pair=pair, special_tokens=special_tokens
+        )
+
+    @property
+    def add_eos_token(self):
+        return self._add_eos_token
+
+    @property
+    def add_bos_token(self):
+        return self._add_bos_token
+
+    @add_eos_token.setter
+    def add_eos_token(self, value):
+        self._add_eos_token = value
+        self.update_post_processor()
+
+    @add_bos_token.setter
+    def add_bos_token(self, value):
+        self._add_bos_token = value
+        self.update_post_processor()
+
     def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
         if not self.can_save_slow_tokenizer:
             raise ValueError(
...
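
The `single`/`pair` templates built in `update_post_processor` above rely on multiplying a string fragment by a boolean, so a fragment is emitted only when the corresponding flag is set. As a standalone illustration (not part of the diff; the bos/eos ids 1 and 2 are assumptions matching the LLaMA vocab used in the test below), here is how the templates expand for the default `add_bos_token=True`, `add_eos_token=False`:

```python
from tokenizers import processors

bos, bos_id, eos, eos_id = "<s>", 1, "</s>", 2
add_bos_token, add_eos_token = True, False

# Multiplying by a bool (0 or 1) includes the fragment only when the flag is set.
single = f"{(bos + ':0 ') * add_bos_token}$A:0{(' ' + eos + ':0') * add_eos_token}"
pair = f"{single}{(' ' + bos + ':1') * add_bos_token} $B:1{(' ' + eos + ':1') * add_eos_token}"
print(single)  # "<s>:0 $A:0"
print(pair)    # "<s>:0 $A:0 <s>:1 $B:1"

# Only the tokens that are actually used need an id mapping.
special_tokens = []
if add_bos_token:
    special_tokens.append((bos, bos_id))
if add_eos_token:
    special_tokens.append((eos, eos_id))

post_processor = processors.TemplateProcessing(
    single=single, pair=pair, special_tokens=special_tokens
)
```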
@@ -315,6 +315,39 @@ class LlamaIntegrationTest(unittest.TestCase):
             },
         )
 
+    def test_fast_special_tokens(self):
+        slow_tokenizer = self.tokenizer
+        fast_tokenizer = self.rust_tokenizer
+        slow = slow_tokenizer.encode("A sample test", add_special_tokens=True)
+        assert slow == [1, 319, 4559, 1243]
+
+        fast_tokenizer.add_eos_token = False
+        fast = fast_tokenizer.encode("A sample test", add_special_tokens=True)
+        assert fast == [1, 319, 4559, 1243]
+
+        fast_tokenizer.add_eos_token = True
+        fast = fast_tokenizer.encode("A sample test", add_special_tokens=True)
+        assert fast == [1, 319, 4559, 1243, 2]
+
+        slow_tokenizer.add_eos_token = True
+        slow = slow_tokenizer.encode("A sample test", add_special_tokens=True)
+        assert slow == [1, 319, 4559, 1243, 2]
+
+        fast_tokenizer = LlamaTokenizerFast.from_pretrained(
+            "hf-internal-testing/llama-tokenizer", add_eos_token=True, add_bos_token=False
+        )
+        fast = fast_tokenizer.encode("A sample test", add_special_tokens=True)
+        assert fast == [319, 4559, 1243, 2]
+        slow_tokenizer = LlamaTokenizer.from_pretrained(
+            "hf-internal-testing/llama-tokenizer", add_eos_token=True, add_bos_token=False
+        )
+        slow = slow_tokenizer.encode("A sample test", add_special_tokens=True)
+        assert slow == [319, 4559, 1243, 2]
+
+        self.tokenizer.add_eos_token = False
+        self.rust_tokenizer.add_eos_token = False
+
     @slow
     def test_conversion(self):
         # This is excruciatingly slow since it has to recreate the entire merge
...
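
Since the post-processor is rebuilt in place on the backend tokenizer, the effect can also be checked without re-encoding, for example by dumping the backend state (a quick check sketch, not part of the PR; `backend_tokenizer` is the public accessor for the wrapped `tokenizers.Tokenizer`):

```python
import json

from transformers import LlamaTokenizerFast

tok = LlamaTokenizerFast.from_pretrained("hf-internal-testing/llama-tokenizer")
tok.add_eos_token = True

# The TemplateProcessing config is part of the serialized backend tokenizer.
state = json.loads(tok.backend_tokenizer.to_str())
print(state["post_processor"]["type"])  # "TemplateProcessing"
```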