Unverified Commit 289ffa3c authored by lvhan028, committed by GitHub

fix the offset during streaming chat (#142)

parent 79595cd1
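In brief: this commit moves the streamed-response offset bookkeeping from stream_infer's caller into stream_consumer. The consumer now receives the prompt's token count (n_input_token), computes offset = n_input_token + preseq_length, and decodes only the tokens past that offset, so the yielded text and token counts cover the newly generated response instead of being recovered by slicing the decoded string at len(session.prompt). (It also uppercases the <bos> token in the prompt template.) A minimal sketch of the new accounting, with all values invented for illustration:

    # offset separates history + prompt tokens from newly generated ones
    preseq_length = 100    # tokens already in the session before this round
    n_input_token = 12     # tokens in this round's prompt
    offset = n_input_token + preseq_length   # 112
    sequence_length = 130  # total sequence length reported by the server
    n_generated = sequence_length - offset   # 18 tokens belong to the response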
@@ -113,7 +113,7 @@ conversation""" # noqa: E501
     def get_prompt(self, prompt, sequence_start=True):
         if sequence_start:
-            return f'<bos>{self.system}\n' \
+            return f'<BOS>{self.system}\n' \
                    f'{self.user}:{prompt}{self.eoh}\n' \
                    f'{self.assistant}:'
         else:
...
@@ -37,7 +37,7 @@ def main(tritonserver_addr: str, model_name: str, session_id: int = 1):
             chatbot.end(session_id)
         else:
             request_id = f'{session_id}-{nth_round}'
-            for status, res, tokens in chatbot.stream_infer(
+            for status, res, n_token in chatbot.stream_infer(
                     session_id,
                     prompt,
                     request_id=request_id,
...
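For context, a hedged sketch of how the renamed loop might be driven on the client side; the helper name and the handling of each yielded triple are assumptions, since only the loop header appears in this hunk:

    def chat_round(chatbot, session_id: int, nth_round: int, prompt: str):
        # Hypothetical helper mirroring the hunk above; `chatbot` is a
        # Chatbot instance (its construction is not shown in this diff).
        request_id = f'{session_id}-{nth_round}'
        for status, res, n_token in chatbot.stream_infer(
                session_id,
                prompt,
                request_id=request_id):
            # after this commit, res holds only the generated reply and
            # n_token counts only generated tokens
            print(f'[{status}] {n_token} tokens so far')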
@@ -381,13 +381,16 @@ class Chatbot:
                       request_output_len, sequence_start,
                       sequence_end, preseq_length, cancel))
         producer.start()
-        for state, res, tokens in self.stream_consumer(
-                self.postprocess, que, session, preseq_length, cancel, logger,
-                self.display, self.profile_generation, self.eos_id):
+        for state, res, tokens in self.stream_consumer(self.postprocess, que,
+                                                       session, input_tokens,
+                                                       preseq_length, cancel,
+                                                       logger, self.display,
+                                                       self.profile_generation,
+                                                       self.eos_id):
             if state.value < 0:
                 yield state, res, 0
             else:
-                yield state, res, tokens - input_tokens
+                yield state, res, tokens
         producer.join()
         self._session = que.get()
         curseq_length = self._session.sequence_length
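The caller-side subtraction `tokens - input_tokens` disappears because the consumer already removes the prompt length via its offset; the reported count is unchanged. A quick check with invented numbers:

    # old: consumer yields sequence_length - preseq_length, caller then
    # subtracts input_tokens; new: consumer yields the difference directly
    preseq_length, n_input_token, sequence_length = 100, 12, 130
    old_reported = (sequence_length - preseq_length) - n_input_token
    new_reported = sequence_length - (n_input_token + preseq_length)
    assert old_reported == new_reported == 18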
@@ -477,8 +480,9 @@ class Chatbot:
         que.put(None)

     @staticmethod
-    def stream_consumer(postprocess, res_queue, session, preseq_length, cancel,
-                        logger, display, profile_generation, eos_id):
+    def stream_consumer(postprocess, res_queue, session, n_input_token,
+                        preseq_length, cancel, logger, display,
+                        profile_generation, eos_id):
         """Consume the response from the triton inference server.

         Args:
@@ -486,6 +490,7 @@ class Chatbot:
                 the generated tokens
             res_queue (multiprocessing.Queue): response queue
             session (Session): an instance of a session
+            n_input_token (int): token number of input prompt
             preseq_length (int): the history sequence length
             cancel (bool): indicator for cancelling the session
             logger (util.Logger):
@@ -496,12 +501,12 @@ class Chatbot:
         Yields:
             tuple: status, text, generated token number
         """
+        offset = n_input_token + preseq_length
         while True:
             result = res_queue.get()
             if result is None:
-                yield StatusCode.TRITON_STREAM_END, \
-                    session.response[len(session.prompt):], \
-                    session.sequence_length - preseq_length
+                yield (StatusCode.TRITON_STREAM_END, session.response,
+                       session.sequence_length - offset)
                 session.status = StatusCode.TRITON_STREAM_END
                 break
             if 'errcode' in result:
@@ -521,7 +526,7 @@ class Chatbot:
                 output_ids = result.as_numpy('output_ids')
                 session.sequence_length = sequence_length.squeeze()
-                sequence_length = sequence_length - preseq_length
+                sequence_length = sequence_length - offset
                 last_token_id = output_ids[-1][-1][session.sequence_length - 1]
                 if last_token_id == eos_id:
                     session.sequence_length = session.sequence_length - 1
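Note that two views of the length coexist here: session.sequence_length stays absolute (it indexes output_ids to find the last token), while the local sequence_length is rebased to count generated tokens only. A small sketch, assuming the server returns the length as a [1, 1] array, as the squeeze() calls suggest:

    import numpy as np

    offset = 112                          # n_input_token + preseq_length
    sequence_length = np.array([[130]])   # shape [1, 1], per the squeeze()
    absolute = sequence_length.squeeze()  # 130: index into output_ids
    relative = sequence_length - offset   # [[18]]: generated tokens only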
@@ -536,23 +541,15 @@ class Chatbot:
                            'postprocessing is ignored during profiling '
                            'token generation', sequence_length.squeeze())
                     continue
-                output_str = postprocess(output_ids[:, :, preseq_length:],
+                output_str = postprocess(output_ids[:, :, offset:],
                                          sequence_length)
                 text = output_str[0].decode()
                 if display:
-                    if len(text) > len(session.prompt):
-                        if session.status == StatusCode.TRITON_SESSION_READY:
-                            new_text = text[len(session.prompt):]
-                            session.status = StatusCode.TRITON_STREAM_ING
-                        else:
-                            new_text = text[len(session.response):]
-                        print(new_text, end='', flush=True)
+                    new_text = text[len(session.response):]
+                    print(new_text, end='', flush=True)
                 session.response = text
-                if len(session.response) > len(session.prompt):
-                    session.status = StatusCode.TRITON_STREAM_ING
-                    yield (StatusCode.TRITON_STREAM_ING,
-                           session.response[len(session.prompt):],
-                           sequence_length.squeeze())
+                yield (StatusCode.TRITON_STREAM_ING, session.response,
+                       sequence_length.squeeze())
             except Exception as e:
                 logger.error(f'catch exception: {e}')
...
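Because postprocess now decodes only output_ids[:, :, offset:], the decoded text never contains the prompt, and incremental display reduces to printing the suffix beyond what was already emitted. A self-contained sketch of that suffix-printing pattern, with sample chunks invented for illustration:

    def emit_increment(prev: str, text: str) -> str:
        # `text` is the full decoded response so far and extends `prev`
        return text[len(prev):]

    printed = ''
    for text in ['He', 'Hello', 'Hello, w', 'Hello, world']:
        print(emit_increment(printed, text), end='', flush=True)
        printed = text
    print()  # -> Hello, world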