Unverified Commit 3a7880a8 authored by AllentDan, committed by GitHub

Fix token count bug (#416)

* fix token count bug

* fix error response
parent d44a8bfe
@@ -57,9 +57,10 @@ def create_error_response(status: HTTPStatus, message: str):
         status (HTTPStatus): HTTP status codes and reason phrases
         message (str): error message
     """
-    return JSONResponse(ErrorResponse(message=message,
-                                      type='invalid_request_error').dict(),
-                        status_code=status.value)
+    return JSONResponse(
+        ErrorResponse(message=message,
+                      type='invalid_request_error',
+                      code=status.value).dict())


 async def check_request(request) -> Optional[JSONResponse]:
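For illustration, a minimal sketch of the payload shape this change produces. The HTTP status now travels in the JSON body as `code` instead of being set as the response's `status_code`. The `ErrorResponse` model below is a stand-in pydantic model, not the one defined in the repository:

```python
# Minimal sketch with a stand-in ErrorResponse model (hypothetical, not the
# repo's class): after this change the status goes into the body as `code`.
from http import HTTPStatus
from pydantic import BaseModel


class ErrorResponse(BaseModel):
    message: str
    type: str
    code: int


def build_error_payload(status: HTTPStatus, message: str) -> dict:
    return ErrorResponse(message=message,
                         type='invalid_request_error',
                         code=status.value).dict()


print(build_error_payload(HTTPStatus.BAD_REQUEST, 'Client disconnected'))
# {'message': 'Client disconnected', 'type': 'invalid_request_error', 'code': 400}
```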
@@ -117,7 +118,7 @@ async def chat_completions_v1(request: ChatCompletionRequest,
     result_generator = VariableInterface.async_engine.generate_openai(
         request.messages,
         instance_id,
-        request.stream,
+        True,  # always use stream to enable batching
         request.renew_session,
         request_output_len=request.max_tokens if request.max_tokens else 512,
         stop=request.stop,
@@ -130,7 +131,7 @@ async def chat_completions_v1(request: ChatCompletionRequest,
         async for _ in VariableInterface.async_engine.generate_openai(
                 request.messages,
                 instance_id,
-                request.stream,
+                True,
                 request.renew_session,
                 stop=True):
             pass
@@ -188,6 +189,7 @@ async def chat_completions_v1(request: ChatCompletionRequest,
     # Non-streaming response
     final_res = None
+    text = ''
     async for res in result_generator:
         if await raw_request.is_disconnected():
             # Abort the request if the client disconnects.
@@ -195,11 +197,12 @@ async def chat_completions_v1(request: ChatCompletionRequest,
             return create_error_response(HTTPStatus.BAD_REQUEST,
                                          'Client disconnected')
         final_res = res
+        text += res.response
     assert final_res is not None
     choices = []
     choice_data = ChatCompletionResponseChoice(
         index=0,
-        message=ChatMessage(role='assistant', content=final_res.response),
+        message=ChatMessage(role='assistant', content=text),
         finish_reason=final_res.finish_reason,
     )
     choices.append(choice_data)
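A hedged sketch of why the handler now concatenates into `text`: with the engine always driven in streaming mode, each yielded result is assumed to carry only the newly generated piece in `.response`, so the non-streaming path assembles the full reply itself instead of taking `final_res.response`. The generator and `StepOutput` type below are stand-ins, not the engine's actual types:

```python
import asyncio
from dataclasses import dataclass
from typing import Optional


@dataclass
class StepOutput:
    """Stand-in for the engine's per-step result (assumed delta semantics)."""
    response: str                   # text produced by this step only (assumption)
    finish_reason: Optional[str] = None


async def fake_stream():
    # Hypothetical generator standing in for generate_openai(..., True, ...).
    for piece in ['Hello', ', ', 'world']:
        yield StepOutput(response=piece)
    yield StepOutput(response='!', finish_reason='stop')


async def collect():
    text, final_res = '', None
    async for res in fake_stream():
        final_res = res
        text += res.response        # mirrors `text += res.response` in the diff
    return text, final_res.finish_reason


print(asyncio.run(collect()))       # ('Hello, world!', 'stop')
```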
@@ -308,7 +311,7 @@ async def generate(request: GenerateRequest, raw_request: Request = None):
     finish_reason = None
     async for out in generation:
         text += out.response
-        tokens += out.generate_token_len
+        tokens = out.generate_token_len
         finish_reason = out.finish_reason
     ret = {'text': text, 'tokens': tokens, 'finish_reason': finish_reason}
     return JSONResponse(ret)
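A small sketch of the arithmetic behind this one-character fix. It assumes, as the change implies, that the `generate_token_len` reported at each step is already a running total for the request, so accumulating it with `+=` over-counts, while plain assignment keeps the latest (correct) value:

```python
# Hypothetical cumulative token counts reported at three successive steps.
step_totals = [2, 5, 9]

buggy, fixed = 0, 0
for generate_token_len in step_totals:
    buggy += generate_token_len   # old behaviour: 2 + 5 + 9 = 16 (over-counts)
    fixed = generate_token_len    # new behaviour: keep the latest total = 9

print(buggy, fixed)               # 16 9
```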