Unverified Commit 16b4b823 authored by AllentDan's avatar AllentDan Committed by GitHub
Browse files

fix `finish_reason` (#816)

parent a5b67b95
...@@ -204,7 +204,7 @@ class AsyncEngine: ...@@ -204,7 +204,7 @@ class AsyncEngine:
if do_preprocess: if do_preprocess:
prompt = self.model.messages2prompt(prompt, sequence_start) prompt = self.model.messages2prompt(prompt, sequence_start)
input_ids = self.tokenizer.encode(prompt, add_bos=sequence_start) input_ids = self.tokenizer.encode(prompt, add_bos=sequence_start)
finish_reason = 'stop' if stop else None finish_reason = None
if self.id2step[str(session_id)] + len( if self.id2step[str(session_id)] + len(
input_ids) + request_output_len >= self.tm_model.session_len: input_ids) + request_output_len >= self.tm_model.session_len:
finish_reason = 'length' finish_reason = 'length'
...@@ -247,9 +247,12 @@ class AsyncEngine: ...@@ -247,9 +247,12 @@ class AsyncEngine:
len(input_ids), tokens, finish_reason) len(input_ids), tokens, finish_reason)
response_size = tokens response_size = tokens
finish_reason = 'length' \
if tokens >= request_output_len else 'stop'
# `response_size` might be note updated since # `response_size` might be note updated since
# ` if response.endswith('�')` # ` if response.endswith('�')`
if response_size != tokens: if response_size == tokens:
response = '' # avaid returning the last response twice
yield GenOut(response, self.id2step[str(session_id)], yield GenOut(response, self.id2step[str(session_id)],
len(input_ids), tokens, finish_reason) len(input_ids), tokens, finish_reason)
# update step # update step
......
...@@ -179,6 +179,7 @@ async def chat_completions_v1(request: ChatCompletionRequest, ...@@ -179,6 +179,7 @@ async def chat_completions_v1(request: ChatCompletionRequest,
response_json = create_stream_response_json( response_json = create_stream_response_json(
index=0, index=0,
text=res.response, text=res.response,
finish_reason=res.finish_reason,
) )
yield f'data: {response_json}\n\n' yield f'data: {response_json}\n\n'
yield 'data: [DONE]\n\n' yield 'data: [DONE]\n\n'
...@@ -329,6 +330,7 @@ async def completions_v1(request: CompletionRequest, ...@@ -329,6 +330,7 @@ async def completions_v1(request: CompletionRequest,
response_json = create_stream_response_json( response_json = create_stream_response_json(
index=0, index=0,
text=res.response, text=res.response,
finish_reason=res.finish_reason,
) )
yield f'data: {response_json}\n\n' yield f'data: {response_json}\n\n'
yield 'data: [DONE]\n\n' yield 'data: [DONE]\n\n'
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment