Unverified Commit 334543ff authored by Xiaoyu Zhang, committed by GitHub

Add continuous_usage_stats support for streaming responses (#12241)

parent c143f416
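
For context, here is a minimal client-side sketch (not part of this commit) of how the new flag is requested. The base URL, API key, and model name are placeholders, and it assumes a server running this change plus an openai client version that forwards stream_options as-is:

# Hypothetical client example: request per-chunk usage stats on a streaming
# chat completion. The base_url and model name are placeholders.
import openai

client = openai.Client(api_key="EMPTY", base_url="http://localhost:30000/v1")

stream = client.chat.completions.create(
    model="placeholder-model",
    messages=[{"role": "user", "content": "What is machine learning?"}],
    stream=True,
    # include_usage keeps the final usage chunk; continuous_usage_stats (added
    # by this commit) attaches running token counts to every streamed chunk.
    stream_options={"include_usage": True, "continuous_usage_stats": True},
)

for chunk in stream:
    if chunk.usage:
        print(chunk.usage.prompt_tokens, chunk.usage.completion_tokens)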
@@ -109,6 +109,7 @@ class UsageInfo(BaseModel):
class StreamOptions(BaseModel):
    include_usage: Optional[bool] = False
    continuous_usage_stats: Optional[bool] = False


class JsonSchemaResponseFormat(BaseModel):
...
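
A small standalone sketch of the schema change above (re-declaring the model locally, purely for illustration): the new field defaults to False, so requests that omit it behave exactly as before.

# Illustration only: a local re-declaration of the StreamOptions model shown
# above, to demonstrate the default behavior of the new field.
from typing import Optional

from pydantic import BaseModel


class StreamOptions(BaseModel):
    include_usage: Optional[bool] = False
    continuous_usage_stats: Optional[bool] = False


assert StreamOptions().continuous_usage_stats is False
assert StreamOptions(continuous_usage_stats=True).continuous_usage_stats is True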
@@ -535,6 +535,17 @@ class OpenAIServingChat(OpenAIServingBase):
            choices=[choice_data],
            model=request.model,
        )

        # Add usage stats if continuous_usage_stats is enabled
        if (
            request.stream_options
            and request.stream_options.continuous_usage_stats
        ):
            chunk.usage = UsageProcessor.calculate_token_usage(
                prompt_tokens=prompt_tokens.get(index, 0),
                completion_tokens=completion_tokens.get(index, 0),
            )

        yield f"data: {chunk.model_dump_json()}\n\n"

        # Handle tool calls
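
The same pattern recurs in the hunks below. As a hedged illustration of what the helper call is doing (UsageProcessor's real implementation is not part of this diff; the stand-in below only mirrors the call signature and the prompt_tokens/completion_tokens/total_tokens fields asserted in the new test):

# Stand-in for illustration only: mirrors the calculate_token_usage call
# signature used above, not the repo's actual UsageProcessor implementation.
from dataclasses import dataclass


@dataclass
class UsageInfo:
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int


def calculate_token_usage(prompt_tokens: int, completion_tokens: int) -> UsageInfo:
    return UsageInfo(
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
        total_tokens=prompt_tokens + completion_tokens,
    )


# prompt_tokens / completion_tokens in the handler appear to be dicts keyed by
# choice index, hence the .get(index, 0) lookups above.
prompt_tokens = {0: 12, 1: 12}
completion_tokens = {0: 5, 1: 7}
print(calculate_token_usage(prompt_tokens.get(0, 0), completion_tokens.get(0, 0)))
# -> UsageInfo(prompt_tokens=12, completion_tokens=5, total_tokens=17)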
@@ -579,6 +590,17 @@ class OpenAIServingChat(OpenAIServingBase):
            choices=[choice_data],
            model=request.model,
        )

        # Add usage stats if continuous_usage_stats is enabled
        if (
            request.stream_options
            and request.stream_options.continuous_usage_stats
        ):
            chunk.usage = UsageProcessor.calculate_token_usage(
                prompt_tokens=prompt_tokens.get(index, 0),
                completion_tokens=completion_tokens.get(index, 0),
            )

        yield f"data: {chunk.model_dump_json()}\n\n"

        # Send finish_reason chunks for each index that completed
@@ -1056,6 +1078,16 @@ class OpenAIServingChat(OpenAIServingBase):
            choices=[choice_data],
            model=request.model,
        )

        # Add usage stats if continuous_usage_stats is enabled
        if request.stream_options and request.stream_options.continuous_usage_stats:
            prompt_tokens = content["meta_info"].get("prompt_tokens", 0)
            completion_tokens = content["meta_info"].get("completion_tokens", 0)
            chunk.usage = UsageProcessor.calculate_token_usage(
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
            )

        yield f"data: {chunk.model_dump_json()}\n\n"

        # Yield tool calls
@@ -1096,6 +1128,16 @@ class OpenAIServingChat(OpenAIServingBase):
            choices=[choice_data],
            model=request.model,
        )

        # Add usage stats if continuous_usage_stats is enabled
        if request.stream_options and request.stream_options.continuous_usage_stats:
            prompt_tokens = content["meta_info"].get("prompt_tokens", 0)
            completion_tokens = content["meta_info"].get("completion_tokens", 0)
            chunk.usage = UsageProcessor.calculate_token_usage(
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
            )

        yield f"data: {chunk.model_dump_json()}\n\n"

    def _check_for_unstreamed_tool_args(
...
@@ -272,6 +272,16 @@ class OpenAIServingCompletion(OpenAIServingBase):
            model=request.model,
        )

        # Add usage stats if continuous_usage_stats is enabled
        if (
            request.stream_options
            and request.stream_options.continuous_usage_stats
        ):
            chunk.usage = UsageProcessor.calculate_token_usage(
                prompt_tokens=prompt_tokens.get(index, 0),
                completion_tokens=completion_tokens.get(index, 0),
            )

        yield f"data: {chunk.model_dump_json()}\n\n"

        if request.return_hidden_states and hidden_states:
...
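
The (non-chat) completions endpoint gets the same treatment; a hedged client sketch, again with a placeholder base URL and model name and assuming an openai client version whose completions.create accepts stream_options:

# Hypothetical client example for the text completions endpoint.
import openai

client = openai.Client(api_key="EMPTY", base_url="http://localhost:30000/v1")

stream = client.completions.create(
    model="placeholder-model",
    prompt="The capital of France is",
    stream=True,
    max_tokens=16,
    stream_options={"include_usage": True, "continuous_usage_stats": True},
)

for chunk in stream:
    if chunk.usage:
        print(chunk.usage.total_tokens)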
import asyncio
import unittest

import openai

from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
    DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
    DEFAULT_URL_FOR_TEST,
    CustomTestCase,
    popen_launch_server,
)


class TestContinuousUsageStats(CustomTestCase):
    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(cls.model, cls.base_url, timeout=300)
        cls.client = openai.Client(api_key="EMPTY", base_url=f"{cls.base_url}/v1")
        cls.aclient = openai.AsyncClient(api_key="EMPTY", base_url=f"{cls.base_url}/v1")

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def test_continuous_usage_stats_enabled(self):
        stream = self.client.chat.completions.create(
            model=self.model,
            messages=[{"role": "user", "content": "What is machine learning?"}],
            stream=True,
            max_tokens=30,
            temperature=0,
            stream_options={"include_usage": True, "continuous_usage_stats": True},
        )

        chunks_with_usage = 0
        chunks_with_content = 0
        last_usage = None
        for chunk in stream:
            has_content = len(chunk.choices) > 0 and chunk.choices[0].delta.content
            if chunk.usage:
                chunks_with_usage += 1
                last_usage = chunk.usage
            if has_content:
                chunks_with_content += 1

        assert chunks_with_content > 0
        assert chunks_with_usage >= chunks_with_content
        assert last_usage.prompt_tokens > 0
        assert last_usage.completion_tokens > 0
        assert (
            last_usage.total_tokens
            == last_usage.prompt_tokens + last_usage.completion_tokens
        )

    async def test_continuous_usage_stats_async(self):
        stream = await self.aclient.chat.completions.create(
            model=self.model,
            messages=[{"role": "user", "content": "What is deep learning?"}],
            stream=True,
            max_tokens=30,
            temperature=0,
            stream_options={"include_usage": True, "continuous_usage_stats": True},
        )

        chunks_with_usage = 0
        chunks_with_content = 0
        async for chunk in stream:
            has_content = len(chunk.choices) > 0 and chunk.choices[0].delta.content
            if chunk.usage:
                chunks_with_usage += 1
            if has_content:
                chunks_with_content += 1

        assert chunks_with_content > 0
        assert chunks_with_usage >= chunks_with_content

    def test_continuous_usage_stats_disabled(self):
        stream = self.client.chat.completions.create(
            model=self.model,
            messages=[{"role": "user", "content": "What is AI?"}],
            stream=True,
            max_tokens=30,
            temperature=0,
            stream_options={"include_usage": True, "continuous_usage_stats": False},
        )

        usage_chunks = []
        for chunk in stream:
            if chunk.usage:
                usage_chunks.append(chunk)

        assert len(usage_chunks) == 1
        assert len(usage_chunks[0].choices) == 0

    def test_async_runner(self):
        asyncio.run(self.test_continuous_usage_stats_async())


if __name__ == "__main__":
    unittest.main()
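
Note on running the test: setUpClass launches a live server via popen_launch_server, so executing it requires the model weights and serving hardware that helper expects; with those in place, the file can be run directly (it calls unittest.main()) or through the standard unittest runner.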