Unverified commit 3de0dbb6 authored by lvhan028, committed by GitHub

Add non-stream inference api for chatbot (#200)

* add non-stream inference api for chatbot

* update according to reviewer's comments
parent b7e7e668
@@ -13,7 +13,9 @@ def input_prompt():
     return '\n'.join(iter(input, sentinel))


-def main(tritonserver_addr: str, session_id: int = 1):
+def main(tritonserver_addr: str,
+         session_id: int = 1,
+         stream_output: bool = True):
     """An example to communicate with inference server through the command line
     interface.

@@ -21,9 +23,12 @@ def main(tritonserver_addr: str, session_id: int = 1):
         tritonserver_addr (str): the address in format "ip:port" of
             triton inference server
         session_id (int): the identical id of a session
+        stream_output (bool): indicator for streaming output or not
     """
     log_level = os.environ.get('SERVICE_LOG_LEVEL', 'WARNING')
-    chatbot = Chatbot(tritonserver_addr, log_level=log_level, display=True)
+    chatbot = Chatbot(tritonserver_addr,
+                      log_level=log_level,
+                      display=stream_output)
     nth_round = 1
     while True:
         prompt = input_prompt()
@@ -33,12 +38,19 @@ def main(tritonserver_addr: str, session_id: int = 1):
             chatbot.end(session_id)
         else:
             request_id = f'{session_id}-{nth_round}'
-            for status, res, n_token in chatbot.stream_infer(
-                    session_id,
-                    prompt,
-                    request_id=request_id,
-                    request_output_len=512):
-                continue
+            if stream_output:
+                for status, res, n_token in chatbot.stream_infer(
+                        session_id,
+                        prompt,
+                        request_id=request_id,
+                        request_output_len=512):
+                    continue
+            else:
+                status, res, n_token = chatbot.infer(session_id,
+                                                     prompt,
+                                                     request_id=request_id,
+                                                     request_output_len=512)
+                print(res)
         nth_round += 1
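For reference, a minimal sketch of what the two branches above amount to, assuming `Chatbot` is importable from `lmdeploy.serve.turbomind.chatbot` and using a hypothetical server address:

```python
from lmdeploy.serve.turbomind.chatbot import Chatbot  # assumed import path

# Hypothetical triton inference server address.
chatbot = Chatbot('0.0.0.0:33337', display=False)

# Streaming mode: results are yielded as generation progresses; `res`
# accumulates the text produced so far.
for status, res, n_token in chatbot.stream_infer(
        1, 'hello', request_id='1-1', request_output_len=512):
    pass
print(res)

# Non-stream mode: one blocking call that returns the complete text.
status, res, n_token = chatbot.infer(
    1, 'hello', request_id='1-2', request_output_len=512)
print(res)
```

In the CLI example, `display=stream_output` ties incremental printing to the streaming mode, so the non-stream branch prints the response once at the end instead.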
@@ -294,6 +294,65 @@ class Chatbot:
             self._session.histories = histories
         return status

+    def infer(self,
+              session_id: int,
+              prompt: str,
+              request_id: str = '',
+              request_output_len: int = None,
+              sequence_start: bool = False,
+              sequence_end: bool = False,
+              *args,
+              **kwargs):
+        """Start a new round of conversation in a session. Return the chat
+        completion in non-stream mode.
+
+        Args:
+            session_id (int): the identical id of a session
+            prompt (str): user's prompt in this round conversation
+            request_id (str): the identical id of this round conversation
+            request_output_len (int): the expected number of generated tokens
+            sequence_start (bool): start flag of a session
+            sequence_end (bool): end flag of a session
+        Returns:
+            tuple(Status, str, int): status, text/chat completion,
+            number of generated tokens
+        """
+        assert isinstance(session_id, int), \
+            f'INT session id is required, but got {type(session_id)}'
+
+        logger = get_logger(log_level=self.log_level)
+        logger.info(f'session {session_id}, request_id {request_id}, '
+                    f'request_output_len {request_output_len}')
+
+        if self._session is None:
+            sequence_start = True
+            self._session = Session(session_id=session_id)
+        elif self._session.status == 0:
+            logger.error(f'session {session_id} has been ended. Please set '
+                         f'`sequence_start` to True if you want to restart it')
+            return StatusCode.TRITON_SESSION_CLOSED, '', 0
+
+        self._session.status = 1
+        self._session.request_id = request_id
+        self._session.response = ''
+
+        self._session.prompt = self._get_prompt(prompt, sequence_start)
+        status, res, tokens = None, '', 0
+        for status, res, tokens in self._stream_infer(self._session,
+                                                      self._session.prompt,
+                                                      request_output_len,
+                                                      sequence_start,
+                                                      sequence_end):
+            if status.value < 0:
+                break
+        if status.value == 0:
+            self._session.histories = \
+                self._session.histories + self._session.prompt + \
+                self._session.response
+        return status, res, tokens
+
     def reset_session(self):
         """reset session."""
         self._session = None
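A rough usage sketch of the new method (server address hypothetical; success is signalled by a status whose `.value` is `0`, matching the check in the implementation above):

```python
chatbot = Chatbot('0.0.0.0:33337')  # hypothetical address

status, text, n_token = chatbot.infer(session_id=1,
                                      prompt='write a haiku about GPUs',
                                      request_id='1-1',
                                      request_output_len=512)
if status.value == 0:
    # Normal completion: the session history now includes this round.
    print(f'{text}  ({n_token} tokens)')
else:
    # A negative value indicates an error; see the early `break` above.
    print(f'inference failed with status {status}')
```

Note that `infer` drains the same internal streaming path (`_stream_infer`) that `stream_infer` exposes, so both modes share one request-handling code path.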
@@ -32,13 +32,17 @@ def valid_str(string, coding='utf-8'):
 def main(model_path,
          session_id: int = 1,
          repetition_penalty: float = 1.0,
-         tp=1):
+         tp=1,
+         stream_output=True):
     """An example to perform model inference through the command line
     interface.

     Args:
         model_path (str): the path of the deployed model
         session_id (int): the identical id of a session
+        repetition_penalty (float): parameter to penalize repetition
+        tp (int): GPU number used in tensor parallelism
+        stream_output (bool): indicator for streaming output or not
     """
     tokenizer_model_path = osp.join(model_path, 'triton_models', 'tokenizer')
     tokenizer = Tokenizer(tokenizer_model_path)
@@ -62,7 +66,8 @@ def main(model_path,
             input_ids=[input_ids],
             request_output_len=512,
             sequence_start=False,
-            sequence_end=True):
+            sequence_end=True,
+            stream_output=stream_output):
         pass
     nth_round = 1
     step = 0
@@ -80,7 +85,7 @@ def main(model_path,
         for outputs in generator.stream_infer(
                 session_id=session_id,
                 input_ids=[input_ids],
-                stream_output=True,
+                stream_output=stream_output,
                 request_output_len=512,
                 sequence_start=(nth_round == 1),
                 sequence_end=False,
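Here the new flag is simply forwarded to `generator.stream_infer`. A rough sketch of the consuming loop, assuming `generator`, `tokenizer`, and `valid_str` are set up as in the surrounding script:

```python
for outputs in generator.stream_infer(session_id=session_id,
                                      input_ids=[input_ids],
                                      stream_output=stream_output,
                                      request_output_len=512,
                                      sequence_start=(nth_round == 1),
                                      sequence_end=False):
    res, tokens = outputs[0]  # assumed unpacking, as elsewhere in the script
# With stream_output=True the loop body runs once per generated chunk;
# with stream_output=False it runs once with the finished sequence.
print(valid_str(tokenizer.decode(res)))
```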