Unverified commit 289ffa3c, authored by lvhan028, committed by GitHub

fix the offset during streaming chat (#142)

parent 79595cd1
@@ -113,7 +113,7 @@ conversation""" # noqa: E501
     def get_prompt(self, prompt, sequence_start=True):
         if sequence_start:
-            return f'<bos>{self.system}\n' \
+            return f'<BOS>{self.system}\n' \
                    f'{self.user}:{prompt}{self.eoh}\n' \
                    f'{self.assistant}:'
         else:
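For reference, a minimal sketch of what the corrected template renders, assuming hypothetical attribute values (the real `system`, `user`, `eoh`, and `assistant` strings are defined elsewhere on the model class):

    # hypothetical values; the actual attributes live on the model class
    system = 'You are an AI assistant.'
    user, eoh, assistant = '<|User|>', '<eoh>', '<|Bot|>'
    prompt = 'hi'
    text = f'<BOS>{system}\n{user}:{prompt}{eoh}\n{assistant}:'
    # -> '<BOS>You are an AI assistant.\n<|User|>:hi<eoh>\n<|Bot|>:'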
@@ -37,7 +37,7 @@ def main(tritonserver_addr: str, model_name: str, session_id: int = 1):
             chatbot.end(session_id)
         else:
             request_id = f'{session_id}-{nth_round}'
-            for status, res, tokens in chatbot.stream_infer(
+            for status, res, n_token in chatbot.stream_infer(
                     session_id,
                     prompt,
                     request_id=request_id,
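The rename makes it clearer that the third yielded element is a token count, not token ids. A sketch of how a caller consumes the stream, assuming `chatbot` is a connected Chatbot instance and `session_id`, `prompt`, and `request_id` are set as elsewhere in this script:

    for status, res, n_token in chatbot.stream_infer(
            session_id, prompt, request_id=request_id):
        # res: accumulated response text; n_token: generated-token count
        print(f'[{status}] {n_token} tokens so far: {res!r}')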
@@ -381,13 +381,16 @@ class Chatbot:
                                 request_output_len, sequence_start,
                                 sequence_end, preseq_length, cancel))
         producer.start()
-        for state, res, tokens in self.stream_consumer(
-                self.postprocess, que, session, preseq_length, cancel, logger,
-                self.display, self.profile_generation, self.eos_id):
+        for state, res, tokens in self.stream_consumer(self.postprocess, que,
+                                                       session, input_tokens,
+                                                       preseq_length, cancel,
+                                                       logger, self.display,
+                                                       self.profile_generation,
+                                                       self.eos_id):
             if state.value < 0:
                 yield state, res, 0
             else:
-                yield state, res, tokens - input_tokens
+                yield state, res, tokens
         producer.join()
         self._session = que.get()
         curseq_length = self._session.sequence_length
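The subtraction of `input_tokens` moves out of `stream_infer` and into `stream_consumer`, which now receives the prompt length and yields an already-adjusted count. A toy walk-through of the arithmetic, with hypothetical numbers:

    preseq_length = 20    # tokens already in the session history
    n_input_token = 7     # tokens in the current prompt
    sequence_length = 35  # total sequence length reported by the server
    offset = n_input_token + preseq_length
    n_generated = sequence_length - offset  # 8 newly generated tokens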
@@ -477,8 +480,9 @@ class Chatbot:
         que.put(None)

     @staticmethod
-    def stream_consumer(postprocess, res_queue, session, preseq_length, cancel,
-                        logger, display, profile_generation, eos_id):
+    def stream_consumer(postprocess, res_queue, session, n_input_token,
+                        preseq_length, cancel, logger, display,
+                        profile_generation, eos_id):
         """Consume the response from the triton inference server.

         Args:
@@ -486,6 +490,7 @@ class Chatbot:
                 the generated tokens
             res_queue (multiprocessing.Queue): response queue
             session (Session): an instance of a session
+            n_input_token (int): token number of input prompt
             preseq_length (int): the history sequence length
             cancel (bool): indicator for cancelling the session
             logger (util.Logger):
@@ -496,12 +501,12 @@ class Chatbot:
         Yields:
             tuple: status, text, generated token number
         """
+        offset = n_input_token + preseq_length
         while True:
             result = res_queue.get()
             if result is None:
-                yield StatusCode.TRITON_STREAM_END, \
-                    session.response[len(session.prompt):], \
-                    session.sequence_length - preseq_length
+                yield (StatusCode.TRITON_STREAM_END, session.response,
+                       session.sequence_length - offset)
                 session.status = StatusCode.TRITON_STREAM_END
                 break
             if 'errcode' in result:
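Because `output_ids` is now sliced from `offset` before decoding (see the hunks below), `session.response` no longer begins with the prompt, so the old `[len(session.prompt):]` slice would truncate generated text. A toy illustration with hypothetical strings:

    prompt = 'hello'
    response = ' world!'           # decoded from output_ids[:, :, offset:]
    response[len(prompt):]         # 'd!' -- the old slice loses ' worl'
    response                       # ' world!' -- now yielded whole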
@@ -521,7 +526,7 @@ class Chatbot:
                 output_ids = result.as_numpy('output_ids')
                 session.sequence_length = sequence_length.squeeze()
-                sequence_length = sequence_length - preseq_length
+                sequence_length = sequence_length - offset
                 last_token_id = output_ids[-1][-1][session.sequence_length - 1]
                 if last_token_id == eos_id:
                     session.sequence_length = session.sequence_length - 1
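Note the two lengths in play after this change: `session.sequence_length` stays absolute (it still indexes `output_ids` and feeds the EOS check), while the local `sequence_length` becomes relative to `offset` and counts only generated tokens. A toy sketch with hypothetical numbers:

    import numpy as np
    raw = np.array([35])                     # absolute length from the server
    session_sequence_length = raw.squeeze()  # 35 -- still indexes output_ids
    offset = 7 + 20                          # n_input_token + preseq_length
    relative = raw - offset                  # array([8]) -- generated count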
@@ -536,23 +541,15 @@ class Chatbot:
                         'postprocessing is ignored during profiling '
                         'token generation', sequence_length.squeeze())
                     continue
-                output_str = postprocess(output_ids[:, :, preseq_length:],
+                output_str = postprocess(output_ids[:, :, offset:],
                                          sequence_length)
                 text = output_str[0].decode()
                 if display:
-                    if len(text) > len(session.prompt):
-                        if session.status == StatusCode.TRITON_SESSION_READY:
-                            new_text = text[len(session.prompt):]
-                            session.status = StatusCode.TRITON_STREAM_ING
-                        else:
-                            new_text = text[len(session.response):]
-                        print(new_text, end='', flush=True)
+                    new_text = text[len(session.response):]
+                    print(new_text, end='', flush=True)
                 session.response = text
-                if len(session.response) > len(session.prompt):
-                    session.status = StatusCode.TRITON_STREAM_ING
-                yield (StatusCode.TRITON_STREAM_ING,
-                       session.response[len(session.prompt):],
-                       sequence_length.squeeze())
+                yield (StatusCode.TRITON_STREAM_ING, session.response,
+                       sequence_length.squeeze())
             except Exception as e:
                 logger.error(f'catch exception: {e}')
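With the prompt gone from the decoded text, the display path reduces to printing whatever suffix has not been shown yet. A toy sketch of the incremental printing, using hypothetical strings:

    session_response = 'The capital of Fran'   # printed so far
    text = 'The capital of France is'          # latest decoded output
    new_text = text[len(session_response):]    # 'ce is' -- fresh suffix only
    print(new_text, end='', flush=True)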