Unverified Commit b728064e authored by MaxMatthew's avatar MaxMatthew Committed by GitHub
Browse files

remove slicing reponse and add resume api (#154)

* Fix lmdeploy.serve.turbomind bug
* add __init__.py for turbomind
* add resume function
* fix the assignment for session.response

* Fix code style
parent f07b697b
# Copyright (c) OpenMMLab. All rights reserved.
...@@ -259,6 +259,47 @@ class Chatbot: ...@@ -259,6 +259,47 @@ class Chatbot:
logger.info(f'cancel session {session_id} failed: {res}') logger.info(f'cancel session {session_id} failed: {res}')
return status return status
def resume(self, session_id: int, *args, **kwargs):
"""Resume a session by sending the history conversations to triton
inference server. After resuming, users can continue chatting with
chatbot.
Args:
session_id (int): the identical id of a session
Returns:
int: 0: success, -1: session not found
"""
assert isinstance(session_id, int), \
f'INT session id is required, but got {type(session_id)}'
logger = get_logger(log_level=self.log_level)
logger.info(f'resume session: {session_id}')
if self._session is None:
logger.error(
f"session {session_id} doesn't exist. It cannot be recovered")
return StatusCode.TRITON_SESSION_INVALID_ARG
if self._session.session_id != session_id:
logger.error(
f'you cannot resume session {session_id}, because this '
f'session is {self._session.session_id}')
return StatusCode.TRITON_SESSION_INVALID_ARG
self._session.status = 1
self._session.sequence_length = 0
histories = self._session.histories
for status, _, _ in self._stream_infer(self._session,
prompt=histories,
request_output_len=0,
sequence_start=True,
sequence_end=False):
if status.value < 0:
return status
self._session.histories = histories
return status
def reset_session(self): def reset_session(self):
"""reset session.""" """reset session."""
self._session = None self._session = None
...@@ -553,7 +594,6 @@ class Chatbot: ...@@ -553,7 +594,6 @@ class Chatbot:
except Exception as e: except Exception as e:
logger.error(f'catch exception: {e}') logger.error(f'catch exception: {e}')
session.response = session.response[len(session.prompt):]
# put session back to queue so that `_stream_infer` can update it in # put session back to queue so that `_stream_infer` can update it in
# `self.sessions` # `self.sessions`
while not res_queue.empty(): while not res_queue.empty():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment