Unverified commit 3a7880a8, authored by AllentDan and committed by GitHub
Browse files

Fix token count bug (#416)

* fix token count bug

* fix error response
parent d44a8bfe
...@@ -57,9 +57,10 @@ def create_error_response(status: HTTPStatus, message: str): ...@@ -57,9 +57,10 @@ def create_error_response(status: HTTPStatus, message: str):
status (HTTPStatus): HTTP status codes and reason phrases status (HTTPStatus): HTTP status codes and reason phrases
message (str): error message message (str): error message
""" """
return JSONResponse(ErrorResponse(message=message, return JSONResponse(
type='invalid_request_error').dict(), ErrorResponse(message=message,
status_code=status.value) type='invalid_request_error',
code=status.value).dict())
async def check_request(request) -> Optional[JSONResponse]: async def check_request(request) -> Optional[JSONResponse]:
...@@ -117,7 +118,7 @@ async def chat_completions_v1(request: ChatCompletionRequest, ...@@ -117,7 +118,7 @@ async def chat_completions_v1(request: ChatCompletionRequest,
result_generator = VariableInterface.async_engine.generate_openai( result_generator = VariableInterface.async_engine.generate_openai(
request.messages, request.messages,
instance_id, instance_id,
request.stream, True, # always use stream to enable batching
request.renew_session, request.renew_session,
request_output_len=request.max_tokens if request.max_tokens else 512, request_output_len=request.max_tokens if request.max_tokens else 512,
stop=request.stop, stop=request.stop,
...@@ -130,7 +131,7 @@ async def chat_completions_v1(request: ChatCompletionRequest, ...@@ -130,7 +131,7 @@ async def chat_completions_v1(request: ChatCompletionRequest,
async for _ in VariableInterface.async_engine.generate_openai( async for _ in VariableInterface.async_engine.generate_openai(
request.messages, request.messages,
instance_id, instance_id,
request.stream, True,
request.renew_session, request.renew_session,
stop=True): stop=True):
pass pass
...@@ -188,6 +189,7 @@ async def chat_completions_v1(request: ChatCompletionRequest, ...@@ -188,6 +189,7 @@ async def chat_completions_v1(request: ChatCompletionRequest,
# Non-streaming response # Non-streaming response
final_res = None final_res = None
text = ''
async for res in result_generator: async for res in result_generator:
if await raw_request.is_disconnected(): if await raw_request.is_disconnected():
# Abort the request if the client disconnects. # Abort the request if the client disconnects.
...@@ -195,11 +197,12 @@ async def chat_completions_v1(request: ChatCompletionRequest, ...@@ -195,11 +197,12 @@ async def chat_completions_v1(request: ChatCompletionRequest,
return create_error_response(HTTPStatus.BAD_REQUEST, return create_error_response(HTTPStatus.BAD_REQUEST,
'Client disconnected') 'Client disconnected')
final_res = res final_res = res
text += res.response
assert final_res is not None assert final_res is not None
choices = [] choices = []
choice_data = ChatCompletionResponseChoice( choice_data = ChatCompletionResponseChoice(
index=0, index=0,
message=ChatMessage(role='assistant', content=final_res.response), message=ChatMessage(role='assistant', content=text),
finish_reason=final_res.finish_reason, finish_reason=final_res.finish_reason,
) )
choices.append(choice_data) choices.append(choice_data)
...@@ -308,7 +311,7 @@ async def generate(request: GenerateRequest, raw_request: Request = None): ...@@ -308,7 +311,7 @@ async def generate(request: GenerateRequest, raw_request: Request = None):
finish_reason = None finish_reason = None
async for out in generation: async for out in generation:
text += out.response text += out.response
tokens += out.generate_token_len tokens = out.generate_token_len
finish_reason = out.finish_reason finish_reason = out.finish_reason
ret = {'text': text, 'tokens': tokens, 'finish_reason': finish_reason} ret = {'text': text, 'tokens': tokens, 'finish_reason': finish_reason}
return JSONResponse(ret) return JSONResponse(ret)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment