Unverified Commit d0b50023 authored by Arthur's avatar Arthur Committed by GitHub
Browse files

[Llama Tokenizer] Fast llama template (#22959)

* update template processing for llama fast to add eos

* style

* update

* address training from new issue

* fix

* update

* special tokens can be given even if not used
parent 00bc6e20
......@@ -1158,13 +1158,28 @@ class LlamaConverter(SpmConverter):
return None
def post_processor(self):
    """Build the template post-processor that wraps sequences in BOS/EOS.

    The template depends on the ``add_bos_token`` / ``add_eos_token`` flags
    read from ``self.original_tokenizer``. Three shapes are possible
    (single sequence / pair of sequences):

    - add_bos and add_eos: ``<s>:0 $A:0 </s>:0`` and
      ``<s>:0 $A:0 </s>:0 <s>:1 $B:1 </s>:1``
    - add_bos only: ``<s>:0 $A:0`` and ``<s>:0 $A:0 <s>:1 $B:1``
    - add_eos only: ``$A:0 </s>:0`` and ``$A:0 </s>:0 $B:1 </s>:1``

    Returns:
        A ``processors.TemplateProcessing`` instance, or ``None`` when
        neither flag is set (no post-processing is needed).

    Raises:
        ValueError: if a flag requests a token that the original tokenizer
            does not define (its token string is ``None``).
    """
    tokenizer = self.original_tokenizer
    add_bos = tokenizer.add_bos_token
    add_eos = tokenizer.add_eos_token
    if not (add_bos or add_eos):
        return None

    bos = tokenizer.bos_token
    bos_token_id = tokenizer.bos_token_id
    eos = tokenizer.eos_token
    eos_token_id = tokenizer.eos_token_id

    # Fail early with a clear message instead of a TypeError from `None + str`.
    if add_bos and bos is None:
        raise ValueError("add_bos_token = True but bos_token = None")
    if add_eos and eos is None:
        raise ValueError("add_eos_token = True but eos_token = None")

    # Only format a fragment when its flag is set, so an unused (and possibly
    # undefined) token is never touched.
    bos_first = f"{bos}:0 " if add_bos else ""
    eos_first = f" {eos}:0" if add_eos else ""
    bos_second = f" {bos}:1" if add_bos else ""
    eos_second = f" {eos}:1" if add_eos else ""

    single = f"{bos_first}$A:0{eos_first}"
    pair = f"{single}{bos_second} $B:1{eos_second}"

    # Special tokens may be registered even when the template does not use
    # them, but an undefined token cannot be registered at all.
    special_tokens = [
        (token, token_id)
        for token, token_id in ((bos, bos_token_id), (eos, eos_token_id))
        if token is not None
    ]
    return processors.TemplateProcessing(single=single, pair=pair, special_tokens=special_tokens)
class MarkupLMConverter(Converter):
......
......@@ -706,7 +706,7 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
kwargs["end_of_word_suffix"] = tokenizer_json["model"]["end_of_word_suffix"]
if tokenizer_json["model"]["type"] == "Unigram" and unk_token is not None:
kwargs["unk_token"] = unk_token
if tokenizer_json["pre_tokenizer"]["type"] == "ByteLevel":
if tokenizer_json["pre_tokenizer"] is not None and tokenizer_json["pre_tokenizer"]["type"] == "ByteLevel":
kwargs["initial_alphabet"] = pre_tokenizers_fast.ByteLevel.alphabet()
trainer_class = MODEL_TO_TRAINER_MAPPING[tokenizer_json["model"]["type"]]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment