Unverified Commit 959c4174 authored by Lianmin Zheng's avatar Lianmin Zheng Committed by GitHub
Browse files

Fix the chat template for QWen (#83)

parent 94e05770
from dataclasses import dataclass from dataclasses import dataclass, field
from enum import Enum, auto from enum import Enum, auto
from typing import Callable, Dict, List, Tuple from typing import Callable, Dict, List, Tuple, Optional
class ChatTemplateStyle(Enum): class ChatTemplateStyle(Enum):
...@@ -13,6 +13,7 @@ class ChatTemplate: ...@@ -13,6 +13,7 @@ class ChatTemplate:
name: str name: str
default_system_prompt: str default_system_prompt: str
role_prefix_and_suffix: Dict[str, Tuple[str]] role_prefix_and_suffix: Dict[str, Tuple[str]]
stop_str: List[str] = ()
image_token: str = "<image>" image_token: str = "<image>"
style: ChatTemplateStyle = ChatTemplateStyle.PLAIN style: ChatTemplateStyle = ChatTemplateStyle.PLAIN
...@@ -110,6 +111,7 @@ register_chat_template( ...@@ -110,6 +111,7 @@ register_chat_template(
"assistant": ("<|im_start|>assistant\n", "\n<|im_end|>\n"), "assistant": ("<|im_start|>assistant\n", "\n<|im_end|>\n"),
}, },
style=ChatTemplateStyle.PLAIN, style=ChatTemplateStyle.PLAIN,
stop_str=('<|im_end|>',)
) )
) )
......
...@@ -486,6 +486,12 @@ class StreamExecutor: ...@@ -486,6 +486,12 @@ class StreamExecutor:
if clone is None: if clone is None:
clone = self.default_sampling_para.clone() clone = self.default_sampling_para.clone()
setattr(clone, item, value) setattr(clone, item, value)
if self.chat_template.stop_str:
if not clone:
clone = self.default_sampling_para.clone()
clone.stop += self.chat_template.stop_str
return clone or self.default_sampling_para return clone or self.default_sampling_para
def __del__(self): def __del__(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment