Unverified Commit 9ad815e4 authored by Arthur's avatar Arthur Committed by GitHub
Browse files

[`LlamaTokenizerFast`] Adds edge cases for the template processor (#26606)

* make sure eos and bos are properly handled for fast tokenizer

* fix code llama as well

* nits

* fix the conversion script as well

* fix failing test
parent 27597fea
...@@ -1192,32 +1192,8 @@ class LlamaConverter(SpmConverter): ...@@ -1192,32 +1192,8 @@ class LlamaConverter(SpmConverter):
return None return None
def post_processor(self): def post_processor(self):
# 3 possible case : # the processor is defined in the LlamaTokenizerFast class.
# - add_bos and add_eos : '<s>:0 $A:0 </s>:0' and '<s>:0 $A:0 </s>:0 <s>:1 $B:1 </s>:1' return None
# - add_bos: '<s>:0 $A:0' and '<s>:0 $A:0 <s>:1 $B:1'
# - add_eos: '$A:0 </s>:0' and '$A:0 </s>:0 $B:1 </s>:1'
add_bos = self.original_tokenizer.add_bos_token
add_eos = self.original_tokenizer.add_eos_token
if add_bos or add_eos:
bos = self.original_tokenizer.bos_token
bos_token_id = self.original_tokenizer.bos_token_id
eos = self.original_tokenizer.eos_token
eos_token_id = self.original_tokenizer.eos_token_id
single = f"{(bos+':0 ') * add_bos}$A:0{(' '+eos+':0') if add_eos else ''}"
pair = f"{single}{(' '+bos+':1') * add_bos} $B:1{(' '+eos+':1') if add_eos else ''}"
special_tokens = []
if add_bos:
special_tokens.append((bos, bos_token_id))
if add_eos:
special_tokens.append((eos, eos_token_id))
return processors.TemplateProcessing(single=single, pair=pair, special_tokens=special_tokens)
else:
return None
class MarkupLMConverter(Converter): class MarkupLMConverter(Converter):
......
...@@ -178,12 +178,16 @@ class CodeLlamaTokenizerFast(PreTrainedTokenizerFast): ...@@ -178,12 +178,16 @@ class CodeLlamaTokenizerFast(PreTrainedTokenizerFast):
""" """
bos = self.bos_token bos = self.bos_token
bos_token_id = self.bos_token_id bos_token_id = self.bos_token_id
if bos is None and self.add_bos_token:
raise ValueError("add_bos_token = True but bos_token = None")
eos = self.eos_token eos = self.eos_token
eos_token_id = self.eos_token_id eos_token_id = self.eos_token_id
if eos is None and self.add_eos_token:
raise ValueError("add_eos_token = True but eos_token = None")
single = f"{(bos+':0 ') * self.add_bos_token}$A:0{(' '+eos+':0') if self.add_eos_token else ''}" single = f"{(bos+':0 ') if self.add_bos_token else ''}$A:0{(' '+eos+':0') if self.add_eos_token else ''}"
pair = f"{single}{(' '+bos+':1') * self.add_bos_token} $B:1{(' '+eos+':1') if self.add_eos_token else ''}" pair = f"{single}{(' '+bos+':1') if self.add_bos_token else ''} $B:1{(' '+eos+':1') if self.add_eos_token else ''}"
special_tokens = [] special_tokens = []
if self.add_bos_token: if self.add_bos_token:
......
...@@ -145,12 +145,16 @@ class LlamaTokenizerFast(PreTrainedTokenizerFast): ...@@ -145,12 +145,16 @@ class LlamaTokenizerFast(PreTrainedTokenizerFast):
""" """
bos = self.bos_token bos = self.bos_token
bos_token_id = self.bos_token_id bos_token_id = self.bos_token_id
if bos is None and self.add_bos_token:
raise ValueError("add_bos_token = True but bos_token = None")
eos = self.eos_token eos = self.eos_token
eos_token_id = self.eos_token_id eos_token_id = self.eos_token_id
if eos is None and self.add_eos_token:
raise ValueError("add_eos_token = True but eos_token = None")
single = f"{(bos+':0 ') * self.add_bos_token}$A:0{(' '+eos+':0') if self.add_eos_token else ''}" single = f"{(bos+':0 ') if self.add_bos_token else ''}$A:0{(' '+eos+':0') if self.add_eos_token else ''}"
pair = f"{single}{(' '+bos+':1') * self.add_bos_token} $B:1{(' '+eos+':1') if self.add_eos_token else ''}" pair = f"{single}{(' '+bos+':1') if self.add_bos_token else ''} $B:1{(' '+eos+':1') if self.add_eos_token else ''}"
special_tokens = [] special_tokens = []
if self.add_bos_token: if self.add_bos_token:
......
...@@ -582,6 +582,19 @@ class LlamaIntegrationTest(unittest.TestCase): ...@@ -582,6 +582,19 @@ class LlamaIntegrationTest(unittest.TestCase):
# a dummy prefix space is not added by the sp_model as it was de-activated # a dummy prefix space is not added by the sp_model as it was de-activated
self.assertEqual(tokens, tokenizer.sp_model.encode("▁▁▁", out_type=str)) self.assertEqual(tokens, tokenizer.sp_model.encode("▁▁▁", out_type=str))
def test_fast_post_processor(self):
tokenizer = LlamaTokenizerFast(
SAMPLE_VOCAB, eos_token=None, bos_token=None, add_bos_token=False, add_eos_token=False
)
tokenizer.encode(" Hey ")
with self.assertRaises(ValueError):
tokenizer = LlamaTokenizerFast(
SAMPLE_VOCAB, bos_token=None, eos_token="<s>", add_bos_token=True, add_eos_token=False
)
with self.assertRaises(ValueError):
tokenizer = LlamaTokenizerFast(SAMPLE_VOCAB, eos_token=None, add_bos_token=True, add_eos_token=True)
@require_jinja @require_jinja
def test_tokenization_for_chat(self): def test_tokenization_for_chat(self):
tokenizer = LlamaTokenizer.from_pretrained("huggyllama/llama-7b", legacy=False) tokenizer = LlamaTokenizer.from_pretrained("huggyllama/llama-7b", legacy=False)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment