Unverified commit 955c019c, authored by lvhan028 and committed by GitHub

add docstring for turbomind (#97)

* add docstring

* update

* update

* fix according to review results
parent b6dc35fe
@@ -36,6 +36,14 @@ def chat_stream(instruction: str,
state_chatbot: Sequence,
llama_chatbot: Chatbot,
model_name: str = None):
"""Chat with AI assistant.
Args:
instruction (str): user's prompt
state_chatbot (Sequence): the chatting history
llama_chatbot (Chatbot): the instance of a chatbot
model_name (str): the name of the deployed model
"""
bot_summarized_response = ''
model_type = 'turbomind'
state_chatbot = state_chatbot + [(instruction, None)]
@@ -61,7 +69,7 @@ def chat_stream(instruction: str,
def reset_all_func(instruction_txtbox: gr.Textbox, state_chatbot: gr.State,
llama_chatbot: gr.State, triton_server_addr: str,
model_name: str):
"""reset the session."""
state_chatbot = []
log_level = os.environ.get('SERVICE_LOG_LEVEL', 'INFO')
llama_chatbot = Chatbot(triton_server_addr,
@@ -82,6 +90,7 @@ def cancel_func(
state_chatbot: gr.State,
llama_chatbot: gr.State,
):
"""cancel the session."""
session_id = llama_chatbot._session.session_id
llama_chatbot.cancel(session_id)
@@ -95,6 +104,14 @@ def run(triton_server_addr: str,
model_name: str,
server_name: str = 'localhost',
server_port: int = 6006):
"""chat with AI assistant through web ui.
Args:
triton_server_addr (str): the communication address of the inference server
model_name (str): the name of the deployed model
server_name (str): the ip address of the gradio server
server_port (int): the port of the gradio server
"""
with gr.Blocks(css=CSS, theme=THEME) as demo:
chat_interface = partial(chat_stream, model_name=model_name)
reset_all = partial(reset_all_func,
...
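For readers unfamiliar with the web UI entry point, here is a minimal, hypothetical sketch of launching it with the `run` function documented above. The module path, server address and model name are assumptions for illustration, not values from this commit.

```python
# Hypothetical launch script for the gradio web UI (paths/addresses are examples).
from lmdeploy.serve.gradio.app import run  # assumed module path

if __name__ == '__main__':
    # The first argument points at the triton inference server;
    # server_name/server_port configure where gradio listens.
    run('0.0.0.0:33337',
        'internlm',
        server_name='0.0.0.0',
        server_port=6006)
```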
@@ -6,6 +6,7 @@ MODELS = Registry('model', locations=['lmdeploy.model'])
@MODELS.register_module(name='vicuna')
class Vicuna:
"""Chat template of vicuna model."""
def __init__(self):
self.system = """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. """ # noqa: E501
@@ -13,6 +14,16 @@ class Vicuna:
self.assistant = 'ASSISTANT'
def get_prompt(self, prompt, sequence_start=True):
"""Return the prompt that is concatenated with other elements in the
chat template.
Args:
prompt (str): user's input prompt
sequence_start (bool): indicator for the first round chat of a
session sequence
Returns:
str: the concatenated prompt
"""
if sequence_start:
return f'{self.system} {self.user}: {prompt} {self.assistant}:'
else:
@@ -20,11 +31,13 @@ class Vicuna:
@property
def stop_words(self):
"""Return the stop-words' token ids."""
return None
@MODELS.register_module(name='internlm')
class InternLM:
"""Chat template of InternLM model."""
def __init__(self):
self.system = ''
@@ -34,6 +47,16 @@ class InternLM:
self.assistant = '<|Bot|>'
def get_prompt(self, prompt, sequence_start=True):
"""Return the prompt that is concatenated with other elements in the
chat template.
Args:
prompt (str): user's input prompt
sequence_start (bool): indicator for the first round chat of a
session sequence
Returns:
str: the concatenated prompt
"""
if sequence_start:
return f'{self.system}\n' \
f'{self.user}:{prompt}{self.eoh}\n' \
@@ -44,20 +67,33 @@ class InternLM:
@property
def stop_words(self):
"""Return the stop-words' token ids."""
return [103027, 103028]
@MODELS.register_module(name='llama')
class Llama:
"""Chat template of LLaMA model."""
def __init__(self):
pass
def get_prompt(self, prompt, sequence_start=True):
"""Return the prompt that is concatenated with other elements in the
chat template.
Args:
prompt (str): user's input prompt
sequence_start (bool): indicator for the first round chat of a
session sequence
Returns:
str: the concatenated prompt
"""
return prompt
@property
def stop_words(self):
"""Return the stop-words' token ids."""
return None
...
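The chat templates above are looked up through the `MODELS` registry, using the same access pattern that appears later in this diff (`MODELS.get(model_name)()`). A short usage sketch; the chosen model name is only an example:

```python
from lmdeploy.model import MODELS

# Fetch a registered chat template class by name and instantiate it.
model = MODELS.get('vicuna')()

# First round of a session: the system prompt is prepended to the user prompt.
print(model.get_prompt('What is the capital of France?', sequence_start=True))

# Stop words are exposed as token ids (None when the model defines none).
print(model.stop_words)
```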
@@ -7,12 +7,22 @@ from lmdeploy.serve.turbomind.chatbot import Chatbot
def input_prompt():
"""Input a prompt in the console interface."""
print('\ndouble enter to end input >>> ', end='')
sentinel = '' # ends when this string is seen
return '\n'.join(iter(input, sentinel))
def main(tritonserver_addr: str, model_name: str, session_id: int = 1):
"""An example to communicate with inference server through the command line
interface.
Args:
tritonserver_addr (str): the address in format "ip:port" of
triton inference server
model_name (str): the name of the deployed model
session_id (int): the unique id of a session
"""
log_level = os.environ.get('SERVICE_LOG_LEVEL', 'WARNING')
chatbot = Chatbot(tritonserver_addr,
model_name,
...
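A hedged sketch of constructing the command-line client documented above, using only the `Chatbot` import and positional arguments visible in this diff; the address and model name are placeholders, and the extra keyword argument (`log_level`) is an assumption based on the surrounding context:

```python
import os

from lmdeploy.serve.turbomind.chatbot import Chatbot

log_level = os.environ.get('SERVICE_LOG_LEVEL', 'WARNING')
# The address and model name are illustrative values, not from this commit.
chatbot = Chatbot('0.0.0.0:33337', 'internlm', log_level=log_level)
```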
@@ -42,6 +42,7 @@ class StatusCode(Enum):
def stream_callback(que, result, error):
"""callback function invoked by triton client."""
if error:
print(error)
que.put(dict(errcode=StatusCode.TRITON_SERVER_ERR, errmsg=f'{error}'))
@@ -50,6 +51,7 @@ def stream_callback(que, result, error):
def get_logger(log_file=None, log_level=logging.INFO):
"""Return the logger."""
from .utils import get_logger
logger = get_logger('service.ft', log_file=log_file, log_level=log_level)
return logger
@@ -258,17 +260,21 @@ class Chatbot:
return status
def reset_session(self):
"""reset session."""
self._session = None
def _get_bos(self):
"""return bos token id."""
token_ids, _ = self.preprocess('<BOS>')
return token_ids[0][0]
def _get_eos(self):
"""return eos token id."""
token_ids, _ = self.preprocess('<EOS>')
return token_ids[0][0]
def _stop_words(self, stop_words: List[int]):
"""return stop-words' token ids."""
if stop_words is None:
return None
assert isinstance(stop_words, List) and \
@@ -283,6 +289,8 @@ class Chatbot:
return stop_words
def _get_prompt(self, prompt: str, sequence_start: bool):
"""return the concatenated prompt according to the model's chat
template."""
if self.profile_generation or self.profile_serving:
return prompt
return self.model.get_prompt(prompt, sequence_start)
@@ -294,6 +302,19 @@ class Chatbot:
sequence_start: bool = True,
sequence_end: bool = False,
cancel: bool = False):
"""communicate with inference server to chat, or cancel a session, or
end a session.
Args:
session (Session): an instance of a session
prompt (str): the concatenated prompt
request_output_len (int): the max number of tokens to be generated
sequence_start (bool): indicator for starting a sequence
sequence_end (bool): indicator for ending a sequence
cancel (bool): indicator for cancelling the session
Yields:
tuple: status, text, generated token number
"""
logger = get_logger(log_level=self.log_level)
logger.info(f'session {session.session_id}, '
f'request id {session.request_id}, '
@@ -368,6 +389,22 @@ class Chatbot:
def _stream_producer(tritonserver_addr, session, que, cfg, input_ids,
input_lengths, request_output_len, sequence_start,
sequence_end, preseq_length, cancel):
"""Send a request to the triton inference server.
Args:
tritonserver_addr (str): the communication address of the inference
server
session (Session): an instance of a session
que (multiprocessing.Queue): response queue
cfg:
input_ids (numpy.ndarray): token ids of input prompt
input_lengths (numpy.ndarray): length of input_ids
request_output_len (int): the max number of tokens to be generated
sequence_start (bool): indicator for starting a sequence
sequence_end (bool): indicator for ending a sequence
preseq_length (int): the history sequence length
cancel (bool): indicator for cancelling the session
"""
request_output_len = np.full(input_lengths.shape,
request_output_len).astype(np.uint32)
@@ -432,7 +469,23 @@ class Chatbot:
@staticmethod
def stream_consumer(postprocess, res_queue, session, preseq_length, cancel,
logger, display, profile_generation, eos_id):
"""Consume the response from the triton inference server.
Args:
postprocess (callable): postprocess function for
the generated tokens
res_queue (multiprocessing.Queue): response queue
session (Session): an instance of a session
preseq_length (int): the history sequence length
cancel (bool): indicator for cancelling the session
logger (util.Logger): the logger instance
display (bool): display the text in the console interface or not
profile_generation (bool): indicator for profiling token generation
eos_id (int): eos token id
Yields:
tuple: status, text, generated token number
"""
while True:
result = res_queue.get()
if result is None:
...
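The `stream_callback` / `_stream_producer` / `stream_consumer` trio above follows a producer-consumer pattern: the triton client's callback pushes partial results into a queue, and the consumer drains the queue until a `None` sentinel signals completion. Below is a generic, self-contained sketch of that pattern, not the actual triton client code:

```python
import queue
import threading

def producer(que: queue.Queue, n_chunks: int = 5):
    """Simulate a server streaming partial results, then signal completion."""
    for i in range(n_chunks):
        que.put(dict(token_count=i + 1, text=f'chunk-{i}'))
    que.put(None)  # sentinel: no more results

def consumer(que: queue.Queue):
    """Yield results until the sentinel is seen, mirroring stream_consumer's loop."""
    while True:
        result = que.get()
        if result is None:
            break
        yield result

que = queue.Queue()
threading.Thread(target=producer, args=(que,), daemon=True).start()
for partial in consumer(que):
    print(partial)
```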
@@ -16,6 +16,13 @@ supported_formats = ['llama', 'hf']
def create_workspace(_path: str):
"""Create a workspace.
Args:
_path (str): the path of the workspace
Returns:
bool: success or not
"""
try:
if osp.exists(_path):
shutil.rmtree(_path)
@@ -28,6 +35,13 @@ def create_workspace(_path: str):
def destroy_workspace(_path: str):
"""destroy workspace.
Args:
_path(str): the path of the workspace
Returns:
bool: success or not
"""
try:
shutil.rmtree(_path)
print(f'destroy workspace in directory {_path}')
@@ -38,6 +52,13 @@ def destroy_workspace(_path: str):
def copy_triton_model_templates(_path: str):
"""copy triton model templates to the specified path.
Args:
_path (str): the target path
Returns:
str: the path of the triton models
"""
try:
cur_path = osp.abspath(__file__)
dir_path = osp.dirname(cur_path)
@@ -55,6 +76,13 @@ def copy_triton_model_templates(_path: str):
def tokenizer_info(model_path: str):
"""Return the vocabulary size, bos token id and eos token id.
Args:
model_path (str): the tokenizer model's path
Returns:
tuple: vocabulary size, bos token id and eos token id
"""
assert os.path.isfile(model_path), model_path
sp_model = SentencePieceProcessor(model_file=model_path)
# BOS / EOS token IDs
@@ -72,6 +100,18 @@ def export(model_name: str,
out_dir: str,
tp: int,
size_per_head: int = 128):
"""Export deploying information to a config file.
Args:
model_name (str): model's name
num_layer (int): the number of transformer blocks
norm_eps (float): norm epsilon
model_params (dict): parameters of a model
tokenizer_path (str): the tokenizer model's path
out_dir (str): the path of the output directory
tp (int): the number of GPUs used for tensor parallelism
size_per_head (int): the dimension of each head
"""
out_dir = osp.join(out_dir, 'weights')
os.makedirs(out_dir, exist_ok=True)
@@ -163,6 +203,16 @@ def export(model_name: str,
def deploy_llama(model_name: str, model_path: str, tokenizer_path: str,
triton_models_path: str, tp: int):
"""Deploy a model with huggingface transformers' format.
Args:
model_name (str): the name of the to-be-deployed model
model_path (str): the path of the directory where the model weight
files are
tokenizer_path (str): the path of the tokenizer model
triton_models_path (str): the path of the exported triton models
tp (int): the number of GPUs used for tensor parallelism
"""
if osp.exists(tokenizer_path):
shutil.copy(tokenizer_path,
osp.join(triton_models_path, 'tokenizer/tokenizer.model'))
@@ -269,13 +319,18 @@ def permute(x: torch.Tensor):
1).transpose(1, 2).reshape(dim, 1)
def check_zero(x: torch.Tensor):
"""Assert that the elements of the tensor sum to zero."""
_sum = x.flatten().sum().item()
assert _sum == 0, str(_sum)
def deploy_hf(model_name: str, model_path: str, tokenizer_path: str,
triton_models_path: str, tp: int):
"""Deploy a model with huggingface transformers' format.
Args:
model_name (str): the name of the to-be-deployed model
model_path (str): the path of the directory where the model weight
files are
tokenizer_path (str): the path of the tokenizer model
triton_models_path (str): the path of the exported triton models
tp (int): the number of GPUs used for tensor parallelism
"""
if tokenizer_path is None:
tokenizer_path = osp.join(model_path, 'tokenizer.model')
if osp.exists(tokenizer_path):
@@ -318,9 +373,11 @@ def deploy_hf(model_name: str, model_path: str, tokenizer_path: str,
_params.update(_tmp)
def get_tensor(name):
"""return tensor according its name."""
return _params[name]
def get_tensor_transposed(name: str):
"""return a transposed tensor according its name."""
if name not in _params and name.find('bias'):
return None
return _params[name].t()
@@ -407,6 +464,11 @@ def deploy_hf(model_name: str, model_path: str, tokenizer_path: str,
def pack_model_repository(workspace_path: str):
"""package the model repository.
Args:
workspace_path: the path of workspace
"""
model_repo_dir = osp.join(workspace_path, 'model_repository')
os.makedirs(model_repo_dir, exist_ok=True)
os.symlink(src=osp.join('../triton_models/interactive'),
...
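For reference, a minimal sketch of the information `tokenizer_info` above reads from a SentencePiece model, based on the `SentencePieceProcessor` call visible in this diff; the exact implementation in the repository may differ:

```python
from sentencepiece import SentencePieceProcessor

def tokenizer_info_sketch(model_path: str):
    """Return the vocabulary size, bos token id and eos token id."""
    sp_model = SentencePieceProcessor(model_file=model_path)
    return sp_model.vocab_size(), sp_model.bos_id(), sp_model.eos_id()

# vocab_size, bos_id, eos_id = tokenizer_info_sketch('/path/to/tokenizer.model')
```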
@@ -9,6 +9,11 @@ import triton_python_backend_utils as pb_utils
class Tokenizer:
"""Tokenize prompts or de-tokenize tokens into texts.
Args:
model_file (str): the path of the tokenizer model
"""
def __init__(self, model_file: str):
model_folder = osp.split(model_file)[0]
@@ -38,6 +43,13 @@ class Tokenizer:
self.model.backend_tokenizer.save(backend_tokenizer_file)
def encode(self, s: str):
"""Tokenize a prompt.
Args:
s (str): a prompt
Returns:
list[int]: token ids
"""
if not self.use_hf_model:
add_bos = False
add_eos = False
@@ -59,6 +71,13 @@ class Tokenizer:
return self.model.encode(s, add_special_tokens=add_special_tokens)
def decode(self, t: List[int]):
"""De-tokenize.
Args:
t (List[int]): a list of token ids
Returns:
str: the decoded text
"""
if not self.use_hf_model:
return self.model.Decode(t)
else:
@@ -173,6 +192,7 @@ class TritonPythonModel:
print('Cleaning up...')
def _postprocessing(self, tokens_batch, sequence_length):
"""decode token ids into texts."""
outputs = []
for beam_tokens, beam_len in zip(tokens_batch, sequence_length):
for tokens, _len in zip(beam_tokens, beam_len):
...
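A hedged usage sketch of the `Tokenizer` wrapper documented above: it encodes a prompt into token ids and decodes ids back into text, falling back to a huggingface tokenizer when the file is not a SentencePiece model. The model path is a placeholder and the class is assumed to be importable from the triton python backend module:

```python
# 'Tokenizer' refers to the class documented above; the path is illustrative.
tokenizer = Tokenizer('/workspace/triton_models/tokenizer/tokenizer.model')

token_ids = tokenizer.encode('Hello, world!')  # -> list[int]
text = tokenizer.decode(token_ids)             # -> str
print(token_ids, text)
```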
@@ -11,6 +11,11 @@ from torch.nn.utils.rnn import pad_sequence
class Tokenizer:
"""Tokenize prompts or de-tokenize tokens into texts.
Args:
model_file (str): the path of the tokenizer model
"""
def __init__(self, model_file: str):
model_folder = osp.split(model_file)[0]
@@ -40,6 +45,13 @@ class Tokenizer:
self.model.backend_tokenizer.save(backend_tokenizer_file)
def encode(self, s: str):
"""Tokenize a prompt.
Args:
s (str): a prompt
Returns:
list[int]: token ids
"""
if not self.use_hf_model:
add_bos = False
add_eos = False
@@ -61,6 +73,13 @@ class Tokenizer:
return self.model.encode(s, add_special_tokens=add_special_tokens)
def decode(self, t: List[int]):
"""De-tokenize.
Args:
t (List[int]): a list of token ids
Returns:
str: the decoded text
"""
if not self.use_hf_model:
return self.model.Decode(t)
else:
@@ -187,6 +206,13 @@ class TritonPythonModel:
print('Cleaning up...')
def _create_request(self, query):
"""Tokenize prompts and return the token ids and their length.
Args:
query (List[str]): a list of prompts
Returns:
tuple: token ids and their length
"""
start_ids = [
torch.IntTensor(self.tokenizer.encode(s[0].decode()))
for s in query
...
@@ -85,6 +85,7 @@ def get_logger(name: str,
def prepare_tensor(name, input_tensor):
"""Create grpcclient's InferInput instance according to a given tensor."""
t = grpcclient.InferInput(name, list(input_tensor.shape),
np_to_triton_dtype(input_tensor.dtype))
t.set_data_from_numpy(input_tensor)
@@ -92,6 +93,12 @@ def prepare_tensor(name, input_tensor):
class Preprocessor:
"""Tokenize prompts.
Args:
tritonserver_addr (str): the communication address of the inference
server
"""
def __init__(self, tritonserver_addr: str):
self.tritonserver_addr = tritonserver_addr
@@ -134,6 +141,12 @@ class Preprocessor:
class Postprocessor:
"""De-tokenize prompts.
Args:
tritonserver_addr (str): the communication address of the inference
server
"""
def __init__(self, tritonserver_addr: str):
self.tritonserver_addr = tritonserver_addr
...
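A short sketch of `prepare_tensor` above, which wraps a numpy array into a triton `InferInput` for a gRPC request; the tensor name and shape are illustrative, and the trailing `return t` is assumed from context since the diff cuts off before it:

```python
import numpy as np
import tritonclient.grpc as grpcclient
from tritonclient.utils import np_to_triton_dtype

def prepare_tensor(name, input_tensor):
    """Create grpcclient's InferInput instance from a numpy array."""
    t = grpcclient.InferInput(name, list(input_tensor.shape),
                              np_to_triton_dtype(input_tensor.dtype))
    t.set_data_from_numpy(input_tensor)
    return t  # assumed from context

input_ids = np.ones((1, 8), dtype=np.uint32)
infer_input = prepare_tensor('input_ids', input_ids)
```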
@@ -13,12 +13,14 @@ os.environ['TM_LOG_LEVEL'] = 'ERROR'
def input_prompt():
"""Input a prompt in the consolo interface."""
print('\ndouble enter to end input >>> ', end='')
sentinel = '' # ends when this string is seen
return '\n'.join(iter(input, sentinel))
def valid_str(string, coding='utf-8'):
"""decode text according to its encoding type."""
invalid_chars = [b'\xef\xbf\xbd']
bstr = bytes(string, coding)
for invalid_char in invalid_chars:
@@ -28,6 +30,14 @@ def valid_str(string, coding='utf-8'):
def main(model_name, model_path, session_id: int = 1):
"""An example to perform model inference through the command line
interface.
Args:
model_name (str): the name of the deployed model
model_path (str): the path of the deployed model
session_id (int): the unique id of a session
"""
model = MODELS.get(model_name)()
tm_model = tm.TurboMind(model_path, stop_words=model.stop_words)
generator = tm_model.create_instance()
...
@@ -7,6 +7,11 @@ from torch.nn.utils.rnn import pad_sequence
class Tokenizer:
"""Tokenize prompts or de-tokenize tokens into texts.
Args:
model_file (str): the path of the tokenizer model
"""
def __init__(self, model_file: str):
if model_file.endswith('.model'):
@@ -41,6 +46,13 @@ class Tokenizer:
self.model.backend_tokenizer.save(backend_tokenizer_file)
def encode(self, s: str): def encode(self, s: str):
"""Tokenize a prompt.
Args:
s (str): a prompt
Returns:
list[int]: token ids
"""
if not self.use_hf_model:
add_bos = False
add_eos = False
@@ -62,6 +74,13 @@ class Tokenizer:
return self.model.encode(s, add_special_tokens=add_special_tokens)
def decode(self, t: Sequence[int]):
"""De-tokenize.
Args:
t (List[int]): a list of token ids
Returns:
str: the decoded text
"""
if not self.use_hf_model:
return self.model.Decode(t)
else:
@@ -71,6 +90,11 @@ class Tokenizer:
class Preprocessor:
"""Tokenize prompts.
Args:
tokenizer (Tokenizer): an instance of tokenizer
"""
def __init__(self, tokenizer: Tokenizer):
self.tokenizer = tokenizer
@@ -110,6 +134,11 @@ class Preprocessor:
class Postprocessor:
"""De-tokenize token ids.
Args:
tokenizer (Tokenizer): an instance of tokenizer
"""
def __init__(self, tokenizer: Tokenizer):
self.tokenizer = tokenizer
...
@@ -18,6 +18,7 @@ import _turbomind as _tm # noqa: E402
def _stop_words(stop_words: List[int]):
"""return list of stop-words to numpy.ndarray."""
if stop_words is None:
return None
assert isinstance(stop_words, List) and \
@@ -33,6 +34,7 @@ def _stop_words(stop_words: List[int]):
def _np_dict_to_tm_dict(np_dict: dict):
"""map numpy.ndarray to turbomind's tensor."""
ret = _tm.TensorMap()
for k, v in np_dict.items():
ret[k] = _tm.from_dlpack(v)
@@ -41,6 +43,7 @@ def _np_dict_to_tm_dict(np_dict: dict):
def _tm_dict_to_torch_dict(tm_dict: _tm.TensorMap):
"""map turbomind's tensor to torch's tensor."""
ret = dict()
for k, v in tm_dict.items():
if v.type == _tm.DataType.TYPE_UINT32:
@@ -51,6 +54,19 @@ def _tm_dict_to_torch_dict(tm_dict: _tm.TensorMap):
class TurboMind:
"""LMDeploy's inference engine.
Args:
model_path (str): the path of turbomind's model
data_type (str): the data type
session_len (int): the max length of a session
eos_id (int): eos token id
stop_words (List[int]): token ids of stop-words
device_id (int): the id of a gpu card
node_id (int): the id of a node
device_num (int): the number of gpu cards
node_num (int): the number of nodes
"""
def __init__(self,
model_path: str,
@@ -81,10 +97,23 @@ class TurboMind:
self.stop_words = _stop_words(stop_words)
def create_instance(self, cuda_stream_id=0):
"""Create a turbomind instance.
Args:
cuda_stream_id (int): identity of a cuda stream
Returns:
TurboMindInstance: an instance of turbomind
"""
return TurboMindInstance(self, cuda_stream_id)
class TurboMindInstance:
"""Instance of TurboMind.
Args:
tm_model (TurboMind): the TurboMind model from which the instance is created
cuda_stream_id (int): identity of a cuda stream
"""
def __init__(self, tm_model, cuda_stream_id=0):
self.tm_model = tm_model
@@ -138,7 +167,28 @@ class TurboMindInstance:
ignore_eos=False,
random_seed=None,
stream_output=False):
"""Perform model inference.
Args:
session_id (int): the id of a session
input_ids (numpy.ndarray): the token ids of a prompt
request_output_len (int): the max number of to-be-generated tokens
sequence_start (bool): indicator for starting a sequence
sequence_end (bool): indicator for ending a sequence
step (int): the offset of the k/v cache
stop (bool): indicator for cancelling the session
top_p (float): If set to float < 1, only the smallest set of most
probable tokens with probabilities that add up to top_p or higher
are kept for generation.
top_k (int): The number of the highest probability vocabulary
tokens to keep for top-k-filtering
temperature (float): to modulate the next token probability
repetition_penalty (float): The parameter for repetition penalty.
1.0 means no penalty
ignore_eos (bool): indicator for ignoring eos
random_seed (int): seed used by sampling
stream_output (bool): indicator for stream output
"""
if stream_output:
self.model_inst.register_callback(self._forward_callback)
...
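A hedged end-to-end sketch combining the pieces documented above, following the CLI snippet shown earlier in this diff (`MODELS.get(...)`, `tm.TurboMind(...)`, `create_instance()`). The import path, model path, token ids and the name of the instance's inference method are assumptions, since the method name is not visible in this hunk:

```python
import numpy as np

from lmdeploy.model import MODELS
from lmdeploy import turbomind as tm  # assumed import path

model = MODELS.get('internlm')()
tm_model = tm.TurboMind('/workspace/model', stop_words=model.stop_words)
generator = tm_model.create_instance()

prompt = model.get_prompt('Hello!', sequence_start=True)
# input_ids would normally come from the Tokenizer; hard-coded for illustration.
input_ids = np.array([[1, 2, 3]], dtype=np.uint32)

# 'stream_infer' is assumed to be the instance method whose parameters
# (top_p, top_k, temperature, ...) are documented in the hunk above.
for outputs in generator.stream_infer(session_id=1,
                                      input_ids=input_ids,
                                      request_output_len=512,
                                      sequence_start=True,
                                      sequence_end=False,
                                      top_p=0.8,
                                      temperature=0.8,
                                      stream_output=True):
    pass
```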