Unverified commit 7785142d, authored by AllentDan, committed by GitHub
Browse files

Pass chat template args including meta_prompt to model (#225)

* pass args like meta_prompt to model

* update chatbot

* update

* rollback

* update llama2 and qwen

* refine
parent f44ef17c
......@@ -8,12 +8,18 @@ MODELS = Registry('model', locations=['lmdeploy.model'])
class BaseModel:
"""Base model."""
def __init__(self,
             session_len=2048,
             top_p=0.8,
             top_k=None,
             temperature=0.8,
             repetition_penalty=1.0,
             **kwargs):
    """Set the sampling defaults shared by every chat template.

    Args:
        session_len (int): session length budget; long-context
            subclasses override it (e.g. 8192, 4096).
        top_p (float): nucleus-sampling cumulative-probability cutoff.
        top_k (int | None): top-k sampling cutoff; presumably ``None``
            disables top-k filtering — confirm against the engine.
        temperature (float): sampling temperature.
        repetition_penalty (float): penalty factor for repeated tokens.
        **kwargs: absorbed silently so subclasses and callers can pass
            template-specific options through without a ``TypeError``.
    """
    self.session_len = session_len
    self.top_p = top_p
    self.top_k = top_k
    self.temperature = temperature
    self.repetition_penalty = repetition_penalty
@staticmethod
def get_prompt(prompt, sequence_start=True):
......@@ -39,11 +45,16 @@ class BaseModel:
class Vicuna(BaseModel):
"""Chat template of vicuna model."""
def __init__(
        self,
        system="""A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. """,  # noqa: E501
        user='USER',
        assistant='ASSISTANT',
        **kwargs):
    """Initialize the Vicuna chat template.

    Args:
        system (str): system prompt prepended to the conversation; the
            default keeps Vicuna's stock preamble (trailing space
            included — it is part of the expected prompt format).
        user (str): role tag marking user turns.
        assistant (str): role tag marking assistant turns.
        **kwargs: forwarded to ``BaseModel`` (session_len, top_p, ...).
    """
    super().__init__(**kwargs)
    self.system = system
    self.user = user
    self.assistant = assistant
def get_prompt(self, prompt, sequence_start=True):
"""Return the prompt that is concatenated with other elements in the
......@@ -65,21 +76,27 @@ class Vicuna(BaseModel):
@MODELS.register_module(name='internlm')
class InternLM(BaseModel):
    """Template of the InternLM base (non-chat) model.

    Inherits every sampling default from :class:`BaseModel` unchanged;
    the subclass exists only to register the ``internlm`` model name.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
@MODELS.register_module(name='internlm-chat-7b')
class InternLMChat7B(BaseModel):
"""Chat template of InternLM model."""
def __init__(self,
             system='',
             user='<|User|>',
             eoh='<eoh>',
             eoa='<eoa>',
             assistant='<|Bot|>',
             **kwargs):
    """Initialize the InternLM-Chat-7B template.

    Args:
        system (str): system prompt; empty by default.
        user (str): tag opening a user turn.
        eoh (str): end-of-human marker appended after a user turn.
        eoa (str): end-of-assistant marker.
        assistant (str): tag opening an assistant turn.
        **kwargs: forwarded to ``BaseModel`` (session_len, top_p, ...).
    """
    super().__init__(**kwargs)
    self.system = system
    self.user = user
    self.eoh = eoh
    self.eoa = eoa
    self.assistant = assistant
def get_prompt(self, prompt, sequence_start=True):
"""Return the prompt that is concatenated with other elements in the
......@@ -108,39 +125,77 @@ class InternLMChat7B(BaseModel):
@MODELS.register_module(name='internlm-chat-7b-8k')
class InternLMChat7B8K(InternLMChat7B):
    """InternLM chat template with an extended 8k session window."""

    def __init__(self, session_len=8192, **kwargs):
        """Same template as :class:`InternLMChat7B`, longer session.

        Args:
            session_len (int): session length budget, defaulting to 8192.
            **kwargs: forwarded to :class:`InternLMChat7B`.
        """
        super().__init__(**kwargs)
        # Assigned after super().__init__ so it overrides the 2048
        # default set by BaseModel.
        self.session_len = session_len
@MODELS.register_module(name='baichuan-7b')
class Baichuan7B(BaseModel):
    """Template of the Baichuan-7B base model."""

    def __init__(self, repetition_penalty=1.1, **kwargs):
        """Baichuan defaults, with a mild repetition penalty.

        Args:
            repetition_penalty (float): default 1.1 instead of the
                BaseModel default of 1.0.
            **kwargs: forwarded to ``BaseModel``.
        """
        super().__init__(**kwargs)
        # Assigned after super().__init__ so it wins over the base 1.0.
        self.repetition_penalty = repetition_penalty
@MODELS.register_module(name='puyu')
class Puyu(BaseModel):
    """Chat template of the puyu model; for internal use in Shanghai AI
    Laboratory only.

    Each user turn is wrapped in role markers; the first turn of a
    session additionally carries a system header with the meta
    instruction.
    """

    def __init__(self,
                 meta_instruction='',
                 user='<|Human|>: ',
                 eoh='',
                 eosys='',
                 assistant='<|Assistant|>: ',
                 system='<|System|>: ',
                 **kwargs):
        super().__init__(**kwargs)
        # Role markers and separators used when assembling prompts.
        self.meta_instruction = meta_instruction
        self.user = user
        self.eoh = eoh
        self.eosys = eosys
        self.assistant = assistant
        self.system = system

    def get_prompt(self, prompt, sequence_start=True):
        """Decorate ``prompt`` with puyu role markers; emit the system
        header (with the meta instruction) only on a session's first
        turn."""
        turn = f'{self.user}{prompt}{self.eoh}\n{self.assistant}'
        if not sequence_start:
            return '\n' + turn
        header = f'<BOS>{self.system}{self.meta_instruction}{self.eosys}\n'
        return header + turn

    @property
    def stop_words(self):
        """Token ids that terminate generation for this model."""
        return [45623]
@MODELS.register_module(name='llama2')
class Llama2(BaseModel):
"""Chat template of LLaMA2 model."""
def __init__(
        self,
        b_inst='[INST]',
        e_inst='[/INST]',
        b_sys='<<SYS>>\n',
        e_sys='\n<</SYS>>\n\n',
        default_sys_prompt="""\
You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.

If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""",  # noqa: E501
        session_len=4096,
        **kwargs):
    """Initialize the LLaMA2 chat template.

    Args:
        b_inst (str): tag opening an instruction block.
        e_inst (str): tag closing an instruction block.
        b_sys (str): tag opening the system section.
        e_sys (str): tag closing the system section.
        default_sys_prompt (str): system prompt used when the caller
            supplies none; defaults to Meta's stock LLaMA2 prompt.
        session_len (int): session length budget, defaulting to 4096.
        **kwargs: forwarded to ``BaseModel`` (top_p, temperature, ...).
    """
    super().__init__(**kwargs)
    self.b_inst = b_inst
    self.e_inst = e_inst
    self.b_sys = b_sys
    self.e_sys = e_sys
    self.default_sys_prompt = default_sys_prompt
    # Assigned after super().__init__ so it overrides the 2048 default.
    self.session_len = session_len
def get_prompt(self, prompt, sequence_start=True):
"""Return the prompt that is concatenated with other elements in the
......@@ -165,16 +220,24 @@ If a question does not make any sense, or is not factually coherent, explain why
class Qwen7BChat(BaseModel):
"""Chat template for Qwen-7B-Chat."""
def __init__(self,
             session_len=8192,
             top_p=0.5,
             top_k=40,
             temperature=1.0,
             im_start='<|im_start|>',
             im_end='<|im_end|>',
             system='You are a helpful assistant.',
             **kwargs):
    """Initialize the Qwen-7B-Chat template.

    Args:
        session_len (int): session length budget, defaulting to 8192.
        top_p (float): nucleus-sampling cutoff tuned for Qwen (0.5).
        top_k (int): top-k sampling cutoff tuned for Qwen (40).
        temperature (float): sampling temperature (1.0).
        im_start (str): ChatML turn-start marker.
        im_end (str): ChatML turn-end marker.
        system (str): default system prompt.
        **kwargs: forwarded to ``BaseModel``.
    """
    super().__init__(**kwargs)
    # Sampling fields are assigned after super().__init__ so Qwen's
    # tuned values override the BaseModel defaults.
    self.session_len = session_len
    self.top_p = top_p
    self.top_k = top_k
    self.temperature = temperature
    self.im_start = im_start
    self.im_end = im_end
    self.system = system
def get_prompt(self, prompt, sequence_start=True):
if sequence_start:
......
......@@ -76,7 +76,8 @@ class Chatbot:
log_level: int = logging.INFO,
display: bool = False,
profile_generation: bool = False,
profile_serving: bool = False):
profile_serving: bool = False,
**model_kwargs):
self.tritonserver_addr = tritonserver_addr
self.model_name = model_name
if self.model_name == '':
......@@ -84,7 +85,7 @@ class Chatbot:
assert self.model_name in MODELS.module_dict.keys(), \
f"'{self.model_name}' is not supported. " \
f'The supported models are: {MODELS.module_dict.keys()}'
self.model = MODELS.get(self.model_name)()
self.model = MODELS.get(self.model_name)(**model_kwargs)
self._session = None
self.preprocess = Preprocessor(tritonserver_addr)
self.postprocess = Postprocessor(tritonserver_addr)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment