"examples/pytorch/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "7c788f531c94dc00d9577102b9599100334b2ba0"
Unverified commit dea2b84b authored by yhyang201, committed by GitHub

[OAI Server Refactor] [ChatCompletions & Completions] Implement UsageInfo Processor (#7360)


Co-authored-by: Chang Su <chang.s.su@oracle.com>
parent cfb2fb5a
@@ -192,6 +192,17 @@ async def v1_score_request(raw_request: Request):
     pass
 
 
+@app.api_route("/v1/models/{model_id}", methods=["GET"])
+async def show_model_detail(model_id: str):
+    served_model_name = app.state.tokenizer_manager.served_model_name
+    return ModelCard(
+        id=served_model_name,
+        root=served_model_name,
+        max_model_len=app.state.tokenizer_manager.model_config.context_len,
+    )
+
+
 # Additional API endpoints will be implemented in separate serving_*.py modules
 # and mounted as APIRouters in future PRs
...
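For quick manual verification, the new route can be hit with any HTTP client once a server is running. A minimal sketch (the host, SGLang's default port 30000, and the served model name are assumptions to adjust for your deployment):

# Illustrative client call for the new /v1/models/{model_id} route.
# Host, port, and model name are placeholders, not values from this PR.
import requests

resp = requests.get("http://127.0.0.1:30000/v1/models/my-served-model")
resp.raise_for_status()
card = resp.json()
print(card["id"], card.get("max_model_len"))  # served model name and context length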
@@ -114,33 +114,6 @@ class OpenAIServingBase(ABC):
         """Validate request"""
         pass
 
-    def _calculate_streaming_usage_base(
-        self,
-        prompt_tokens: Dict[int, int],
-        completion_tokens: Dict[int, int],
-        cached_tokens: Dict[int, int],
-        n_choices: int,
-    ) -> UsageInfo:
-        """Calculate usage information for streaming responses (common logic)"""
-        total_prompt_tokens = sum(
-            tokens for i, tokens in prompt_tokens.items() if i % n_choices == 0
-        )
-        total_completion_tokens = sum(tokens for tokens in completion_tokens.values())
-
-        cache_report = self.tokenizer_manager.server_args.enable_cache_report
-        prompt_tokens_details = None
-        if cache_report:
-            cached_tokens_sum = sum(tokens for tokens in cached_tokens.values())
-            if cached_tokens_sum > 0:
-                prompt_tokens_details = {"cached_tokens": cached_tokens_sum}
-
-        return UsageInfo(
-            prompt_tokens=total_prompt_tokens,
-            completion_tokens=total_completion_tokens,
-            total_tokens=total_prompt_tokens + total_completion_tokens,
-            prompt_tokens_details=prompt_tokens_details,
-        )
-
     def create_error_response(
         self,
         message: str,
...
@@ -26,8 +26,8 @@ from sglang.srt.entrypoints.openai.protocol import (
     TopLogprob,
 )
 from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase
+from sglang.srt.entrypoints.openai.usage_processor import UsageProcessor
 from sglang.srt.entrypoints.openai.utils import (
-    aggregate_token_usage,
     detect_template_content_format,
     process_content_for_template_format,
     to_openai_style_logprobs,
@@ -546,11 +546,12 @@ class OpenAIServingChat(OpenAIServingBase):
             # Additional usage chunk
             if request.stream_options and request.stream_options.include_usage:
-                usage = self._calculate_streaming_usage_base(
+                usage = UsageProcessor.calculate_streaming_usage(
                     prompt_tokens,
                     completion_tokens,
                     cached_tokens,
-                    request.n,
+                    n_choices=request.n,
+                    enable_cache_report=self.tokenizer_manager.server_args.enable_cache_report,
                 )
                 usage_chunk = ChatCompletionStreamResponse(
                     id=content["meta_info"]["id"],
@@ -658,7 +659,9 @@ class OpenAIServingChat(OpenAIServingBase):
         # Calculate usage
         cache_report = self.tokenizer_manager.server_args.enable_cache_report
-        usage = aggregate_token_usage(ret, request.n, cache_report)
+        usage = UsageProcessor.calculate_response_usage(
+            ret, n_choices=request.n, enable_cache_report=cache_report
+        )
 
         return ChatCompletionResponse(
             id=ret[0]["meta_info"]["id"],
...
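As a rough sketch of what the streaming call site above feeds into UsageProcessor.calculate_streaming_usage (the token counts below are invented, not taken from this PR): each mapping is keyed by choice index, so with n choices per prompt the idx % n_choices == 0 check counts each prompt exactly once while completion tokens are summed over every choice.

# Hypothetical per-choice counts for one prompt with n=2 choices.
from sglang.srt.entrypoints.openai.usage_processor import UsageProcessor

prompt_tokens = {0: 12, 1: 12}    # same prompt reported under every choice index
completion_tokens = {0: 5, 1: 7}  # tokens generated per choice
cached_tokens = {0: 3, 1: 3}      # prefix-cache hits per choice

usage = UsageProcessor.calculate_streaming_usage(
    prompt_tokens,
    completion_tokens,
    cached_tokens,
    n_choices=2,
    enable_cache_report=True,
)
assert usage.prompt_tokens == 12         # counted once, not once per choice
assert usage.total_tokens == 12 + 5 + 7  # prompt + all completions
# usage.prompt_tokens_details carries {"cached_tokens": 6} because cache reporting is on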
@@ -18,10 +18,8 @@ from sglang.srt.entrypoints.openai.protocol import (
     ErrorResponse,
 )
 from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase
-from sglang.srt.entrypoints.openai.utils import (
-    aggregate_token_usage,
-    to_openai_style_logprobs,
-)
+from sglang.srt.entrypoints.openai.usage_processor import UsageProcessor
+from sglang.srt.entrypoints.openai.utils import to_openai_style_logprobs
 from sglang.srt.managers.io_struct import GenerateReqInput
 
 logger = logging.getLogger(__name__)
@@ -214,11 +212,12 @@ class OpenAIServingCompletion(OpenAIServingBase):
             # Handle final usage chunk
             if request.stream_options and request.stream_options.include_usage:
-                usage = self._calculate_streaming_usage_base(
+                usage = UsageProcessor.calculate_streaming_usage(
                     prompt_tokens,
                     completion_tokens,
                     cached_tokens,
-                    request.n,
+                    n_choices=request.n,
+                    enable_cache_report=self.tokenizer_manager.server_args.enable_cache_report,
                 )
                 final_usage_chunk = CompletionStreamResponse(
                     id=content["meta_info"]["id"],
@@ -322,7 +321,9 @@ class OpenAIServingCompletion(OpenAIServingBase):
        # Calculate usage
         cache_report = self.tokenizer_manager.server_args.enable_cache_report
-        usage = aggregate_token_usage(ret, request.n, cache_report)
+        usage = UsageProcessor.calculate_response_usage(
+            ret, n_choices=request.n, enable_cache_report=cache_report
+        )
 
         return CompletionResponse(
             id=ret[0]["meta_info"]["id"],
...
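The non-streaming path aggregates over the flat ret list of per-choice results instead; a minimal sketch with invented meta_info values (only the fields calculate_response_usage actually reads) shows the same deduplication of prompt tokens:

# Hypothetical ret payload: one prompt, n=2 choices (values invented).
from sglang.srt.entrypoints.openai.usage_processor import UsageProcessor

ret = [
    {"meta_info": {"prompt_tokens": 10, "completion_tokens": 4, "cached_tokens": 2}},
    {"meta_info": {"prompt_tokens": 10, "completion_tokens": 6, "cached_tokens": 2}},
]

usage = UsageProcessor.calculate_response_usage(ret, n_choices=2, enable_cache_report=True)
assert usage.prompt_tokens == 10      # only ret[0] counted; ret[1] repeats the same prompt
assert usage.completion_tokens == 10  # 4 + 6
assert usage.total_tokens == 20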
from __future__ import annotations

from typing import Any, Dict, List, Mapping, Optional, final

from sglang.srt.entrypoints.openai.protocol import UsageInfo


@final
class UsageProcessor:
    """Stateless helpers that turn raw token counts into a UsageInfo."""

    @staticmethod
    def _details_if_cached(count: int) -> Optional[Dict[str, int]]:
        """Return {"cached_tokens": N} only when N > 0 (keeps JSON slim)."""
        return {"cached_tokens": count} if count > 0 else None

    @staticmethod
    def calculate_response_usage(
        responses: List[Dict[str, Any]],
        n_choices: int = 1,
        enable_cache_report: bool = False,
    ) -> UsageInfo:
        completion_tokens = sum(r["meta_info"]["completion_tokens"] for r in responses)

        prompt_tokens = sum(
            responses[i]["meta_info"]["prompt_tokens"]
            for i in range(0, len(responses), n_choices)
        )

        cached_details = None
        if enable_cache_report:
            cached_total = sum(
                r["meta_info"].get("cached_tokens", 0) for r in responses
            )
            cached_details = UsageProcessor._details_if_cached(cached_total)

        return UsageProcessor.calculate_token_usage(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            cached_tokens=cached_details,
        )

    @staticmethod
    def calculate_streaming_usage(
        prompt_tokens: Mapping[int, int],
        completion_tokens: Mapping[int, int],
        cached_tokens: Mapping[int, int],
        n_choices: int,
        enable_cache_report: bool = False,
    ) -> UsageInfo:
        # index % n_choices == 0 marks the first choice of a prompt
        total_prompt_tokens = sum(
            tok for idx, tok in prompt_tokens.items() if idx % n_choices == 0
        )
        total_completion_tokens = sum(completion_tokens.values())

        cached_details = (
            UsageProcessor._details_if_cached(sum(cached_tokens.values()))
            if enable_cache_report
            else None
        )

        return UsageProcessor.calculate_token_usage(
            prompt_tokens=total_prompt_tokens,
            completion_tokens=total_completion_tokens,
            cached_tokens=cached_details,
        )

    @staticmethod
    def calculate_token_usage(
        prompt_tokens: int,
        completion_tokens: int,
        cached_tokens: Optional[Dict[str, int]] = None,
    ) -> UsageInfo:
        """Calculate token usage information."""
        return UsageInfo(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
            prompt_tokens_details=cached_tokens,
        )
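One design detail worth noting: _details_if_cached returns None instead of {"cached_tokens": 0}, so prompt_tokens_details is omitted from the serialized usage whenever nothing was served from cache. A tiny sanity check with invented numbers:

# Demonstrates the "keep JSON slim" behavior of the cached-token details.
from sglang.srt.entrypoints.openai.usage_processor import UsageProcessor

assert UsageProcessor._details_if_cached(0) is None
assert UsageProcessor._details_if_cached(6) == {"cached_tokens": 6}

usage = UsageProcessor.calculate_token_usage(prompt_tokens=12, completion_tokens=8)
assert usage.total_tokens == 20
assert usage.prompt_tokens_details is None  # no cached tokens reported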
 import logging
-from typing import Any, Dict, List, Optional
 
 import jinja2.nodes
 import transformers.utils.chat_template_utils as hf_chat_utils
 
-from sglang.srt.entrypoints.openai.protocol import LogProbs, UsageInfo
+from sglang.srt.entrypoints.openai.protocol import LogProbs
 
 logger = logging.getLogger(__name__)
@@ -171,62 +170,6 @@ def process_content_for_template_format(
     return new_msg
 
 
-def calculate_token_usage(
-    prompt_tokens: int,
-    completion_tokens: int,
-    cached_tokens: Optional[Dict[str, int]] = None,
-) -> UsageInfo:
-    """Calculate token usage information"""
-    return UsageInfo(
-        prompt_tokens=prompt_tokens,
-        completion_tokens=completion_tokens,
-        total_tokens=prompt_tokens + completion_tokens,
-        prompt_tokens_details=cached_tokens,
-    )
-
-
-def aggregate_token_usage(
-    responses: List[Dict[str, Any]],
-    n_choices: int = 1,
-    enable_cache_report: bool = False,
-) -> UsageInfo:
-    """Aggregate token usage from multiple responses
-
-    Args:
-        responses: List of response dictionaries with meta_info
-        n_choices: Number of choices per request (for prompt token counting)
-        enable_cache_report: Whether to include cached token details
-
-    Returns:
-        Aggregated UsageInfo
-    """
-    # Sum completion tokens from all responses
-    completion_tokens = sum(
-        response["meta_info"]["completion_tokens"] for response in responses
-    )
-
-    # For prompt tokens, only count every n_choices-th response to avoid double counting
-    prompt_tokens = sum(
-        responses[i]["meta_info"]["prompt_tokens"]
-        for i in range(0, len(responses), n_choices)
-    )
-
-    # Handle cached tokens if cache reporting is enabled
-    cached_tokens_details = None
-    if enable_cache_report:
-        cached_tokens_sum = sum(
-            response["meta_info"].get("cached_tokens", 0) for response in responses
-        )
-        if cached_tokens_sum > 0:
-            cached_tokens_details = {"cached_tokens": cached_tokens_sum}
-
-    return calculate_token_usage(
-        prompt_tokens=prompt_tokens,
-        completion_tokens=completion_tokens,
-        cached_tokens=cached_tokens_details,
-    )
-
-
 def to_openai_style_logprobs(
     input_token_logprobs=None,
     output_token_logprobs=None,
...