Commit 852a49c5 authored by maxiao

adapt to dsv32 on dcu

parent 8f7453e3
......@@ -4,7 +4,6 @@ Mimics TokenizerManager's state management and ZMQ communication patterns.
"""
import asyncio
import copy
import dataclasses
import logging
import os
......@@ -12,8 +11,7 @@ import signal
import sys
import threading
import time
import uuid
from typing import Any, AsyncGenerator, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Union
import grpc
import zmq
......@@ -81,10 +79,11 @@ class GrpcReqState:
last_completion_tokens: int = 1
# Streaming state
last_output_offset: int = 0
stream_finished: bool = False
input_logprobs_sent: bool = False # Track if input logprobs were sent in streaming
# Token accumulation (for non-streaming)
# Output accumulation
text: str = ""
output_ids: List[int] = dataclasses.field(default_factory=list)
input_token_logprobs_val: List[float] = dataclasses.field(default_factory=list)
input_token_logprobs_idx: List[int] = dataclasses.field(default_factory=list)
......@@ -140,6 +139,8 @@ class GrpcRequestManager:
self.is_pause_cond = asyncio.Condition()
# Metrics
self.request_counter = 0
self.request_counter_lock = asyncio.Lock()
self.last_receive_tstamp = time.time()
# Crash dump for debugging
......@@ -157,133 +158,22 @@ class GrpcRequestManager:
obj: TokenizedGenerateReqInput,
request_id: Optional[str] = None,
grpc_context: Optional[grpc.aio.ServicerContext] = None,
) -> AsyncGenerator[Union[Dict, List[Dict]], None]:
) -> asyncio.Queue:
"""
Submit a generation request to the scheduler with n>1 parallel sampling support.
This method implements the same two-phase approach as tokenizer_manager.py:
1. Phase 1: Send prefix caching request (max_new_tokens=0)
2. Phase 2: Send n generation requests that reuse the cached prefix
Yields individual responses for streaming, or aggregated responses for non-streaming.
Submit a generation request to the scheduler.
Returns a queue for streaming outputs.
"""
n = getattr(obj.sampling_params, "n", 1)
if n <= 1:
async for response in self._handle_single_request(
obj, request_id, grpc_context
):
yield response
return
# N>1 handling - two-phase approach
logger.debug(f"Multiple sampling request (n={n}), using two-phase approach")
# Generate base request ID if not provided
if request_id is None:
base_request_id = f"grpc-{uuid.uuid4().hex}"
else:
base_request_id = request_id
# Phase 1: Cache the common prefix
logger.debug(f"Phase 1: Caching prefix for request {base_request_id}")
prefix_obj = copy.copy(obj)
prefix_obj.sampling_params = copy.copy(obj.sampling_params)
prefix_obj.sampling_params.max_new_tokens = 0 # Prefill-only
prefix_obj.sampling_params.n = 1 # Don't replicate prefix request
# Send prefix caching request and consume response
async for _ in self._handle_single_request(
prefix_obj, f"{base_request_id}-prefix", grpc_context
):
# Consume prefix response (usually just one chunk with finish_reason)
pass
logger.debug(f"Phase 1 completed: Prefix cached for {base_request_id}")
# Phase 2: Generate n parallel requests
logger.debug(f"Phase 2: Generating {n} parallel requests")
generators = []
request_ids = []
for i in range(n):
# Create individual generation request
gen_obj = copy.copy(obj)
gen_obj.sampling_params = copy.copy(obj.sampling_params)
gen_obj.sampling_params.n = 1 # Each request generates 1 response
gen_request_id = f"{base_request_id}-{i}"
request_ids.append(gen_request_id)
# Start generation request
generators.append(
self._handle_single_request(gen_obj, gen_request_id, grpc_context)
)
# Handle response aggregation
is_stream = getattr(obj, "stream", False)
if not is_stream:
# Non-streaming: collect all responses and return as batch
logger.debug(f"Non-streaming mode: collecting {n} responses")
responses = []
for generator in generators:
async for response in generator:
responses.append(response)
yield responses # Return all responses as a batch
else:
# Streaming mode: multiplex responses with index for ordering
logger.debug(f"Streaming mode: multiplexing {n} streams")
rid_to_index = {rid: i for i, rid in enumerate(request_ids)}
# Create async tasks for all generators
task_map = {}
for generator in generators:
task = asyncio.create_task(generator.__anext__())
task_map[task] = generator
# Process responses as they arrive
while task_map:
done, _ = await asyncio.wait(
task_map.keys(), return_when=asyncio.FIRST_COMPLETED
)
for task in done:
generator = task_map.pop(task)
try:
response = await task
# Add index for client-side ordering
if isinstance(response, dict) and "meta_info" in response:
response_rid = response["meta_info"].get("id", "")
if response_rid in rid_to_index:
response["index"] = rid_to_index[response_rid]
yield response
# Create next task for this generator
next_task = asyncio.create_task(generator.__anext__())
task_map[next_task] = generator
except StopAsyncIteration:
# This generator is finished
pass
async def _handle_single_request(
self,
obj: TokenizedGenerateReqInput,
request_id: Optional[str] = None,
grpc_context: Optional[grpc.aio.ServicerContext] = None,
):
"""Handle a single request - core implementation without n>1 logic."""
# Generate request ID if not provided
if request_id is None:
request_id = f"grpc-{uuid.uuid4().hex}"
async with self.request_counter_lock:
request_id = f"grpc-{self.request_counter}"
self.request_counter += 1
obj.rid = request_id
# Create and register request state
# TODO: support log_request
# Create request state
state = GrpcReqState(
request_id=request_id,
grpc_context=grpc_context,
......@@ -299,51 +189,19 @@ class GrpcRequestManager:
state.session_id = obj.session_params.session_id
state.is_session_request = True
# Register state
self.rid_to_state[request_id] = state
self.record_request_for_crash_dump(obj)
# Send to scheduler via ZMQ
try:
# Send to scheduler - let exceptions bubble up to grpc_server.py
await self._send_to_scheduler(obj)
is_stream = getattr(obj, "stream", False)
while True:
# Client cancelled - notify scheduler and exit
if grpc_context and grpc_context.cancelled():
await self.abort_request(request_id)
return
try:
response = await asyncio.wait_for(state.out_queue.get(), timeout=4)
if is_stream:
yield response
# Non-streaming: yield final response with accumulated tokens from state
if isinstance(response, dict) and response.get("finished", False):
if not is_stream:
final_response = response.copy()
final_response["token_ids"] = state.output_ids
yield final_response
break
except asyncio.TimeoutError:
# Timeout waiting for response - abort and cleanup
logger.warning(
f"Timeout waiting for response for request {request_id}"
)
await self.abort_request(request_id)
return
finally:
# Always clean up request state when exiting
self._cleanup_request_state(request_id)
def _cleanup_request_state(self, request_id: str):
"""Clean up local request state (does not notify scheduler)."""
if request_id in self.rid_to_state:
except Exception as e:
# Clean up on failure
del self.rid_to_state[request_id]
raise RuntimeError(f"Failed to send request to scheduler: {e}")
return state.out_queue
async def embedding_request(
self,
......@@ -356,7 +214,9 @@ class GrpcRequestManager:
"""
# Generate request ID if not provided
if request_id is None:
request_id = f"grpc-embed-{uuid.uuid4().hex}"
async with self.request_counter_lock:
request_id = f"grpc-embed-{self.request_counter}"
self.request_counter += 1
obj.rid = request_id
......@@ -495,6 +355,7 @@ class GrpcRequestManager:
# Extract output for this request
output_data = {
"request_id": rid,
"text": batch_out.decoded_texts[i] if batch_out.decoded_texts else "",
"token_ids": batch_out.output_ids[i] if batch_out.output_ids else [],
"finished": batch_out.finished_reasons[i] is not None,
"meta_info": {
......@@ -506,9 +367,6 @@ class GrpcRequestManager:
if batch_out.completion_tokens
else 0
),
"cached_tokens": (
batch_out.cached_tokens[i] if batch_out.cached_tokens else 0
),
"finish_reason": (
str(batch_out.finished_reasons[i])
if batch_out.finished_reasons[i]
......@@ -517,110 +375,29 @@ class GrpcRequestManager:
},
}
# Accumulate input logprobs (only once, usually in first chunk)
if batch_out.input_token_logprobs_val and i < len(
batch_out.input_token_logprobs_val
):
if not state.input_token_logprobs_val:
state.input_token_logprobs_val.extend(
batch_out.input_token_logprobs_val[i]
)
if batch_out.input_token_logprobs_idx and i < len(
batch_out.input_token_logprobs_idx
):
state.input_token_logprobs_idx.extend(
batch_out.input_token_logprobs_idx[i]
)
if batch_out.input_top_logprobs_val and i < len(
batch_out.input_top_logprobs_val
):
state.input_top_logprobs_val.extend(
batch_out.input_top_logprobs_val[i]
)
if batch_out.input_top_logprobs_idx and i < len(
batch_out.input_top_logprobs_idx
):
state.input_top_logprobs_idx.extend(
batch_out.input_top_logprobs_idx[i]
)
# Send input logprobs based on mode
if state.input_token_logprobs_val:
if state.obj.stream and not state.input_logprobs_sent:
# Streaming: send input logprobs once in first chunk that has them
output_data["input_logprobs"] = {
"token_logprobs_val": state.input_token_logprobs_val,
"token_logprobs_idx": state.input_token_logprobs_idx,
"top_logprobs_val": state.input_top_logprobs_val,
"top_logprobs_idx": state.input_top_logprobs_idx,
}
state.input_logprobs_sent = True
elif not state.obj.stream and output_data["finished"]:
# Non-streaming: send input logprobs in final chunk
output_data["input_logprobs"] = {
"token_logprobs_val": state.input_token_logprobs_val,
"token_logprobs_idx": state.input_token_logprobs_idx,
"top_logprobs_val": state.input_top_logprobs_val,
"top_logprobs_idx": state.input_top_logprobs_idx,
}
# Add output logprobs if available (RAW - no detokenization!)
# Add logprobs if available
if batch_out.output_token_logprobs_val and i < len(
batch_out.output_token_logprobs_val
):
# Accumulate in state first
state.output_token_logprobs_val.extend(
batch_out.output_token_logprobs_val[i]
)
if batch_out.output_token_logprobs_idx and i < len(
batch_out.output_token_logprobs_idx
):
state.output_token_logprobs_idx.extend(
batch_out.output_token_logprobs_idx[i]
)
if batch_out.output_top_logprobs_val and i < len(
batch_out.output_top_logprobs_val
):
state.output_top_logprobs_val.extend(
output_data["logprobs"] = {
"tokens": batch_out.output_token_logprobs_val[i],
"top_logprobs": (
batch_out.output_top_logprobs_val[i]
)
if batch_out.output_top_logprobs_idx and i < len(
batch_out.output_top_logprobs_idx
):
state.output_top_logprobs_idx.extend(
batch_out.output_top_logprobs_idx[i]
)
if state.obj.stream:
# For streaming: send incremental logprobs (only new tokens in this chunk)
# NOTE: this is different than TokenizerManager, which always accumulates
def get_part(attr_name):
source_list = getattr(batch_out, attr_name, None)
return (
source_list[i]
if source_list and i < len(source_list)
else []
)
output_data["output_logprobs"] = {
"token_logprobs_val": batch_out.output_token_logprobs_val[i],
"token_logprobs_idx": get_part("output_token_logprobs_idx"),
"top_logprobs_val": get_part("output_top_logprobs_val"),
"top_logprobs_idx": get_part("output_top_logprobs_idx"),
}
elif output_data["finished"]:
# Non-streaming: send cumulative output logprobs in final chunk
output_data["output_logprobs"] = {
"token_logprobs_val": state.output_token_logprobs_val,
"token_logprobs_idx": state.output_token_logprobs_idx,
"top_logprobs_val": state.output_top_logprobs_val,
"top_logprobs_idx": state.output_top_logprobs_idx,
}
# Update state for accumulation
if batch_out.output_top_logprobs_val
and i < len(batch_out.output_top_logprobs_val)
else None
),
}
# Update state
if output_data["text"]:
state.text += output_data["text"][state.last_output_offset :]
state.last_output_offset = len(output_data["text"])
if output_data["token_ids"]:
state.output_ids.extend(output_data["token_ids"])
# Send to output queue
await state.out_queue.put(output_data)
# Handle completion
......
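With this change generate_request returns an asyncio.Queue instead of an async generator, so every caller has to poll the queue itself. A minimal consumer sketch (names hypothetical, not part of the commit) that mirrors the pattern grpc_server.py uses below: poll with a timeout, abort on client cancellation, stop on the item marked "finished":

import asyncio

async def consume_outputs(request_manager, tokenized_req, grpc_context=None):
    # Hypothetical caller of the new queue-based API. Each queue item is a
    # dict that may carry "error" and "finished" keys, as produced by the
    # request manager's batch-output handler.
    output_queue = await request_manager.generate_request(
        obj=tokenized_req, request_id=None, grpc_context=grpc_context
    )
    while True:
        try:
            # Short timeout so client cancellation is noticed promptly.
            output = await asyncio.wait_for(output_queue.get(), timeout=4)
        except asyncio.TimeoutError:
            if grpc_context is not None and grpc_context.cancelled():
                await request_manager.abort_request(tokenized_req.rid)
                return
            continue
        if "error" in output:
            raise RuntimeError(output["error"])
        yield output
        if output.get("finished", False):
            return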
......@@ -181,34 +181,20 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
# Convert gRPC request to internal format
tokenized_req = self._convert_generate_request(request)
# Submit to request manager (automatically handles n>1)
response_generator = self.request_manager.generate_request(
# Submit to request manager
output_queue = await self.request_manager.generate_request(
obj=tokenized_req,
request_id=request.request_id,
grpc_context=context,
)
async for output in response_generator:
# Handle batch responses (for n>1 non-streaming)
if isinstance(output, list):
for batch_output in output:
if "error" in batch_output:
yield sglang_scheduler_pb2.GenerateResponse(
request_id=request.request_id,
error=sglang_scheduler_pb2.GenerateError(
message=batch_output["error"],
http_status_code=(
"500" if "abort" not in batch_output else "499"
),
),
)
else:
# All non-error batch outputs are final responses
yield self._create_completion_response(
request.request_id, batch_output
)
else:
# Handle single response (for streaming or n=1 non-streaming)
# Stream outputs
while True:
try:
# Get output with timeout
output = await asyncio.wait_for(output_queue.get(), timeout=4)
# Check for errors
if "error" in output:
yield sglang_scheduler_pb2.GenerateResponse(
request_id=request.request_id,
......@@ -219,13 +205,27 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
),
),
)
elif output.get("finished", False):
break
# Check if finished
if output.get("finished", False):
# Send completion
yield self._create_completion_response(
request.request_id, output
)
break
else:
# Send chunk
yield self._create_chunk_response(request.request_id, output)
except asyncio.TimeoutError:
# Check if context is still active
if context.cancelled():
# Abort the request
await self.request_manager.abort_request(request.request_id)
break
continue
except Exception as e:
logger.error(f"Generate failed: {e}\n{get_exception_traceback()}")
yield sglang_scheduler_pb2.GenerateResponse(
......@@ -266,6 +266,7 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
prompt_tokens=result.get("prompt_tokens", 0),
cached_tokens=0,
embedding_dim=len(result["embedding"]),
generation_time=time.time() - self.start_time,
),
)
......@@ -321,14 +322,14 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
logger.info(f"Sending health check request to request manager...")
# Submit and wait for response
output_generator = self.request_manager.generate_request(
output_queue = await self.request_manager.generate_request(
health_request, request_id=rid
)
try:
# Get first response with timeout
# Wait for response with configurable timeout
response = await asyncio.wait_for(
output_generator.__anext__(), timeout=HEALTH_CHECK_TIMEOUT
output_queue.get(), timeout=HEALTH_CHECK_TIMEOUT
)
# Clean up
......@@ -403,8 +404,8 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
return_logprob=grpc_req.return_logprob,
logprob_start_len=grpc_req.logprob_start_len or -1,
top_logprobs_num=grpc_req.top_logprobs_num or 0,
stream=grpc_req.stream or False,
lora_id=grpc_req.lora_id if grpc_req.lora_id else None,
stream=True, # Always stream for gRPC
lora_path=grpc_req.lora_id if grpc_req.lora_id else None,
token_ids_logprob=(
list(grpc_req.token_ids_logprob) if grpc_req.token_ids_logprob else None
),
......@@ -437,7 +438,6 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
regex = None
json_schema = None
ebnf_grammar = None
structural_tag = None
if grpc_params.HasField("regex"):
regex = grpc_params.regex
......@@ -445,8 +445,6 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
json_schema = grpc_params.json_schema
elif grpc_params.HasField("ebnf_grammar"):
ebnf_grammar = grpc_params.ebnf_grammar
elif grpc_params.HasField("structural_tag"):
structural_tag = grpc_params.structural_tag
return SGLSamplingParams(
temperature=grpc_params.temperature or 1.0,
......@@ -458,74 +456,33 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
repetition_penalty=grpc_params.repetition_penalty or 1.0,
max_new_tokens=grpc_params.max_new_tokens or 128,
min_new_tokens=grpc_params.min_new_tokens or 0,
stop=list(grpc_params.stop) if grpc_params.stop else [],
stop=list(grpc_params.stop) if grpc_params.stop else None,
stop_token_ids=(
list(grpc_params.stop_token_ids) if grpc_params.stop_token_ids else []
list(grpc_params.stop_token_ids) if grpc_params.stop_token_ids else None
),
skip_special_tokens=grpc_params.skip_special_tokens,
spaces_between_special_tokens=grpc_params.spaces_between_special_tokens,
regex=regex,
json_schema=json_schema,
ebnf=ebnf_grammar,
structural_tag=structural_tag,
n=grpc_params.n or 1,
ignore_eos=grpc_params.ignore_eos,
)
def _convert_logprobs_to_proto(
self, logprobs_data: Dict
) -> Optional[sglang_scheduler_pb2.LogProbs]:
"""Convert logprobs dict to proto LogProbs format (transport RAW data only)."""
if not logprobs_data:
return None
token_logprobs_val = logprobs_data.get("token_logprobs_val", [])
token_logprobs_idx = logprobs_data.get("token_logprobs_idx", [])
top_logprobs_val = logprobs_data.get("top_logprobs_val", [])
top_logprobs_idx = logprobs_data.get("top_logprobs_idx", [])
# Build TopLogProbs entries
top_logprobs_proto = []
if top_logprobs_val and top_logprobs_idx:
for val_list, idx_list in zip(top_logprobs_val, top_logprobs_idx):
top_logprobs_proto.append(
sglang_scheduler_pb2.TopLogProbs(
values=val_list,
token_ids=idx_list,
)
)
return sglang_scheduler_pb2.LogProbs(
token_logprobs=token_logprobs_val,
token_ids=token_logprobs_idx,
top_logprobs=top_logprobs_proto,
)
def _create_chunk_response(
self, request_id: str, output: Dict
) -> sglang_scheduler_pb2.GenerateResponse:
"""Create a streaming chunk response."""
meta_info = output.get("meta_info", {})
# Convert output logprobs if present
output_logprobs_proto = self._convert_logprobs_to_proto(
output.get("output_logprobs")
)
# Convert input logprobs if present (only in first chunk)
input_logprobs_proto = self._convert_logprobs_to_proto(
output.get("input_logprobs")
)
return sglang_scheduler_pb2.GenerateResponse(
request_id=request_id,
chunk=sglang_scheduler_pb2.GenerateStreamChunk(
token_ids=output.get("token_ids", []),
prompt_tokens=meta_info.get("prompt_tokens", 0),
completion_tokens=meta_info.get("completion_tokens", 0),
cached_tokens=meta_info.get("cached_tokens", 0),
output_logprobs=output_logprobs_proto,
input_logprobs=input_logprobs_proto,
token_id=output["token_ids"][-1] if output.get("token_ids") else 0,
text=output.get("text", ""),
prompt_tokens=0,
completion_tokens=len(output.get("token_ids", [])),
cached_tokens=0,
generation_time=time.time() - self.start_time,
queue_time=0.0,
),
)
......@@ -534,56 +491,20 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
) -> sglang_scheduler_pb2.GenerateResponse:
"""Create a completion response."""
# Extract meta info and finish reason details
# Determine finish reason
finish_reason = sglang_scheduler_pb2.GenerateComplete.STOP
meta_info = output.get("meta_info", {})
finish_reason_data = meta_info.get("finish_reason")
# Determine finish reason, default is stop
finish_reason = "stop"
if finish_reason_data:
if isinstance(finish_reason_data, dict):
finish_reason_type = finish_reason_data.get("type")
else:
# Handle legacy string format
finish_reason_type = finish_reason_data
if finish_reason_type == "length":
finish_reason = "length"
elif finish_reason_type == "abort":
finish_reason = "abort"
# Extract matched_stop information
matched_stop_kwargs = {}
if isinstance(finish_reason_data, dict) and "matched" in finish_reason_data:
matched = finish_reason_data["matched"]
if isinstance(matched, int):
matched_stop_kwargs["matched_token_id"] = matched
elif isinstance(matched, str):
matched_stop_kwargs["matched_stop_str"] = matched
# Convert output logprobs if present
output_logprobs_proto = self._convert_logprobs_to_proto(
output.get("output_logprobs")
)
# Convert input logprobs if present
input_logprobs_proto = self._convert_logprobs_to_proto(
output.get("input_logprobs")
)
if meta_info.get("finish_reason") == "length":
finish_reason = sglang_scheduler_pb2.GenerateComplete.LENGTH
elif meta_info.get("finish_reason") == "eos_token":
finish_reason = sglang_scheduler_pb2.GenerateComplete.EOS_TOKEN
return sglang_scheduler_pb2.GenerateResponse(
request_id=request_id,
complete=sglang_scheduler_pb2.GenerateComplete(
output_ids=output.get("token_ids", []),
output_text=output.get("text", ""),
finish_reason=finish_reason,
prompt_tokens=meta_info.get("prompt_tokens", 0),
completion_tokens=meta_info.get(
"completion_tokens", len(output.get("token_ids", []))
),
cached_tokens=meta_info.get("cached_tokens", 0),
output_logprobs=output_logprobs_proto,
input_logprobs=input_logprobs_proto,
**matched_stop_kwargs,
),
)
......
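For reference, a streaming GenerateResponse now carries one of three payloads (chunk, complete, error); chunks are reduced to a single token_id plus its decoded text, and completion carries an enum finish_reason. A hedged client-side dispatch sketch; the field names come from the .proto further down in this commit, everything else is illustrative:

def handle_generate_response(response) -> bool:
    # `response` is one GenerateResponse message from the stream.
    # Returns True once the final message has been seen.
    if response.HasField("error"):
        raise RuntimeError(f"generation failed: {response.error.message}")
    if response.HasField("complete"):
        # Final message: full output ids/text and an enum finish reason
        # (STOP / LENGTH / EOS_TOKEN / STOP_STR / ABORT).
        print(response.complete.output_text, response.complete.finish_reason)
        return True
    # Incremental chunk: one generated token id and its decoded text.
    print(response.chunk.text, end="", flush=True)
    return False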
......@@ -16,7 +16,7 @@
import time
import uuid
from dataclasses import dataclass
from typing import Any, Dict, List, NamedTuple, Optional, TypeAlias, Union
from typing import Any, Dict, List, Optional, TypeAlias, Union
from openai.types.responses import (
ResponseFunctionToolCall,
......@@ -228,15 +228,11 @@ class CompletionRequest(BaseModel):
# For request id
rid: Optional[Union[List[str], str]] = None
# Extra key for classifying the request (e.g. cache_salt)
extra_key: Optional[Union[List[str], str]] = None
# Cache salt for request caching
cache_salt: Optional[Union[List[str], str]] = None
# Priority for the request
priority: Optional[int] = None
# For custom metric labels
custom_labels: Optional[Dict[str, str]] = None
# For customer metric labels
customer_labels: Optional[Dict[str, str]] = None
@field_validator("max_tokens")
@classmethod
......@@ -343,7 +339,7 @@ class FunctionResponse(BaseModel):
"""Function response."""
name: Optional[str] = None
arguments: Optional[str | Dict[str, Any]] = None
arguments: Optional[str] = None
class ToolCall(BaseModel):
......@@ -392,7 +388,7 @@ class Function(BaseModel):
"""Function descriptions."""
description: Optional[str] = Field(default=None, examples=[None])
name: str
name: Optional[str] = None
parameters: Optional[object] = None
strict: bool = False
......@@ -549,10 +545,6 @@ class ChatCompletionRequest(BaseModel):
# For request id
rid: Optional[Union[List[str], str]] = None
# Extra key for classifying the request (e.g. cache_salt)
extra_key: Optional[Union[List[str], str]] = None
# Cache salt for request caching
cache_salt: Optional[Union[List[str], str]] = None
# Priority for the request
priority: Optional[int] = None
......@@ -786,13 +778,6 @@ class ResponsesRequest(BaseModel):
description="The request_id related to this request. If the caller does not set it, a random uuid will be generated.",
)
priority: int = Field(default=0, description="Request priority")
extra_key: Optional[str] = Field(
default=None,
description="Extra key for classifying the request (e.g. cache_salt)",
)
cache_salt: Optional[str] = Field(
default=None, description="Cache salt for request caching"
)
# SGLang-specific sampling parameters
frequency_penalty: float = 0.0
......@@ -943,16 +928,6 @@ class MessageProcessingResult:
tool_call_constraint: Optional[Any] = None
class ToolCallProcessingResult(NamedTuple):
"""Result of processing tool calls in a response."""
tool_calls: Optional[
List[Any]
] # List of ToolCall objects or None if parsing failed
remaining_text: str # Text remaining after parsing tool calls
finish_reason: Dict[str, Any] # Updated finish reason dictionary
class ResponseReasoningTextContent(BaseModel):
text: str
type: Literal["reasoning_text"] = "reasoning_text"
......
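Since FunctionResponse.arguments is a plain string again after this change, parsed tool parameters have to be JSON-encoded before they are put on the model. A small sketch (tool name and parameters hypothetical) of building the structure serving_chat.py emits:

import json
import uuid

from sglang.srt.entrypoints.openai.protocol import FunctionResponse, ToolCall

call = ToolCall(
    id=f"call_{uuid.uuid4().hex[:24]}",
    index=0,
    function=FunctionResponse(
        name="get_weather",
        # Arguments are serialized to a JSON string, not passed as a dict.
        arguments=json.dumps({"city": "Beijing"}, ensure_ascii=False),
    ),
)
print(call.model_dump_json())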
......@@ -27,10 +27,10 @@ class OpenAIServingBase(ABC):
self.tokenizer_manager = tokenizer_manager
self.allowed_custom_labels = (
set(
self.tokenizer_manager.server_args.tokenizer_metrics_allowed_custom_labels
self.tokenizer_manager.server_args.tokenizer_metrics_allowed_customer_labels
)
if isinstance(self.tokenizer_manager.server_args, ServerArgs)
and self.tokenizer_manager.server_args.tokenizer_metrics_allowed_custom_labels
and self.tokenizer_manager.server_args.tokenizer_metrics_allowed_customer_labels
else None
)
......@@ -62,12 +62,6 @@ class OpenAIServingBase(ABC):
return self.create_error_response(
message=e.detail, err_type=str(e.status_code), status_code=e.status_code
)
except ValueError as e:
return self.create_error_response(
message=str(e),
err_type="BadRequest",
status_code=400,
)
except Exception as e:
logger.exception(f"Error in request: {e}")
return self.create_error_response(
......@@ -92,19 +86,6 @@ class OpenAIServingBase(ABC):
return f"{self._request_id_prefix()}{uuid.uuid4().hex}"
def _compute_extra_key(self, request: OpenAIServingRequest) -> Optional[str]:
"""Compute the final extra_key by concatenating cache_salt and extra_key if both are provided."""
parts = []
for key in ["cache_salt", "extra_key"]:
value = getattr(request, key, None)
if value:
if not isinstance(value, str):
raise TypeError(
f"Value of {key} must be a string, but got {type(value).__name__}"
)
parts.append(value)
return "".join(parts) if parts else None
@abstractmethod
def _convert_to_internal_request(
self,
......@@ -184,14 +165,14 @@ class OpenAIServingBase(ABC):
)
return json.dumps({"error": error.model_dump()})
def extract_custom_labels(self, raw_request):
def extract_customer_labels(self, raw_request):
if (
not self.allowed_custom_labels
or not self.tokenizer_manager.server_args.tokenizer_metrics_custom_labels_header
):
return None
custom_labels = None
customer_labels = None
header = (
self.tokenizer_manager.server_args.tokenizer_metrics_custom_labels_header
)
......@@ -206,9 +187,9 @@ class OpenAIServingBase(ABC):
raw_labels = None
if isinstance(raw_labels, dict):
custom_labels = {
customer_labels = {
label: value
for label, value in raw_labels.items()
if label in self.allowed_custom_labels
}
return custom_labels
return customer_labels
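extract_customer_labels boils down to: read the configured header, parse it as a JSON object, and keep only the allowed label keys. A standalone sketch of that filtering step, assuming the header value has already been read from the request:

import json
from typing import Dict, Optional, Set


def filter_metric_labels(
    header_value: str, allowed_labels: Optional[Set[str]]
) -> Optional[Dict[str, str]]:
    # Sketch of the header -> labels filtering done by extract_customer_labels.
    # Only keys present in `allowed_labels` survive.
    if not allowed_labels:
        return None
    try:
        raw_labels = json.loads(header_value)
    except json.JSONDecodeError:
        return None
    if not isinstance(raw_labels, dict):
        return None
    return {k: v for k, v in raw_labels.items() if k in allowed_labels}


# Example: only "team" is allowed, so "user" is dropped.
print(filter_metric_labels('{"team": "search", "user": "alice"}', {"team"}))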
......@@ -9,7 +9,6 @@ from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, List, Optional, Uni
from fastapi import Request
from fastapi.responses import ORJSONResponse, StreamingResponse
from jsonschema import Draft202012Validator, SchemaError
from sglang.srt.entrypoints.openai.protocol import (
ChatCompletionRequest,
......@@ -26,8 +25,6 @@ from sglang.srt.entrypoints.openai.protocol import (
LogProbs,
MessageProcessingResult,
ToolCall,
ToolCallProcessingResult,
ToolChoice,
TopLogprob,
)
from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase
......@@ -36,10 +33,7 @@ from sglang.srt.entrypoints.openai.utils import (
process_hidden_states_from_ret,
to_openai_style_logprobs,
)
from sglang.srt.function_call.core_types import ToolCallItem
from sglang.srt.function_call.function_call_parser import FunctionCallParser
from sglang.srt.function_call.json_array_parser import JsonArrayParser
from sglang.srt.function_call.utils import get_json_schema_constraint
from sglang.srt.managers.io_struct import GenerateReqInput
from sglang.srt.parser.conversation import generate_chat_conv
from sglang.srt.parser.jinja_template_utils import process_content_for_template_format
......@@ -64,7 +58,6 @@ class OpenAIServingChat(OpenAIServingBase):
super().__init__(tokenizer_manager)
self.template_manager = template_manager
self.tool_call_parser = self.tokenizer_manager.server_args.tool_call_parser
self.reasoning_parser = self.tokenizer_manager.server_args.reasoning_parser
def _request_id_prefix(self) -> str:
return "chatcmpl-"
......@@ -81,23 +74,6 @@ class OpenAIServingChat(OpenAIServingBase):
):
return "Tools cannot be empty if tool choice is set to required."
if request.tool_choice is not None and not isinstance(request.tool_choice, str):
if not request.tools:
return "Tools cannot be empty if tool choice is set to a specific tool."
tool_name = request.tool_choice.function.name
tool_exists = any(tool.function.name == tool_name for tool in request.tools)
if not tool_exists:
return f"Tool '{tool_name}' not found in tools list."
# Validate tool definitions
for i, tool in enumerate(request.tools or []):
if tool.function.parameters is None:
continue
try:
Draft202012Validator.check_schema(tool.function.parameters)
except SchemaError as e:
return f"Tool {i} function has invalid 'parameters' schema: {str(e)}"
max_output_tokens = request.max_completion_tokens or request.max_tokens
server_context_length = self.tokenizer_manager.server_args.context_length
if (
......@@ -152,8 +128,8 @@ class OpenAIServingChat(OpenAIServingBase):
else:
prompt_kwargs = {"input_ids": processed_messages.prompt_ids}
# Extract custom labels from raw request headers
custom_labels = self.extract_custom_labels(raw_request)
# Extract customer labels from raw request headers
customer_labels = self.extract_customer_labels(raw_request)
adapted_request = GenerateReqInput(
**prompt_kwargs,
......@@ -173,9 +149,8 @@ class OpenAIServingChat(OpenAIServingBase):
bootstrap_room=request.bootstrap_room,
return_hidden_states=request.return_hidden_states,
rid=request.rid,
extra_key=self._compute_extra_key(request),
priority=request.priority,
custom_labels=custom_labels,
customer_labels=customer_labels,
)
return adapted_request, request
......@@ -213,14 +188,6 @@ class OpenAIServingChat(OpenAIServingBase):
tool_call_constraint = parser.get_structure_constraint(
request.tool_choice
)
# Handle JSON schema constraint directly for required or named tool choice
if request.tool_choice == "required" or isinstance(
request.tool_choice, ToolChoice
):
json_schema = get_json_schema_constraint(
request.tools, request.tool_choice
)
tool_call_constraint = ("json_schema", json_schema)
# Use chat template
if self.template_manager.chat_template_name is None:
......@@ -468,10 +435,6 @@ class OpenAIServingChat(OpenAIServingBase):
sampling_params[constraint_type] = convert_json_schema_to_str(
constraint_value.model_dump(by_alias=True)
)
elif constraint_type == "json_schema":
sampling_params[constraint_type] = convert_json_schema_to_str(
constraint_value
)
else:
sampling_params[constraint_type] = constraint_value
return sampling_params
......@@ -564,7 +527,10 @@ class OpenAIServingChat(OpenAIServingBase):
stream_buffers[index] = stream_buffer + delta
# Handle reasoning content
if self.reasoning_parser and request.separate_reasoning:
if (
self.tokenizer_manager.server_args.reasoning_parser
and request.separate_reasoning
):
reasoning_text, delta = self._process_reasoning_stream(
index, delta, reasoning_parser_dict, content, request
)
......@@ -754,7 +720,7 @@ class OpenAIServingChat(OpenAIServingBase):
# Handle reasoning content
reasoning_text = None
reasoning_parser = self.reasoning_parser
reasoning_parser = self.tokenizer_manager.server_args.reasoning_parser
if reasoning_parser and request.separate_reasoning:
is_force_reasoning = (
self.template_manager.force_reasoning
......@@ -782,13 +748,8 @@ class OpenAIServingChat(OpenAIServingBase):
and request.tools
and self.tool_call_parser
):
history_tool_calls_cnt = self._get_history_tool_calls_cnt(request)
tool_calls, text, finish_reason = self._process_tool_calls(
text,
request.tools,
finish_reason,
request.tool_choice,
history_tool_calls_cnt,
text, request.tools, finish_reason
)
choice_data = ChatCompletionResponseChoice(
......@@ -878,76 +839,13 @@ class OpenAIServingChat(OpenAIServingBase):
token_logprobs = self._process_logprobs_tokens(logprobs, use_token_index=True)
return ChoiceLogprobs(content=token_logprobs)
def _process_tool_call_id(
self,
call_item: ToolCallItem,
history_tool_calls_cnt: int,
) -> str:
"""Process for generating a new and unique `tool_call_id`"""
if self.tool_call_parser != "kimi_k2":
# A simple uuid is sufficient for all models except for Kimi-K2.
tool_call_id = f"call_{uuid.uuid4().hex[:24]}"
return tool_call_id
else:
# Align with Kimi-K2 format: functions.{name}:{index}
# Kimi-K2 allows multiple tool_calls in one message; SGLang sets call_item.tool_index to the *local* position inside that message.
# Therefore, the index must be corrected by using `history_tool_calls_cnt + call_item.tool_index` to ensure globally unique and properly ordered.
tool_call_id = f"functions.{call_item.name}:{history_tool_calls_cnt+call_item.tool_index}"
logger.debug(
f"Process tool call idx, parser: {self.tool_call_parser}, tool_call_id: {tool_call_id}, history_cnt: {history_tool_calls_cnt}"
)
return tool_call_id
def _process_tool_calls(
self,
text: str,
tools: List[Any],
finish_reason: Dict[str, Any],
tool_choice: Optional[Union[str, ToolChoice]] = None,
history_tool_calls_cnt: int = 0,
) -> ToolCallProcessingResult:
) -> tuple[Optional[List[ToolCall]], str, Dict[str, Any]]:
"""Process tool calls in the response"""
# Handle required or named tool choice
if tool_choice == "required" or (
isinstance(tool_choice, ToolChoice) and tool_choice.type == "function"
):
# Set finish reason to tool_calls since we're processing tool calls
if finish_reason["type"] == "stop":
finish_reason["type"] = "tool_calls"
finish_reason["matched"] = None
try:
# For required tool choice, we expect a JSON array of tool calls
tool_call_data = json.loads(text)
tool_calls = []
for i, tool in enumerate(tool_call_data):
# Create a ToolCallItem from the JSON data
call_info = ToolCallItem(
tool_index=i, # Use the loop index as tool_index
name=tool["name"],
parameters=json.dumps(tool["parameters"], ensure_ascii=False),
)
tool_id = self._process_tool_call_id(
call_info, history_tool_calls_cnt
)
tool_calls.append(
ToolCall(
id=tool_id,
index=i,
function=FunctionResponse(
name=tool["name"],
arguments=json.dumps(
tool["parameters"], ensure_ascii=False
),
),
)
)
return ToolCallProcessingResult(tool_calls, "", finish_reason)
except json.JSONDecodeError as e:
logger.error(f"Tool call parsing error: {e}")
return ToolCallProcessingResult(None, text, finish_reason)
# Use parser since output is not constrained by JSON schema
parser = FunctionCallParser(tools, self.tool_call_parser)
if parser.has_tool_call(text):
if finish_reason["type"] == "stop":
......@@ -957,9 +855,15 @@ class OpenAIServingChat(OpenAIServingBase):
text, call_info_list = parser.parse_non_stream(text)
tool_calls = []
for call_info in call_info_list:
tool_id = self._process_tool_call_id(
call_info, history_tool_calls_cnt
)
# For Kimi-K2, align tool_call_id with the model format: functions.{name}:{index}
if (
self.tool_call_parser == "kimi_k2"
and call_info.name is not None
):
tool_id = f"functions.{call_info.name}:{call_info.tool_index}"
else:
tool_id = f"call_{uuid.uuid4().hex[:24]}"
tool_calls.append(
ToolCall(
id=tool_id,
......@@ -969,13 +873,13 @@ class OpenAIServingChat(OpenAIServingBase):
),
)
)
return ToolCallProcessingResult(tool_calls, text, finish_reason)
return tool_calls, text, finish_reason
except Exception as e:
logger.error(f"Tool call parsing error: {e}")
# Return error but don't fail the whole request
return ToolCallProcessingResult(None, text, finish_reason)
return None, text, finish_reason
return ToolCallProcessingResult(None, text, finish_reason)
return None, text, finish_reason
def _process_streaming_logprobs(
self, content: Dict[str, Any], n_prev_token: int
......@@ -1008,33 +912,13 @@ class OpenAIServingChat(OpenAIServingBase):
or self._get_enable_thinking_from_request(request)
)
reasoning_parser_dict[index] = ReasoningParser(
self.reasoning_parser,
self.tokenizer_manager.server_args.reasoning_parser,
request.stream_reasoning,
is_force_reasoning,
)
reasoning_parser = reasoning_parser_dict[index]
return reasoning_parser.parse_stream_chunk(delta)
def _get_history_tool_calls_cnt(self, request: ChatCompletionRequest) -> int:
"""Counts the number of tool calls in the request's message history.
NOTE: This method is only useful for models that include self-increasing
history tool call idx in tool calls id, such as kimi-k2
Args:
request: The chat completion request object.
Returns:
The total number of tool calls in the history, or 0 if not applicable.
"""
messages = getattr(request, "messages", [])
idx = 0
for msg in messages:
if msg.role == "assistant":
tool_calls = getattr(msg, "tool_calls", None)
idx += len(list(tool_calls)) if tool_calls is not None else 0 # noqa
return idx
def _get_enable_thinking_from_request(self, request: ChatCompletionRequest) -> bool:
"""Extracts the 'enable_thinking' flag from request chat_template_kwargs.
......@@ -1048,11 +932,11 @@ class OpenAIServingChat(OpenAIServingBase):
"""
if hasattr(request, "chat_template_kwargs") and request.chat_template_kwargs:
# For Qwen3 models, `enable_thinking` is supported.
if self.reasoning_parser in ["qwen3", "glm45"]:
return request.chat_template_kwargs.get("enable_thinking", False)
if request.chat_template_kwargs.get("enable_thinking") is not None:
return request.chat_template_kwargs.get("enable_thinking")
# For DeepSeek-V3.1 models, `thinking` is supported.
elif self.reasoning_parser in ["deepseek-v3"]:
return request.chat_template_kwargs.get("thinking", False)
elif request.chat_template_kwargs.get("thinking") is not None:
return request.chat_template_kwargs.get("thinking")
else:
return False
return False
......@@ -1068,25 +952,13 @@ class OpenAIServingChat(OpenAIServingBase):
):
"""Process tool calls in streaming response"""
if index not in parser_dict:
# Use JSON detector directly for required or named tool choice
if request.tool_choice == "required" or isinstance(
request.tool_choice, ToolChoice
):
parser_dict[index] = JsonArrayParser()
else:
parser_dict[index] = FunctionCallParser(
tools=request.tools,
tool_call_parser=self.tool_call_parser,
)
parser_dict[index] = FunctionCallParser(
tools=request.tools,
tool_call_parser=self.tool_call_parser,
)
parser = parser_dict[index]
# Handle both FunctionCallParser and JsonArrayParser
if isinstance(parser, JsonArrayParser):
result = parser.parse_streaming_increment(delta, request.tools)
normal_text, calls = result.normal_text, result.calls
else:
normal_text, calls = parser.parse_stream_chunk(delta)
normal_text, calls = parser.parse_stream_chunk(delta)
# Yield normal text
if normal_text:
......@@ -1104,7 +976,6 @@ class OpenAIServingChat(OpenAIServingBase):
yield f"data: {chunk.model_dump_json()}\n\n"
# Yield tool calls
history_tool_calls_cnt = self._get_history_tool_calls_cnt(request)
for call_item in calls:
# Mark that this choice has tool calls
has_tool_calls[index] = True
......@@ -1112,9 +983,11 @@ class OpenAIServingChat(OpenAIServingBase):
# Tool call ID should be generated only once per tool call
if call_item.name:
# First chunk: include ID and function name
tool_call_id = self._process_tool_call_id(
call_item, history_tool_calls_cnt
)
if self.tool_call_parser == "kimi_k2":
# Align with Kimi-K2 format: functions.{name}:{index}
tool_call_id = f"functions.{call_item.name}:{call_item.tool_index}"
else:
tool_call_id = f"call_{uuid.uuid4().hex[:24]}"
function_name = call_item.name
else:
# Subsequent chunks: null ID and name for argument deltas
......@@ -1145,7 +1018,7 @@ class OpenAIServingChat(OpenAIServingBase):
def _check_for_unstreamed_tool_args(
self,
parser: Union[FunctionCallParser, JsonArrayParser],
parser: FunctionCallParser,
content: Dict[str, Any],
request: ChatCompletionRequest,
index: int,
......@@ -1155,31 +1028,30 @@ class OpenAIServingChat(OpenAIServingBase):
when generation finishes. This ensures tool calls are properly completed
even if the model generates the final arguments in the last chunk.
"""
# Get the detector - either from FunctionCallParser or directly if json detector
detector = parser.detector if hasattr(parser, "detector") else parser
# Only check if we have tool calls and the detector has tracked data
# Only check if we have tool calls and the parser has tracked data
if (
not hasattr(detector, "prev_tool_call_arr")
or not detector.prev_tool_call_arr
not hasattr(parser.detector, "prev_tool_call_arr")
or not parser.detector.prev_tool_call_arr
):
return None
if (
not hasattr(detector, "streamed_args_for_tool")
or not detector.streamed_args_for_tool
not hasattr(parser.detector, "streamed_args_for_tool")
or not parser.detector.streamed_args_for_tool
):
return None
# Get the last tool call that was being processed
tool_index = len(detector.prev_tool_call_arr) - 1
if tool_index < 0 or tool_index >= len(detector.streamed_args_for_tool):
tool_index = len(parser.detector.prev_tool_call_arr) - 1
if tool_index < 0 or tool_index >= len(parser.detector.streamed_args_for_tool):
return None
# Get expected vs actual arguments
expected_args = detector.prev_tool_call_arr[tool_index].get("arguments", {})
expected_args = parser.detector.prev_tool_call_arr[tool_index].get(
"arguments", {}
)
expected_call = json.dumps(expected_args, ensure_ascii=False)
actual_call = detector.streamed_args_for_tool[tool_index]
actual_call = parser.detector.streamed_args_for_tool[tool_index]
# Check if there are remaining arguments to send
remaining_call = (
......
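The tool_call_id helper is inlined again by this change: Kimi-K2 ids follow functions.{name}:{index}, while every other parser gets a random call_... id. A minimal sketch of that branch with hypothetical inputs:

import uuid


def make_tool_call_id(tool_call_parser: str, name: str, tool_index: int) -> str:
    # Sketch of the id scheme used in serving_chat.py after this commit.
    if tool_call_parser == "kimi_k2" and name is not None:
        # Kimi-K2 expects ids of the form functions.{name}:{index}.
        return f"functions.{name}:{tool_index}"
    # All other parsers just need a unique id.
    return f"call_{uuid.uuid4().hex[:24]}"


print(make_tool_call_id("kimi_k2", "get_weather", 0))  # functions.get_weather:0
print(make_tool_call_id("qwen25", "get_weather", 0))   # call_<random hex>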
......@@ -90,8 +90,8 @@ class OpenAIServingCompletion(OpenAIServingBase):
else:
prompt_kwargs = {"input_ids": prompt}
# Extract custom labels from raw request headers
custom_labels = self.extract_custom_labels(raw_request)
# Extract customer labels from raw request headers
customer_labels = self.extract_customer_labels(raw_request)
adapted_request = GenerateReqInput(
**prompt_kwargs,
......@@ -107,9 +107,8 @@ class OpenAIServingCompletion(OpenAIServingBase):
bootstrap_room=request.bootstrap_room,
return_hidden_states=request.return_hidden_states,
rid=request.rid,
extra_key=self._compute_extra_key(request),
priority=request.priority,
custom_labels=custom_labels,
customer_labels=customer_labels,
)
return adapted_request, request
......
......@@ -245,7 +245,6 @@ class OpenAIServingResponses(OpenAIServingChat):
sampling_params=sampling_params,
stream=request.stream,
rid=request.request_id,
extra_key=self._compute_extra_key(request),
background=request.background,
)
......@@ -1251,7 +1250,6 @@ class OpenAIServingResponses(OpenAIServingChat):
sampling_params=sampling_params,
stream=adapted_request.stream,
rid=request_id,
extra_key=adapted_request.extra_key,
return_logprob=adapted_request.return_logprob,
logprob_start_len=adapted_request.logprob_start_len,
top_logprobs_num=adapted_request.top_logprobs_num,
......
......@@ -231,7 +231,6 @@ class ExpertLocationMetadata:
logical_to_all_physical_map_num_valid=logical_to_all_physical_map_num_valid,
logical_to_rank_dispatch_physical_map=(
compute_logical_to_rank_dispatch_physical_map(
server_args=server_args,
logical_to_all_physical_map=logical_to_all_physical_map,
num_gpus=ep_size,
num_physical_experts=num_physical_experts,
......@@ -341,7 +340,6 @@ def _pad_nested_array(arr, pad_value):
# TODO optimize performance (rewrite and/or run in separate process with overlap)
def compute_logical_to_rank_dispatch_physical_map(
server_args: ServerArgs,
logical_to_all_physical_map: torch.Tensor,
num_gpus: int,
num_physical_experts: int,
......@@ -350,9 +348,7 @@ def compute_logical_to_rank_dispatch_physical_map(
):
r = random.Random(seed)
num_local_gpu_physical_experts = num_physical_experts // num_gpus
num_gpus_per_node = server_args.ep_size // server_args.nnodes
num_local_node_physical_experts = num_local_gpu_physical_experts * num_gpus_per_node
num_local_physical_experts = num_physical_experts // num_gpus
num_layers, num_logical_experts, _ = logical_to_all_physical_map.shape
dtype = logical_to_all_physical_map.dtype
......@@ -376,28 +372,13 @@ def compute_logical_to_rank_dispatch_physical_map(
physical_expert_id
for physical_expert_id in candidate_physical_expert_ids
if _compute_gpu_id_of_physical_expert(
physical_expert_id, num_local_gpu_physical_experts
physical_expert_id, num_local_physical_experts
)
== gpu_id
]
if len(same_gpu_physical_expert_ids) > 0:
# 1. Prefer same-GPU experts
output_partial[gpu_id] = same_gpu_physical_expert_ids[0]
else:
# 2. Otherwise, prefer same-node experts
node_id = gpu_id // num_gpus_per_node
same_node_physical_expert_ids = [
physical_expert_id
for physical_expert_id in candidate_physical_expert_ids
if _compute_node_id_of_physical_expert(
physical_expert_id, num_local_node_physical_experts
)
== node_id
]
if len(same_node_physical_expert_ids) > 0:
output_partial[gpu_id] = same_node_physical_expert_ids[0]
# 3. Fill remaining slots with fair random choices
num_remain = torch.sum(output_partial == -1).item()
output_partial[output_partial == -1] = torch.tensor(
_fair_choices(candidate_physical_expert_ids, k=num_remain, r=r),
......@@ -423,15 +404,9 @@ def _logical_to_all_physical_raw(
def _compute_gpu_id_of_physical_expert(
physical_expert_id: int, num_local_gpu_physical_experts: int
) -> int:
return physical_expert_id // num_local_gpu_physical_experts
def _compute_node_id_of_physical_expert(
physical_expert_id: int, num_local_host_physical_experts: int
physical_expert_id: int, num_local_physical_experts: int
) -> int:
return physical_expert_id // num_local_host_physical_experts
return physical_expert_id // num_local_physical_experts
def _fair_choices(arr: List, k: int, r: random.Random) -> List:
......
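With the same-node preference removed, dispatch reduces to: pick a candidate physical expert hosted on the same GPU if one exists, otherwise fall back to a fair random choice, where the owning GPU of a physical expert is a plain integer division. A worked sketch with hypothetical sizes:

# Hypothetical sizes: 16 physical experts over 4 GPUs,
# so each GPU hosts 4 local physical experts.
num_physical_experts = 16
num_gpus = 4
num_local_physical_experts = num_physical_experts // num_gpus  # 4


def gpu_of(physical_expert_id: int) -> int:
    # Same arithmetic as _compute_gpu_id_of_physical_expert.
    return physical_expert_id // num_local_physical_experts


# Candidate replicas of one logical expert (hypothetical ids).
candidates = [2, 9, 14]
gpu_id = 2
same_gpu = [e for e in candidates if gpu_of(e) == gpu_id]
# gpu_of(9) == 2, so expert 9 is preferred for GPU 2; GPUs with no
# same-GPU candidate fall back to _fair_choices over all candidates.
print(same_gpu)  # [9]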
......@@ -20,7 +20,6 @@ from sglang.srt.function_call.pythonic_detector import PythonicDetector
from sglang.srt.function_call.qwen3_coder_detector import Qwen3CoderDetector
from sglang.srt.function_call.qwen25_detector import Qwen25Detector
from sglang.srt.function_call.step3_detector import Step3Detector
from sglang.srt.function_call.utils import get_json_schema_constraint
logger = logging.getLogger(__name__)
......@@ -179,8 +178,8 @@ class FunctionCallParser:
strict_tag = self.get_structure_tag()
return ("structural_tag", strict_tag)
elif tool_choice == "required" or isinstance(tool_choice, ToolChoice):
json_schema = get_json_schema_constraint(self.tools, tool_choice)
return ("json_schema", json_schema)
ebnf = self.get_ebnf(tool_choice)
return ("ebnf", ebnf) if ebnf is not None else None
def get_ebnf(
self, tool_choice: Union[ToolChoice, Literal["required"]]
......
......@@ -39,7 +39,7 @@ def parse_arguments(json_value):
class Glm4MoeDetector(BaseFormatDetector):
"""
Detector for GLM-4.5 and GLM-4.6 models.
Detector for GLM-4.5 models.
Assumes function call format:
<tool_call>get_weather\n<arg_key>city</arg_key>\n<arg_value>北京</arg_value>\n<arg_key>date</arg_key>\n<arg_value>2024-06-27</arg_value>\n</tool_call>\n<tool_call>get_weather\n<arg_key>city</arg_key>\n<arg_value>上海</arg_value>\n<arg_key>date</arg_key>\n<arg_value>2024-06-27</arg_value>\n</tool_call>
"""
......@@ -53,7 +53,7 @@ class Glm4MoeDetector(BaseFormatDetector):
self.func_arg_regex = r"<arg_key>(.*?)</arg_key>\s*<arg_value>(.*?)</arg_value>"
def has_tool_call(self, text: str) -> bool:
"""Check if the text contains a glm-4.5 / glm-4.6 format tool call."""
"""Check if the text contains a glm-4.5 format tool call."""
return self.bot_token in text
def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
......@@ -102,7 +102,7 @@ class Glm4MoeDetector(BaseFormatDetector):
self, new_text: str, tools: List[Tool]
) -> StreamingParseResult:
"""
Streaming incremental parsing tool calls for GLM-4.5 and GLM-4.6 format.
Streaming incremental parsing tool calls for GLM-4.5 format.
"""
self._buffer += new_text
current_text = self._buffer
......
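Using the example string from the docstring above, the detector's func_arg_regex recovers the argument key/value pairs. In this sketch the name extraction is an illustrative assumption; only func_arg_regex is taken from the detector itself:

import re

# Example tool-call text in the GLM-4.5 format from the docstring above.
text = (
    "<tool_call>get_weather\n"
    "<arg_key>city</arg_key>\n<arg_value>北京</arg_value>\n"
    "<arg_key>date</arg_key>\n<arg_value>2024-06-27</arg_value>\n"
    "</tool_call>"
)

func_arg_regex = r"<arg_key>(.*?)</arg_key>\s*<arg_value>(.*?)</arg_value>"

# The function name is the text between <tool_call> and the first newline.
name = re.search(r"<tool_call>([^\n<]+)", text).group(1)
args = dict(re.findall(func_arg_regex, text))
print(name, args)  # get_weather {'city': '北京', 'date': '2024-06-27'}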
import json
import re
from typing import List
from sglang.srt.entrypoints.openai.protocol import Tool
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
from sglang.srt.function_call.core_types import StreamingParseResult
class JsonArrayParser(BaseFormatDetector):
"""
Parser for JSON array tool calls when JSON schema constraints are active.
This parser is used when tool_choice="required" or a specific tool is named,
bypassing model-specific parsers in favor of direct JSON array parsing.
"""
def __init__(self):
super().__init__()
# Configure for JSON array parsing
self.bot_token = "["
self.eot_token = "]"
self.tool_call_separator = ","
def has_tool_call(self, text: str) -> bool:
"""
Check if the given text contains a JSON tool call (array or single object).
"""
return "[" in text or "{" in text
def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
"""
Parse JSON tool calls using the base class implementation.
"""
raise NotImplementedError(
"Detect and parse not supported for JSON schema constraints."
)
def build_ebnf(self, tools: List[Tool]) -> str:
"""
Build an EBNF grammar for constrained generation.
This is not used for JSON schema constraints as they are handled
by the constraint backends directly.
"""
raise NotImplementedError(
"EBNF generation is not supported for JSON schema constraints."
)
def parse_streaming_increment(
self, new_text: str, tools: List[Tool]
) -> StreamingParseResult:
"""
Streaming incremental parsing with tool validation.
"""
return super().parse_streaming_increment(new_text, tools)
def structure_info(self) -> callable:
"""
Return a function that creates StructureInfo for constrained generation.
This is not used for JSON schema constraints as they are handled
by the constraint backends directly.
"""
raise NotImplementedError("structure_info not used for JSON schema constraints")
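The constrained output this parser targets is a JSON array of {"name", "parameters"} objects. A sketch, mirroring the required-tool-choice branch removed from serving_chat.py in this commit, of turning such an array into ToolCall objects (the example text is hypothetical):

import json

from sglang.srt.entrypoints.openai.protocol import FunctionResponse, ToolCall

# Hypothetical constrained output for tool_choice="required": a JSON array of
# {"name", "parameters"} objects, the shape enforced by the JSON-schema
# constraint in function_call/utils.py.
text = '[{"name": "get_weather", "parameters": {"city": "Beijing"}}]'

tool_calls = [
    ToolCall(
        id=f"call_{i}",
        index=i,
        function=FunctionResponse(
            name=item["name"],
            arguments=json.dumps(item["parameters"], ensure_ascii=False),
        ),
    )
    for i, item in enumerate(json.loads(text))
]
print([c.function.name for c in tool_calls])  # ['get_weather']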
import json
from json import JSONDecodeError, JSONDecoder
from json.decoder import WHITESPACE
from typing import Any, List, Literal, Optional, Tuple, Union
from typing import Any, Tuple
import partial_json_parser
from partial_json_parser.core.options import Allow
from sglang.srt.entrypoints.openai.protocol import Tool, ToolChoice
def _find_common_prefix(s1: str, s2: str) -> str:
prefix = ""
......@@ -40,12 +37,10 @@ def _partial_json_loads(input_str: str, flags: Allow) -> Tuple[Any, int]:
"""
try:
return (partial_json_parser.loads(input_str, flags), len(input_str))
except (JSONDecodeError, IndexError) as e:
msg = getattr(e, "msg", str(e))
if "Extra data" in msg or "pop from empty list" in msg:
start = WHITESPACE.match(input_str, 0).end()
obj, end = JSONDecoder().raw_decode(input_str, start)
return obj, end
except JSONDecodeError as e:
if "Extra data" in e.msg:
dec = JSONDecoder()
return dec.raw_decode(input_str)
raise
......@@ -55,89 +50,3 @@ def _is_complete_json(input_str: str) -> bool:
return True
except JSONDecodeError:
return False
def _get_tool_schema_defs(tools: List[Tool]) -> dict:
"""
Get consolidated $defs from all tools, validating for conflicts.
Args:
tools: List of tools to process
Returns:
Dictionary of consolidated $defs from all tools
Raises:
ValueError: If conflicting $defs are found
"""
all_defs = {}
for tool in tools:
if tool.function.parameters is None:
continue
defs = tool.function.parameters.get("$defs", {})
for def_name, def_schema in defs.items():
if def_name in all_defs and all_defs[def_name] != def_schema:
raise ValueError(
f"Tool definition '{def_name}' has "
"multiple schemas, which is not "
"supported."
)
else:
all_defs[def_name] = def_schema
return all_defs
def _get_tool_schema(tool: Tool) -> dict:
return {
"properties": {
"name": {"type": "string", "enum": [tool.function.name]},
"parameters": (
tool.function.parameters
if tool.function.parameters
else {"type": "object", "properties": {}}
),
},
"required": ["name", "parameters"],
}
def get_json_schema_constraint(
tools: List[Tool], tool_choice: Union[ToolChoice, Literal["required"]]
) -> Optional[dict]:
"""
Get the JSON schema constraint for the specified tool choice.
Args:
tool_choice: The tool choice specification
Returns:
JSON schema dict, or None if no valid tools found
"""
if isinstance(tool_choice, ToolChoice):
# For specific function choice, return the user's parameters schema directly
fn_name = tool_choice.function.name
for tool in tools:
if tool.function.name == fn_name:
return {
"type": "array",
"minItems": 1,
"maxItems": 1,
"items": _get_tool_schema(tool),
}
return None
elif tool_choice == "required":
json_schema = {
"type": "array",
"minItems": 1,
"items": {
"type": "object",
"anyOf": [_get_tool_schema(tool) for tool in tools],
},
}
json_schema_defs = _get_tool_schema_defs(tools)
if json_schema_defs:
json_schema["$defs"] = json_schema_defs
return json_schema
return None
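For a single hypothetical tool get_weather(city: str) and tool_choice="required", the schema built by get_json_schema_constraint comes out as an array-of-objects constraint; writing it out makes the shape concrete:

# Schema produced for tool_choice="required" with one hypothetical tool,
# get_weather(city: str); the tool's own parameters schema is assumed.
required_schema = {
    "type": "array",
    "minItems": 1,
    "items": {
        "type": "object",
        "anyOf": [
            {
                "properties": {
                    "name": {"type": "string", "enum": ["get_weather"]},
                    "parameters": {
                        "type": "object",
                        "properties": {"city": {"type": "string"}},
                    },
                },
                "required": ["name", "parameters"],
            }
        ],
    },
}
# A matching completion: [{"name": "get_weather", "parameters": {"city": "Beijing"}}]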
......@@ -36,9 +36,9 @@ message SamplingParams {
float presence_penalty = 6;
float repetition_penalty = 7;
optional int32 max_new_tokens = 8;
int32 max_new_tokens = 8;
repeated string stop = 9;
repeated uint32 stop_token_ids = 10;
repeated int32 stop_token_ids = 10;
bool skip_special_tokens = 11;
bool spaces_between_special_tokens = 12;
......@@ -47,24 +47,24 @@ message SamplingParams {
string regex = 13;
string json_schema = 14;
string ebnf_grammar = 15;
string structural_tag = 16;
}
// LoRA adapter
string lora_path = 17;
string lora_path = 16;
// Speculative decoding
int32 n = 18; // Number of samples
int32 n = 17; // Number of samples
// Token healing
bool token_healing = 19;
bool token_healing = 18;
// Additional parameters
int32 min_new_tokens = 20;
bool ignore_eos = 21;
bool no_stop_trim = 22;
int32 stream_interval = 23;
map<string, float> logit_bias = 24;
int32 min_new_tokens = 19;
bool ignore_eos = 20;
bool no_stop_trim = 21;
int32 stream_interval = 22;
map<string, float> logit_bias = 23;
string structural_tag = 24;
// Custom parameters for extensibility
google.protobuf.Struct custom_params = 25;
......@@ -98,7 +98,7 @@ message GenerateRequest {
bool return_logprob = 5;
int32 logprob_start_len = 6;
int32 top_logprobs_num = 7;
repeated uint32 token_ids_logprob = 8;
repeated int32 token_ids_logprob = 8;
bool return_hidden_states = 9;
// For disaggregated serving
......@@ -122,14 +122,11 @@ message GenerateRequest {
// For load balancing
int32 dp_balance_id = 17;
// Whether client wants streaming response
bool stream = 18;
}
message TokenizedInput {
string original_text = 1; // For reference
repeated uint32 input_ids = 2;
repeated int32 input_ids = 2;
}
message MultimodalInputs {
......@@ -166,50 +163,51 @@ message GenerateResponse {
}
message GenerateStreamChunk {
// Generated tokens (incremental chunk)
repeated uint32 token_ids = 1;
// Generated token
int32 token_id = 1;
string text = 2;
// Cumulative counts
int32 prompt_tokens = 2;
int32 completion_tokens = 3;
int32 cached_tokens = 4;
int32 prompt_tokens = 3;
int32 completion_tokens = 4;
int32 cached_tokens = 5;
// Output logprobs (if requested) - incremental for streaming
LogProbs output_logprobs = 5;
// Logprobs (if requested)
LogProbs logprobs = 6;
// Hidden states (if requested)
repeated float hidden_states = 6;
repeated float hidden_states = 7;
// Input logprobs (if requested) - only in first chunk
LogProbs input_logprobs = 7;
// Metadata
float generation_time = 8; // Time to generate this token
int32 queue_time = 9; // Time spent in queue
}
message GenerateComplete {
// Final output
repeated uint32 output_ids = 1;
// Finish reason as OpenAI-compatible string ("stop", "length", "abort")
string finish_reason = 2;
// Token usage counts
int32 prompt_tokens = 3;
int32 completion_tokens = 4;
int32 cached_tokens = 5;
repeated int32 output_ids = 1;
string output_text = 2;
// Finish reason
enum FinishReason {
// The model generated a stop sequence.
STOP = 0;
// The model reached the maximum generation length.
LENGTH = 1;
// The model generated an end-of-sequence (EOS) token.
EOS_TOKEN = 2;
// The model generated a user-provided stop string.
STOP_STR = 3;
// The request was aborted by the user or system.
ABORT = 4;
}
FinishReason finish_reason = 3;
// Output logprobs if requested (cumulative)
LogProbs output_logprobs = 6;
// All logprobs if requested
repeated LogProbs all_logprobs = 11;
// All hidden states if requested
repeated HiddenStates all_hidden_states = 7;
// Matched stop information (for stop sequences)
oneof matched_stop {
uint32 matched_token_id = 8;
string matched_stop_str = 9;
}
// Input logprobs if requested (for prompt tokens)
LogProbs input_logprobs = 10;
repeated HiddenStates all_hidden_states = 12;
}
message GenerateError {
......@@ -224,11 +222,15 @@ message LogProbs {
// Top logprobs at each position
repeated TopLogProbs top_logprobs = 3;
// Decoded text for tokens
repeated string token_texts = 4;
}
message TopLogProbs {
repeated float values = 1;
repeated int32 token_ids = 2;
repeated string token_texts = 3;
}
message HiddenStates {
......@@ -283,9 +285,10 @@ message EmbedComplete {
// Additional metadata
int32 embedding_dim = 4;
float generation_time = 5;
// For batch embeddings
repeated Embedding batch_embeddings = 5;
repeated Embedding batch_embeddings = 6;
}
message Embedding {
......
......@@ -3,6 +3,7 @@ import datetime
from google.protobuf import timestamp_pb2 as _timestamp_pb2
from google.protobuf import struct_pb2 as _struct_pb2
from google.protobuf.internal import containers as _containers
from google.protobuf.internal import enum_type_wrapper as _enum_type_wrapper
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from collections.abc import Iterable as _Iterable, Mapping as _Mapping
......@@ -11,7 +12,7 @@ from typing import ClassVar as _ClassVar, Optional as _Optional, Union as _Union
DESCRIPTOR: _descriptor.FileDescriptor
class SamplingParams(_message.Message):
__slots__ = ("temperature", "top_p", "top_k", "min_p", "frequency_penalty", "presence_penalty", "repetition_penalty", "max_new_tokens", "stop", "stop_token_ids", "skip_special_tokens", "spaces_between_special_tokens", "regex", "json_schema", "ebnf_grammar", "structural_tag", "lora_path", "n", "token_healing", "min_new_tokens", "ignore_eos", "no_stop_trim", "stream_interval", "logit_bias", "custom_params")
__slots__ = ("temperature", "top_p", "top_k", "min_p", "frequency_penalty", "presence_penalty", "repetition_penalty", "max_new_tokens", "stop", "stop_token_ids", "skip_special_tokens", "spaces_between_special_tokens", "regex", "json_schema", "ebnf_grammar", "lora_path", "n", "token_healing", "min_new_tokens", "ignore_eos", "no_stop_trim", "stream_interval", "logit_bias", "structural_tag", "custom_params")
class LogitBiasEntry(_message.Message):
__slots__ = ("key", "value")
KEY_FIELD_NUMBER: _ClassVar[int]
......@@ -34,7 +35,6 @@ class SamplingParams(_message.Message):
REGEX_FIELD_NUMBER: _ClassVar[int]
JSON_SCHEMA_FIELD_NUMBER: _ClassVar[int]
EBNF_GRAMMAR_FIELD_NUMBER: _ClassVar[int]
STRUCTURAL_TAG_FIELD_NUMBER: _ClassVar[int]
LORA_PATH_FIELD_NUMBER: _ClassVar[int]
N_FIELD_NUMBER: _ClassVar[int]
TOKEN_HEALING_FIELD_NUMBER: _ClassVar[int]
......@@ -43,6 +43,7 @@ class SamplingParams(_message.Message):
NO_STOP_TRIM_FIELD_NUMBER: _ClassVar[int]
STREAM_INTERVAL_FIELD_NUMBER: _ClassVar[int]
LOGIT_BIAS_FIELD_NUMBER: _ClassVar[int]
STRUCTURAL_TAG_FIELD_NUMBER: _ClassVar[int]
CUSTOM_PARAMS_FIELD_NUMBER: _ClassVar[int]
temperature: float
top_p: float
......@@ -59,7 +60,6 @@ class SamplingParams(_message.Message):
regex: str
json_schema: str
ebnf_grammar: str
structural_tag: str
lora_path: str
n: int
token_healing: bool
......@@ -68,8 +68,9 @@ class SamplingParams(_message.Message):
no_stop_trim: bool
stream_interval: int
logit_bias: _containers.ScalarMap[str, float]
structural_tag: str
custom_params: _struct_pb2.Struct
def __init__(self, temperature: _Optional[float] = ..., top_p: _Optional[float] = ..., top_k: _Optional[int] = ..., min_p: _Optional[float] = ..., frequency_penalty: _Optional[float] = ..., presence_penalty: _Optional[float] = ..., repetition_penalty: _Optional[float] = ..., max_new_tokens: _Optional[int] = ..., stop: _Optional[_Iterable[str]] = ..., stop_token_ids: _Optional[_Iterable[int]] = ..., skip_special_tokens: bool = ..., spaces_between_special_tokens: bool = ..., regex: _Optional[str] = ..., json_schema: _Optional[str] = ..., ebnf_grammar: _Optional[str] = ..., structural_tag: _Optional[str] = ..., lora_path: _Optional[str] = ..., n: _Optional[int] = ..., token_healing: bool = ..., min_new_tokens: _Optional[int] = ..., ignore_eos: bool = ..., no_stop_trim: bool = ..., stream_interval: _Optional[int] = ..., logit_bias: _Optional[_Mapping[str, float]] = ..., custom_params: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ...) -> None: ...
def __init__(self, temperature: _Optional[float] = ..., top_p: _Optional[float] = ..., top_k: _Optional[int] = ..., min_p: _Optional[float] = ..., frequency_penalty: _Optional[float] = ..., presence_penalty: _Optional[float] = ..., repetition_penalty: _Optional[float] = ..., max_new_tokens: _Optional[int] = ..., stop: _Optional[_Iterable[str]] = ..., stop_token_ids: _Optional[_Iterable[int]] = ..., skip_special_tokens: bool = ..., spaces_between_special_tokens: bool = ..., regex: _Optional[str] = ..., json_schema: _Optional[str] = ..., ebnf_grammar: _Optional[str] = ..., lora_path: _Optional[str] = ..., n: _Optional[int] = ..., token_healing: bool = ..., min_new_tokens: _Optional[int] = ..., ignore_eos: bool = ..., no_stop_trim: bool = ..., stream_interval: _Optional[int] = ..., logit_bias: _Optional[_Mapping[str, float]] = ..., structural_tag: _Optional[str] = ..., custom_params: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ...) -> None: ...
class DisaggregatedParams(_message.Message):
__slots__ = ("bootstrap_host", "bootstrap_port", "bootstrap_room")
......@@ -82,7 +83,7 @@ class DisaggregatedParams(_message.Message):
def __init__(self, bootstrap_host: _Optional[str] = ..., bootstrap_port: _Optional[int] = ..., bootstrap_room: _Optional[int] = ...) -> None: ...
class GenerateRequest(_message.Message):
__slots__ = ("request_id", "tokenized", "mm_inputs", "sampling_params", "return_logprob", "logprob_start_len", "top_logprobs_num", "token_ids_logprob", "return_hidden_states", "disaggregated_params", "custom_logit_processor", "timestamp", "log_metrics", "input_embeds", "lora_id", "data_parallel_rank", "dp_balance_id", "stream")
__slots__ = ("request_id", "tokenized", "mm_inputs", "sampling_params", "return_logprob", "logprob_start_len", "top_logprobs_num", "token_ids_logprob", "return_hidden_states", "disaggregated_params", "custom_logit_processor", "timestamp", "log_metrics", "input_embeds", "lora_id", "data_parallel_rank", "dp_balance_id")
REQUEST_ID_FIELD_NUMBER: _ClassVar[int]
TOKENIZED_FIELD_NUMBER: _ClassVar[int]
MM_INPUTS_FIELD_NUMBER: _ClassVar[int]
......@@ -100,7 +101,6 @@ class GenerateRequest(_message.Message):
LORA_ID_FIELD_NUMBER: _ClassVar[int]
DATA_PARALLEL_RANK_FIELD_NUMBER: _ClassVar[int]
DP_BALANCE_ID_FIELD_NUMBER: _ClassVar[int]
STREAM_FIELD_NUMBER: _ClassVar[int]
request_id: str
tokenized: TokenizedInput
mm_inputs: MultimodalInputs
......@@ -118,8 +118,7 @@ class GenerateRequest(_message.Message):
lora_id: str
data_parallel_rank: int
dp_balance_id: int
stream: bool
def __init__(self, request_id: _Optional[str] = ..., tokenized: _Optional[_Union[TokenizedInput, _Mapping]] = ..., mm_inputs: _Optional[_Union[MultimodalInputs, _Mapping]] = ..., sampling_params: _Optional[_Union[SamplingParams, _Mapping]] = ..., return_logprob: bool = ..., logprob_start_len: _Optional[int] = ..., top_logprobs_num: _Optional[int] = ..., token_ids_logprob: _Optional[_Iterable[int]] = ..., return_hidden_states: bool = ..., disaggregated_params: _Optional[_Union[DisaggregatedParams, _Mapping]] = ..., custom_logit_processor: _Optional[str] = ..., timestamp: _Optional[_Union[datetime.datetime, _timestamp_pb2.Timestamp, _Mapping]] = ..., log_metrics: bool = ..., input_embeds: _Optional[_Iterable[float]] = ..., lora_id: _Optional[str] = ..., data_parallel_rank: _Optional[int] = ..., dp_balance_id: _Optional[int] = ..., stream: bool = ...) -> None: ...
def __init__(self, request_id: _Optional[str] = ..., tokenized: _Optional[_Union[TokenizedInput, _Mapping]] = ..., mm_inputs: _Optional[_Union[MultimodalInputs, _Mapping]] = ..., sampling_params: _Optional[_Union[SamplingParams, _Mapping]] = ..., return_logprob: bool = ..., logprob_start_len: _Optional[int] = ..., top_logprobs_num: _Optional[int] = ..., token_ids_logprob: _Optional[_Iterable[int]] = ..., return_hidden_states: bool = ..., disaggregated_params: _Optional[_Union[DisaggregatedParams, _Mapping]] = ..., custom_logit_processor: _Optional[str] = ..., timestamp: _Optional[_Union[datetime.datetime, _timestamp_pb2.Timestamp, _Mapping]] = ..., log_metrics: bool = ..., input_embeds: _Optional[_Iterable[float]] = ..., lora_id: _Optional[str] = ..., data_parallel_rank: _Optional[int] = ..., dp_balance_id: _Optional[int] = ...) -> None: ...
class TokenizedInput(_message.Message):
__slots__ = ("original_text", "input_ids")
......@@ -162,46 +161,52 @@ class GenerateResponse(_message.Message):
def __init__(self, request_id: _Optional[str] = ..., chunk: _Optional[_Union[GenerateStreamChunk, _Mapping]] = ..., complete: _Optional[_Union[GenerateComplete, _Mapping]] = ..., error: _Optional[_Union[GenerateError, _Mapping]] = ...) -> None: ...
class GenerateStreamChunk(_message.Message):
__slots__ = ("token_ids", "prompt_tokens", "completion_tokens", "cached_tokens", "output_logprobs", "hidden_states", "input_logprobs")
TOKEN_IDS_FIELD_NUMBER: _ClassVar[int]
__slots__ = ("token_id", "text", "prompt_tokens", "completion_tokens", "cached_tokens", "logprobs", "hidden_states", "generation_time", "queue_time")
TOKEN_ID_FIELD_NUMBER: _ClassVar[int]
TEXT_FIELD_NUMBER: _ClassVar[int]
PROMPT_TOKENS_FIELD_NUMBER: _ClassVar[int]
COMPLETION_TOKENS_FIELD_NUMBER: _ClassVar[int]
CACHED_TOKENS_FIELD_NUMBER: _ClassVar[int]
OUTPUT_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
LOGPROBS_FIELD_NUMBER: _ClassVar[int]
HIDDEN_STATES_FIELD_NUMBER: _ClassVar[int]
INPUT_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
token_ids: _containers.RepeatedScalarFieldContainer[int]
GENERATION_TIME_FIELD_NUMBER: _ClassVar[int]
QUEUE_TIME_FIELD_NUMBER: _ClassVar[int]
token_id: int
text: str
prompt_tokens: int
completion_tokens: int
cached_tokens: int
output_logprobs: LogProbs
logprobs: LogProbs
hidden_states: _containers.RepeatedScalarFieldContainer[float]
input_logprobs: LogProbs
def __init__(self, token_ids: _Optional[_Iterable[int]] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., output_logprobs: _Optional[_Union[LogProbs, _Mapping]] = ..., hidden_states: _Optional[_Iterable[float]] = ..., input_logprobs: _Optional[_Union[LogProbs, _Mapping]] = ...) -> None: ...
generation_time: float
queue_time: int
def __init__(self, token_id: _Optional[int] = ..., text: _Optional[str] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., logprobs: _Optional[_Union[LogProbs, _Mapping]] = ..., hidden_states: _Optional[_Iterable[float]] = ..., generation_time: _Optional[float] = ..., queue_time: _Optional[int] = ...) -> None: ...
class GenerateComplete(_message.Message):
__slots__ = ("output_ids", "finish_reason", "prompt_tokens", "completion_tokens", "cached_tokens", "output_logprobs", "all_hidden_states", "matched_token_id", "matched_stop_str", "input_logprobs")
__slots__ = ("output_ids", "output_text", "finish_reason", "all_logprobs", "all_hidden_states")
class FinishReason(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
__slots__ = ()
STOP: _ClassVar[GenerateComplete.FinishReason]
LENGTH: _ClassVar[GenerateComplete.FinishReason]
EOS_TOKEN: _ClassVar[GenerateComplete.FinishReason]
STOP_STR: _ClassVar[GenerateComplete.FinishReason]
ABORT: _ClassVar[GenerateComplete.FinishReason]
STOP: GenerateComplete.FinishReason
LENGTH: GenerateComplete.FinishReason
EOS_TOKEN: GenerateComplete.FinishReason
STOP_STR: GenerateComplete.FinishReason
ABORT: GenerateComplete.FinishReason
OUTPUT_IDS_FIELD_NUMBER: _ClassVar[int]
OUTPUT_TEXT_FIELD_NUMBER: _ClassVar[int]
FINISH_REASON_FIELD_NUMBER: _ClassVar[int]
PROMPT_TOKENS_FIELD_NUMBER: _ClassVar[int]
COMPLETION_TOKENS_FIELD_NUMBER: _ClassVar[int]
CACHED_TOKENS_FIELD_NUMBER: _ClassVar[int]
OUTPUT_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
ALL_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
ALL_HIDDEN_STATES_FIELD_NUMBER: _ClassVar[int]
MATCHED_TOKEN_ID_FIELD_NUMBER: _ClassVar[int]
MATCHED_STOP_STR_FIELD_NUMBER: _ClassVar[int]
INPUT_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
output_ids: _containers.RepeatedScalarFieldContainer[int]
finish_reason: str
prompt_tokens: int
completion_tokens: int
cached_tokens: int
output_logprobs: LogProbs
output_text: str
finish_reason: GenerateComplete.FinishReason
all_logprobs: _containers.RepeatedCompositeFieldContainer[LogProbs]
all_hidden_states: _containers.RepeatedCompositeFieldContainer[HiddenStates]
matched_token_id: int
matched_stop_str: str
input_logprobs: LogProbs
def __init__(self, output_ids: _Optional[_Iterable[int]] = ..., finish_reason: _Optional[str] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., output_logprobs: _Optional[_Union[LogProbs, _Mapping]] = ..., all_hidden_states: _Optional[_Iterable[_Union[HiddenStates, _Mapping]]] = ..., matched_token_id: _Optional[int] = ..., matched_stop_str: _Optional[str] = ..., input_logprobs: _Optional[_Union[LogProbs, _Mapping]] = ...) -> None: ...
def __init__(self, output_ids: _Optional[_Iterable[int]] = ..., output_text: _Optional[str] = ..., finish_reason: _Optional[_Union[GenerateComplete.FinishReason, str]] = ..., all_logprobs: _Optional[_Iterable[_Union[LogProbs, _Mapping]]] = ..., all_hidden_states: _Optional[_Iterable[_Union[HiddenStates, _Mapping]]] = ...) -> None: ...
class GenerateError(_message.Message):
__slots__ = ("message", "http_status_code", "details")
......@@ -214,22 +219,26 @@ class GenerateError(_message.Message):
def __init__(self, message: _Optional[str] = ..., http_status_code: _Optional[str] = ..., details: _Optional[str] = ...) -> None: ...
class LogProbs(_message.Message):
__slots__ = ("token_logprobs", "token_ids", "top_logprobs")
__slots__ = ("token_logprobs", "token_ids", "top_logprobs", "token_texts")
TOKEN_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
TOKEN_IDS_FIELD_NUMBER: _ClassVar[int]
TOP_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
TOKEN_TEXTS_FIELD_NUMBER: _ClassVar[int]
token_logprobs: _containers.RepeatedScalarFieldContainer[float]
token_ids: _containers.RepeatedScalarFieldContainer[int]
top_logprobs: _containers.RepeatedCompositeFieldContainer[TopLogProbs]
def __init__(self, token_logprobs: _Optional[_Iterable[float]] = ..., token_ids: _Optional[_Iterable[int]] = ..., top_logprobs: _Optional[_Iterable[_Union[TopLogProbs, _Mapping]]] = ...) -> None: ...
token_texts: _containers.RepeatedScalarFieldContainer[str]
def __init__(self, token_logprobs: _Optional[_Iterable[float]] = ..., token_ids: _Optional[_Iterable[int]] = ..., top_logprobs: _Optional[_Iterable[_Union[TopLogProbs, _Mapping]]] = ..., token_texts: _Optional[_Iterable[str]] = ...) -> None: ...
class TopLogProbs(_message.Message):
__slots__ = ("values", "token_ids")
__slots__ = ("values", "token_ids", "token_texts")
VALUES_FIELD_NUMBER: _ClassVar[int]
TOKEN_IDS_FIELD_NUMBER: _ClassVar[int]
TOKEN_TEXTS_FIELD_NUMBER: _ClassVar[int]
values: _containers.RepeatedScalarFieldContainer[float]
token_ids: _containers.RepeatedScalarFieldContainer[int]
def __init__(self, values: _Optional[_Iterable[float]] = ..., token_ids: _Optional[_Iterable[int]] = ...) -> None: ...
token_texts: _containers.RepeatedScalarFieldContainer[str]
def __init__(self, values: _Optional[_Iterable[float]] = ..., token_ids: _Optional[_Iterable[int]] = ..., token_texts: _Optional[_Iterable[str]] = ...) -> None: ...
class HiddenStates(_message.Message):
__slots__ = ("values", "layer", "position")
......@@ -274,18 +283,20 @@ class EmbedResponse(_message.Message):
def __init__(self, request_id: _Optional[str] = ..., complete: _Optional[_Union[EmbedComplete, _Mapping]] = ..., error: _Optional[_Union[EmbedError, _Mapping]] = ...) -> None: ...
class EmbedComplete(_message.Message):
__slots__ = ("embedding", "prompt_tokens", "cached_tokens", "embedding_dim", "batch_embeddings")
__slots__ = ("embedding", "prompt_tokens", "cached_tokens", "embedding_dim", "generation_time", "batch_embeddings")
EMBEDDING_FIELD_NUMBER: _ClassVar[int]
PROMPT_TOKENS_FIELD_NUMBER: _ClassVar[int]
CACHED_TOKENS_FIELD_NUMBER: _ClassVar[int]
EMBEDDING_DIM_FIELD_NUMBER: _ClassVar[int]
GENERATION_TIME_FIELD_NUMBER: _ClassVar[int]
BATCH_EMBEDDINGS_FIELD_NUMBER: _ClassVar[int]
embedding: _containers.RepeatedScalarFieldContainer[float]
prompt_tokens: int
cached_tokens: int
embedding_dim: int
generation_time: float
batch_embeddings: _containers.RepeatedCompositeFieldContainer[Embedding]
def __init__(self, embedding: _Optional[_Iterable[float]] = ..., prompt_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., embedding_dim: _Optional[int] = ..., batch_embeddings: _Optional[_Iterable[_Union[Embedding, _Mapping]]] = ...) -> None: ...
def __init__(self, embedding: _Optional[_Iterable[float]] = ..., prompt_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., embedding_dim: _Optional[int] = ..., generation_time: _Optional[float] = ..., batch_embeddings: _Optional[_Iterable[_Union[Embedding, _Mapping]]] = ...) -> None: ...
class Embedding(_message.Message):
__slots__ = ("values", "index")
......
# This file is auto-generated. Do not edit manually.
# Regenerate with: python compile_proto.py
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
"""Client and server classes corresponding to protobuf-defined services."""
import grpc
......
......@@ -119,6 +119,37 @@ def get_hf_text_config(config: PretrainedConfig):
return config
def _load_deepseek_v32_model(
model_path: str,
trust_remote_code: bool = False,
revision: Optional[str] = None,
**kwargs,
):
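# Workaround for DeepSeek-V3.2 checkpoints: their config.json declares
# model_type "deepseek_v32", which this transformers version does not
# recognize, so rewrite the config to the DeepseekV3 classes and load the
# patched copy instead.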
# First, resolve the local checkpoint path (downloading from HF if needed).
local_path = download_from_hf(model_path)
# Then load the config file as JSON.
config_file = os.path.join(local_path, "config.json")
if not os.path.exists(config_file):
raise RuntimeError(f"Can't find config file in {local_path}.")
with open(config_file, "r") as f:
config_json = json.load(f)
config_json["architectures"] = ["DeepseekV3ForCausalLM"]
config_json["model_type"] = "deepseek_v3"
tmp_path = os.path.join(local_path, "_tmp_config_folder")
os.makedirs(tmp_path, exist_ok=True)
unique_path = os.path.join(tmp_path, f"deepseek_v32_{os.getpid()}")
with open(unique_path, "w") as f:
json.dump(config_json, f)
return AutoConfig.from_pretrained(
unique_path, trust_remote_code=trust_remote_code, revision=revision, **kwargs
)
@lru_cache_frozenset(maxsize=32)
def get_config(
model: str,
......@@ -140,9 +171,17 @@ def get_config(
client.pull_files(ignore_pattern=["*.pt", "*.safetensors", "*.bin"])
model = client.get_local_dir()
config = AutoConfig.from_pretrained(
model, trust_remote_code=trust_remote_code, revision=revision, **kwargs
)
try:
config = AutoConfig.from_pretrained(
model, trust_remote_code=trust_remote_code, revision=revision, **kwargs
)
except ValueError as e:
if not "deepseek_v32" in str(e):
raise e
config = _load_deepseek_v32_model(
model, trust_remote_code=trust_remote_code, revision=revision, **kwargs
)
if (
config.architectures is not None
and config.architectures[0] == "Phi4MMForCausalLM"
......
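A sketch of how the fallback is exercised; the model id is a placeholder and the keyword argument is assumed from get_config's body above, not from its full (collapsed) signature:
# The checkpoint's config.json declares model_type "deepseek_v32", so
# AutoConfig.from_pretrained raises ValueError and get_config falls back to
# _load_deepseek_v32_model, which patches the config to deepseek_v3.
config = get_config("deepseek-ai/DeepSeek-V3.2-Exp", trust_remote_code=True)
assert config.model_type == "deepseek_v3"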
......@@ -619,11 +619,7 @@ class AiterAttnBackend(AttentionBackend):
assert len(k.shape) == 3
assert len(v.shape) == 3
if (
forward_batch.forward_mode.is_extend()
and not forward_batch.forward_mode.is_target_verify()
and not forward_batch.forward_mode.is_draft_extend()
):
if forward_batch.forward_mode.is_extend():
if kv_indices.shape[0] == 0:
o = flash_attn_varlen_func(
q,
......
......@@ -3,6 +3,7 @@ from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING, List, Optional
import custom_ops
import torch
import torch_npu
from torch.nn.functional import scaled_dot_product_attention
......@@ -36,6 +37,8 @@ class ForwardMetadata:
seq_lens_cpu_int: Optional[torch.Tensor] = None
seq_lens_cpu_list: Optional[List[int]] = None
seq_lens_list_cumsum: Optional[List[int]] = None
seq_lens: Optional[torch.Tensor] = None
actual_seq_lengths_q: Optional[torch.Tensor] = None
class AscendAttnBackend(AttentionBackend):
......@@ -67,6 +70,9 @@ class AscendAttnBackend(AttentionBackend):
if self.use_mla:
self.kv_lora_rank = model_runner.model_config.kv_lora_rank
self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
self.q_head_dim = (
self.qk_rope_head_dim + model_runner.model_config.qk_nope_head_dim
)
self.native_attn = TorchNativeAttnBackend(model_runner)
self.graph_metadata = {}
self.max_context_len = model_runner.model_config.context_len
......@@ -102,10 +108,6 @@ class AscendAttnBackend(AttentionBackend):
self.forward_metadata.seq_lens_cpu_int = forward_batch.seq_lens_cpu.int()
seq_lens_list_cumsum = np.cumsum(forward_batch.extend_seq_lens_cpu)
if forward_batch.is_extend_in_batch:
seq_lens_list_cumsum[-1] = (
(seq_lens_list_cumsum[-1] - 1) // tp_size + 1
) * tp_size
self.forward_metadata.seq_lens_list_cumsum = seq_lens_list_cumsum
self.graph_mode = False
......@@ -133,6 +135,10 @@ class AscendAttnBackend(AttentionBackend):
metadata.block_tables = self.graph_metadata["block_tables"][:bs, :]
metadata.seq_lens_cpu_list = seq_lens.cpu().int().tolist()
metadata.seq_lens = seq_lens
metadata.actual_seq_lengths_q = torch.tensor(
[1 + i * 1 for i in range(bs)], dtype=torch.int32, device=seq_lens.device
)
self.graph_metadata[bs] = metadata
self.forward_metadata = metadata
......@@ -161,6 +167,8 @@ class AscendAttnBackend(AttentionBackend):
metadata.block_tables[:bs, max_seq_pages:].fill_(0)
metadata.block_tables[bs:, :].fill_(0)
metadata.seq_lens[:bs].copy_(seq_lens[:bs])
self.forward_metadata = metadata
self.graph_mode = True
......@@ -168,6 +176,64 @@ class AscendAttnBackend(AttentionBackend):
def get_cuda_graph_seq_len_fill_value(self):
return 0
def forward_sparse(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache: bool = True,
# For multi-head latent attention
q_rope: Optional[torch.Tensor] = None,
k_rope: Optional[torch.Tensor] = None,
topk_indices: Optional[torch.Tensor] = None,
):
is_prefill = forward_batch.forward_mode.is_extend()
if save_kv_cache:
k = k.view(-1, layer.tp_k_head_num, self.kv_lora_rank)
k_rope = k_rope.view(-1, layer.tp_k_head_num, self.qk_rope_head_dim)
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, forward_batch.out_cache_loc, k, k_rope
)
q_nope, q_pe = q, q_rope
k_nope, k_pe = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
block_table = self.forward_metadata.block_tables
if is_prefill:
actual_seq_qlen = torch.cumsum(forward_batch.seq_lens, dim=0)
else:
if self.forward_metadata.actual_seq_lengths_q is None:
actual_seq_qlen = (
torch.arange(1, q.shape[0] + 1).to(q.device).to(torch.int32)
)
else:
actual_seq_qlen = self.forward_metadata.actual_seq_lengths_q
if self.forward_metadata.seq_lens_cpu_int is None:
actual_seq_lengths_kv = self.forward_metadata.seq_lens
else:
actual_seq_lengths_kv = self.forward_metadata.seq_lens_cpu_int
attn_out = torch.ops.custom.npu_sparse_flash_attention(
query=q_nope,
key=k_nope,
value=k_nope,
query_rope=q_pe,
key_rope=k_pe,
sparse_indices=topk_indices,
scale_value=layer.scaling,
actual_seq_lengths_query=actual_seq_qlen.to(torch.int32),
actual_seq_lengths_kv=actual_seq_lengths_kv.to(q.device),
block_table=block_table,
sparse_block_size=1,
layout_query="TND",
layout_kv="PA_BSND",
sparse_mode=3,
)
return attn_out
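# Dispatch sketch (hypothetical caller, not part of this commit): the
# forward_extend / forward_decode paths below route here whenever
# topk_indices is provided by the sparse indexer, and keep the dense MLA
# path otherwise, e.g.
#   out = backend.forward_extend(q, k, v, layer, forward_batch,
#                                q_rope=q_rope, k_rope=k_rope,
#                                topk_indices=topk_indices)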
def forward_extend(
self,
q,
......@@ -176,7 +242,23 @@ class AscendAttnBackend(AttentionBackend):
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache: bool = True,
# For multi-head latent attention
q_rope: Optional[torch.Tensor] = None,
k_rope: Optional[torch.Tensor] = None,
topk_indices: Optional[torch.Tensor] = None,
):
if topk_indices is not None:
return self.forward_sparse(
q,
k,
v,
layer,
forward_batch,
save_kv_cache,
q_rope,
k_rope,
topk_indices,
)
if not self.use_mla:
if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(
......@@ -437,10 +519,23 @@ class AscendAttnBackend(AttentionBackend):
# For multi-head latent attention
q_rope: Optional[torch.Tensor] = None,
k_rope: Optional[torch.Tensor] = None,
topk_indices: Optional[torch.Tensor] = None,
):
if is_mla_preprocess_enabled():
# MLAPO already saves the KV cache itself
save_kv_cache = False
if topk_indices is not None:
return self.forward_sparse(
q,
k,
v,
layer,
forward_batch,
save_kv_cache,
q_rope,
k_rope,
topk_indices,
)
if self.graph_mode:
return self.forward_decode_graph(
......
import logging
logger = logging.getLogger(__name__)
ATTENTION_BACKENDS = {}
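# register_attention_backend itself is defined in the collapsed region; a
# minimal sketch of the assumed registry pattern (not verbatim from this file):
#   def register_attention_backend(name):
#       def decorator(fn):
#           ATTENTION_BACKENDS[name] = fn
#           return fn
#       return decorator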
......@@ -66,6 +62,13 @@ def create_ascend_backend(runner):
return AscendAttnBackend(runner)
@register_attention_backend("nsa")
def create_nsa_backend(runner):
from sglang.srt.layers.attention.nsa_backend import NativeSparseAttnBackend
return NativeSparseAttnBackend(runner)
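# With the factory above registered, the NSA backend can be requested the
# same way the assertions below point at triton/ascend (illustrative command;
# the flag name is taken from the messages elsewhere in this hunk):
#   python -m sglang.launch_server --model-path <model> --attention-backend nsa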
@register_attention_backend("triton")
def create_triton_backend(runner):
assert not runner.model_config.is_encoder_decoder, (
......@@ -162,37 +165,35 @@ def create_dual_chunk_flash_attn_backend(runner):
return DualChunkFlashAttentionBackend(runner)
def attn_backend_wrapper(runner, full_attn_backend):
"""
Wrapper for special models like hybrid GDN, so we don't
need to change the code of the original attention backend.
"""
assert not (
runner.is_hybrid_gdn and runner.use_mla_backend
), "hybrid_gdn can only be used with non-MLA models."
# wrap for hybrid GDN models
if runner.is_hybrid_gdn:
from sglang.srt.utils import is_blackwell, is_npu
if is_blackwell():
assert (
runner.server_args.attention_backend == "triton"
), "triton backend is the only supported backend on Blackwell GPUs for hybrid GDN models, use --attention-backend triton to specify the backend."
if is_npu():
assert (
runner.server_args.attention_backend == "ascend"
), "ascend backend is the only supported backend on NPU for hybrid GDN models, use --attention-backend ascend to specify the backend."
logger.info(f"Using hybrid linear attention backend for hybrid GDN models.")
from sglang.srt.layers.attention.hybrid_linear_attn_backend import (
HybridLinearAttnBackend,
MambaAttnBackend,
)
@register_attention_backend("hybrid_linear_attn")
def create_hybrid_linear_attn_backend(runner):
assert (
runner.is_hybrid_gdn
), "hybrid_linear_attn backend can only be used with hybrid GDN models."
from sglang.srt.layers.attention.hybrid_linear_attn_backend import (
HybridLinearAttnBackend,
MambaAttnBackend,
)
from sglang.srt.utils import is_blackwell, is_npu
linear_attn_backend = MambaAttnBackend(runner)
full_attn_layers = runner.model_config.hf_config.full_attention_layer_ids
return HybridLinearAttnBackend(
full_attn_backend, linear_attn_backend, full_attn_layers
if is_npu():
from sglang.srt.layers.attention.ascend_backend import AscendAttnBackend
full_attn_backend = AscendAttnBackend(runner)
elif is_blackwell():
from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
full_attn_backend = TritonAttnBackend(runner)
else:
from sglang.srt.layers.attention.flashattention_backend import (
FlashAttentionBackend,
)
return full_attn_backend
full_attn_backend = FlashAttentionBackend(runner)
linear_attn_backend = MambaAttnBackend(runner)
full_attn_layers = runner.model_config.hf_config.full_attention_layer_ids
return HybridLinearAttnBackend(
full_attn_backend, linear_attn_backend, full_attn_layers
)