Unverified Commit b0b722ee authored by Xinwei Xiong's avatar Xinwei Xiong Committed by GitHub
Browse files

Refactor ChatTemplate for Enhanced Clarity and Efficiency (#201)

parent 01b07ea3
......@@ -12,42 +12,35 @@ class ChatTemplateStyle(Enum):
class ChatTemplate:
name: str
default_system_prompt: str
role_prefix_and_suffix: Dict[str, Tuple[str]]
role_prefix_and_suffix: Dict[str, Tuple[str, str]]
stop_str: List[str] = ()
image_token: str = "<image>"
style: ChatTemplateStyle = ChatTemplateStyle.PLAIN
def get_prefix_and_suffix(self, role, hist_messages):
if self.style == ChatTemplateStyle.PLAIN:
return self.role_prefix_and_suffix[role]
elif self.style == ChatTemplateStyle.LLAMA2:
if len(hist_messages) == 0 and role == "system":
return (
self.role_prefix_and_suffix["user"][0]
+ self.role_prefix_and_suffix["system"][0],
self.role_prefix_and_suffix["system"][1],
)
elif (
len(hist_messages) == 1
and role == "user"
and hist_messages[0]["content"] is not None
):
return ("", self.role_prefix_and_suffix["user"][1])
return self.role_prefix_and_suffix[role]
else:
raise ValueError(f"Invalid style: {self.style}")
def get_prompt(self, messages):
def get_prefix_and_suffix(self, role: str, hist_messages: List[Dict]) -> Tuple[str, str]:
prefix, suffix = self.role_prefix_and_suffix.get(role, ("", ""))
if self.style == ChatTemplateStyle.LLAMA2:
if role == "system" and not hist_messages:
user_prefix, _ = self.role_prefix_and_suffix.get("user", ("", ""))
system_prefix, system_suffix = self.role_prefix_and_suffix.get("system", ("", ""))
return (user_prefix + system_prefix, system_suffix)
elif role == "user" and len(hist_messages) == 1 and hist_messages[0]["content"] is not None:
return ("", suffix)
return prefix, suffix
def get_prompt(self, messages: List[Dict]) -> str:
prompt = ""
for i in range(len(messages)):
role, content = messages[i]["role"], messages[i]["content"]
for i, message in enumerate(messages):
role, content = message["role"], message["content"]
if role == "system" and content is None:
content = self.default_system_prompt
if content is None:
continue
prefix, suffix = self.get_prefix_and_suffix(role, messages[:i])
prompt += prefix + content + suffix
prompt += f"{prefix}{content}{suffix}"
return prompt
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment