Unverified Commit 3de27ead authored by lvhan028, committed by GitHub

update internlm's chat template (#54)

* update internlm model

* update

* update

* update

* update

* update temperature, topk and top_p

* update

* update

* loosen log level
parent d2c9caa4
@@ -23,30 +23,42 @@ class Vicuna:
         return None
 
 
-@MODELS.register_module(name='puyu')
-class Puyu:
+@MODELS.register_module(name='internlm')
+class InternLM:
 
     def __init__(self):
-        self.system = """meta instruction
-You are an AI assistant whose name is InternLM (书生·浦语).
-- 书生·浦语 is a conversational language model that is developed by Shanghai AI Laboratory (上海人工智能实验室). It is designed to be helpful, honest, and harmless.
-- 书生·浦语 can understand and communicate fluently in the language chosen by the user such as English and 中文.
-conversation"""  # noqa: E501
-        self.user = '<|Human|>'
-        self.eou = 'െ'
-        self.assistant = '<|Assistant|>'
+        self.system = ''
+        self.user = '<|User|>'
+        self.eoh = '<eoh>'
+        self.eoa = '<eoa>'
+        self.assistant = '<|Bot|>'
 
     def get_prompt(self, prompt, sequence_start=True):
         if sequence_start:
             return f'{self.system}\n' \
-                   f'{self.user}:{prompt}{self.eou}\n' \
+                   f'{self.user}:{prompt}{self.eoh}\n' \
                    f'{self.assistant}:'
         else:
-            return f'\n{self.user}:{prompt}{self.eou}\n{self.assistant}:'
+            return f'\n{self.user}:{prompt}{self.eoh}\n' \
+                   f'{self.assistant}:'
 
     @property
     def stop_words(self):
-        return [45623]
+        return [103027, 103028]
+
+
+@MODELS.register_module(name='llama')
+class Llama:
+
+    def __init__(self):
+        pass
+
+    def get_prompt(self, prompt, sequence_start=True):
+        return prompt
+
+    @property
+    def stop_words(self):
+        return None
 
 
 def main(model_name: str = 'test'):
...
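(Side note, not part of the commit: a minimal sketch of what the renamed
template now renders. Because the system string is empty, both branches of
get_prompt produce the same text; the variable name `t` is illustrative.)

    t = InternLM()
    t.get_prompt('hi', sequence_start=True)   # '\n<|User|>:hi<eoh>\n<|Bot|>:'
    t.get_prompt('hi', sequence_start=False)  # '\n<|User|>:hi<eoh>\n<|Bot|>:'

The new stop words [103027, 103028] are presumably the token ids of <eoh>
and <eoa> in the InternLM vocabulary.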
@@ -13,7 +13,7 @@ def input_prompt():
 
 def main(tritonserver_addr: str, model_name: str, session_id: int = 1):
-    log_level = os.environ.get('SERVICE_LOG_LEVEL', 'INFO')
+    log_level = os.environ.get('SERVICE_LOG_LEVEL', 'WARNING')
     chatbot = Chatbot(tritonserver_addr,
                       model_name,
                       log_level=log_level,
@@ -33,7 +33,6 @@ def main(tritonserver_addr: str, model_name: str, session_id: int = 1):
                                              request_id=request_id,
                                              request_output_len=512):
                 continue
-            print(f'session {session_id}, {status}, {tokens}, {res}')
             nth_round += 1
...
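(Side note: with the fallback loosened to WARNING and the per-round status
print removed, the client now only shows the streamed response. A minimal
sketch, assuming it runs in the same process before the client starts, to
restore the old verbosity:)

    import os
    os.environ['SERVICE_LOG_LEVEL'] = 'INFO'  # overrides the new WARNING default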
@@ -34,6 +34,7 @@ class Session:
 class StatusCode(Enum):
     TRITON_STREAM_END = 0  # end of streaming
     TRITON_STREAM_ING = 1  # response is in streaming
+    TRITON_SESSION_READY = 2  # session is ready for inference
     TRITON_SERVER_ERR = -1  # triton server's error
     TRITON_SESSION_CLOSED = -2  # session has been closed
     TRITON_SESSION_OUT_OF_LIMIT = -3  # request length out of limit
@@ -79,9 +80,9 @@ class Chatbot:
                  tritonserver_addr: str,
                  model_name: str,
                  session_len: int = 2048,
-                 top_p: float = 1.0,
-                 top_k: int = 40,
-                 temperature: float = 1.0,
+                 top_p: float = 0.8,
+                 top_k: int = None,
+                 temperature: float = 0.8,
                  repetition_penalty: float = 1.0,
                  ignore_eos: bool = False,
                  log_level: int = logging.INFO,
@@ -340,6 +341,7 @@ class Chatbot:
         preseq_length = session.sequence_length
         session.response = ''
+        session.status = StatusCode.TRITON_SESSION_READY
 
         que = queue.Queue()
         producer = threading.Thread(target=self._stream_producer,
@@ -375,8 +377,6 @@ class Chatbot:
             prepare_tensor('input_ids', input_ids),
             prepare_tensor('input_lengths', input_lengths),
             prepare_tensor('request_output_len', request_output_len),
-            prepare_tensor('runtime_top_k',
-                           cfg.top_k * np.ones((1, 1), dtype=np.uint32)),
             prepare_tensor('runtime_top_p',
                            cfg.top_p * np.ones((1, 1), dtype=np.float32)),
             prepare_tensor(
@@ -389,6 +389,10 @@ class Chatbot:
             prepare_tensor('step',
                            preseq_length * np.ones((1, 1), dtype=np.int32))
         ]
+        if cfg.top_k is not None:
+            inputs += prepare_tensor(
+                'runtime_top_k',
+                cfg.top_k * np.ones((1, 1), dtype=np.uint32)),
         if cfg.stop_words is not None:
             inputs += [prepare_tensor('stop_words_list', cfg.stop_words)]
         if cfg.bad_words is not None:
@@ -435,6 +439,7 @@ class Chatbot:
                 yield StatusCode.TRITON_STREAM_END, \
                     session.response[len(session.prompt):], \
                     session.sequence_length - preseq_length
+                session.status = StatusCode.TRITON_STREAM_END
                 break
             if 'errcode' in result:
                 logger.error(f'got error from turbomind, code '
@@ -472,10 +477,16 @@ class Chatbot:
                                sequence_length)
             text = output_str[0].decode()
             if display:
-                new_text = text[len(session.response):]
-                print(new_text, end='', flush=True)
+                if len(text) > len(session.prompt):
+                    if session.status == StatusCode.TRITON_SESSION_READY:
+                        new_text = text[len(session.prompt):]
+                        session.status = StatusCode.TRITON_STREAM_ING
+                    else:
+                        new_text = text[len(session.response):]
+                    print(new_text, end='', flush=True)
             session.response = text
             if len(session.response) > len(session.prompt):
+                session.status = StatusCode.TRITON_STREAM_ING
                 yield (StatusCode.TRITON_STREAM_ING,
                        session.response[len(session.prompt):],
                        sequence_length.squeeze())
...
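(Side note on the conditional runtime_top_k hunk above: the statement
`inputs += prepare_tensor(...),` ends with a comma, so the right-hand side is
a one-element tuple, which list `+=` accepts. A standalone sketch of that
Python behaviour:)

    inputs = []
    inputs += 'runtime_top_k',   # trailing comma builds a 1-tuple; += extends the list
    assert inputs == ['runtime_top_k']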
@@ -280,13 +280,11 @@ def deploy_hf(model_name: str, model_path: str, tokenizer_path: str,
     if osp.exists(tokenizer_path):
         shutil.copy(tokenizer_path,
                     osp.join(triton_models_path, 'tokenizer/tokenizer.model'))
-        for json_file in os.listdir(model_path):
-            if json_file.endswith(
-                    '.json') and json_file != 'pytorch_model.bin.index.json':
-                json_path = osp.join(model_path, json_file)
-                shutil.copy(
-                    json_path,
-                    osp.join(triton_models_path, 'tokenizer', json_file))
+        for _file in os.listdir(model_path):
+            if _file.endswith('.json') or _file.endswith('.py'):
+                json_path = osp.join(model_path, _file)
+                shutil.copy(json_path,
+                            osp.join(triton_models_path, 'tokenizer', _file))
     else:
         print(f'tokenizer model {tokenizer_path} does not exist')
         exit(-1)
...
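(Side note: the rewritten loop copies `.py` files along with every `.json`
file into the tokenizer directory, presumably so that a tokenizer loaded with
trust_remote_code=True finds its Python sources, and it drops the old
exclusion of pytorch_model.bin.index.json. A minimal sketch of the new filter
over a hypothetical listing:)

    files = ['config.json', 'tokenization_internlm.py', 'pytorch_model.bin']
    copied = [f for f in files if f.endswith('.json') or f.endswith('.py')]
    assert copied == ['config.json', 'tokenization_internlm.py']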
@@ -27,12 +27,14 @@ class Tokenizer:
         if not osp.exists(backend_tokenizer_file):
             print('WARNING: Can not find tokenizer.json. '
                   'It may take long time to initialize the tokenizer.')
-        self.model = AutoTokenizer.from_pretrained(model_folder)
+        self.model = AutoTokenizer.from_pretrained(model_folder,
+                                                   trust_remote_code=True)
         self.vocab_size = self.model.vocab_size
         self.start_id = self.model.bos_token_id
         self.end_id = self.model.eos_token_id
         # save tokenizer.json to reuse
-        if not osp.exists(backend_tokenizer_file):
+        if not osp.exists(backend_tokenizer_file) and \
+                hasattr(self.model, 'backend_tokenizer'):
             self.model.backend_tokenizer.save(backend_tokenizer_file)
 
     def encode(self, s: str):
...
@@ -29,12 +29,14 @@ class Tokenizer:
         if not osp.exists(backend_tokenizer_file):
             print('WARNING: Can not find tokenizer.json. '
                   'It may take long time to initialize the tokenizer.')
-        self.model = AutoTokenizer.from_pretrained(model_folder)
+        self.model = AutoTokenizer.from_pretrained(model_folder,
+                                                   trust_remote_code=True)
         self.vocab_size = self.model.vocab_size
         self.start_id = self.model.bos_token_id
         self.end_id = self.model.eos_token_id
         # save tokenizer.json to reuse
-        if not osp.exists(backend_tokenizer_file):
+        if not osp.exists(backend_tokenizer_file) and \
+                hasattr(self.model, 'backend_tokenizer'):
             self.model.backend_tokenizer.save(backend_tokenizer_file)
 
     def encode(self, s: str):
...
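(Side note: both tokenizer wrappers get the same two fixes. trust_remote_code=True
lets AutoTokenizer import a tokenizer class shipped inside the model folder, and
the hasattr guard skips the tokenizer.json export for slow tokenizers, since only
fast, Rust-backed tokenizers expose backend_tokenizer. A condensed sketch; the
folder path is hypothetical:)

    from transformers import AutoTokenizer

    model = AutoTokenizer.from_pretrained('./internlm-model',
                                          trust_remote_code=True)
    if hasattr(model, 'backend_tokenizer'):  # fast tokenizers only
        model.backend_tokenizer.save('tokenizer.json')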
+# Copyright (c) OpenMMLab. All rights reserved.
 import os.path as osp
 import random
@@ -50,8 +51,8 @@ def main(model_name, model_path, session_id: int = 1):
                 random_seed=seed if nth_round == 1 else None):
             res, tokens = outputs[0]
             # decode res
-            response = tokenizer.decode(
-                res[step:], skip_special_tokens=True)
+            response = tokenizer.decode(res[step:],
+                                        skip_special_tokens=True)
             print(f'session {session_id}, {tokens}, {response}')
             # update step
             step = tokens - 1
...