Unverified Commit efb2ba66 authored by Iskren Ivov Chernev, committed by GitHub

Better handling missing SYS in llama conversation tokenizer (#24997)

* Better handling missing SYS in llama conversation tokenizer

The existing code failed to include SYS in the generated token ids when the
conversation had history without SYS, yet it still mutated the passed
conversation object as though it had (a minimal sketch follows the commit
message below).

Rearrange the code so that modifications to the conversation object are
taken into account when the token ids are generated.

* Fix formatting with black

* Avoid one-liners

* Also fix fast tokenizer

* Drop List decl
parent 67049231
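To make the bug concrete, here is a minimal sketch (not part of the PR) of how the patched method is exercised. The checkpoint path is a placeholder, and `_build_conversation_input_ids` is the private helper this diff modifies. Before this commit, the ids built from a conversation with history lacked the SYS segment even though the `Conversation` object was mutated to contain it; after the commit, the mutation happens before the dialogue is snapshotted, so the ids include it.

    from transformers import Conversation, LlamaTokenizer
    from transformers.models.llama.tokenization_llama import B_SYS

    # Placeholder path: any local Llama-2 chat checkpoint.
    tokenizer = LlamaTokenizer.from_pretrained("path/to/llama-2-checkpoint")

    # Build a conversation that already has one completed round of history,
    # with no explicit system prompt anywhere.
    conv = Conversation("Hi there!")
    conv.mark_processed()                       # moves "Hi there!" into past_user_inputs
    conv.append_response("Hello! How can I help?")  # assistant reply for round one
    conv.add_user_input("What is the capital of France?")

    ids = tokenizer._build_conversation_input_ids(conv)

    # The default system prompt is now prepended to the first user turn
    # *before* iter_texts() is consumed, so it is reflected in `ids` too.
    assert conv.past_user_inputs[0].startswith(B_SYS)
    print(tokenizer.decode(ids)[:80])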
src/transformers/models/llama/tokenization_llama.py
@@ -356,6 +356,17 @@ class LlamaTokenizer(PreTrainedTokenizer):
             `List[int]`:
                 Input ids for the conversation.
         """
+        if len(conversation.past_user_inputs) > 0:
+            if not conversation.past_user_inputs[0].startswith(B_SYS) or E_SYS not in conversation.past_user_inputs[0]:
+                conversation.past_user_inputs[0] = (
+                    B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + conversation.past_user_inputs[0]
+                )
+        elif conversation.new_user_input:
+            if not conversation.new_user_input.startswith(B_SYS) or E_SYS not in conversation.new_user_input:
+                conversation.new_user_input = B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + conversation.new_user_input
+        else:
+            raise ValueError("Last message must be from user")
         dialogue = list(conversation.iter_texts())
         if not all([is_user for is_user, msg in dialogue[::2]]) or not all(
             [not is_user for is_user, msg in dialogue[1::2]]
@@ -365,14 +376,6 @@ class LlamaTokenizer(PreTrainedTokenizer):
             )
         dialog_tokens: List[int] = []
-        if len(conversation.past_user_inputs) > 0:
-            if not conversation.past_user_inputs[0].startswith(B_SYS) or E_SYS not in conversation.past_user_inputs[0]:
-                conversation.past_user_inputs[0] = (
-                    B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + conversation.past_user_inputs[0]
-                )
-        elif not dialogue[0][1].startswith(B_SYS) or E_SYS not in dialogue[0][1]:
-            dialogue[0] = (dialogue[0][0], B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + dialogue[0][1])
         dialog_tokens += sum(
             [
                 [self.bos_token_id]
@@ -384,8 +387,6 @@ class LlamaTokenizer(PreTrainedTokenizer):
             ],
             [],
         )
-        if not (dialogue[-1][0]):
-            raise ValueError(f"Last message must be from user, got {dialogue[-1]['role']}")
         dialog_tokens += [self.bos_token_id] + self.encode(
             f"{B_INST} {(dialogue[-1][1]).strip()} {E_INST}", add_special_tokens=False
         )
...
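Both hunks leave the role-alternation check untouched: even positions in the `dialogue` list must be user turns and odd positions assistant turns. A standalone sketch of that check, with illustrative made-up turns shaped like the `(is_user, text)` pairs that `Conversation.iter_texts()` yields:

    # Illustrative data: (is_user, text) pairs as yielded by Conversation.iter_texts().
    dialogue = [
        (True, "Hi there!"),                       # user turn
        (False, "Hello! How can I help?"),         # assistant turn
        (True, "What is the capital of France?"),  # user turn again
    ]

    # Even indices must be user turns and odd indices assistant turns,
    # i.e. strict u/a/u/a alternation starting with the user.
    is_alternating = all(is_user for is_user, _ in dialogue[::2]) and all(
        not is_user for is_user, _ in dialogue[1::2]
    )
    assert is_alternating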
src/transformers/models/llama/tokenization_llama_fast.py
@@ -212,6 +212,17 @@ class LlamaTokenizerFast(PreTrainedTokenizerFast):
             `List[int]`:
                 Input ids for the conversation.
         """
+        if len(conversation.past_user_inputs) > 0:
+            if not conversation.past_user_inputs[0].startswith(B_SYS) or E_SYS not in conversation.past_user_inputs[0]:
+                conversation.past_user_inputs[0] = (
+                    B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + conversation.past_user_inputs[0]
+                )
+        elif conversation.new_user_input:
+            if not conversation.new_user_input.startswith(B_SYS) or E_SYS not in conversation.new_user_input:
+                conversation.new_user_input = B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + conversation.new_user_input
+        else:
+            raise ValueError("Last message must be from user")
         dialogue = list(conversation.iter_texts())
         if not all([is_user for is_user, msg in dialogue[::2]]) or not all(
             [not is_user for is_user, msg in dialogue[1::2]]
@@ -221,14 +232,6 @@ class LlamaTokenizerFast(PreTrainedTokenizerFast):
             )
         dialog_tokens = []
-        if len(conversation.past_user_inputs) > 0:
-            if not conversation.past_user_inputs[0].startswith(B_SYS) or E_SYS not in conversation.past_user_inputs[0]:
-                conversation.past_user_inputs[0] = (
-                    B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + conversation.past_user_inputs[0]
-                )
-        elif not dialogue[0][1].startswith(B_SYS) or E_SYS not in dialogue[0][1]:
-            dialogue[0] = (dialogue[0][0], B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + dialogue[0][1])
         dialog_tokens += sum(
             [
                 [self.bos_token_id]
@@ -240,8 +243,6 @@ class LlamaTokenizerFast(PreTrainedTokenizerFast):
             ],
             [],
         )
-        if not (dialogue[-1][0]):
-            raise ValueError(f"Last message must be from user, got {dialogue[-1]['role']}")
         dialog_tokens += [self.bos_token_id] + self.encode(
             f"{B_INST} {(dialogue[-1][1]).strip()} {E_INST}", add_special_tokens=False
         )
...