Unverified Commit 7785142d authored by AllentDan, committed by GitHub

Pass chat template args including meta_prompt to model (#225)

* pass args like meta_prompt to model

* update chatbot

* update

* rollback

* update llama2 and qwen

* refine
parent f44ef17c
@@ -8,12 +8,18 @@ MODELS = Registry('model', locations=['lmdeploy.model'])
 class BaseModel:
     """Base model."""
 
-    def __init__(self):
-        self.session_len = 2048
-        self.top_p = 0.8
-        self.top_k = None
-        self.temperature = 0.8
-        self.repetition_penalty = 1.0
+    def __init__(self,
+                 session_len=2048,
+                 top_p=0.8,
+                 top_k=None,
+                 temperature=0.8,
+                 repetition_penalty=1.0,
+                 **kwargs):
+        self.session_len = session_len
+        self.top_p = top_p
+        self.top_k = top_k
+        self.temperature = temperature
+        self.repetition_penalty = repetition_penalty
 
     @staticmethod
     def get_prompt(prompt, sequence_start=True):
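With the defaults promoted to keyword arguments, a template's session and sampling parameters can now be overridden per instance. A minimal sketch of the new behavior, using the MODELS registry from this file:

```python
from lmdeploy.model import MODELS

# Defaults still apply when nothing is passed.
model = MODELS.get('internlm-chat-7b')()
assert model.session_len == 2048 and model.top_p == 0.8

# Any keyword accepted by the template's __init__ can be overridden;
# unrecognized keys are absorbed by **kwargs.
model = MODELS.get('internlm-chat-7b')(temperature=0.6, top_k=40)
assert model.temperature == 0.6 and model.top_k == 40
```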
@@ -39,11 +45,16 @@ class BaseModel:
 class Vicuna(BaseModel):
     """Chat template of vicuna model."""
 
-    def __init__(self):
-        super().__init__()
-        self.system = """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. """  # noqa: E501
-        self.user = 'USER'
-        self.assistant = 'ASSISTANT'
+    def __init__(
+            self,
+            system="""A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. """,  # noqa: E501
+            user='USER',
+            assistant='ASSISTANT',
+            **kwargs):
+        super().__init__(**kwargs)
+        self.system = system
+        self.user = user
+        self.assistant = assistant
 
     def get_prompt(self, prompt, sequence_start=True):
         """Return the prompt that is concatenated with other elements in the
@@ -65,21 +76,27 @@ class Vicuna(BaseModel):
 @MODELS.register_module(name='internlm')
 class InternLM(BaseModel):
 
-    def __init__(self):
-        super().__init__()
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
 
 
 @MODELS.register_module(name='internlm-chat-7b')
 class InternLMChat7B(BaseModel):
     """Chat template of InternLM model."""
 
-    def __init__(self):
-        super().__init__()
-        self.system = ''
-        self.user = '<|User|>'
-        self.eoh = '<eoh>'
-        self.eoa = '<eoa>'
-        self.assistant = '<|Bot|>'
+    def __init__(self,
+                 system='',
+                 user='<|User|>',
+                 eoh='<eoh>',
+                 eoa='<eoa>',
+                 assistant='<|Bot|>',
+                 **kwargs):
+        super().__init__(**kwargs)
+        self.system = system
+        self.user = user
+        self.eoh = eoh
+        self.eoa = eoa
+        self.assistant = assistant
 
     def get_prompt(self, prompt, sequence_start=True):
         """Return the prompt that is concatenated with other elements in the
@@ -108,39 +125,77 @@ class InternLMChat7B(BaseModel):
 @MODELS.register_module(name='internlm-chat-7b-8k')
 class InternLMChat7B8K(InternLMChat7B):
 
-    def __init__(self):
-        super(InternLMChat7B8K, self).__init__()
-        self.session_len = 8192
+    def __init__(self, session_len=8192, **kwargs):
+        super(InternLMChat7B8K, self).__init__(**kwargs)
+        self.session_len = session_len
 
 
 @MODELS.register_module(name='baichuan-7b')
 class Baichuan7B(BaseModel):
 
-    def __init__(self):
-        super().__init__()
-        self.repetition_penalty = 1.1
+    def __init__(self, repetition_penalty=1.1, **kwargs):
+        super().__init__(**kwargs)
+        self.repetition_penalty = repetition_penalty
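Subclasses keep their retuned defaults but forward everything else, so a variant's overrides remain overridable themselves. A quick sketch:

```python
from lmdeploy.model import MODELS

# The 8k variant defaults to a longer session, but still accepts a custom one.
m = MODELS.get('internlm-chat-7b-8k')()
assert m.session_len == 8192
m = MODELS.get('internlm-chat-7b-8k')(session_len=16384)
assert m.session_len == 16384
```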
+
+
+@MODELS.register_module(name='puyu')
+class Puyu(BaseModel):
+    """Chat template of puyu model. This is only for internal usage in
+    Shanghai AI Laboratory."""
+
+    def __init__(self,
+                 meta_instruction='',
+                 user='<|Human|>: ',
+                 eoh='',
+                 eosys='',
+                 assistant='<|Assistant|>: ',
+                 system='<|System|>: ',
+                 **kwargs):
+        super().__init__(**kwargs)
+        self.meta_instruction = meta_instruction
+        self.user = user
+        self.eoh = eoh
+        self.eosys = eosys
+        self.assistant = assistant
+        self.system = system
+
+    def get_prompt(self, prompt, sequence_start=True):
+        if sequence_start:
+            return f'<BOS>{self.system}{self.meta_instruction}{self.eosys}\n' \
+                   f'{self.user}{prompt}{self.eoh}\n' \
+                   f'{self.assistant}'
+        else:
+            return f'\n{self.user}{prompt}{self.eoh}\n{self.assistant}'
+
+    @property
+    def stop_words(self):
+        """Return the stop-words' token ids."""
+        return [45623]
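The new Puyu class is what realizes the commit title: the meta prompt (meta_instruction) is a plain constructor argument instead of a hard-coded string. Rendering a first turn makes the wiring visible; the expected output follows directly from get_prompt above ('You are an AI assistant.' is a hypothetical meta instruction):

```python
from lmdeploy.model import MODELS

puyu = MODELS.get('puyu')(meta_instruction='You are an AI assistant.')
print(puyu.get_prompt('hi', sequence_start=True))
# <BOS><|System|>: You are an AI assistant.
# <|Human|>: hi
# <|Assistant|>:
```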
 
 
 @MODELS.register_module(name='llama2')
 class Llama2(BaseModel):
     """Chat template of LLaMA2 model."""
 
-    def __init__(self):
-        super().__init__()
-        B_INST, E_INST = '[INST]', '[/INST]'
-        B_SYS, E_SYS = '<<SYS>>\n', '\n<</SYS>>\n\n'
-
-        DEFAULT_SYSTEM_PROMPT = """\
-You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
-
-If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."""  # noqa: E501
-
-        self.b_inst = B_INST
-        self.e_inst = E_INST
-        self.b_sys = B_SYS
-        self.e_sys = E_SYS
-        self.default_sys_prompt = DEFAULT_SYSTEM_PROMPT
-        self.session_len = 4096
+    def __init__(
+            self,
+            b_inst='[INST]',
+            e_inst='[/INST]',
+            b_sys='<<SYS>>\n',
+            e_sys='\n<</SYS>>\n\n',
+            default_sys_prompt="""\
+You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
+
+If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""",  # noqa: E501
+            session_len=4096,
+            **kwargs):
+        super().__init__(**kwargs)
+        self.b_inst = b_inst
+        self.e_inst = e_inst
+        self.b_sys = b_sys
+        self.e_sys = e_sys
+        self.default_sys_prompt = default_sys_prompt
+        self.session_len = session_len
 
     def get_prompt(self, prompt, sequence_start=True):
         """Return the prompt that is concatenated with other elements in the
@@ -165,16 +220,24 @@ If a question does not make any sense, or is not factually coherent, explain why
 class Qwen7BChat(BaseModel):
     """Chat template for Qwen-7B-Chat."""
 
-    def __init__(self):
-        super().__init__()
-        self.session_len = 8192
-        self.top_p = 0.5
-        self.top_k = 40
-        self.temperature = 1.0
-
-        self.im_start = '<|im_start|>'
-        self.im_end = '<|im_end|>'
-        self.system = 'You are a helpful assistant.'
+    def __init__(self,
+                 session_len=8192,
+                 top_p=0.5,
+                 top_k=40,
+                 temperature=1.0,
+                 im_start='<|im_start|>',
+                 im_end='<|im_end|>',
+                 system='You are a helpful assistant.',
+                 **kwargs):
+        super().__init__(**kwargs)
+        self.session_len = session_len
+        self.top_p = top_p
+        self.top_k = top_k
+        self.temperature = temperature
+        self.im_start = im_start
+        self.im_end = im_end
+        self.system = system
 
     def get_prompt(self, prompt, sequence_start=True):
         if sequence_start:
...
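Qwen-7B-Chat follows the same pattern while also carrying retuned sampling defaults; its ChatML-style markers and system prompt become configurable as well:

```python
from lmdeploy.model import Qwen7BChat

# Qwen's retuned defaults survive the refactor...
qwen = Qwen7BChat()
assert (qwen.top_p, qwen.top_k, qwen.temperature) == (0.5, 40, 1.0)

# ...and the system prompt can now be swapped out at construction time.
# 'You are a terse assistant.' is a hypothetical replacement prompt.
qwen = Qwen7BChat(system='You are a terse assistant.', temperature=0.7)
```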
@@ -76,7 +76,8 @@ class Chatbot:
                  log_level: int = logging.INFO,
                  display: bool = False,
                  profile_generation: bool = False,
-                 profile_serving: bool = False):
+                 profile_serving: bool = False,
+                 **model_kwargs):
         self.tritonserver_addr = tritonserver_addr
         self.model_name = model_name
         if self.model_name == '':
@@ -84,7 +85,7 @@ class Chatbot:
         assert self.model_name in MODELS.module_dict.keys(), \
             f"'{self.model_name}' is not supported. " \
             f'The supported models are: {MODELS.module_dict.keys()}'
-        self.model = MODELS.get(self.model_name)()
+        self.model = MODELS.get(self.model_name)(**model_kwargs)
         self._session = None
         self.preprocess = Preprocessor(tritonserver_addr)
         self.postprocess = Postprocessor(tritonserver_addr)
...
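End to end, any chat template keyword can now be passed through the Chatbot constructor straight to the template class. A sketch, with an assumed import path and a placeholder server address:

```python
# Import path assumed; the diff shows only the Chatbot class body.
from lmdeploy.serve.turbomind.chatbot import Chatbot

# '0.0.0.0:33337' is a placeholder Triton server address. meta_instruction
# travels via **model_kwargs into MODELS.get('puyu')(meta_instruction=...).
chatbot = Chatbot('0.0.0.0:33337',
                  model_name='puyu',
                  meta_instruction='You are an AI assistant.')
```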