Unverified Commit 686f5e32 authored by Iskren Ivov Chernev, committed by GitHub

Return usage for openai streaming requests (#1663)

parent 415d1095
@@ -245,6 +245,7 @@ async def create_chat_completion(request: ChatCompletionRequest,
         index: int,
         text: str,
         finish_reason: Optional[str] = None,
+        usage: Optional[UsageInfo] = None,
     ) -> str:
         choice_data = ChatCompletionResponseStreamChoice(
             index=index,
@@ -257,7 +258,10 @@ async def create_chat_completion(request: ChatCompletionRequest,
             model=model_name,
             choices=[choice_data],
         )
-        response_json = response.json(ensure_ascii=False)
+        if usage is not None:
+            response.usage = usage
+        # exclude unset to leave details out of each sse
+        response_json = response.json(exclude_unset=True, ensure_ascii=False)
         return response_json
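The combination of `exclude_unset=True` and the explicit `response.usage = usage` assignment is what keeps `usage` out of every chunk except the final one. A minimal sketch of that Pydantic v1 behaviour, using a stand-in `Chunk` model rather than the actual vLLM classes:

```python
from typing import Optional

from pydantic import BaseModel  # assumes Pydantic v1 semantics, as used here


class Chunk(BaseModel):
    text: str
    usage: Optional[dict] = None  # defaulted, so "unset" until assigned


chunk = Chunk(text="hello")
print(chunk.json())                    # {"text": "hello", "usage": null}
print(chunk.json(exclude_unset=True))  # {"text": "hello"}

# Attribute assignment marks the field as set, so it now serializes.
chunk.usage = {"prompt_tokens": 1, "completion_tokens": 2, "total_tokens": 3}
print(chunk.json(exclude_unset=True))  # usage appears in this chunk only
```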
@@ -283,17 +287,25 @@ async def create_chat_completion(request: ChatCompletionRequest,
                 i = output.index
                 delta_text = output.text[len(previous_texts[i]):]
                 previous_texts[i] = output.text
-                previous_num_tokens[i] = len(output.token_ids)
+                completion_tokens = len(output.token_ids)
+                previous_num_tokens[i] = completion_tokens
                 response_json = create_stream_response_json(
                     index=i,
                     text=delta_text,
                 )
                 yield f"data: {response_json}\n\n"
                 if output.finish_reason is not None:
+                    prompt_tokens = len(res.prompt_token_ids)
+                    final_usage = UsageInfo(
+                        prompt_tokens=prompt_tokens,
+                        completion_tokens=completion_tokens,
+                        total_tokens=prompt_tokens + completion_tokens,
+                    )
                     response_json = create_stream_response_json(
                         index=i,
                         text="",
                         finish_reason=output.finish_reason,
+                        usage=final_usage,
                     )
                     yield f"data: {response_json}\n\n"
         yield "data: [DONE]\n\n"
@@ -462,6 +474,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
         text: str,
         logprobs: Optional[LogProbs] = None,
         finish_reason: Optional[str] = None,
+        usage: Optional[UsageInfo] = None,
     ) -> str:
         choice_data = CompletionResponseStreamChoice(
             index=index,
@@ -475,7 +488,9 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
             model=model_name,
             choices=[choice_data],
         )
-        response_json = response.json(ensure_ascii=False)
+        if usage is not None:
+            response.usage = usage
+        response_json = response.json(exclude_unset=True, ensure_ascii=False)
         return response_json
@@ -505,11 +520,19 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
                 if output.finish_reason is not None:
                     logprobs = (LogProbs()
                                 if request.logprobs is not None else None)
+                    prompt_tokens = len(res.prompt_token_ids)
+                    completion_tokens = len(output.token_ids)
+                    final_usage = UsageInfo(
+                        prompt_tokens=prompt_tokens,
+                        completion_tokens=completion_tokens,
+                        total_tokens=prompt_tokens + completion_tokens,
+                    )
                     response_json = create_stream_response_json(
                         index=i,
                         text="",
                         logprobs=logprobs,
                         finish_reason=output.finish_reason,
+                        usage=final_usage,
                     )
                     yield f"data: {response_json}\n\n"
         yield "data: [DONE]\n\n"
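The wire-level invariant is the same for both endpoints: intermediate events omit `usage` entirely, and the `finish_reason` event satisfies `total_tokens == prompt_tokens + completion_tokens`. A self-contained check against a hand-written sample event (all values are placeholders, other response fields elided):

```python
import json

# Hand-written sample of a final streamed-completion event (placeholder values).
event = (
    'data: {"id": "cmpl-xyz", "model": "my-model", '
    '"choices": [{"index": 0, "text": "", "finish_reason": "stop"}], '
    '"usage": {"prompt_tokens": 5, "completion_tokens": 42, "total_tokens": 47}}'
)

chunk = json.loads(event[len("data: "):])
usage = chunk["usage"]
assert usage["total_tokens"] == usage["prompt_tokens"] + usage["completion_tokens"]
```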
@@ -139,6 +139,7 @@ class CompletionStreamResponse(BaseModel):
     created: int = Field(default_factory=lambda: int(time.time()))
     model: str
     choices: List[CompletionResponseStreamChoice]
+    usage: Optional[UsageInfo]
 
 class ChatMessage(BaseModel):
@@ -178,3 +179,5 @@ class ChatCompletionStreamResponse(BaseModel):
     created: int = Field(default_factory=lambda: int(time.time()))
     model: str
     choices: List[ChatCompletionResponseStreamChoice]
+    usage: Optional[UsageInfo] = Field(
+        default=None, description="data about request and response")
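Both new fields refer to the `UsageInfo` model defined earlier in protocol.py; a sketch consistent with how it is used in this patch (the exact definition may differ):

```python
from typing import Optional

from pydantic import BaseModel


class UsageInfo(BaseModel):
    prompt_tokens: int = 0
    total_tokens: int = 0
    completion_tokens: Optional[int] = 0
```

Note that the bare `usage: Optional[UsageInfo]` on CompletionStreamResponse still defaults to None under Pydantic v1 (Optional fields are implicitly given a None default), so the two stream-response models behave identically; the chat variant merely adds a field description.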