Unverified Commit 3de0dbb6 authored by lvhan028's avatar lvhan028 Committed by GitHub
Browse files

Add non-stream inference api for chatbot (#200)

* add non-stream inference api for chatbot

* update according to reviewer's comments
parent b7e7e668
......@@ -13,7 +13,9 @@ def input_prompt():
return '\n'.join(iter(input, sentinel))
def main(tritonserver_addr: str, session_id: int = 1):
def main(tritonserver_addr: str,
session_id: int = 1,
stream_output: bool = True):
"""An example to communicate with inference server through the command line
interface.
......@@ -21,9 +23,12 @@ def main(tritonserver_addr: str, session_id: int = 1):
tritonserver_addr (str): the address in format "ip:port" of
triton inference server
session_id (int): the identical id of a session
stream_output (bool): indicator for streaming output or not
"""
log_level = os.environ.get('SERVICE_LOG_LEVEL', 'WARNING')
chatbot = Chatbot(tritonserver_addr, log_level=log_level, display=True)
chatbot = Chatbot(tritonserver_addr,
log_level=log_level,
display=stream_output)
nth_round = 1
while True:
prompt = input_prompt()
......@@ -33,12 +38,19 @@ def main(tritonserver_addr: str, session_id: int = 1):
chatbot.end(session_id)
else:
request_id = f'{session_id}-{nth_round}'
if stream_output:
for status, res, n_token in chatbot.stream_infer(
session_id,
prompt,
request_id=request_id,
request_output_len=512):
continue
else:
status, res, n_token = chatbot.infer(session_id,
prompt,
request_id=request_id,
request_output_len=512)
print(res)
nth_round += 1
......
......@@ -294,6 +294,65 @@ class Chatbot:
self._session.histories = histories
return status
def infer(self,
          session_id: int,
          prompt: str,
          request_id: str = '',
          request_output_len: int = None,
          sequence_start: bool = False,
          sequence_end: bool = False,
          *args,
          **kwargs):
    """Start a new round of conversation in a session and return the chat
    completion in non-stream mode.

    Internally drains the streaming generator `self._stream_infer` and
    keeps only the last yielded triple.

    Args:
        session_id (int): the identical id of a session
        prompt (str): user's prompt in this round conversation
        request_id (str): the identical id of this round conversation
        request_output_len (int): the expected generated token numbers
        sequence_start (bool): start flag of a session
        sequence_end (bool): end flag of a session

    Returns:
        tuple(Status, str, int): status, text/chat completion,
            generated token number
    """
    assert isinstance(session_id, int), \
        f'INT session id is required, but got {type(session_id)}'

    logger = get_logger(log_level=self.log_level)
    logger.info(f'session {session_id}, request_id {request_id}, '
                f'request_output_len {request_output_len}')

    # A missing session means a fresh start; a status of 0 means the
    # session was explicitly ended and must not be reused.
    if self._session is None:
        sequence_start = True
        self._session = Session(session_id=session_id)
    elif self._session.status == 0:
        logger.error(f'session {session_id} has been ended. Please set '
                     f'`sequence_start` be True if you want to restart it')
        return StatusCode.TRITON_SESSION_CLOSED, '', 0

    self._session.status = 1
    self._session.request_id = request_id
    self._session.response = ''
    self._session.prompt = self._get_prompt(prompt, sequence_start)

    # Consume the stream; stop early on the first error status.
    status, res, tokens = None, '', 0
    for status, res, tokens in self._stream_infer(self._session,
                                                  self._session.prompt,
                                                  request_output_len,
                                                  sequence_start,
                                                  sequence_end):
        if status.value < 0:
            break

    # FIX: the original read `status.value` unconditionally, which raises
    # AttributeError when the generator yields nothing (status is None).
    # Both of the original branches returned the same tuple, so the
    # duplicated returns are collapsed into one.
    if status is not None and status.value == 0:
        # Successful completion: fold this round into the session history.
        self._session.histories = \
            self._session.histories + self._session.prompt + \
            self._session.response
    return status, res, tokens
def reset_session(self):
    """Reset the chatbot's session state.

    Drops the cached session object so the next inference call begins a
    brand-new session (`infer` forces `sequence_start=True` and creates a
    fresh `Session` when `self._session` is None).
    """
    self._session = None
......
......@@ -32,13 +32,17 @@ def valid_str(string, coding='utf-8'):
def main(model_path,
session_id: int = 1,
repetition_penalty: float = 1.0,
tp=1):
tp=1,
stream_output=True):
"""An example to perform model inference through the command line
interface.
Args:
model_path (str): the path of the deployed model
session_id (int): the identical id of a session
repetition_penalty (float): parameter to penalize repetition
tp (int): GPU number used in tensor parallelism
stream_output (bool): indicator for streaming output or not
"""
tokenizer_model_path = osp.join(model_path, 'triton_models', 'tokenizer')
tokenizer = Tokenizer(tokenizer_model_path)
......@@ -62,7 +66,8 @@ def main(model_path,
input_ids=[input_ids],
request_output_len=512,
sequence_start=False,
sequence_end=True):
sequence_end=True,
stream_output=stream_output):
pass
nth_round = 1
step = 0
......@@ -80,7 +85,7 @@ def main(model_path,
for outputs in generator.stream_infer(
session_id=session_id,
input_ids=[input_ids],
stream_output=True,
stream_output=stream_output,
request_output_len=512,
sequence_start=(nth_round == 1),
sequence_end=False,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment