"...en/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "f48f9c250fe7f6027c7445f8be04bb4499583f58"
Commit ef2adb04 authored by lvhan028, committed by GitHub

add chatbot (#2)

parent a75c0a47
@@ -50,4 +50,4 @@ repos:
     rev: v0.2.0
     hooks:
       - id: check-copyright
-        args: []
+        args: ["llmdeploy"]
# Copyright (c) OpenMMLab. All rights reserved.
import os
import fire
from llmdeploy.serve.fastertransformer.chatbot import Chatbot
def input_prompt():
    """Read a multi-line prompt from the console, ended by two Enters."""
    print('\ndouble enter to end input >>> ', end='')
sentinel = '' # ends when this string is seen
return '\n'.join(iter(input, sentinel))
def main(triton_server_addr: str, model_name: str, session_id: int):
    """Chat with the deployed model in the console.

    Args:
        triton_server_addr (str): '<ip>:<port>' of the triton inference server
        model_name (str): name of the deployed model
        session_id (int): the unique id of the chat session
    """
    log_level = os.environ.get('SERVICE_LOG_LEVEL', 'INFO')
chatbot = Chatbot(triton_server_addr,
model_name,
log_level=log_level,
display=True)
nth_round = 1
while True:
prompt = input_prompt()
if prompt == 'exit':
exit(0)
elif prompt == 'end':
chatbot.end(session_id)
else:
request_id = f'{session_id}-{nth_round}'
for status, res, tokens in chatbot.stream_infer(
session_id, prompt, request_id=request_id):
continue
print(f'session {session_id}, {status}, {tokens}, {res}')
nth_round += 1
if __name__ == '__main__':
fire.Fire(main)
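# A minimal usage sketch of the console client above. `fire.Fire(main)` maps
# the parameters of `main` onto command line arguments; the script path,
# server address and model name below are placeholders rather than values
# taken from this commit:
#
#     python client.py 0.0.0.0:33337 vicuna 1
#
# Type a prompt and press Enter twice to send it, type `exit` to quit, or
# type `end` to release the session on the triton inference server.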
# Copyright (c) OpenMMLab. All rights reserved.
from llmdeploy.serve.fastertransformer.chatbot import \
Chatbot # noqa: F401,F403
# Copyright (c) OpenMMLab. All rights reserved.
import json
import logging
import queue
import random
import threading
from dataclasses import dataclass
from enum import Enum
from functools import partial
from typing import List, Union
import google.protobuf.json_format
import mmengine
import numpy as np
import tritonclient.grpc as grpcclient
from tritonclient.grpc.service_pb2 import ModelInferResponse
from llmdeploy.serve.fastertransformer.utils import (Postprocessor,
Preprocessor,
prepare_tensor)
@dataclass
class Session:
session_id: Union[int, str]
request_id: str = ''
prev: str = '' # history of the session in text format
round_prev: str = '' # previous generated text in the current round
sequence_length: int = 0 # the total generated token number in the session
response: str = ''
status: int = None # status of the session
class StatusCode(Enum):
TRITON_STREAM_END = 0 # end of streaming
TRITON_STREAM_ING = 1 # response is in streaming
TRITON_SERVER_ERR = -1 # triton server's error
TRITON_SESSION_CLOSED = -2 # session has been closed
TRITON_SESSION_OUT_OF_LIMIT = -3 # request length out of limit
TRITON_SESSION_INVALID_ARG = -4 # invalid argument
def stream_callback(que, result, error):
if error:
print(error)
que.put(dict(errcode=StatusCode.TRITON_SERVER_ERR, errmsg=f'{error}'))
else:
que.put(result.get_response(as_json=True))
def get_logger(log_file=None, log_level=logging.INFO):
from .utils import get_logger
logger = get_logger('service.ft', log_file=log_file, log_level=log_level)
return logger
class Chatbot:
"""Chatbot for LLaMA series models with fastertransformer as inference
engine.
Args:
tritonserver_addr (str): communicating address '<ip>:<port>' of
triton inference server
        model_name (str): name of the deployed model
session_len (int): the maximum context length of the model
top_p (float): If set to float < 1, only the smallest set of most
probable tokens with probabilities that add up to top_p or higher
are kept for generation.
top_k (int): The number of the highest probability vocabulary tokens to
keep for top-k-filtering
temperature (float): to modulate the next token probability
repetition_penalty (float): The parameter for repetition penalty.
1.0 means no penalty
        stop_words (list): List of token ids that stop the generation
bad_words (list): List of token ids that are not allowed to be
generated.
log_level (int): the level of the log
        display (bool): display the generated text on console or not
"""
def __init__(self,
tritonserver_addr: str,
model_name: str,
session_len: int = 2048,
top_p: float = 1.0,
top_k: int = 40,
temperature: float = 1.0,
repetition_penalty: float = 1.0,
stop_words: List = None,
bad_words: List = None,
log_level: int = logging.INFO,
display: bool = False):
self._session = None
self.tritonserver_addr = tritonserver_addr
self.model_name = model_name
if stop_words is not None:
stop_words = np.array(stop_words, dtype=np.int32)
if bad_words is not None:
bad_words = np.array(bad_words, dtype=np.int32)
self.cfg = mmengine.Config(
dict(session_len=session_len,
top_p=top_p,
top_k=top_k,
temperature=temperature,
repetition_penalty=repetition_penalty,
stop_words=stop_words,
bad_words=bad_words))
self.preprocess = Preprocessor(tritonserver_addr)
self.postprocess = Postprocessor(tritonserver_addr)
self.log_level = log_level
self.display = display
def stream_infer(self,
session_id: int,
prompt: str,
request_id: str = '',
request_output_len: int = None,
sequence_start: bool = False,
sequence_end: bool = False,
*args,
**kwargs):
"""Start a new round conversion of a session.
Args:
            session_id (int): the unique id of a session
            prompt (str): user's prompt in this round of conversation
            request_id (str): the unique id of this round of conversation
request_output_len (int): the expected generated token numbers
sequence_start (bool): start flag of a session
sequence_end (bool): end flag of a session
Returns:
iterator: The generated content by chatbot
"""
assert isinstance(session_id, int), \
f'INT session id is required, but got {type(session_id)}'
logger = get_logger(log_level=self.log_level)
logger.info(f'session {session_id}, request_id {request_id}, '
f'request_output_len {request_output_len}')
if self._session is None:
sequence_start = True
self._session = Session(session_id=session_id)
elif self._session.status == 0:
            logger.error(f'session {session_id} has been ended. Please set '
                         f'`sequence_start` to True if you want to restart it')
yield StatusCode.TRITON_SESSION_CLOSED, '', 0
return
self._session.status = 1
self._session.request_id = request_id
self._session.response = ''
prompt = self._get_prompt(prompt, sequence_start)
for status, res, tokens in self._stream_infer(self._session, prompt,
request_output_len,
sequence_start,
sequence_end):
yield status, res, tokens
self._session.prev = self._session.prev + self._session.round_prev
def end(self, session_id: int, *args, **kwargs):
"""end a session. Triton inference server will release the session's
occupied resource when it is ended.
Args:
            session_id (int): the unique id of a session
Returns:
            StatusCode: `TRITON_STREAM_END` on success, otherwise an error code
"""
assert isinstance(session_id, int), \
f'INT session id is required, but got {type(session_id)}'
logger = get_logger(log_level=self.log_level)
logger.info(f'end session: {session_id}')
if self._session is None:
logger.error(
f"session {session_id} doesn't exist. It cannot be ended")
return StatusCode.TRITON_SESSION_INVALID_ARG
if self._session.session_id != session_id:
logger.error(f'you cannot end session {session_id}, because this '
f'session is {self._session.session_id}')
return StatusCode.TRITON_SESSION_INVALID_ARG
if self._session.status == 0:
logger.warning(f'session {session_id} has already been ended')
return StatusCode.TRITON_SESSION_CLOSED
self._session.status = 0
for status, _, _ in self._stream_infer(self._session,
prompt='',
request_output_len=0,
sequence_start=False,
sequence_end=True):
if status != StatusCode.TRITON_STREAM_END:
return status
return StatusCode.TRITON_STREAM_END
def cancel(self, session_id: int, *args, **kwargs):
"""Cancel the session during generating tokens.
Args:
            session_id (int): the unique id of a session
Returns:
            StatusCode: `TRITON_STREAM_END` if cancellation succeeds, otherwise
                an error code
"""
assert isinstance(session_id, int), \
f'INT session id is required, but got {type(session_id)}'
logger = get_logger(log_level=self.log_level)
logger.info(f'cancel session: {session_id}')
if self._session is None:
logger.error(
f"session {session_id} doesn't exist. It cannot be cancelled")
return StatusCode.TRITON_SESSION_INVALID_ARG
if self._session.session_id != session_id:
logger.error(
f'you cannot cancel session {session_id}, because this '
f'session is {self._session.session_id}')
return StatusCode.TRITON_SESSION_INVALID_ARG
if self._session.status == 0:
logger.error(f'session {session_id} has already been ended. '
f'It cannot be cancelled')
return StatusCode.TRITON_SESSION_CLOSED
prev_session = self._session
for status, res, _ in self._stream_infer(self._session,
prompt='',
request_output_len=0,
sequence_start=False,
sequence_end=False,
cancel=True):
if status.value < 0:
break
if status == StatusCode.TRITON_STREAM_END:
logger.info(f'cancel session {session_id} successfully')
if prev_session.prev:
                logger.warning(f'TODO: start to recover session {session_id}')
else:
logger.info(f'cancel session {session_id} failed: {res}')
return status
    def _get_prompt(self, prompt: str, sequence_start: bool):
        # NOTE: only the vicuna conversation template is handled for now;
        # any other model name falls through and returns None implicitly
        if self.model_name == 'vicuna':
if sequence_start:
return f'USER: {prompt} ASSISTANT:'
else:
return f'</s>USER: {prompt} ASSISTANT:'
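        # For example, with the template above the first round of a session
        # sends 'USER: hello ASSISTANT:' to the server, while later rounds
        # send '</s>USER: hello ASSISTANT:' to continue the dialogue.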
def _stream_infer(self,
session: Session,
prompt: str,
request_output_len: int = 512,
sequence_start: bool = True,
sequence_end: bool = False,
cancel: bool = False):
logger = get_logger(log_level=self.log_level)
logger.info(f'session {session.session_id}, '
f'request id {session.request_id}, '
f'request_output_len {request_output_len}, '
f'start {sequence_start}, '
f'end {sequence_end}, cancel {cancel}')
assert request_output_len is None or \
isinstance(request_output_len, int), \
f'request_output_len is supposed to be None or int, ' \
f'but got {type(request_output_len)}'
input_ids, input_lengths = self.preprocess(prompt)
input_tokens = input_lengths.squeeze()
if request_output_len is None:
request_output_len = max(
128,
self.cfg.session_len - session.sequence_length - input_tokens)
if input_tokens + request_output_len + \
session.sequence_length > self.cfg.session_len:
errmsg = f'session {session.session_id}, ' \
f'out of max sequence length {self.cfg.session_len}, ' \
f'#input tokens {input_tokens}, ' \
f'history tokens {session.sequence_length}, ' \
f'request length {request_output_len}'
            yield StatusCode.TRITON_SESSION_OUT_OF_LIMIT, errmsg, 0
            return
logger.info(f'session {session.session_id}, '
f'input tokens: {input_tokens}, '
f'request tokens: {request_output_len}, '
f'history tokens: {session.sequence_length}')
preseq_length = session.sequence_length
session.round_prev = ''
que = queue.Queue()
producer = threading.Thread(target=self._stream_producer,
args=(self.tritonserver_addr, session, que,
self.cfg, input_ids, input_lengths,
request_output_len, sequence_start,
sequence_end, preseq_length, cancel))
producer.start()
for state, res, tokens in self.stream_consumer(self.postprocess, que,
session, preseq_length,
cancel, logger,
self.display):
if state.value < 0:
yield state, res, 0
else:
yield state, res, tokens - input_tokens
producer.join()
self._session = que.get()
curseq_length = self._session.sequence_length
logger.info(f'session {session.session_id}, pre seq_len '
f'{preseq_length}, cur seq_len {curseq_length}, '
f'diff {curseq_length - preseq_length}')
@staticmethod
def _stream_producer(tritonserver_addr, session, que, cfg, input_ids,
input_lengths, request_output_len, sequence_start,
sequence_end, preseq_length, cancel):
request_output_len = np.full(input_lengths.shape,
request_output_len).astype(np.uint32)
callback = partial(stream_callback, que)
with grpcclient.InferenceServerClient(tritonserver_addr) as client:
inputs = [
prepare_tensor('input_ids', input_ids),
prepare_tensor('input_lengths', input_lengths),
prepare_tensor('request_output_len', request_output_len),
prepare_tensor('runtime_top_k',
cfg.top_k * np.ones((1, 1), dtype=np.uint32)),
prepare_tensor('runtime_top_p',
cfg.top_p * np.ones((1, 1), dtype=np.float32)),
prepare_tensor(
'temperature',
cfg.temperature * np.ones((1, 1), dtype=np.float32)),
prepare_tensor(
'repetition_penalty',
cfg.repetition_penalty * np.ones(
(1, 1), dtype=np.float32)),
prepare_tensor('step',
preseq_length * np.ones((1, 1), dtype=np.int32))
]
if cfg.stop_words is not None:
inputs += [prepare_tensor('stop_words_list', cfg.stop_words)]
if cfg.bad_words is not None:
inputs += [prepare_tensor('bad_words_list', cfg.bad_words)]
inputs += [
prepare_tensor(
'session_len',
cfg.session_len *
np.ones([input_ids.shape[0], 1], dtype=np.uint32)),
prepare_tensor('START', (1 if sequence_start else 0) * np.ones(
(1, 1), dtype=np.int32)),
prepare_tensor('END', (1 if sequence_end else 0) * np.ones(
(1, 1), dtype=np.int32)),
prepare_tensor(
'CORRID',
session.session_id * np.ones((1, 1), dtype=np.uint64)),
prepare_tensor('STOP', (1 if cancel else 0) * np.ones(
(1, 1), dtype=np.int32))
]
if sequence_start:
random_seed = random.getrandbits(64)
inputs += [
prepare_tensor(
'random_seed',
random_seed * np.ones((1, 1), dtype=np.uint64))
]
client.start_stream(callback)
client.async_stream_infer('fastertransformer',
inputs,
sequence_id=session.session_id,
request_id=session.request_id,
sequence_start=sequence_start,
sequence_end=sequence_end)
que.put(None)
@staticmethod
def stream_consumer(postprocess, res_queue, session, preseq_length, cancel,
logger, display):
        # NOTE: `process_response` is not called anywhere yet and refers to
        # `session.ai_says` and `session.eoa`, which `Session` does not define
        def process_response(res):
if session.ai_says is None:
return res, True
index = res.find(session.ai_says)
if index == -1:
return res, False
res = res[index + len(session.ai_says):].replace(session.eoa, '')
return res, True
while True:
result = res_queue.get()
if result is None:
yield StatusCode.TRITON_STREAM_END, session.response, \
session.sequence_length - preseq_length
break
if 'errcode' in result:
logger.error(f'got error from fastertransformer, code '
f"{result['errcode']}, {result['errmsg']}, "
f'token {session.sequence_length}')
session.sequence_length = preseq_length
yield result['errcode'], result['errmsg'], 0
break
if cancel:
continue
try:
message = ModelInferResponse()
google.protobuf.json_format.Parse(json.dumps(result), message)
result = grpcclient.InferResult(message)
sequence_length = result.as_numpy('sequence_length')
output_ids = result.as_numpy('output_ids')
session.sequence_length = sequence_length.squeeze()
sequence_length = sequence_length - preseq_length
output_ids = output_ids.reshape((1, 1, output_ids.shape[-1]))
sequence_length = sequence_length.reshape(
(1, sequence_length.shape[-1]))
output_str = postprocess(output_ids[:, :, preseq_length:],
sequence_length)
text = output_str[0].decode()
if display:
new_text = text[len(session.round_prev):]
print(new_text, end='', flush=True)
session.round_prev = text
yield (StatusCode.TRITON_STREAM_ING, session.response,
sequence_length.squeeze())
except Exception as e:
logger.error(f'catch exception: {e}')
        # put the session back to the queue so that `_stream_infer` can fetch
        # it and update `self._session`
while not res_queue.empty():
res_queue.get()
res_queue.put(session)
if display:
print('\n')
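# A minimal interactive sketch of the Chatbot API above. The server address,
# model name and session id are placeholders rather than values taken from
# this commit:
#
#     chatbot = Chatbot('0.0.0.0:33337', 'vicuna', log_level=logging.INFO)
#     for status, res, n_token in chatbot.stream_infer(
#             session_id=1, prompt='hello', request_id='1-1'):
#         print(status, n_token, res)
#     chatbot.end(1)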
# Copyright (c) OpenMMLab. All rights reserved.
import logging
from typing import List, Optional, Union
import numpy as np
import tritonclient.grpc as grpcclient
from tritonclient.utils import np_to_triton_dtype
logger_initialized = {}
def get_logger(name: str,
log_file: Optional[str] = None,
log_level: int = logging.INFO,
file_mode: str = 'w'):
"""Initialize and get a logger by name.
If the logger has not been initialized, this method will initialize the
logger by adding one or two handlers, otherwise the initialized logger will
be directly returned. During initialization, a StreamHandler will always be
added. If `log_file` is specified, a FileHandler will also be added.
Args:
name (str): Logger name.
log_file (str | None): The log filename. If specified, a FileHandler
will be added to the logger.
log_level (int): The logger level.
file_mode (str): The file mode used in opening log file.
Defaults to 'w'.
Returns:
logging.Logger: The expected logger.
"""
# use logger in mmengine if exists.
try:
from mmengine.logging import MMLogger
if MMLogger.check_instance_created(name):
logger = MMLogger.get_instance(name)
else:
logger = MMLogger.get_instance(name,
logger_name=name,
log_file=log_file,
log_level=log_level,
file_mode=file_mode)
return logger
except Exception:
pass
logger = logging.getLogger(name)
if name in logger_initialized:
return logger
# handle hierarchical names
# e.g., logger "a" is initialized, then logger "a.b" will skip the
# initialization since it is a child of "a".
for logger_name in logger_initialized:
if name.startswith(logger_name):
return logger
# handle duplicate logs to the console
for handler in logger.root.handlers:
if type(handler) is logging.StreamHandler:
handler.setLevel(logging.ERROR)
stream_handler = logging.StreamHandler()
handlers = [stream_handler]
if log_file is not None:
        # The default file mode of `logging.FileHandler` is 'a' (append).
        # `file_mode` is exposed so it can be overridden; it defaults to 'w'
        # here.
file_handler = logging.FileHandler(log_file, file_mode)
handlers.append(file_handler)
formatter = logging.Formatter(
'%(asctime)s - %(name)s - %(levelname)s - %(message)s')
for handler in handlers:
handler.setFormatter(formatter)
handler.setLevel(log_level)
logger.addHandler(handler)
logger.setLevel(log_level)
logger_initialized[name] = True
return logger
def prepare_tensor(name, input_tensor):
t = grpcclient.InferInput(name, list(input_tensor.shape),
np_to_triton_dtype(input_tensor.dtype))
t.set_data_from_numpy(input_tensor)
return t
class Preprocessor:
def __init__(self, tritonserver_addr: str):
self.tritonserver_addr = tritonserver_addr
self.model_name = 'preprocessing'
def __call__(self, *args, **kwargs):
return self.infer(*args, **kwargs)
def infer(self, prompts: Union[str, List[str]]) -> tuple:
"""Tokenize the input prompts.
Args:
prompts(str | List[str]): user's prompt, or a batch prompts
Returns:
            Tuple(numpy.ndarray, numpy.ndarray): the prompt's token ids and
                the token ids' lengths
"""
if isinstance(prompts, str):
input0 = [[prompts]]
elif isinstance(prompts, List):
input0 = [[prompt] for prompt in prompts]
else:
assert 0, f'str or List[str] prompts are expected but got ' \
f'{type(prompts)}'
input0_data = np.array(input0).astype(object)
output0_len = np.ones_like(input0).astype(np.uint32)
inputs = [
prepare_tensor('QUERY', input0_data),
prepare_tensor('REQUEST_OUTPUT_LEN', output0_len)
]
with grpcclient.InferenceServerClient(self.tritonserver_addr) as \
client:
result = client.infer(self.model_name, inputs)
output0 = result.as_numpy('INPUT_ID')
output1 = result.as_numpy('REQUEST_INPUT_LEN')
return output0, output1
class Postprocessor:
def __init__(self, tritonserver_addr: str):
self.tritonserver_addr = tritonserver_addr
def __call__(self, *args, **kwargs):
return self.infer(*args, **kwargs)
def infer(self, output_ids: np.ndarray, seqlen: np.ndarray):
"""De-tokenize tokens for text.
Args:
output_ids(np.ndarray): tokens' id
seqlen(np.ndarray): sequence length
Returns:
            np.ndarray: the decoded texts
"""
inputs = [
prepare_tensor('TOKENS_BATCH', output_ids),
prepare_tensor('sequence_length', seqlen)
]
model_name = 'postprocessing'
with grpcclient.InferenceServerClient(self.tritonserver_addr) \
as client:
result = client.infer(model_name, inputs)
output0 = result.as_numpy('OUTPUT')
return output0
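# A round-trip sketch of the two triton clients above. It assumes the
# 'preprocessing' and 'postprocessing' models are deployed on the server;
# the address is a placeholder and the exact tensor shapes depend on those
# model definitions:
#
#     preprocess = Preprocessor('0.0.0.0:33337')
#     postprocess = Postprocessor('0.0.0.0:33337')
#     input_ids, input_lengths = preprocess('hello world')
#     texts = postprocess(input_ids[:, None, :], input_lengths)
#     print(texts[0].decode())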
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple
__version__ = '0.0.1'
short_version = __version__
def parse_version_info(version_str: str) -> Tuple:
"""Parse version from a string.
Args:
        version_str (str): A string representing the version info.
Returns:
        tuple: A tuple of integers and strings representing the version.
"""
_version_info = []
for x in version_str.split('.'):
if x.isdigit():
_version_info.append(int(x))
elif x.find('rc') != -1:
patch_version = x.split('rc')
_version_info.append(int(patch_version[0]))
_version_info.append(f'rc{patch_version[1]}')
return tuple(_version_info)
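# For example, parse_version_info('0.0.1') returns (0, 0, 1), and
# parse_version_info('0.2.0rc1') returns (0, 2, 0, 'rc1').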
version_info = parse_version_info(__version__)
__all__ = ['__version__', 'version_info', 'parse_version_info']
import os
from setuptools import find_packages, setup
pwd = os.path.dirname(__file__)
version_file = 'llmdeploy/version.py'
def readme():
with open(os.path.join(pwd, 'README.md'), encoding='utf-8') as f:
content = f.read()
return content
def get_version():
with open(os.path.join(pwd, version_file), 'r') as f:
exec(compile(f.read(), version_file, 'exec'))
return locals()['__version__']
if __name__ == '__main__':
setup(name='llmdeploy',
version=get_version(),
description='triton inference service of llama',
long_description=readme(),
long_description_content_type='text/markdown',
author='OpenMMLab',
author_email='openmmlab@gmail.com',
          packages=find_packages(
              exclude=('llmdeploy.serve.fastertransformer.triton_models*', )),
classifiers=[
'Programming Language :: Python :: 3.8',
'Programming Language :: Python :: 3.9',
'Programming Language :: Python :: 3.10',
'Intended Audience :: Developers',
'Intended Audience :: Education',
'Intended Audience :: Science/Research',
])