Unverified Commit 68296844 authored by Lyu Han, committed by GitHub

Fix the TIS client got-no-space-result side effect introduced by PR #197 (#222)

* rollback

* rollback chatbot.py
parent af517a4a
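
For context, the rollback restores the consumption pattern sketched below: the client fixes an offset just past the prompt, re-decodes the entire generated suffix on every result pulled from the Triton response queue, and displays only the delta against what it has already shown. A minimal sketch with hypothetical names (decode and res_queue stand in for the real postprocessing call and stream queue), not the actual client API:

    # Sketch of the restored stream-consumption loop (hypothetical names).
    offset = n_input_token + preseq_length       # skip the prompt tokens once
    response = ''
    while True:
        result = res_queue.get()
        if result is None:                       # end-of-stream sentinel
            break
        output_ids, seq_len = result             # full sequence generated so far
        text = decode(output_ids[offset:seq_len])        # cumulative re-decode
        print(text[len(response):], end='', flush=True)  # emit only the delta
        response = text                          # overwrite, never append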
@@ -26,7 +26,6 @@ class Session:
     request_id: str = ''
     histories: str = ''  # history conversations of the session
     sequence_length: int = 0  # the total generated token number in the session
-    sequence_offset: int = 0  # the new generated token offset in the session
     prompt: str = ''
     response: str = ''
     status: int = None  # status of the session
@@ -599,15 +598,14 @@ class Chatbot:
         Yields:
             tuple: status, text, generated token number
         """
-        session.sequence_offset = n_input_token + preseq_length
-        sentinel = n_input_token + preseq_length
+        offset = n_input_token + preseq_length
         status, res, n_token = None, '', 0
         while True:
             result = res_queue.get()
             if result is None:
                 status = StatusCode.TRITON_STREAM_END
                 res = session.response
-                n_token = session.sequence_length - sentinel
+                n_token = session.sequence_length - offset
                 session.status = StatusCode.TRITON_STREAM_END
                 break
             if 'errcode' in result:
@@ -630,31 +628,30 @@
                 output_ids = result.as_numpy('output_ids')
                 session.sequence_length = sequence_length.squeeze()
-                new_token_length = sequence_length - session.sequence_offset
+                sequence_length = sequence_length - offset
                 last_token_id = output_ids[-1][-1][session.sequence_length - 1]
                 if last_token_id == eos_id:
                     session.sequence_length = session.sequence_length - 1
-                    new_token_length = new_token_length - 1
+                    sequence_length = sequence_length - 1
                 output_ids = output_ids.reshape((1, 1, output_ids.shape[-1]))
-                new_token_length = new_token_length.reshape(
-                    (1, new_token_length.shape[-1]))
+                sequence_length = sequence_length.reshape(
+                    (1, sequence_length.shape[-1]))
                 if profile_generation:
                     yield (StatusCode.TRITON_STREAM_ING,
                            'postprocessing is ignored during profiling '
-                           'token generation', new_token_length.squeeze())
+                           'token generation', sequence_length.squeeze())
                     continue
-                output_str = postprocess(
-                    output_ids[:, :, session.sequence_offset:],
-                    new_token_length)
-                session.sequence_offset = session.sequence_length
+                output_str = postprocess(output_ids[:, :, offset:],
+                                         sequence_length)
                 text = output_str[0].decode()
                 if display:
-                    print(text, end='', flush=True)
-                session.response += text
+                    new_text = text[len(session.response):]
+                    print(new_text, end='', flush=True)
+                session.response = text
                 yield (StatusCode.TRITON_STREAM_ING, session.response,
-                       session.sequence_offset - sentinel)
+                       sequence_length.squeeze())
             except Exception as e:
                 logger.error(f'catch exception: {e}')
...
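
The got-no-space-result symptom that motivates this rollback comes from detokenizing each new chunk in isolation: SentencePiece-style detokenizers strip the word-boundary marker at the start of a string, so the space between consecutive chunks is lost. Decoding the whole suffix cumulatively and diffing, as the reverted code does, avoids this. A toy illustration (the real postprocessing model is not part of this diff):

    def decode(pieces):
        # Toy detokenizer: '▁' marks a word boundary, stripped at text start.
        return ''.join(p.replace('▁', ' ') for p in pieces).lstrip(' ')

    pieces = ['▁Hello', '▁world']
    decode(pieces[:1]) + decode(pieces[1:])   # 'Helloworld' - boundary space lost
    decode(pieces)                            # 'Hello world' - cumulative decode keeps it
    decode(pieces)[len(decode(pieces[:1])):]  # ' world' - the safe display delta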
@@ -23,7 +23,7 @@ output [
 instance_group [
   {
-    count: 1
+    count: 16
     kind: KIND_CPU
   }
 ]
...
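
Reverting count back to 16 lets Triton schedule sixteen parallel instances of this CPU-side model again; PR #197's change to 1 serialized those requests. Assuming the hunk belongs to the postprocessing model's config.pbtxt (the file name is not visible in this excerpt) and a server reachable on localhost:8000, the deployed value can be spot-checked with tritonclient; a sketch under those assumptions:

    import tritonclient.http as httpclient

    client = httpclient.InferenceServerClient('localhost:8000')
    # The model name 'postprocessing' is an assumption, not taken from the diff.
    config = client.get_model_config('postprocessing')
    print(config['instance_group'])  # expect count: 16, kind: KIND_CPU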