Unverified Commit 12dc3e14 authored by AllentDan's avatar AllentDan Committed by GitHub
Browse files

add chat template for Yi (#779)

parent 816022e4
...@@ -716,6 +716,79 @@ class UltraChat(BaseModel): ...@@ -716,6 +716,79 @@ class UltraChat(BaseModel):
return ret return ret
@MODELS.register_module(name='yi')
class Yi(BaseModel):
"""Chat template of Yi model."""
def __init__(self,
system='<|im_start|>system\n',
meta_instruction=None,
user='<|im_start|>user\n',
eoh='<|im_end|>\n',
eoa='<|im_end|>\n',
eosys='<|im_end|>\n',
assistant='<|im_start|>assistant\n',
stop_words=['<|im_end|>', '<|endoftext|>'],
**kwargs):
super().__init__(**kwargs)
self.system = system
self.meta_instruction = meta_instruction
self.user = user
self.eoh = eoh
self.eoa = eoa
self.eosys = eosys
self.assistant = assistant
self.stop_words = stop_words
def decorate_prompt(self, prompt, sequence_start=True):
"""Return the prompt that is concatenated with other elements in the
chat template.
Args:
prompt (str): user's input prompt
sequence_start (bool): indicator for the first round chat of a
session sequence
Returns:
str: the concatenated prompt
"""
assert self.capability == 'chat', \
f'{type(self).__name__} has no capability of {self.capability}'
if sequence_start:
if self.meta_instruction is None:
return f'{self.user}{prompt}{self.eoh}' \
f'{self.assistant}'
return f'{self.system}{self.meta_instruction}{self.eosys}' \
f'{self.user}{prompt}{self.eoh}' \
f'{self.assistant}'
else:
return f'{self.user}{prompt}{self.eoh}' \
f'{self.assistant}'
def messages2prompt(self, messages, sequence_start=True):
"""Return the prompt that is concatenated with other elements in the
chat template.
Args:
messages (str | List): user's input prompt
Returns:
str: the concatenated prompt
"""
if isinstance(messages, str):
return self.get_prompt(messages, sequence_start)
eox_map = dict(user=self.eoh, assistant=self.eoa, system=self.eosys)
ret = ''
if self.meta_instruction:
ret += f'{self.system}:{self.meta_instruction}{self.eosys}'
for message in messages:
role = message['role']
content = message['content']
ret += f'{eval(f"self.{role}")}{content}{eox_map[role]}'
ret += f'{self.assistant}'
return ret
def main(model_name: str = 'test'): def main(model_name: str = 'test'):
assert model_name in MODELS.module_dict.keys(), \ assert model_name in MODELS.module_dict.keys(), \
f"'{model_name}' is not supported. " \ f"'{model_name}' is not supported. " \
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment