Unverified Commit b0b722ee authored by Xinwei Xiong's avatar Xinwei Xiong Committed by GitHub
Browse files

Refactor ChatTemplate for Enhanced Clarity and Efficiency (#201)

parent 01b07ea3
...@@ -12,42 +12,35 @@ class ChatTemplateStyle(Enum): ...@@ -12,42 +12,35 @@ class ChatTemplateStyle(Enum):
class ChatTemplate: class ChatTemplate:
name: str name: str
default_system_prompt: str default_system_prompt: str
role_prefix_and_suffix: Dict[str, Tuple[str]] role_prefix_and_suffix: Dict[str, Tuple[str, str]]
stop_str: List[str] = () stop_str: List[str] = ()
image_token: str = "<image>" image_token: str = "<image>"
style: ChatTemplateStyle = ChatTemplateStyle.PLAIN style: ChatTemplateStyle = ChatTemplateStyle.PLAIN
def get_prefix_and_suffix(self, role, hist_messages): def get_prefix_and_suffix(self, role: str, hist_messages: List[Dict]) -> Tuple[str, str]:
if self.style == ChatTemplateStyle.PLAIN: prefix, suffix = self.role_prefix_and_suffix.get(role, ("", ""))
return self.role_prefix_and_suffix[role]
elif self.style == ChatTemplateStyle.LLAMA2: if self.style == ChatTemplateStyle.LLAMA2:
if len(hist_messages) == 0 and role == "system": if role == "system" and not hist_messages:
return ( user_prefix, _ = self.role_prefix_and_suffix.get("user", ("", ""))
self.role_prefix_and_suffix["user"][0] system_prefix, system_suffix = self.role_prefix_and_suffix.get("system", ("", ""))
+ self.role_prefix_and_suffix["system"][0], return (user_prefix + system_prefix, system_suffix)
self.role_prefix_and_suffix["system"][1], elif role == "user" and len(hist_messages) == 1 and hist_messages[0]["content"] is not None:
) return ("", suffix)
elif (
len(hist_messages) == 1 return prefix, suffix
and role == "user"
and hist_messages[0]["content"] is not None def get_prompt(self, messages: List[Dict]) -> str:
):
return ("", self.role_prefix_and_suffix["user"][1])
return self.role_prefix_and_suffix[role]
else:
raise ValueError(f"Invalid style: {self.style}")
def get_prompt(self, messages):
prompt = "" prompt = ""
for i in range(len(messages)): for i, message in enumerate(messages):
role, content = messages[i]["role"], messages[i]["content"] role, content = message["role"], message["content"]
if role == "system" and content is None: if role == "system" and content is None:
content = self.default_system_prompt content = self.default_system_prompt
if content is None: if content is None:
continue continue
prefix, suffix = self.get_prefix_and_suffix(role, messages[:i]) prefix, suffix = self.get_prefix_and_suffix(role, messages[:i])
prompt += prefix + content + suffix prompt += f"{prefix}{content}{suffix}"
return prompt return prompt
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment