Unverified Commit dea2b84b authored by yhyang201, committed by GitHub

[OAI Server Refactor] [ChatCompletions & Completions] Implement UsageInfo Processor (#7360)


Co-authored-by: Chang Su <chang.s.su@oracle.com>
parent cfb2fb5a
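For context, a minimal sketch of how the UsageProcessor introduced in this commit is meant to be called from the non-streaming handlers (the full implementation is in the new usage_processor.py below). The sample meta_info payloads and the two-choice request are invented for illustration; only the import path and call signature come from this diff.

from sglang.srt.entrypoints.openai.usage_processor import UsageProcessor

# Engine output: one entry per generated choice; with n=2 both choices share one prompt.
ret = [
    {"meta_info": {"prompt_tokens": 12, "completion_tokens": 30, "cached_tokens": 4}},
    {"meta_info": {"prompt_tokens": 12, "completion_tokens": 25, "cached_tokens": 4}},
]

usage = UsageProcessor.calculate_response_usage(
    ret, n_choices=2, enable_cache_report=True
)
# prompt_tokens=12 (counted once per prompt), completion_tokens=55, total_tokens=67,
# and prompt_tokens_details reports the 8 cached tokens because cache reporting is on.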
......@@ -192,6 +192,17 @@ async def v1_score_request(raw_request: Request):
    pass


@app.api_route("/v1/models/{model_id}", methods=["GET"])
async def show_model_detail(model_id: str):
    served_model_name = app.state.tokenizer_manager.served_model_name
    return ModelCard(
        id=served_model_name,
        root=served_model_name,
        max_model_len=app.state.tokenizer_manager.model_config.context_len,
    )


# Additional API endpoints will be implemented in separate serving_*.py modules
# and mounted as APIRouters in future PRs
......
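A quick, hypothetical illustration of calling the endpoint added above (host and port are assumptions; the id, root, and max_model_len fields come from the ModelCard constructed in the hunk):

import requests

# Any model_id path segment returns the served model's card in this implementation.
resp = requests.get("http://localhost:30000/v1/models/my-model")
print(resp.json())  # e.g. {"id": "<served_model_name>", "root": "<served_model_name>", "max_model_len": ...}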
......@@ -114,33 +114,6 @@ class OpenAIServingBase(ABC):
"""Validate request"""
pass
def _calculate_streaming_usage_base(
self,
prompt_tokens: Dict[int, int],
completion_tokens: Dict[int, int],
cached_tokens: Dict[int, int],
n_choices: int,
) -> UsageInfo:
"""Calculate usage information for streaming responses (common logic)"""
total_prompt_tokens = sum(
tokens for i, tokens in prompt_tokens.items() if i % n_choices == 0
)
total_completion_tokens = sum(tokens for tokens in completion_tokens.values())
cache_report = self.tokenizer_manager.server_args.enable_cache_report
prompt_tokens_details = None
if cache_report:
cached_tokens_sum = sum(tokens for tokens in cached_tokens.values())
if cached_tokens_sum > 0:
prompt_tokens_details = {"cached_tokens": cached_tokens_sum}
return UsageInfo(
prompt_tokens=total_prompt_tokens,
completion_tokens=total_completion_tokens,
total_tokens=total_prompt_tokens + total_completion_tokens,
prompt_tokens_details=prompt_tokens_details,
)
def create_error_response(
self,
message: str,
......
......@@ -26,8 +26,8 @@ from sglang.srt.entrypoints.openai.protocol import (
    TopLogprob,
)
from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase
from sglang.srt.entrypoints.openai.usage_processor import UsageProcessor
from sglang.srt.entrypoints.openai.utils import (
    aggregate_token_usage,
    detect_template_content_format,
    process_content_for_template_format,
    to_openai_style_logprobs,
......@@ -546,11 +546,12 @@ class OpenAIServingChat(OpenAIServingBase):
        # Additional usage chunk
        if request.stream_options and request.stream_options.include_usage:
            usage = self._calculate_streaming_usage_base(
            usage = UsageProcessor.calculate_streaming_usage(
                prompt_tokens,
                completion_tokens,
                cached_tokens,
                request.n,
                n_choices=request.n,
                enable_cache_report=self.tokenizer_manager.server_args.enable_cache_report,
            )
            usage_chunk = ChatCompletionStreamResponse(
                id=content["meta_info"]["id"],
......@@ -658,7 +659,9 @@ class OpenAIServingChat(OpenAIServingBase):
        # Calculate usage
        cache_report = self.tokenizer_manager.server_args.enable_cache_report
        usage = aggregate_token_usage(ret, request.n, cache_report)
        usage = UsageProcessor.calculate_response_usage(
            ret, n_choices=request.n, enable_cache_report=cache_report
        )

        return ChatCompletionResponse(
            id=ret[0]["meta_info"]["id"],
......
......@@ -18,10 +18,8 @@ from sglang.srt.entrypoints.openai.protocol import (
    ErrorResponse,
)
from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase
from sglang.srt.entrypoints.openai.utils import (
    aggregate_token_usage,
    to_openai_style_logprobs,
)
from sglang.srt.entrypoints.openai.usage_processor import UsageProcessor
from sglang.srt.entrypoints.openai.utils import to_openai_style_logprobs
from sglang.srt.managers.io_struct import GenerateReqInput

logger = logging.getLogger(__name__)
......@@ -214,11 +212,12 @@ class OpenAIServingCompletion(OpenAIServingBase):
        # Handle final usage chunk
        if request.stream_options and request.stream_options.include_usage:
            usage = self._calculate_streaming_usage_base(
            usage = UsageProcessor.calculate_streaming_usage(
                prompt_tokens,
                completion_tokens,
                cached_tokens,
                request.n,
                n_choices=request.n,
                enable_cache_report=self.tokenizer_manager.server_args.enable_cache_report,
            )
            final_usage_chunk = CompletionStreamResponse(
                id=content["meta_info"]["id"],
......@@ -322,7 +321,9 @@ class OpenAIServingCompletion(OpenAIServingBase):
        # Calculate usage
        cache_report = self.tokenizer_manager.server_args.enable_cache_report
        usage = aggregate_token_usage(ret, request.n, cache_report)
        usage = UsageProcessor.calculate_response_usage(
            ret, n_choices=request.n, enable_cache_report=cache_report
        )

        return CompletionResponse(
            id=ret[0]["meta_info"]["id"],
......
from __future__ import annotations

from typing import Any, Dict, List, Mapping, Optional, final

from sglang.srt.entrypoints.openai.protocol import UsageInfo


@final
class UsageProcessor:
    """Stateless helpers that turn raw token counts into a UsageInfo."""

    @staticmethod
    def _details_if_cached(count: int) -> Optional[Dict[str, int]]:
        """Return {"cached_tokens": N} only when N > 0 (keeps JSON slim)."""
        return {"cached_tokens": count} if count > 0 else None

    @staticmethod
    def calculate_response_usage(
        responses: List[Dict[str, Any]],
        n_choices: int = 1,
        enable_cache_report: bool = False,
    ) -> UsageInfo:
        completion_tokens = sum(r["meta_info"]["completion_tokens"] for r in responses)

        prompt_tokens = sum(
            responses[i]["meta_info"]["prompt_tokens"]
            for i in range(0, len(responses), n_choices)
        )

        cached_details = None
        if enable_cache_report:
            cached_total = sum(
                r["meta_info"].get("cached_tokens", 0) for r in responses
            )
            cached_details = UsageProcessor._details_if_cached(cached_total)

        return UsageProcessor.calculate_token_usage(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            cached_tokens=cached_details,
        )

    @staticmethod
    def calculate_streaming_usage(
        prompt_tokens: Mapping[int, int],
        completion_tokens: Mapping[int, int],
        cached_tokens: Mapping[int, int],
        n_choices: int,
        enable_cache_report: bool = False,
    ) -> UsageInfo:
        # index % n_choices == 0 marks the first choice of a prompt
        total_prompt_tokens = sum(
            tok for idx, tok in prompt_tokens.items() if idx % n_choices == 0
        )
        total_completion_tokens = sum(completion_tokens.values())

        cached_details = (
            UsageProcessor._details_if_cached(sum(cached_tokens.values()))
            if enable_cache_report
            else None
        )

        return UsageProcessor.calculate_token_usage(
            prompt_tokens=total_prompt_tokens,
            completion_tokens=total_completion_tokens,
            cached_tokens=cached_details,
        )

    @staticmethod
    def calculate_token_usage(
        prompt_tokens: int,
        completion_tokens: int,
        cached_tokens: Optional[Dict[str, int]] = None,
    ) -> UsageInfo:
        """Calculate token usage information"""
        return UsageInfo(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
            prompt_tokens_details=cached_tokens,
        )
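A hedged worked example of the streaming path, to make the idx % n_choices == 0 bookkeeping concrete; the token counts are invented. During streaming the handlers accumulate counts keyed by output index, and with n_choices=2 indices 0 and 1 belong to the same prompt, so prompt tokens are counted only at index 0.

from sglang.srt.entrypoints.openai.usage_processor import UsageProcessor

prompt_tokens = {0: 12, 1: 12}      # both choices report the shared prompt
completion_tokens = {0: 30, 1: 25}
cached_tokens = {0: 4, 1: 4}

usage = UsageProcessor.calculate_streaming_usage(
    prompt_tokens,
    completion_tokens,
    cached_tokens,
    n_choices=2,
    enable_cache_report=True,
)
# -> prompt_tokens=12, completion_tokens=55, total_tokens=67,
#    prompt_tokens_details carrying cached_tokens=8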
import logging
from typing import Any, Dict, List, Optional

import jinja2.nodes
import transformers.utils.chat_template_utils as hf_chat_utils

from sglang.srt.entrypoints.openai.protocol import LogProbs, UsageInfo
from sglang.srt.entrypoints.openai.protocol import LogProbs

logger = logging.getLogger(__name__)

......@@ -171,62 +170,6 @@ def process_content_for_template_format(
    return new_msg


def calculate_token_usage(
    prompt_tokens: int,
    completion_tokens: int,
    cached_tokens: Optional[Dict[str, int]] = None,
) -> UsageInfo:
    """Calculate token usage information"""
    return UsageInfo(
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
        total_tokens=prompt_tokens + completion_tokens,
        prompt_tokens_details=cached_tokens,
    )


def aggregate_token_usage(
    responses: List[Dict[str, Any]],
    n_choices: int = 1,
    enable_cache_report: bool = False,
) -> UsageInfo:
    """Aggregate token usage from multiple responses

    Args:
        responses: List of response dictionaries with meta_info
        n_choices: Number of choices per request (for prompt token counting)
        enable_cache_report: Whether to include cached token details

    Returns:
        Aggregated UsageInfo
    """
    # Sum completion tokens from all responses
    completion_tokens = sum(
        response["meta_info"]["completion_tokens"] for response in responses
    )

    # For prompt tokens, only count every n_choices-th response to avoid double counting
    prompt_tokens = sum(
        responses[i]["meta_info"]["prompt_tokens"]
        for i in range(0, len(responses), n_choices)
    )

    # Handle cached tokens if cache reporting is enabled
    cached_tokens_details = None
    if enable_cache_report:
        cached_tokens_sum = sum(
            response["meta_info"].get("cached_tokens", 0) for response in responses
        )
        if cached_tokens_sum > 0:
            cached_tokens_details = {"cached_tokens": cached_tokens_sum}

    return calculate_token_usage(
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
        cached_tokens=cached_tokens_details,
    )


def to_openai_style_logprobs(
    input_token_logprobs=None,
    output_token_logprobs=None,
......