Unverified Commit 406f8c9f authored by q.yao's avatar q.yao Committed by GitHub
Browse files

add llama2 chat template (#140)



* add llama2 template

* update readme and fix lint

* update readme

* add bos

* add bos

* remove bos

* Update model.py

---------
Co-authored-by: grimoire <yaoqian@pjlab.org.cn>
parent 8ba2d7c5
......@@ -11,9 +11,10 @@ English | [简体中文](README_zh-CN.md)
______________________________________________________________________
## News
## News 🎉
\[2023/07\] TurboMind supports tensor-parallel inference of InternLM.
- \[2023/07\] TurboMind supports tensor-parallel inference of InternLM.
- \[2023/07\] TurboMind supports llama2 7b/13b.
______________________________________________________________________
......
......@@ -11,9 +11,10 @@
______________________________________________________________________
## 更新
## 更新 🎉
\[2023/07\] TurboMind 支持 InternLM 的 Tensor Parallel 推理
- \[2023/07\] TurboMind 支持 InternLM 的 Tensor Parallel 推理
- \[2023/07\] TurboMind 支持 Llama2 7b/13b 模型
______________________________________________________________________
......
......@@ -125,6 +125,50 @@ conversation""" # noqa: E501
return [45623]
@MODELS.register_module(name='llama2')
class Llama2:
    """Chat template of LLaMA2 model."""

    def __init__(self):
        # Instruction and system-prompt delimiters used by the LLaMA2
        # chat format.
        self.b_inst = '[INST]'
        self.e_inst = '[/INST]'
        self.b_sys = '<<SYS>>\n'
        self.e_sys = '\n<</SYS>>\n\n'
        # System prompt injected on the first round of a session when the
        # caller supplies none of its own.
        self.default_sys_prompt = """\
You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.

If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."""  # noqa: E501

    def get_prompt(self, prompt, sequence_start=True):
        """Return the prompt that is concatenated with other elements in the
        chat template.

        Args:
            prompt (str): user's input prompt
            sequence_start (bool): indicator for the first round chat of a
                session sequence

        Returns:
            str: the concatenated prompt
        """
        if not sequence_start:
            # Follow-up rounds carry only the instruction delimiters.
            return f'{self.b_inst} {prompt} {self.e_inst} '
        # First round: prepend the BOS marker and the system prompt block.
        system = f'{self.b_sys} {self.default_sys_prompt} {self.e_sys}'
        return f'<BOS>{self.b_inst} {system}{prompt} {self.e_inst} '

    @property
    def stop_words(self):
        """Return the stop-words' token ids."""
        return None
def main(model_name: str = 'test'):
assert model_name in MODELS.module_dict.keys(), \
f"'{model_name}' is not supported. " \
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment