Unverified Commit 23c05372 authored by lvhan028, committed by GitHub

Add profile (#15)

* remove constraints on model name

* remove duplicate model converter

* add profile

* get eos and bos from server

* update stop_words

* update sequence_length when the last generated token is eos_id

* fix

* fix

* check-in models

* validate model_name

* make stop_words as property

* debug profiling

* better stats

* fix assistant response

* update profile serving

* update

* update
parent 2700abb3
import multiprocessing as mp
import time
import fire
import numpy as np
from llmdeploy.serve.fastertransformer.chatbot import Chatbot
def infer(chatbot, session_id: int, prompt: str, output_seqlen: int,
test_round: int, que: mp.Queue):
stats = []
for i in range(test_round):
timestamps = []
tokens = []
start = time.perf_counter()
for status, res, token in chatbot.stream_infer(
session_id,
prompt,
request_output_len=output_seqlen,
sequence_start=True,
sequence_end=True):
timestamps.append(time.perf_counter())
tokens.append(token)
first_token_latency = timestamps[0] - start
token_latency = timestamps[-1] - timestamps[0]
token = tokens[-1] - tokens[0]
stats.append([first_token_latency, token, token_latency])
chatbot.reset_session()
que.put((session_id, stats))
def warmup(tritonserver_addr: str,
model_name: str,
concurrency: int,
session_len: int,
output_seqlen: int,
warmup_round: int = 4):
print('start to warmup ...')
def _infer(_chatbot, session_id):
for _ in range(warmup_round):
            for _, _, _ in _chatbot.stream_infer(
session_id,
prompt='',
request_output_len=output_seqlen,
sequence_start=True,
sequence_end=True):
continue
            _chatbot.reset_session()
_start = time.perf_counter()
chatbots = [
Chatbot(tritonserver_addr=tritonserver_addr,
model_name=model_name,
session_len=session_len,
ignore_eos=True,
profile_generation=True) for _ in range(concurrency)
]
procs = []
for i, chatbot in enumerate(chatbots):
proc = mp.Process(target=_infer, args=(chatbot, i + 1))
procs.append(proc)
proc.start()
for proc in procs:
proc.join()
_end = time.perf_counter()
print(f'end warmup, elapsed time: {round(_end - _start, 2)}s')
def main(tritonserver_addr: str,
model_name: str,
concurrency: int = 1,
session_len: int = 2048,
input_seqlen: int = 0,
output_seqlen: int = 512,
test_round: int = 10):
warmup(tritonserver_addr, model_name, concurrency, session_len,
output_seqlen)
# make up a prompt that can be tokenized into {input_seqlen} tokens
prompt = '' if input_seqlen == 0 else 'hi' + ' hi' * (input_seqlen - 1)
que = mp.Queue()
procs = []
_start = time.perf_counter()
for i in range(concurrency):
chatbot = Chatbot(tritonserver_addr=tritonserver_addr,
model_name=model_name,
session_len=session_len,
ignore_eos=True,
profile_generation=True)
proc = mp.Process(target=infer,
args=(chatbot, i + 1, prompt, output_seqlen,
test_round, que))
procs.append(proc)
proc.start()
for proc in procs:
proc.join()
_end = time.perf_counter()
elapsed_time = _end - _start
stats = []
while not que.empty():
session_id, _stats = que.get()
print(f'\n{"-" * 50}\n'
f'session {session_id} stats: \n{_stats}\n{"-" * 50}\n')
stats.append(_stats)
stats = np.array(stats).reshape(-1, 3)
first_token_latency_min = np.min(stats[:, 0], axis=0)
first_token_latency_max = np.max(stats[:, 0], axis=0)
first_token_latency_ave = np.mean(stats[:, 0], axis=0)
token_latency_min = np.min(stats[:, 2], axis=0)
token_latency_max = np.max(stats[:, 2], axis=0)
token_latency_ave = np.mean(stats[:, 2], axis=0)
throughput = np.sum(stats[:, 1], axis=0) / np.sum(stats[:, 2], axis=0)
print(f'\n{"-" * 50}\ncocurrency: {concurrency}, input_tokens: '
f'{input_seqlen}, output_tokens: {output_seqlen}\n'
f'elapsed_time: {elapsed_time:.2f}s\n'
f'first_token latency(min, max, ave): '
f'{first_token_latency_min:.2f}s, {first_token_latency_max:.2f}s, '
f'{first_token_latency_ave:.2f}s\ntoken latency(min, max, ave): '
f'{token_latency_min:.2f}s, {token_latency_max:.2f}s, '
f'{token_latency_ave:.2f}s\n'
f'throughput: {throughput} token/s\n{"-" * 50}')
if __name__ == '__main__':
fire.Fire(main)
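
For reference, a minimal sketch of how the per-round rows gathered by infer turn into the summary printed above. Each row holds [first_token_latency in seconds, generated tokens, completion latency in seconds]; the numbers below are made up for illustration:

import numpy as np

# illustrative stats matrix: one row per round per session
stats = np.array([[0.12, 512, 6.4],
                  [0.10, 512, 6.1],
                  [0.11, 512, 6.3]])
first_token_latency = stats[:, 0]
print(first_token_latency.min(), first_token_latency.max(),
      first_token_latency.mean())
# throughput aggregates every round: total generated tokens divided by the
# total time spent generating them
throughput = stats[:, 1].sum() / stats[:, 2].sum()
print(f'{throughput:.2f} token/s')
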
import json
import multiprocessing as mp
import os
import random
import time
from typing import List
import fire
import numpy as np
from sentencepiece import SentencePieceProcessor
from llmdeploy.serve.fastertransformer.chatbot import Chatbot
class Tokenizer:
def __init__(self, model_path: str):
# reload tokenizer
assert os.path.isfile(model_path), model_path
self.sp_model = SentencePieceProcessor(model_file=model_path)
def encode(self, prompts: List):
prompts_token_ids = self.sp_model.Encode(prompts,
add_bos=False,
add_eos=False)
return [len(token_ids) for token_ids in prompts_token_ids]
def infer(chatbot, session_id: int, req_que: mp.Queue, res_que: mp.Queue):
stats = []
while not req_que.empty():
prompt, input_seqlen, output_seqlen = req_que.get()
print(f'request info: session {session_id}, '
f'input_seqlen {input_seqlen}, output_seqlen {output_seqlen}')
timestamps = []
tokens = []
start = time.perf_counter()
for status, res, token in chatbot.stream_infer(
session_id,
prompt,
request_output_len=output_seqlen,
sequence_start=True,
sequence_end=True):
timestamps.append(time.perf_counter())
tokens.append(token)
chatbot.reset_session()
first_token_latency = timestamps[1] - start
token_latency = timestamps[-1] - timestamps[0]
token = tokens[-1] - tokens[0]
stats.append([first_token_latency, token, token_latency])
res_que.put((session_id, stats))
def warmup(tritonserver_addr: str,
model_name: str,
concurrency: int,
session_len: int,
output_seqlen: int,
warmup_round: int = 4):
print('start to warmup ...')
def _infer(_chatbot, session_id):
for _ in range(warmup_round):
            for _, _, _ in _chatbot.stream_infer(
session_id,
prompt='',
request_output_len=output_seqlen,
sequence_start=True,
sequence_end=True):
continue
            _chatbot.reset_session()
_start = time.perf_counter()
chatbots = [
Chatbot(tritonserver_addr=tritonserver_addr,
model_name=model_name,
session_len=session_len,
ignore_eos=True,
profile_generation=True) for _ in range(concurrency)
]
procs = []
for i, chatbot in enumerate(chatbots):
proc = mp.Process(target=_infer, args=(chatbot, i + 1))
procs.append(proc)
proc.start()
for proc in procs:
proc.join()
_end = time.perf_counter()
print(f'end warmup, elapsed time: {round(_end - _start, 2)} s')
def read_dataset(tritonserver_addr, tokenizer_path: str, dataset_path: str,
samples: int, test_round: int, session_len: int):
start = time.perf_counter()
with open(dataset_path) as f:
dataset = json.load(f)
dataset = [data for data in dataset if len(data['conversations']) >= 2]
# Only keep the first two turns of each conversation.
dataset = [(data['conversations'][0]['value'],
data['conversations'][1]['value']) for data in dataset]
prompts = [prompt for prompt, _ in dataset]
completions = [completion for _, completion in dataset]
    print(f'elapsed time for reading data: '
f'{round(time.perf_counter() - start, 2)} s')
start = time.perf_counter()
tokenizer = Tokenizer(tokenizer_path)
prompts_token_lens = tokenizer.encode(prompts)
completions_token_lens = tokenizer.encode(completions)
print(f'elapsed time for tokenization: '
f'{round(time.perf_counter() - start, 2)} s')
start = time.perf_counter()
filtered_dataset = []
for (prompt, _), input_len, output_len in zip(dataset, prompts_token_lens,
completions_token_lens):
if input_len + output_len > session_len:
            # skip conversations that are too long for the session
continue
filtered_dataset.append([prompt, input_len, output_len])
if samples > 0:
filtered_dataset = random.sample(filtered_dataset, samples)
filtered_dataset *= test_round
random.shuffle(filtered_dataset)
que = mp.Queue()
for data in filtered_dataset:
que.put(data)
print(f'elapsed time for filtering: '
f'{round(time.perf_counter() - start, 2)} s')
return que
def main(tritonserver_addr: str,
model_name: str,
tokenizer_path: str,
dataset_path: str,
concurrency: int = 1,
session_len: int = 2048,
samples: int = 2000,
test_round: int = 1):
warmup(tritonserver_addr, model_name, concurrency, session_len,
session_len)
req_que = read_dataset(tritonserver_addr, tokenizer_path, dataset_path,
samples, test_round, session_len)
res_que = mp.Queue()
procs = []
_start = time.perf_counter()
for i in range(concurrency):
chatbot = Chatbot(tritonserver_addr=tritonserver_addr,
model_name=model_name,
session_len=session_len,
display=False,
profile_serving=True,
ignore_eos=True)
proc = mp.Process(target=infer,
args=(chatbot, i + 1, req_que, res_que))
procs.append(proc)
proc.start()
for proc in procs:
proc.join()
_end = time.perf_counter()
elapsed_time = _end - _start
stats = []
while not res_que.empty():
session_id, _stats = res_que.get()
print(f'\n{"-" * 50}\n'
f'session {session_id} stats: \n{_stats}\n{"-" * 50}\n')
stats.append(_stats)
stats = np.array(stats).reshape(-1, 3)
first_token_latency_min = np.min(stats[:, 0], axis=0)
first_token_latency_max = np.max(stats[:, 0], axis=0)
first_token_latency_ave = np.mean(stats[:, 0], axis=0)
throughput = np.sum(stats[:, 1], axis=0) / elapsed_time
print(f'\n{"-" * 50}\ncocurrency: {concurrency}\n'
f'elapsed_time: {elapsed_time:.2f}s\n'
f'first_token latency(min, max, ave): '
f'{first_token_latency_min:.2f}s, {first_token_latency_max:.2f}s, '
f'{first_token_latency_ave:.2f}s\n'
f'throughput: {throughput:.2f} token/s\n{"-" * 50}')
if __name__ == '__main__':
fire.Fire(main)
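
The serving benchmark expects a ShareGPT-style JSON file. A minimal sketch of the layout that read_dataset parses, limited to the fields the code above actually touches (the text itself is made up):

import json

# Each entry needs a 'conversations' list with at least two turns: the first
# turn's 'value' becomes the prompt, and the token length of the second turn
# sets the requested output length.
example_dataset = [{
    'conversations': [
        {'value': 'What is the capital of France?'},
        {'value': 'The capital of France is Paris.'},
    ]
}]
with open('example_dataset.json', 'w') as f:  # hypothetical file name
    json.dump(example_dataset, f, ensure_ascii=False, indent=2)
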
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine import Registry
MODELS = Registry('model', locations=['llmdeploy.model'])
@MODELS.register_module(name='vicuna')
class Vicuna:
def __init__(self):
self.system = """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. """ # noqa: E501
self.user = 'USER'
self.assistant = 'ASSISTANT'
def get_prompt(self, prompt, sequence_start=True):
if sequence_start:
return f'{self.system} {self.user}: {prompt} {self.assistant}:'
else:
return f'</s>{self.user}: {prompt} {self.assistant}:'
@property
def stop_words(self):
return None
@MODELS.register_module(name='puyu')
class Puyu:
def __init__(self):
self.system = """meta instruction
You are an AI assistant whose name is InternLM (书生·浦语).
- 书生·浦语 is a conversational language model that is developed by Shanghai AI Laboratory (上海人工智能实验室). It is designed to be helpful, honest, and harmless.
- 书生·浦语 can understand and communicate fluently in the language chosen by the user such as English and 中文.
conversation""" # noqa: E501
self.user = '<|Human|>'
self.eou = 'െ'
self.assistant = '<|Assistant|>'
def get_prompt(self, prompt, sequence_start=True):
if sequence_start:
return f'{self.system}\n' \
f'{self.user}:{prompt}{self.eou}\n' \
f'{self.assistant}:'
else:
return f'\n{self.user}:{prompt}{self.eou}\n{self.assistant}:'
@property
def stop_words(self):
return [45623]
def main(model_name: str = 'vicuna'):
assert model_name in MODELS.module_dict.keys(), \
f"'{model_name}' is not supported. " \
f'The supported models are: {MODELS.module_dict.keys()}'
    model = MODELS.get(model_name)()
prompt = model.get_prompt(prompt='hi')
print(prompt)
if __name__ == '__main__':
import fire
fire.Fire(main)
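
Because prompt templates are resolved through the MODELS registry, supporting another chat model only takes registering a class with the same two members, get_prompt and stop_words. A hypothetical sketch; the name, role tags, and prompt format below are illustrative and not part of this commit:

@MODELS.register_module(name='my_chat_model')
class MyChatModel:
    """Illustrative template with made-up role tags."""

    def get_prompt(self, prompt, sequence_start=True):
        if sequence_start:
            return '<system>You are a helpful assistant.</system>\n' \
                   f'<user>{prompt}</user>\n<assistant>'
        return f'\n<user>{prompt}</user>\n<assistant>'

    @property
    def stop_words(self):
        # token ids that stop generation, or None to rely on eos alone
        return None
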
......@@ -12,9 +12,9 @@ def input_prompt():
return '\n'.join(iter(input, sentinel))
def main(triton_server_addr: str, model_name: str, session_id: int):
def main(tritonserver_addr: str, model_name: str, session_id: int = 1):
log_level = os.environ.get('SERVICE_LOG_LEVEL', 'INFO')
chatbot = Chatbot(triton_server_addr,
chatbot = Chatbot(tritonserver_addr,
model_name,
log_level=log_level,
display=True)
......
......@@ -15,6 +15,7 @@ import numpy as np
import tritonclient.grpc as grpcclient
from tritonclient.grpc.service_pb2 import ModelInferResponse
from llmdeploy.model import MODELS
from llmdeploy.serve.fastertransformer.utils import (Postprocessor,
Preprocessor,
prepare_tensor)
......@@ -24,9 +25,9 @@ from llmdeploy.serve.fastertransformer.utils import (Postprocessor,
class Session:
session_id: Union[int, str]
request_id: str = ''
prev: str = '' # history of the session in text format
round_prev: str = '' # previous generated text in the current round
histories: str = '' # history conversations of the session
sequence_length: int = 0 # the total generated token number in the session
prompt: str = ''
response: str = ''
status: int = None # status of the session
......@@ -71,11 +72,9 @@ class Chatbot:
temperature (float): to modulate the next token probability
repetition_penalty (float): The parameter for repetition penalty.
1.0 means no penalty
stop_words (list): List of token ids that stops the generation
bad_words (list): List of token ids that are not allowed to be
generated.
log_level (int): the level of the log
        display (bool): display the generated text on the console or not
profile_generation (bool): profile token generation or not
"""
def __init__(self,
......@@ -86,18 +85,27 @@ class Chatbot:
top_k: int = 40,
temperature: float = 1.0,
repetition_penalty: float = 1.0,
stop_words: List = None,
bad_words: List = None,
ignore_eos: bool = False,
log_level: int = logging.INFO,
display: bool = False):
display: bool = False,
profile_generation: bool = False,
profile_serving: bool = False):
assert model_name in MODELS.module_dict.keys(), \
f"'{model_name}' is not supported. " \
f'The supported models are: {MODELS.module_dict.keys()}'
self.model_name = model_name
self.model = MODELS.get(self.model_name)()
self._session = None
self.tritonserver_addr = tritonserver_addr
self.model_name = model_name
if stop_words is not None:
stop_words = np.array(stop_words, dtype=np.int32)
if bad_words is not None:
bad_words = np.array(bad_words, dtype=np.int32)
self.preprocess = Preprocessor(tritonserver_addr)
self.postprocess = Postprocessor(tritonserver_addr)
self.bos_id = self._get_bos()
self.eos_id = self._get_eos()
stop_words = self._stop_words(self.model.stop_words)
bad_words = None
if ignore_eos:
stop_words = None
bad_words = np.array([[[self.eos_id], [1]]], dtype=np.int32)
self.cfg = mmengine.Config(
dict(session_len=session_len,
top_p=top_p,
......@@ -106,10 +114,10 @@ class Chatbot:
repetition_penalty=repetition_penalty,
stop_words=stop_words,
bad_words=bad_words))
self.preprocess = Preprocessor(tritonserver_addr)
self.postprocess = Postprocessor(tritonserver_addr)
self.log_level = log_level
self.display = display
self.profile_generation = profile_generation
self.profile_serving = profile_serving
def stream_infer(self,
session_id: int,
......@@ -152,13 +160,16 @@ class Chatbot:
self._session.request_id = request_id
self._session.response = ''
prompt = self._get_prompt(prompt, sequence_start)
for status, res, tokens in self._stream_infer(self._session, prompt,
self._session.prompt = self._get_prompt(prompt, sequence_start)
for status, res, tokens in self._stream_infer(self._session,
self._session.prompt,
request_output_len,
sequence_start,
sequence_end):
yield status, res, tokens
self._session.prev = self._session.prev + self._session.round_prev
self._session.histories = \
self._session.histories + self._session.prompt + \
self._session.response
def end(self, session_id: int, *args, **kwargs):
"""end a session. Triton inference server will release the session's
......@@ -237,20 +248,41 @@ class Chatbot:
break
if status == StatusCode.TRITON_STREAM_END:
logger.info(f'cancel session {session_id} successfully')
if prev_session.prev:
if prev_session.histories:
logger.warn(f'TODO: start to recover session {session_id}')
else:
logger.info(f'cancel session {session_id} failed: {res}')
return status
def reset_session(self):
self._session = None
def _get_bos(self):
token_ids, _ = self.preprocess('<BOS>')
return token_ids[0][0]
def _get_eos(self):
token_ids, _ = self.preprocess('<EOS>')
return token_ids[0][0]
def _stop_words(self, stop_words: List[int]):
if stop_words is None:
return None
assert isinstance(stop_words, List) and \
all(isinstance(elem, int) for elem in stop_words), \
            f'stop_words must be a list of ints but got {type(stop_words)}'
# each id in stop_words represents a stop word
# refer to https://github.com/fauxpilot/fauxpilot/discussions/165 for
# detailed explanation about fastertransformer's stop_words
stop_word_offsets = range(1, len(stop_words) + 1)
stop_words = np.array([[stop_words,
stop_word_offsets]]).astype(np.int32)
return stop_words
def _get_prompt(self, prompt: str, sequence_start: bool):
if self.model_name == 'vicuna':
if sequence_start:
return f'USER: {prompt} ASSISTANT:'
else:
return f'</s>USER: {prompt} ASSISTANT:'
else:
if self.profile_generation or self.profile_serving:
return prompt
return self.model.get_prompt(prompt, sequence_start)
def _stream_infer(self,
session: Session,
......@@ -271,9 +303,17 @@ class Chatbot:
f'request_output_len is supposed to be None or int, ' \
f'but got {type(request_output_len)}'
if sequence_start:
logger.info(f'session {session.session_id}, clear history since '
f'sequence_start is True')
session.histories = ''
session.sequence_length = 0
input_ids, input_lengths = self.preprocess(prompt)
input_tokens = input_lengths.squeeze()
if self.profile_generation:
yield StatusCode.TRITON_STREAM_ING, \
'ignore preprocessing during profiling generation', 0
if request_output_len is None:
request_output_len = max(
128,
......@@ -293,7 +333,7 @@ class Chatbot:
f'history tokens: {session.sequence_length}')
preseq_length = session.sequence_length
session.round_prev = ''
session.response = ''
que = queue.Queue()
producer = threading.Thread(target=self._stream_producer,
......@@ -302,10 +342,9 @@ class Chatbot:
request_output_len, sequence_start,
sequence_end, preseq_length, cancel))
producer.start()
for state, res, tokens in self.stream_consumer(self.postprocess, que,
session, preseq_length,
cancel, logger,
self.display):
for state, res, tokens in self.stream_consumer(
self.postprocess, que, session, preseq_length, cancel, logger,
self.display, self.profile_generation, self.eos_id):
if state.value < 0:
yield state, res, 0
else:
......@@ -382,21 +421,13 @@ class Chatbot:
@staticmethod
def stream_consumer(postprocess, res_queue, session, preseq_length, cancel,
logger, display):
def process_response(res):
if session.ai_says is None:
return res, True
index = res.find(session.ai_says)
if index == -1:
return res, False
res = res[index + len(session.ai_says):].replace(session.eoa, '')
return res, True
logger, display, profile_generation, eos_id):
while True:
result = res_queue.get()
if result is None:
yield StatusCode.TRITON_STREAM_END, session.response, \
yield StatusCode.TRITON_STREAM_END, \
session.response[len(session.prompt):], \
session.sequence_length - preseq_length
break
if 'errcode' in result:
......@@ -417,22 +448,35 @@ class Chatbot:
session.sequence_length = sequence_length.squeeze()
sequence_length = sequence_length - preseq_length
last_token_id = output_ids[-1][-1][session.sequence_length - 1]
if last_token_id == eos_id:
session.sequence_length = session.sequence_length - 1
sequence_length = sequence_length - 1
output_ids = output_ids.reshape((1, 1, output_ids.shape[-1]))
sequence_length = sequence_length.reshape(
(1, sequence_length.shape[-1]))
if profile_generation:
yield (StatusCode.TRITON_STREAM_ING,
'postprocessing is ignored during profiling '
'token generation', sequence_length.squeeze())
continue
output_str = postprocess(output_ids[:, :, preseq_length:],
sequence_length)
text = output_str[0].decode()
if display:
new_text = text[len(session.round_prev):]
new_text = text[len(session.response):]
print(new_text, end='', flush=True)
session.round_prev = text
yield (StatusCode.TRITON_STREAM_ING, session.response,
sequence_length.squeeze())
session.response = text
if len(session.response) > len(session.prompt):
yield (StatusCode.TRITON_STREAM_ING,
session.response[len(session.prompt):],
sequence_length.squeeze())
except Exception as e:
logger.error(f'catch exception: {e}')
session.response = session.response[len(session.prompt):]
# put session back to queue so that `_stream_infer` can update it in
# `self.sessions`
while not res_queue.empty():
......
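
As the _stop_words helper and the ignore_eos branch above show, both stop_words and bad_words are packed into the 1 x 2 x N layout that FasterTransformer expects: token ids in the first row, cumulative end offsets in the second. A small sketch of the resulting tensors, using the Puyu stop word registered earlier and an assumed eos id of 2:

import numpy as np

stop_words = [45623]  # Puyu's end-of-utterance token id from the model registry
offsets = list(range(1, len(stop_words) + 1))
stop_words_tensor = np.array([[stop_words, offsets]]).astype(np.int32)
print(stop_words_tensor.shape)  # (1, 2, 1)

eos_id = 2  # assumed value; the chatbot obtains it from _get_eos() at runtime
bad_words_tensor = np.array([[[eos_id], [1]]], dtype=np.int32)
print(bad_words_tensor.shape)  # (1, 2, 1)
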
......@@ -21,10 +21,14 @@ class Tokenizer:
def encode(self, s: str):
add_bos = False
add_eos = False
if s.find('<BOS>') != -1:
s = s.replace('<BOS>', '')
add_bos = True
return self.model.Encode(s, add_bos=add_bos)
if s == '<EOS>':
s = ''
add_eos = True
return self.model.Encode(s, add_bos=add_bos, add_eos=add_eos)
def decode(self, t: List[int]):
return self.model.Decode(t)
......
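
The <BOS>/<EOS> sentinels added to encode above are what let Chatbot._get_bos and _get_eos recover the special token ids through the preprocessing model. A small usage sketch assuming a local SentencePiece model; the path is hypothetical:

from sentencepiece import SentencePieceProcessor

sp = SentencePieceProcessor(model_file='/path/to/tokenizer.model')  # hypothetical path
# '<BOS>' strips the sentinel and encodes '' with add_bos=True -> [bos_id]
# '<EOS>' matches exactly and encodes '' with add_eos=True -> [eos_id]
print(sp.Encode('', add_bos=True))  # e.g. [1] for a llama-style tokenizer
print(sp.Encode('', add_eos=True))  # e.g. [2]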