Unverified commit 46493789 authored by Atream, committed by GitHub
Browse files

fix chat template encoding

parent 449a83df
...@@ -387,23 +387,11 @@ class BalanceServeInterface(BackendInterfaceBase): ...@@ -387,23 +387,11 @@ class BalanceServeInterface(BackendInterfaceBase):
return input_ids return input_ids
def format_and_tokenize_input_ids(self, thread_id: ObjectID, messages: List):
    """Render *messages* through the tokenizer's chat template and tokenize.

    The template output already contains every special token the template
    emits, so the re-tokenization pass uses ``add_special_tokens=False`` to
    avoid prepending a duplicate BOS/EOS marker.

    Returns a ``(1, seq_len)`` tensor of input ids on ``self.args.device``.
    """
    # add_generation_prompt=True appends the assistant-turn header so the
    # model continues generation as the assistant.
    prompt: str = self.tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    # drop <think> token in chat template: some templates force a trailing
    # '<think>\n'; strip it so the model may choose whether to think.
    think_tag = '<think>\n'
    if prompt.endswith(think_tag):
        prompt = prompt[:-len(think_tag)]
    input_ids = self.tokenizer.encode(
        prompt, return_tensors="pt", add_special_tokens=False
    ).to(self.args.device)
    logger.debug(f"get input ids of shape {input_ids.shape}")
    return input_ids
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment