Commit 852a49c5 authored by maxiao

Adapt to dsv32 on DCU

parent 8f7453e3
...@@ -4,7 +4,6 @@ Mimics TokenizerManager's state management and ZMQ communication patterns. ...@@ -4,7 +4,6 @@ Mimics TokenizerManager's state management and ZMQ communication patterns.
""" """
import asyncio import asyncio
import copy
import dataclasses import dataclasses
import logging import logging
import os import os
...@@ -12,8 +11,7 @@ import signal ...@@ -12,8 +11,7 @@ import signal
import sys import sys
import threading import threading
import time import time
import uuid from typing import Any, Dict, List, Optional, Union
from typing import Any, AsyncGenerator, Dict, List, Optional, Union
import grpc import grpc
import zmq import zmq
...@@ -81,10 +79,11 @@ class GrpcReqState: ...@@ -81,10 +79,11 @@ class GrpcReqState:
last_completion_tokens: int = 1 last_completion_tokens: int = 1
# Streaming state # Streaming state
last_output_offset: int = 0
stream_finished: bool = False stream_finished: bool = False
input_logprobs_sent: bool = False # Track if input logprobs were sent in streaming
# Token accumulation (for non-streaming) # Output accumulation
text: str = ""
output_ids: List[int] = dataclasses.field(default_factory=list) output_ids: List[int] = dataclasses.field(default_factory=list)
input_token_logprobs_val: List[float] = dataclasses.field(default_factory=list) input_token_logprobs_val: List[float] = dataclasses.field(default_factory=list)
input_token_logprobs_idx: List[int] = dataclasses.field(default_factory=list) input_token_logprobs_idx: List[int] = dataclasses.field(default_factory=list)
...@@ -140,6 +139,8 @@ class GrpcRequestManager: ...@@ -140,6 +139,8 @@ class GrpcRequestManager:
self.is_pause_cond = asyncio.Condition() self.is_pause_cond = asyncio.Condition()
# Metrics # Metrics
self.request_counter = 0
self.request_counter_lock = asyncio.Lock()
self.last_receive_tstamp = time.time() self.last_receive_tstamp = time.time()
# Crash dump for debugging # Crash dump for debugging
...@@ -157,133 +158,22 @@ class GrpcRequestManager: ...@@ -157,133 +158,22 @@ class GrpcRequestManager:
obj: TokenizedGenerateReqInput, obj: TokenizedGenerateReqInput,
request_id: Optional[str] = None, request_id: Optional[str] = None,
grpc_context: Optional[grpc.aio.ServicerContext] = None, grpc_context: Optional[grpc.aio.ServicerContext] = None,
) -> AsyncGenerator[Union[Dict, List[Dict]], None]: ) -> asyncio.Queue:
""" """
Submit a generation request to the scheduler with n>1 parallel sampling support. Submit a generation request to the scheduler.
Returns a queue for streaming outputs.
This method implements the same two-phase approach as tokenizer_manager.py:
1. Phase 1: Send prefix caching request (max_new_tokens=0)
2. Phase 2: Send n generation requests that reuse the cached prefix
Yields individual responses for streaming, or aggregated responses for non-streaming.
""" """
n = getattr(obj.sampling_params, "n", 1)
if n <= 1:
async for response in self._handle_single_request(
obj, request_id, grpc_context
):
yield response
return
# N>1 handling - two-phase approach
logger.debug(f"Multiple sampling request (n={n}), using two-phase approach")
# Generate base request ID if not provided
if request_id is None:
base_request_id = f"grpc-{uuid.uuid4().hex}"
else:
base_request_id = request_id
# Phase 1: Cache the common prefix
logger.debug(f"Phase 1: Caching prefix for request {base_request_id}")
prefix_obj = copy.copy(obj)
prefix_obj.sampling_params = copy.copy(obj.sampling_params)
prefix_obj.sampling_params.max_new_tokens = 0 # Prefill-only
prefix_obj.sampling_params.n = 1 # Don't replicate prefix request
# Send prefix caching request and consume response
async for _ in self._handle_single_request(
prefix_obj, f"{base_request_id}-prefix", grpc_context
):
# Consume prefix response (usually just one chunk with finish_reason)
pass
logger.debug(f"Phase 1 completed: Prefix cached for {base_request_id}")
# Phase 2: Generate n parallel requests
logger.debug(f"Phase 2: Generating {n} parallel requests")
generators = []
request_ids = []
for i in range(n):
# Create individual generation request
gen_obj = copy.copy(obj)
gen_obj.sampling_params = copy.copy(obj.sampling_params)
gen_obj.sampling_params.n = 1 # Each request generates 1 response
gen_request_id = f"{base_request_id}-{i}"
request_ids.append(gen_request_id)
# Start generation request
generators.append(
self._handle_single_request(gen_obj, gen_request_id, grpc_context)
)
# Handle response aggregation
is_stream = getattr(obj, "stream", False)
if not is_stream:
# Non-streaming: collect all responses and return as batch
logger.debug(f"Non-streaming mode: collecting {n} responses")
responses = []
for generator in generators:
async for response in generator:
responses.append(response)
yield responses # Return all responses as a batch
else:
# Streaming mode: multiplex responses with index for ordering
logger.debug(f"Streaming mode: multiplexing {n} streams")
rid_to_index = {rid: i for i, rid in enumerate(request_ids)}
# Create async tasks for all generators
task_map = {}
for generator in generators:
task = asyncio.create_task(generator.__anext__())
task_map[task] = generator
# Process responses as they arrive
while task_map:
done, _ = await asyncio.wait(
task_map.keys(), return_when=asyncio.FIRST_COMPLETED
)
for task in done:
generator = task_map.pop(task)
try:
response = await task
# Add index for client-side ordering
if isinstance(response, dict) and "meta_info" in response:
response_rid = response["meta_info"].get("id", "")
if response_rid in rid_to_index:
response["index"] = rid_to_index[response_rid]
yield response
# Create next task for this generator
next_task = asyncio.create_task(generator.__anext__())
task_map[next_task] = generator
except StopAsyncIteration:
# This generator is finished
pass
async def _handle_single_request(
self,
obj: TokenizedGenerateReqInput,
request_id: Optional[str] = None,
grpc_context: Optional[grpc.aio.ServicerContext] = None,
):
"""Handle a single request - core implementation without n>1 logic."""
# Generate request ID if not provided # Generate request ID if not provided
if request_id is None: if request_id is None:
request_id = f"grpc-{uuid.uuid4().hex}" async with self.request_counter_lock:
request_id = f"grpc-{self.request_counter}"
self.request_counter += 1
obj.rid = request_id obj.rid = request_id
# Create and register request state
# TODO: support log_request # TODO: support log_request
# Create request state
state = GrpcReqState( state = GrpcReqState(
request_id=request_id, request_id=request_id,
grpc_context=grpc_context, grpc_context=grpc_context,
...@@ -299,51 +189,19 @@ class GrpcRequestManager: ...@@ -299,51 +189,19 @@ class GrpcRequestManager:
state.session_id = obj.session_params.session_id state.session_id = obj.session_params.session_id
state.is_session_request = True state.is_session_request = True
# Register state
self.rid_to_state[request_id] = state self.rid_to_state[request_id] = state
self.record_request_for_crash_dump(obj) self.record_request_for_crash_dump(obj)
# Send to scheduler via ZMQ
try: try:
# Send to scheduler - let exceptions bubble up to grpc_server.py
await self._send_to_scheduler(obj) await self._send_to_scheduler(obj)
except Exception as e:
is_stream = getattr(obj, "stream", False) # Clean up on failure
while True:
# Client cancelled - notify scheduler and exit
if grpc_context and grpc_context.cancelled():
await self.abort_request(request_id)
return
try:
response = await asyncio.wait_for(state.out_queue.get(), timeout=4)
if is_stream:
yield response
# Non-streaming: yield final response with accumulated tokens from state
if isinstance(response, dict) and response.get("finished", False):
if not is_stream:
final_response = response.copy()
final_response["token_ids"] = state.output_ids
yield final_response
break
except asyncio.TimeoutError:
# Timeout waiting for response - abort and cleanup
logger.warning(
f"Timeout waiting for response for request {request_id}"
)
await self.abort_request(request_id)
return
finally:
# Always clean up request state when exiting
self._cleanup_request_state(request_id)
def _cleanup_request_state(self, request_id: str):
"""Clean up local request state (does not notify scheduler)."""
if request_id in self.rid_to_state:
del self.rid_to_state[request_id] del self.rid_to_state[request_id]
raise RuntimeError(f"Failed to send request to scheduler: {e}")
return state.out_queue
async def embedding_request( async def embedding_request(
self, self,
...@@ -356,7 +214,9 @@ class GrpcRequestManager: ...@@ -356,7 +214,9 @@ class GrpcRequestManager:
""" """
# Generate request ID if not provided # Generate request ID if not provided
if request_id is None: if request_id is None:
request_id = f"grpc-embed-{uuid.uuid4().hex}" async with self.request_counter_lock:
request_id = f"grpc-embed-{self.request_counter}"
self.request_counter += 1
obj.rid = request_id obj.rid = request_id
...@@ -495,6 +355,7 @@ class GrpcRequestManager: ...@@ -495,6 +355,7 @@ class GrpcRequestManager:
# Extract output for this request # Extract output for this request
output_data = { output_data = {
"request_id": rid, "request_id": rid,
"text": batch_out.decoded_texts[i] if batch_out.decoded_texts else "",
"token_ids": batch_out.output_ids[i] if batch_out.output_ids else [], "token_ids": batch_out.output_ids[i] if batch_out.output_ids else [],
"finished": batch_out.finished_reasons[i] is not None, "finished": batch_out.finished_reasons[i] is not None,
"meta_info": { "meta_info": {
...@@ -506,9 +367,6 @@ class GrpcRequestManager: ...@@ -506,9 +367,6 @@ class GrpcRequestManager:
if batch_out.completion_tokens if batch_out.completion_tokens
else 0 else 0
), ),
"cached_tokens": (
batch_out.cached_tokens[i] if batch_out.cached_tokens else 0
),
"finish_reason": ( "finish_reason": (
str(batch_out.finished_reasons[i]) str(batch_out.finished_reasons[i])
if batch_out.finished_reasons[i] if batch_out.finished_reasons[i]
...@@ -517,110 +375,29 @@ class GrpcRequestManager: ...@@ -517,110 +375,29 @@ class GrpcRequestManager:
}, },
} }
# Accumulate input logprobs (only once, usually in first chunk) # Add logprobs if available
if batch_out.input_token_logprobs_val and i < len(
batch_out.input_token_logprobs_val
):
if not state.input_token_logprobs_val:
state.input_token_logprobs_val.extend(
batch_out.input_token_logprobs_val[i]
)
if batch_out.input_token_logprobs_idx and i < len(
batch_out.input_token_logprobs_idx
):
state.input_token_logprobs_idx.extend(
batch_out.input_token_logprobs_idx[i]
)
if batch_out.input_top_logprobs_val and i < len(
batch_out.input_top_logprobs_val
):
state.input_top_logprobs_val.extend(
batch_out.input_top_logprobs_val[i]
)
if batch_out.input_top_logprobs_idx and i < len(
batch_out.input_top_logprobs_idx
):
state.input_top_logprobs_idx.extend(
batch_out.input_top_logprobs_idx[i]
)
# Send input logprobs based on mode
if state.input_token_logprobs_val:
if state.obj.stream and not state.input_logprobs_sent:
# Streaming: send input logprobs once in first chunk that has them
output_data["input_logprobs"] = {
"token_logprobs_val": state.input_token_logprobs_val,
"token_logprobs_idx": state.input_token_logprobs_idx,
"top_logprobs_val": state.input_top_logprobs_val,
"top_logprobs_idx": state.input_top_logprobs_idx,
}
state.input_logprobs_sent = True
elif not state.obj.stream and output_data["finished"]:
# Non-streaming: send input logprobs in final chunk
output_data["input_logprobs"] = {
"token_logprobs_val": state.input_token_logprobs_val,
"token_logprobs_idx": state.input_token_logprobs_idx,
"top_logprobs_val": state.input_top_logprobs_val,
"top_logprobs_idx": state.input_top_logprobs_idx,
}
# Add output logprobs if available (RAW - no detokenization!)
if batch_out.output_token_logprobs_val and i < len( if batch_out.output_token_logprobs_val and i < len(
batch_out.output_token_logprobs_val batch_out.output_token_logprobs_val
): ):
# Accumulate in state first output_data["logprobs"] = {
state.output_token_logprobs_val.extend( "tokens": batch_out.output_token_logprobs_val[i],
batch_out.output_token_logprobs_val[i] "top_logprobs": (
)
if batch_out.output_token_logprobs_idx and i < len(
batch_out.output_token_logprobs_idx
):
state.output_token_logprobs_idx.extend(
batch_out.output_token_logprobs_idx[i]
)
if batch_out.output_top_logprobs_val and i < len(
batch_out.output_top_logprobs_val
):
state.output_top_logprobs_val.extend(
batch_out.output_top_logprobs_val[i] batch_out.output_top_logprobs_val[i]
) if batch_out.output_top_logprobs_val
if batch_out.output_top_logprobs_idx and i < len( and i < len(batch_out.output_top_logprobs_val)
batch_out.output_top_logprobs_idx else None
): ),
state.output_top_logprobs_idx.extend(
batch_out.output_top_logprobs_idx[i]
)
if state.obj.stream:
# For streaming: send incremental logprobs (only new tokens in this chunk)
# NOTE: this is different than TokenizerManager, which always accumulates
def get_part(attr_name):
source_list = getattr(batch_out, attr_name, None)
return (
source_list[i]
if source_list and i < len(source_list)
else []
)
output_data["output_logprobs"] = {
"token_logprobs_val": batch_out.output_token_logprobs_val[i],
"token_logprobs_idx": get_part("output_token_logprobs_idx"),
"top_logprobs_val": get_part("output_top_logprobs_val"),
"top_logprobs_idx": get_part("output_top_logprobs_idx"),
}
elif output_data["finished"]:
# Non-streaming: send cumulative output logprobs in final chunk
output_data["output_logprobs"] = {
"token_logprobs_val": state.output_token_logprobs_val,
"token_logprobs_idx": state.output_token_logprobs_idx,
"top_logprobs_val": state.output_top_logprobs_val,
"top_logprobs_idx": state.output_top_logprobs_idx,
} }
# Update state for accumulation # Update state
if output_data["text"]:
state.text += output_data["text"][state.last_output_offset :]
state.last_output_offset = len(output_data["text"])
if output_data["token_ids"]: if output_data["token_ids"]:
state.output_ids.extend(output_data["token_ids"]) state.output_ids.extend(output_data["token_ids"])
# Send to output queue
await state.out_queue.put(output_data) await state.out_queue.put(output_data)
# Handle completion # Handle completion
......
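Editor's note: the hunks above replace the async-generator form of GrpcRequestManager.generate_request (which handled n>1 sampling with a prefix-caching pass) with a variant that registers the request and returns its per-request asyncio.Queue. A minimal sketch of how a caller might drain that queue, assuming a manager object with the new signature and an already tokenized request; the 4-second poll mirrors the timeout used in the new server code:

import asyncio

async def consume_outputs(manager, tokenized_req, request_id=None):
    # The new API returns the request's output queue instead of a generator.
    queue = await manager.generate_request(tokenized_req, request_id=request_id)
    while True:
        try:
            output = await asyncio.wait_for(queue.get(), timeout=4)
        except asyncio.TimeoutError:
            # Wake up periodically so the caller can check for cancellation.
            continue
        if "error" in output:
            raise RuntimeError(output["error"])
        yield output
        if output.get("finished", False):
            break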
...@@ -181,34 +181,20 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer) ...@@ -181,34 +181,20 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
# Convert gRPC request to internal format # Convert gRPC request to internal format
tokenized_req = self._convert_generate_request(request) tokenized_req = self._convert_generate_request(request)
# Submit to request manager (automatically handles n>1) # Submit to request manager
response_generator = self.request_manager.generate_request( output_queue = await self.request_manager.generate_request(
obj=tokenized_req, obj=tokenized_req,
request_id=request.request_id, request_id=request.request_id,
grpc_context=context, grpc_context=context,
) )
async for output in response_generator: # Stream outputs
# Handle batch responses (for n>1 non-streaming) while True:
if isinstance(output, list): try:
for batch_output in output: # Get output with timeout
if "error" in batch_output: output = await asyncio.wait_for(output_queue.get(), timeout=4)
yield sglang_scheduler_pb2.GenerateResponse(
request_id=request.request_id, # Check for errors
error=sglang_scheduler_pb2.GenerateError(
message=batch_output["error"],
http_status_code=(
"500" if "abort" not in batch_output else "499"
),
),
)
else:
# All non-error batch outputs are final responses
yield self._create_completion_response(
request.request_id, batch_output
)
else:
# Handle single response (for streaming or n=1 non-streaming)
if "error" in output: if "error" in output:
yield sglang_scheduler_pb2.GenerateResponse( yield sglang_scheduler_pb2.GenerateResponse(
request_id=request.request_id, request_id=request.request_id,
...@@ -219,13 +205,27 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer) ...@@ -219,13 +205,27 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
), ),
), ),
) )
elif output.get("finished", False): break
# Check if finished
if output.get("finished", False):
# Send completion
yield self._create_completion_response( yield self._create_completion_response(
request.request_id, output request.request_id, output
) )
break
else: else:
# Send chunk
yield self._create_chunk_response(request.request_id, output) yield self._create_chunk_response(request.request_id, output)
except asyncio.TimeoutError:
# Check if context is still active
if context.cancelled():
# Abort the request
await self.request_manager.abort_request(request.request_id)
break
continue
except Exception as e: except Exception as e:
logger.error(f"Generate failed: {e}\n{get_exception_traceback()}") logger.error(f"Generate failed: {e}\n{get_exception_traceback()}")
yield sglang_scheduler_pb2.GenerateResponse( yield sglang_scheduler_pb2.GenerateResponse(
...@@ -266,6 +266,7 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer) ...@@ -266,6 +266,7 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
prompt_tokens=result.get("prompt_tokens", 0), prompt_tokens=result.get("prompt_tokens", 0),
cached_tokens=0, cached_tokens=0,
embedding_dim=len(result["embedding"]), embedding_dim=len(result["embedding"]),
generation_time=time.time() - self.start_time,
), ),
) )
...@@ -321,14 +322,14 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer) ...@@ -321,14 +322,14 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
logger.info(f"Sending health check request to request manager...") logger.info(f"Sending health check request to request manager...")
# Submit and wait for response # Submit and wait for response
output_generator = self.request_manager.generate_request( output_queue = await self.request_manager.generate_request(
health_request, request_id=rid health_request, request_id=rid
) )
try: try:
# Get first response with timeout # Wait for response with configurable timeout
response = await asyncio.wait_for( response = await asyncio.wait_for(
output_generator.__anext__(), timeout=HEALTH_CHECK_TIMEOUT output_queue.get(), timeout=HEALTH_CHECK_TIMEOUT
) )
# Clean up # Clean up
...@@ -403,8 +404,8 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer) ...@@ -403,8 +404,8 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
return_logprob=grpc_req.return_logprob, return_logprob=grpc_req.return_logprob,
logprob_start_len=grpc_req.logprob_start_len or -1, logprob_start_len=grpc_req.logprob_start_len or -1,
top_logprobs_num=grpc_req.top_logprobs_num or 0, top_logprobs_num=grpc_req.top_logprobs_num or 0,
stream=grpc_req.stream or False, stream=True, # Always stream for gRPC
lora_id=grpc_req.lora_id if grpc_req.lora_id else None, lora_path=grpc_req.lora_id if grpc_req.lora_id else None,
token_ids_logprob=( token_ids_logprob=(
list(grpc_req.token_ids_logprob) if grpc_req.token_ids_logprob else None list(grpc_req.token_ids_logprob) if grpc_req.token_ids_logprob else None
), ),
...@@ -437,7 +438,6 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer) ...@@ -437,7 +438,6 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
regex = None regex = None
json_schema = None json_schema = None
ebnf_grammar = None ebnf_grammar = None
structural_tag = None
if grpc_params.HasField("regex"): if grpc_params.HasField("regex"):
regex = grpc_params.regex regex = grpc_params.regex
...@@ -445,8 +445,6 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer) ...@@ -445,8 +445,6 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
json_schema = grpc_params.json_schema json_schema = grpc_params.json_schema
elif grpc_params.HasField("ebnf_grammar"): elif grpc_params.HasField("ebnf_grammar"):
ebnf_grammar = grpc_params.ebnf_grammar ebnf_grammar = grpc_params.ebnf_grammar
elif grpc_params.HasField("structural_tag"):
structural_tag = grpc_params.structural_tag
return SGLSamplingParams( return SGLSamplingParams(
temperature=grpc_params.temperature or 1.0, temperature=grpc_params.temperature or 1.0,
...@@ -458,74 +456,33 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer) ...@@ -458,74 +456,33 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
repetition_penalty=grpc_params.repetition_penalty or 1.0, repetition_penalty=grpc_params.repetition_penalty or 1.0,
max_new_tokens=grpc_params.max_new_tokens or 128, max_new_tokens=grpc_params.max_new_tokens or 128,
min_new_tokens=grpc_params.min_new_tokens or 0, min_new_tokens=grpc_params.min_new_tokens or 0,
stop=list(grpc_params.stop) if grpc_params.stop else [], stop=list(grpc_params.stop) if grpc_params.stop else None,
stop_token_ids=( stop_token_ids=(
list(grpc_params.stop_token_ids) if grpc_params.stop_token_ids else [] list(grpc_params.stop_token_ids) if grpc_params.stop_token_ids else None
), ),
skip_special_tokens=grpc_params.skip_special_tokens, skip_special_tokens=grpc_params.skip_special_tokens,
spaces_between_special_tokens=grpc_params.spaces_between_special_tokens, spaces_between_special_tokens=grpc_params.spaces_between_special_tokens,
regex=regex, regex=regex,
json_schema=json_schema, json_schema=json_schema,
ebnf=ebnf_grammar, ebnf=ebnf_grammar,
structural_tag=structural_tag,
n=grpc_params.n or 1, n=grpc_params.n or 1,
ignore_eos=grpc_params.ignore_eos, ignore_eos=grpc_params.ignore_eos,
) )
def _convert_logprobs_to_proto(
self, logprobs_data: Dict
) -> Optional[sglang_scheduler_pb2.LogProbs]:
"""Convert logprobs dict to proto LogProbs format (transport RAW data only)."""
if not logprobs_data:
return None
token_logprobs_val = logprobs_data.get("token_logprobs_val", [])
token_logprobs_idx = logprobs_data.get("token_logprobs_idx", [])
top_logprobs_val = logprobs_data.get("top_logprobs_val", [])
top_logprobs_idx = logprobs_data.get("top_logprobs_idx", [])
# Build TopLogProbs entries
top_logprobs_proto = []
if top_logprobs_val and top_logprobs_idx:
for val_list, idx_list in zip(top_logprobs_val, top_logprobs_idx):
top_logprobs_proto.append(
sglang_scheduler_pb2.TopLogProbs(
values=val_list,
token_ids=idx_list,
)
)
return sglang_scheduler_pb2.LogProbs(
token_logprobs=token_logprobs_val,
token_ids=token_logprobs_idx,
top_logprobs=top_logprobs_proto,
)
def _create_chunk_response( def _create_chunk_response(
self, request_id: str, output: Dict self, request_id: str, output: Dict
) -> sglang_scheduler_pb2.GenerateResponse: ) -> sglang_scheduler_pb2.GenerateResponse:
"""Create a streaming chunk response.""" """Create a streaming chunk response."""
meta_info = output.get("meta_info", {})
# Convert output logprobs if present
output_logprobs_proto = self._convert_logprobs_to_proto(
output.get("output_logprobs")
)
# Convert input logprobs if present (only in first chunk)
input_logprobs_proto = self._convert_logprobs_to_proto(
output.get("input_logprobs")
)
return sglang_scheduler_pb2.GenerateResponse( return sglang_scheduler_pb2.GenerateResponse(
request_id=request_id, request_id=request_id,
chunk=sglang_scheduler_pb2.GenerateStreamChunk( chunk=sglang_scheduler_pb2.GenerateStreamChunk(
token_ids=output.get("token_ids", []), token_id=output["token_ids"][-1] if output.get("token_ids") else 0,
prompt_tokens=meta_info.get("prompt_tokens", 0), text=output.get("text", ""),
completion_tokens=meta_info.get("completion_tokens", 0), prompt_tokens=0,
cached_tokens=meta_info.get("cached_tokens", 0), completion_tokens=len(output.get("token_ids", [])),
output_logprobs=output_logprobs_proto, cached_tokens=0,
input_logprobs=input_logprobs_proto, generation_time=time.time() - self.start_time,
queue_time=0.0,
), ),
) )
...@@ -534,56 +491,20 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer) ...@@ -534,56 +491,20 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
) -> sglang_scheduler_pb2.GenerateResponse: ) -> sglang_scheduler_pb2.GenerateResponse:
"""Create a completion response.""" """Create a completion response."""
# Extract meta info and finish reason details # Determine finish reason
finish_reason = sglang_scheduler_pb2.GenerateComplete.STOP
meta_info = output.get("meta_info", {}) meta_info = output.get("meta_info", {})
finish_reason_data = meta_info.get("finish_reason") if meta_info.get("finish_reason") == "length":
finish_reason = sglang_scheduler_pb2.GenerateComplete.LENGTH
# Determine finish reason, default is stop elif meta_info.get("finish_reason") == "eos_token":
finish_reason = "stop" finish_reason = sglang_scheduler_pb2.GenerateComplete.EOS_TOKEN
if finish_reason_data:
if isinstance(finish_reason_data, dict):
finish_reason_type = finish_reason_data.get("type")
else:
# Handle legacy string format
finish_reason_type = finish_reason_data
if finish_reason_type == "length":
finish_reason = "length"
elif finish_reason_type == "abort":
finish_reason = "abort"
# Extract matched_stop information
matched_stop_kwargs = {}
if isinstance(finish_reason_data, dict) and "matched" in finish_reason_data:
matched = finish_reason_data["matched"]
if isinstance(matched, int):
matched_stop_kwargs["matched_token_id"] = matched
elif isinstance(matched, str):
matched_stop_kwargs["matched_stop_str"] = matched
# Convert output logprobs if present
output_logprobs_proto = self._convert_logprobs_to_proto(
output.get("output_logprobs")
)
# Convert input logprobs if present
input_logprobs_proto = self._convert_logprobs_to_proto(
output.get("input_logprobs")
)
return sglang_scheduler_pb2.GenerateResponse( return sglang_scheduler_pb2.GenerateResponse(
request_id=request_id, request_id=request_id,
complete=sglang_scheduler_pb2.GenerateComplete( complete=sglang_scheduler_pb2.GenerateComplete(
output_ids=output.get("token_ids", []), output_ids=output.get("token_ids", []),
output_text=output.get("text", ""),
finish_reason=finish_reason, finish_reason=finish_reason,
prompt_tokens=meta_info.get("prompt_tokens", 0),
completion_tokens=meta_info.get(
"completion_tokens", len(output.get("token_ids", []))
),
cached_tokens=meta_info.get("cached_tokens", 0),
output_logprobs=output_logprobs_proto,
input_logprobs=input_logprobs_proto,
**matched_stop_kwargs,
), ),
) )
......
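Editor's note: with streaming forced on in _convert_generate_request, every Generate call now produces a sequence of chunk messages followed by a complete (or error) message. A hypothetical client-side loop consuming that stream; the chunk/complete/error fields match the responses built above, while the stub and RPC name are assumptions for illustration:

async def collect_text(stub, request):
    # Accumulate chunk.text until a complete or error message arrives.
    parts = []
    async for response in stub.Generate(request):  # RPC name assumed
        if response.HasField("error"):
            raise RuntimeError(response.error.message)
        if response.HasField("chunk"):
            parts.append(response.chunk.text)
        elif response.HasField("complete"):
            break
    return "".join(parts)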
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
import time import time
import uuid import uuid
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict, List, NamedTuple, Optional, TypeAlias, Union from typing import Any, Dict, List, Optional, TypeAlias, Union
from openai.types.responses import ( from openai.types.responses import (
ResponseFunctionToolCall, ResponseFunctionToolCall,
...@@ -228,15 +228,11 @@ class CompletionRequest(BaseModel): ...@@ -228,15 +228,11 @@ class CompletionRequest(BaseModel):
# For request id # For request id
rid: Optional[Union[List[str], str]] = None rid: Optional[Union[List[str], str]] = None
# Extra key for classifying the request (e.g. cache_salt)
extra_key: Optional[Union[List[str], str]] = None
# Cache salt for request caching
cache_salt: Optional[Union[List[str], str]] = None
# Priority for the request # Priority for the request
priority: Optional[int] = None priority: Optional[int] = None
# For custom metric labels # For customer metric labels
custom_labels: Optional[Dict[str, str]] = None customer_labels: Optional[Dict[str, str]] = None
@field_validator("max_tokens") @field_validator("max_tokens")
@classmethod @classmethod
...@@ -343,7 +339,7 @@ class FunctionResponse(BaseModel): ...@@ -343,7 +339,7 @@ class FunctionResponse(BaseModel):
"""Function response.""" """Function response."""
name: Optional[str] = None name: Optional[str] = None
arguments: Optional[str | Dict[str, Any]] = None arguments: Optional[str] = None
class ToolCall(BaseModel): class ToolCall(BaseModel):
...@@ -392,7 +388,7 @@ class Function(BaseModel): ...@@ -392,7 +388,7 @@ class Function(BaseModel):
"""Function descriptions.""" """Function descriptions."""
description: Optional[str] = Field(default=None, examples=[None]) description: Optional[str] = Field(default=None, examples=[None])
name: str name: Optional[str] = None
parameters: Optional[object] = None parameters: Optional[object] = None
strict: bool = False strict: bool = False
...@@ -549,10 +545,6 @@ class ChatCompletionRequest(BaseModel): ...@@ -549,10 +545,6 @@ class ChatCompletionRequest(BaseModel):
# For request id # For request id
rid: Optional[Union[List[str], str]] = None rid: Optional[Union[List[str], str]] = None
# Extra key for classifying the request (e.g. cache_salt)
extra_key: Optional[Union[List[str], str]] = None
# Cache salt for request caching
cache_salt: Optional[Union[List[str], str]] = None
# Priority for the request # Priority for the request
priority: Optional[int] = None priority: Optional[int] = None
...@@ -786,13 +778,6 @@ class ResponsesRequest(BaseModel): ...@@ -786,13 +778,6 @@ class ResponsesRequest(BaseModel):
description="The request_id related to this request. If the caller does not set it, a random uuid will be generated.", description="The request_id related to this request. If the caller does not set it, a random uuid will be generated.",
) )
priority: int = Field(default=0, description="Request priority") priority: int = Field(default=0, description="Request priority")
extra_key: Optional[str] = Field(
default=None,
description="Extra key for classifying the request (e.g. cache_salt)",
)
cache_salt: Optional[str] = Field(
default=None, description="Cache salt for request caching"
)
# SGLang-specific sampling parameters # SGLang-specific sampling parameters
frequency_penalty: float = 0.0 frequency_penalty: float = 0.0
...@@ -943,16 +928,6 @@ class MessageProcessingResult: ...@@ -943,16 +928,6 @@ class MessageProcessingResult:
tool_call_constraint: Optional[Any] = None tool_call_constraint: Optional[Any] = None
class ToolCallProcessingResult(NamedTuple):
"""Result of processing tool calls in a response."""
tool_calls: Optional[
List[Any]
] # List of ToolCall objects or None if parsing failed
remaining_text: str # Text remaining after parsing tool calls
finish_reason: Dict[str, Any] # Updated finish reason dictionary
class ResponseReasoningTextContent(BaseModel): class ResponseReasoningTextContent(BaseModel):
text: str text: str
type: Literal["reasoning_text"] = "reasoning_text" type: Literal["reasoning_text"] = "reasoning_text"
......
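Editor's note: since FunctionResponse.arguments is narrowed back to a plain string in this hunk, callers have to serialize argument dicts themselves. A small sketch under that assumption; the tool name and id values are made up:

import json

from sglang.srt.entrypoints.openai.protocol import FunctionResponse, ToolCall

call = ToolCall(
    id="call_0123456789abcdef01234567",  # example id in the call_<hex> form
    index=0,
    function=FunctionResponse(
        name="get_weather",  # hypothetical tool name
        arguments=json.dumps({"city": "Beijing"}, ensure_ascii=False),
    ),
)
print(call.model_dump_json())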
...@@ -27,10 +27,10 @@ class OpenAIServingBase(ABC): ...@@ -27,10 +27,10 @@ class OpenAIServingBase(ABC):
self.tokenizer_manager = tokenizer_manager self.tokenizer_manager = tokenizer_manager
self.allowed_custom_labels = ( self.allowed_custom_labels = (
set( set(
self.tokenizer_manager.server_args.tokenizer_metrics_allowed_custom_labels self.tokenizer_manager.server_args.tokenizer_metrics_allowed_customer_labels
) )
if isinstance(self.tokenizer_manager.server_args, ServerArgs) if isinstance(self.tokenizer_manager.server_args, ServerArgs)
and self.tokenizer_manager.server_args.tokenizer_metrics_allowed_custom_labels and self.tokenizer_manager.server_args.tokenizer_metrics_allowed_customer_labels
else None else None
) )
...@@ -62,12 +62,6 @@ class OpenAIServingBase(ABC): ...@@ -62,12 +62,6 @@ class OpenAIServingBase(ABC):
return self.create_error_response( return self.create_error_response(
message=e.detail, err_type=str(e.status_code), status_code=e.status_code message=e.detail, err_type=str(e.status_code), status_code=e.status_code
) )
except ValueError as e:
return self.create_error_response(
message=str(e),
err_type="BadRequest",
status_code=400,
)
except Exception as e: except Exception as e:
logger.exception(f"Error in request: {e}") logger.exception(f"Error in request: {e}")
return self.create_error_response( return self.create_error_response(
...@@ -92,19 +86,6 @@ class OpenAIServingBase(ABC): ...@@ -92,19 +86,6 @@ class OpenAIServingBase(ABC):
return f"{self._request_id_prefix()}{uuid.uuid4().hex}" return f"{self._request_id_prefix()}{uuid.uuid4().hex}"
def _compute_extra_key(self, request: OpenAIServingRequest) -> Optional[str]:
"""Compute the final extra_key by concatenating cache_salt and extra_key if both are provided."""
parts = []
for key in ["cache_salt", "extra_key"]:
value = getattr(request, key, None)
if value:
if not isinstance(value, str):
raise TypeError(
f"Value of {key} must be a string, but got {type(value).__name__}"
)
parts.append(value)
return "".join(parts) if parts else None
@abstractmethod @abstractmethod
def _convert_to_internal_request( def _convert_to_internal_request(
self, self,
...@@ -184,14 +165,14 @@ class OpenAIServingBase(ABC): ...@@ -184,14 +165,14 @@ class OpenAIServingBase(ABC):
) )
return json.dumps({"error": error.model_dump()}) return json.dumps({"error": error.model_dump()})
def extract_custom_labels(self, raw_request): def extract_customer_labels(self, raw_request):
if ( if (
not self.allowed_custom_labels not self.allowed_custom_labels
or not self.tokenizer_manager.server_args.tokenizer_metrics_custom_labels_header or not self.tokenizer_manager.server_args.tokenizer_metrics_custom_labels_header
): ):
return None return None
custom_labels = None customer_labels = None
header = ( header = (
self.tokenizer_manager.server_args.tokenizer_metrics_custom_labels_header self.tokenizer_manager.server_args.tokenizer_metrics_custom_labels_header
) )
...@@ -206,9 +187,9 @@ class OpenAIServingBase(ABC): ...@@ -206,9 +187,9 @@ class OpenAIServingBase(ABC):
raw_labels = None raw_labels = None
if isinstance(raw_labels, dict): if isinstance(raw_labels, dict):
custom_labels = { customer_labels = {
label: value label: value
for label, value in raw_labels.items() for label, value in raw_labels.items()
if label in self.allowed_custom_labels if label in self.allowed_custom_labels
} }
return custom_labels return customer_labels
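Editor's note: the relabelled helper above filters a JSON metrics-labels header against the allow-list taken from server args. A standalone restatement of that filtering step, with the header parsing that the excerpt elides reconstructed under the assumption that the header value is a JSON object:

import json
from typing import Dict, Optional, Set

def filter_header_labels(header_value: str, allowed: Set[str]) -> Optional[Dict[str, str]]:
    # Parse the metrics-labels header and keep only allow-listed keys.
    try:
        raw_labels = json.loads(header_value)
    except (TypeError, json.JSONDecodeError):
        return None
    if not isinstance(raw_labels, dict):
        return None
    return {label: value for label, value in raw_labels.items() if label in allowed}

print(filter_header_labels('{"team": "search", "env": "prod"}', {"team"}))  # {'team': 'search'}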
...@@ -9,7 +9,6 @@ from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, List, Optional, Uni ...@@ -9,7 +9,6 @@ from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, List, Optional, Uni
from fastapi import Request from fastapi import Request
from fastapi.responses import ORJSONResponse, StreamingResponse from fastapi.responses import ORJSONResponse, StreamingResponse
from jsonschema import Draft202012Validator, SchemaError
from sglang.srt.entrypoints.openai.protocol import ( from sglang.srt.entrypoints.openai.protocol import (
ChatCompletionRequest, ChatCompletionRequest,
...@@ -26,8 +25,6 @@ from sglang.srt.entrypoints.openai.protocol import ( ...@@ -26,8 +25,6 @@ from sglang.srt.entrypoints.openai.protocol import (
LogProbs, LogProbs,
MessageProcessingResult, MessageProcessingResult,
ToolCall, ToolCall,
ToolCallProcessingResult,
ToolChoice,
TopLogprob, TopLogprob,
) )
from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase
...@@ -36,10 +33,7 @@ from sglang.srt.entrypoints.openai.utils import ( ...@@ -36,10 +33,7 @@ from sglang.srt.entrypoints.openai.utils import (
process_hidden_states_from_ret, process_hidden_states_from_ret,
to_openai_style_logprobs, to_openai_style_logprobs,
) )
from sglang.srt.function_call.core_types import ToolCallItem
from sglang.srt.function_call.function_call_parser import FunctionCallParser from sglang.srt.function_call.function_call_parser import FunctionCallParser
from sglang.srt.function_call.json_array_parser import JsonArrayParser
from sglang.srt.function_call.utils import get_json_schema_constraint
from sglang.srt.managers.io_struct import GenerateReqInput from sglang.srt.managers.io_struct import GenerateReqInput
from sglang.srt.parser.conversation import generate_chat_conv from sglang.srt.parser.conversation import generate_chat_conv
from sglang.srt.parser.jinja_template_utils import process_content_for_template_format from sglang.srt.parser.jinja_template_utils import process_content_for_template_format
...@@ -64,7 +58,6 @@ class OpenAIServingChat(OpenAIServingBase): ...@@ -64,7 +58,6 @@ class OpenAIServingChat(OpenAIServingBase):
super().__init__(tokenizer_manager) super().__init__(tokenizer_manager)
self.template_manager = template_manager self.template_manager = template_manager
self.tool_call_parser = self.tokenizer_manager.server_args.tool_call_parser self.tool_call_parser = self.tokenizer_manager.server_args.tool_call_parser
self.reasoning_parser = self.tokenizer_manager.server_args.reasoning_parser
def _request_id_prefix(self) -> str: def _request_id_prefix(self) -> str:
return "chatcmpl-" return "chatcmpl-"
...@@ -81,23 +74,6 @@ class OpenAIServingChat(OpenAIServingBase): ...@@ -81,23 +74,6 @@ class OpenAIServingChat(OpenAIServingBase):
): ):
return "Tools cannot be empty if tool choice is set to required." return "Tools cannot be empty if tool choice is set to required."
if request.tool_choice is not None and not isinstance(request.tool_choice, str):
if not request.tools:
return "Tools cannot be empty if tool choice is set to a specific tool."
tool_name = request.tool_choice.function.name
tool_exists = any(tool.function.name == tool_name for tool in request.tools)
if not tool_exists:
return f"Tool '{tool_name}' not found in tools list."
# Validate tool definitions
for i, tool in enumerate(request.tools or []):
if tool.function.parameters is None:
continue
try:
Draft202012Validator.check_schema(tool.function.parameters)
except SchemaError as e:
return f"Tool {i} function has invalid 'parameters' schema: {str(e)}"
max_output_tokens = request.max_completion_tokens or request.max_tokens max_output_tokens = request.max_completion_tokens or request.max_tokens
server_context_length = self.tokenizer_manager.server_args.context_length server_context_length = self.tokenizer_manager.server_args.context_length
if ( if (
...@@ -152,8 +128,8 @@ class OpenAIServingChat(OpenAIServingBase): ...@@ -152,8 +128,8 @@ class OpenAIServingChat(OpenAIServingBase):
else: else:
prompt_kwargs = {"input_ids": processed_messages.prompt_ids} prompt_kwargs = {"input_ids": processed_messages.prompt_ids}
# Extract custom labels from raw request headers # Extract customer labels from raw request headers
custom_labels = self.extract_custom_labels(raw_request) customer_labels = self.extract_customer_labels(raw_request)
adapted_request = GenerateReqInput( adapted_request = GenerateReqInput(
**prompt_kwargs, **prompt_kwargs,
...@@ -173,9 +149,8 @@ class OpenAIServingChat(OpenAIServingBase): ...@@ -173,9 +149,8 @@ class OpenAIServingChat(OpenAIServingBase):
bootstrap_room=request.bootstrap_room, bootstrap_room=request.bootstrap_room,
return_hidden_states=request.return_hidden_states, return_hidden_states=request.return_hidden_states,
rid=request.rid, rid=request.rid,
extra_key=self._compute_extra_key(request),
priority=request.priority, priority=request.priority,
custom_labels=custom_labels, customer_labels=customer_labels,
) )
return adapted_request, request return adapted_request, request
...@@ -213,14 +188,6 @@ class OpenAIServingChat(OpenAIServingBase): ...@@ -213,14 +188,6 @@ class OpenAIServingChat(OpenAIServingBase):
tool_call_constraint = parser.get_structure_constraint( tool_call_constraint = parser.get_structure_constraint(
request.tool_choice request.tool_choice
) )
# Handle JSON schema constraint directly for required or named tool choice
if request.tool_choice == "required" or isinstance(
request.tool_choice, ToolChoice
):
json_schema = get_json_schema_constraint(
request.tools, request.tool_choice
)
tool_call_constraint = ("json_schema", json_schema)
# Use chat template # Use chat template
if self.template_manager.chat_template_name is None: if self.template_manager.chat_template_name is None:
...@@ -468,10 +435,6 @@ class OpenAIServingChat(OpenAIServingBase): ...@@ -468,10 +435,6 @@ class OpenAIServingChat(OpenAIServingBase):
sampling_params[constraint_type] = convert_json_schema_to_str( sampling_params[constraint_type] = convert_json_schema_to_str(
constraint_value.model_dump(by_alias=True) constraint_value.model_dump(by_alias=True)
) )
elif constraint_type == "json_schema":
sampling_params[constraint_type] = convert_json_schema_to_str(
constraint_value
)
else: else:
sampling_params[constraint_type] = constraint_value sampling_params[constraint_type] = constraint_value
return sampling_params return sampling_params
...@@ -564,7 +527,10 @@ class OpenAIServingChat(OpenAIServingBase): ...@@ -564,7 +527,10 @@ class OpenAIServingChat(OpenAIServingBase):
stream_buffers[index] = stream_buffer + delta stream_buffers[index] = stream_buffer + delta
# Handle reasoning content # Handle reasoning content
if self.reasoning_parser and request.separate_reasoning: if (
self.tokenizer_manager.server_args.reasoning_parser
and request.separate_reasoning
):
reasoning_text, delta = self._process_reasoning_stream( reasoning_text, delta = self._process_reasoning_stream(
index, delta, reasoning_parser_dict, content, request index, delta, reasoning_parser_dict, content, request
) )
...@@ -754,7 +720,7 @@ class OpenAIServingChat(OpenAIServingBase): ...@@ -754,7 +720,7 @@ class OpenAIServingChat(OpenAIServingBase):
# Handle reasoning content # Handle reasoning content
reasoning_text = None reasoning_text = None
reasoning_parser = self.reasoning_parser reasoning_parser = self.tokenizer_manager.server_args.reasoning_parser
if reasoning_parser and request.separate_reasoning: if reasoning_parser and request.separate_reasoning:
is_force_reasoning = ( is_force_reasoning = (
self.template_manager.force_reasoning self.template_manager.force_reasoning
...@@ -782,13 +748,8 @@ class OpenAIServingChat(OpenAIServingBase): ...@@ -782,13 +748,8 @@ class OpenAIServingChat(OpenAIServingBase):
and request.tools and request.tools
and self.tool_call_parser and self.tool_call_parser
): ):
history_tool_calls_cnt = self._get_history_tool_calls_cnt(request)
tool_calls, text, finish_reason = self._process_tool_calls( tool_calls, text, finish_reason = self._process_tool_calls(
text, text, request.tools, finish_reason
request.tools,
finish_reason,
request.tool_choice,
history_tool_calls_cnt,
) )
choice_data = ChatCompletionResponseChoice( choice_data = ChatCompletionResponseChoice(
...@@ -878,76 +839,13 @@ class OpenAIServingChat(OpenAIServingBase): ...@@ -878,76 +839,13 @@ class OpenAIServingChat(OpenAIServingBase):
token_logprobs = self._process_logprobs_tokens(logprobs, use_token_index=True) token_logprobs = self._process_logprobs_tokens(logprobs, use_token_index=True)
return ChoiceLogprobs(content=token_logprobs) return ChoiceLogprobs(content=token_logprobs)
def _process_tool_call_id(
self,
call_item: ToolCallItem,
history_tool_calls_cnt: int,
) -> str:
"""Process for generating a new and unique `tool_call_id`"""
if self.tool_call_parser != "kimi_k2":
# A simple uuid is sufficient for all models except for Kimi-K2.
tool_call_id = f"call_{uuid.uuid4().hex[:24]}"
return tool_call_id
else:
# Align with Kimi-K2 format: functions.{name}:{index}
# Kimi-K2 allows multiple tool_calls in one message; SGLang sets call_item.tool_index to the *local* position inside that message.
# Therefore, the index must be corrected by using `history_tool_calls_cnt + call_item.tool_index` to ensure globally unique and properly ordered.
tool_call_id = f"functions.{call_item.name}:{history_tool_calls_cnt+call_item.tool_index}"
logger.debug(
f"Process tool call idx, parser: {self.tool_call_parser}, tool_call_id: {tool_call_id}, history_cnt: {history_tool_calls_cnt}"
)
return tool_call_id
def _process_tool_calls( def _process_tool_calls(
self, self,
text: str, text: str,
tools: List[Any], tools: List[Any],
finish_reason: Dict[str, Any], finish_reason: Dict[str, Any],
tool_choice: Optional[Union[str, ToolChoice]] = None, ) -> tuple[Optional[List[ToolCall]], str, Dict[str, Any]]:
history_tool_calls_cnt: int = 0,
) -> ToolCallProcessingResult:
"""Process tool calls in the response""" """Process tool calls in the response"""
# Handle required or named tool choice
if tool_choice == "required" or (
isinstance(tool_choice, ToolChoice) and tool_choice.type == "function"
):
# Set finish reason to tool_calls since we're processing tool calls
if finish_reason["type"] == "stop":
finish_reason["type"] = "tool_calls"
finish_reason["matched"] = None
try:
# For required tool choice, we expect a JSON array of tool calls
tool_call_data = json.loads(text)
tool_calls = []
for i, tool in enumerate(tool_call_data):
# Create a ToolCallItem from the JSON data
call_info = ToolCallItem(
tool_index=i, # Use the loop index as tool_index
name=tool["name"],
parameters=json.dumps(tool["parameters"], ensure_ascii=False),
)
tool_id = self._process_tool_call_id(
call_info, history_tool_calls_cnt
)
tool_calls.append(
ToolCall(
id=tool_id,
index=i,
function=FunctionResponse(
name=tool["name"],
arguments=json.dumps(
tool["parameters"], ensure_ascii=False
),
),
)
)
return ToolCallProcessingResult(tool_calls, "", finish_reason)
except json.JSONDecodeError as e:
logger.error(f"Tool call parsing error: {e}")
return ToolCallProcessingResult(None, text, finish_reason)
# Use parser since output is not constrained by JSON schema
parser = FunctionCallParser(tools, self.tool_call_parser) parser = FunctionCallParser(tools, self.tool_call_parser)
if parser.has_tool_call(text): if parser.has_tool_call(text):
if finish_reason["type"] == "stop": if finish_reason["type"] == "stop":
...@@ -957,9 +855,15 @@ class OpenAIServingChat(OpenAIServingBase): ...@@ -957,9 +855,15 @@ class OpenAIServingChat(OpenAIServingBase):
text, call_info_list = parser.parse_non_stream(text) text, call_info_list = parser.parse_non_stream(text)
tool_calls = [] tool_calls = []
for call_info in call_info_list: for call_info in call_info_list:
tool_id = self._process_tool_call_id( # For Kimi-K2, align tool_call_id with the model format: functions.{name}:{index}
call_info, history_tool_calls_cnt if (
) self.tool_call_parser == "kimi_k2"
and call_info.name is not None
):
tool_id = f"functions.{call_info.name}:{call_info.tool_index}"
else:
tool_id = f"call_{uuid.uuid4().hex[:24]}"
tool_calls.append( tool_calls.append(
ToolCall( ToolCall(
id=tool_id, id=tool_id,
...@@ -969,13 +873,13 @@ class OpenAIServingChat(OpenAIServingBase): ...@@ -969,13 +873,13 @@ class OpenAIServingChat(OpenAIServingBase):
), ),
) )
) )
return ToolCallProcessingResult(tool_calls, text, finish_reason) return tool_calls, text, finish_reason
except Exception as e: except Exception as e:
logger.error(f"Tool call parsing error: {e}") logger.error(f"Tool call parsing error: {e}")
# Return error but don't fail the whole request # Return error but don't fail the whole request
return ToolCallProcessingResult(None, text, finish_reason) return None, text, finish_reason
return ToolCallProcessingResult(None, text, finish_reason) return None, text, finish_reason
def _process_streaming_logprobs( def _process_streaming_logprobs(
self, content: Dict[str, Any], n_prev_token: int self, content: Dict[str, Any], n_prev_token: int
...@@ -1008,33 +912,13 @@ class OpenAIServingChat(OpenAIServingBase): ...@@ -1008,33 +912,13 @@ class OpenAIServingChat(OpenAIServingBase):
or self._get_enable_thinking_from_request(request) or self._get_enable_thinking_from_request(request)
) )
reasoning_parser_dict[index] = ReasoningParser( reasoning_parser_dict[index] = ReasoningParser(
self.reasoning_parser, self.tokenizer_manager.server_args.reasoning_parser,
request.stream_reasoning, request.stream_reasoning,
is_force_reasoning, is_force_reasoning,
) )
reasoning_parser = reasoning_parser_dict[index] reasoning_parser = reasoning_parser_dict[index]
return reasoning_parser.parse_stream_chunk(delta) return reasoning_parser.parse_stream_chunk(delta)
def _get_history_tool_calls_cnt(self, request: ChatCompletionRequest) -> int:
"""Counts the number of tool calls in the request's message history.
NOTE: This method is only useful for models that include self-increasing
history tool call idx in tool calls id, such as kimi-k2
Args:
request: The chat completion request object.
Returns:
The total number of tool calls in the history, or 0 if not applicable.
"""
messages = getattr(request, "messages", [])
idx = 0
for msg in messages:
if msg.role == "assistant":
tool_calls = getattr(msg, "tool_calls", None)
idx += len(list(tool_calls)) if tool_calls is not None else 0 # noqa
return idx
def _get_enable_thinking_from_request(self, request: ChatCompletionRequest) -> bool: def _get_enable_thinking_from_request(self, request: ChatCompletionRequest) -> bool:
"""Extracts the 'enable_thinking' flag from request chat_template_kwargs. """Extracts the 'enable_thinking' flag from request chat_template_kwargs.
...@@ -1048,11 +932,11 @@ class OpenAIServingChat(OpenAIServingBase): ...@@ -1048,11 +932,11 @@ class OpenAIServingChat(OpenAIServingBase):
""" """
if hasattr(request, "chat_template_kwargs") and request.chat_template_kwargs: if hasattr(request, "chat_template_kwargs") and request.chat_template_kwargs:
# For Qwen3 models, `enable_thinking` is supported. # For Qwen3 models, `enable_thinking` is supported.
if self.reasoning_parser in ["qwen3", "glm45"]: if request.chat_template_kwargs.get("enable_thinking") is not None:
return request.chat_template_kwargs.get("enable_thinking", False) return request.chat_template_kwargs.get("enable_thinking")
# For DeepSeek-V3.1 models, `thinking` is supported. # For DeepSeek-V3.1 models, `thinking` is supported.
elif self.reasoning_parser in ["deepseek-v3"]: elif request.chat_template_kwargs.get("thinking") is not None:
return request.chat_template_kwargs.get("thinking", False) return request.chat_template_kwargs.get("thinking")
else: else:
return False return False
return False return False
...@@ -1068,24 +952,12 @@ class OpenAIServingChat(OpenAIServingBase): ...@@ -1068,24 +952,12 @@ class OpenAIServingChat(OpenAIServingBase):
): ):
"""Process tool calls in streaming response""" """Process tool calls in streaming response"""
if index not in parser_dict: if index not in parser_dict:
# Use JSON detector directly for required or named tool choice
if request.tool_choice == "required" or isinstance(
request.tool_choice, ToolChoice
):
parser_dict[index] = JsonArrayParser()
else:
parser_dict[index] = FunctionCallParser( parser_dict[index] = FunctionCallParser(
tools=request.tools, tools=request.tools,
tool_call_parser=self.tool_call_parser, tool_call_parser=self.tool_call_parser,
) )
parser = parser_dict[index] parser = parser_dict[index]
# Handle both FunctionCallParser and JsonArrayParser
if isinstance(parser, JsonArrayParser):
result = parser.parse_streaming_increment(delta, request.tools)
normal_text, calls = result.normal_text, result.calls
else:
normal_text, calls = parser.parse_stream_chunk(delta) normal_text, calls = parser.parse_stream_chunk(delta)
# Yield normal text # Yield normal text
...@@ -1104,7 +976,6 @@ class OpenAIServingChat(OpenAIServingBase): ...@@ -1104,7 +976,6 @@ class OpenAIServingChat(OpenAIServingBase):
yield f"data: {chunk.model_dump_json()}\n\n" yield f"data: {chunk.model_dump_json()}\n\n"
# Yield tool calls # Yield tool calls
history_tool_calls_cnt = self._get_history_tool_calls_cnt(request)
for call_item in calls: for call_item in calls:
# Mark that this choice has tool calls # Mark that this choice has tool calls
has_tool_calls[index] = True has_tool_calls[index] = True
...@@ -1112,9 +983,11 @@ class OpenAIServingChat(OpenAIServingBase): ...@@ -1112,9 +983,11 @@ class OpenAIServingChat(OpenAIServingBase):
# Tool call ID should be generated only once per tool call # Tool call ID should be generated only once per tool call
if call_item.name: if call_item.name:
# First chunk: include ID and function name # First chunk: include ID and function name
tool_call_id = self._process_tool_call_id( if self.tool_call_parser == "kimi_k2":
call_item, history_tool_calls_cnt # Align with Kimi-K2 format: functions.{name}:{index}
) tool_call_id = f"functions.{call_item.name}:{call_item.tool_index}"
else:
tool_call_id = f"call_{uuid.uuid4().hex[:24]}"
function_name = call_item.name function_name = call_item.name
else: else:
# Subsequent chunks: null ID and name for argument deltas # Subsequent chunks: null ID and name for argument deltas
...@@ -1145,7 +1018,7 @@ class OpenAIServingChat(OpenAIServingBase): ...@@ -1145,7 +1018,7 @@ class OpenAIServingChat(OpenAIServingBase):
def _check_for_unstreamed_tool_args( def _check_for_unstreamed_tool_args(
self, self,
parser: Union[FunctionCallParser, JsonArrayParser], parser: FunctionCallParser,
content: Dict[str, Any], content: Dict[str, Any],
request: ChatCompletionRequest, request: ChatCompletionRequest,
index: int, index: int,
...@@ -1155,31 +1028,30 @@ class OpenAIServingChat(OpenAIServingBase): ...@@ -1155,31 +1028,30 @@ class OpenAIServingChat(OpenAIServingBase):
when generation finishes. This ensures tool calls are properly completed when generation finishes. This ensures tool calls are properly completed
even if the model generates the final arguments in the last chunk. even if the model generates the final arguments in the last chunk.
""" """
# Get the detector - either from FunctionCallParser or directly if json detector # Only check if we have tool calls and the parser has tracked data
detector = parser.detector if hasattr(parser, "detector") else parser
# Only check if we have tool calls and the detector has tracked data
if ( if (
not hasattr(detector, "prev_tool_call_arr") not hasattr(parser.detector, "prev_tool_call_arr")
or not detector.prev_tool_call_arr or not parser.detector.prev_tool_call_arr
): ):
return None return None
if ( if (
not hasattr(detector, "streamed_args_for_tool") not hasattr(parser.detector, "streamed_args_for_tool")
or not detector.streamed_args_for_tool or not parser.detector.streamed_args_for_tool
): ):
return None return None
# Get the last tool call that was being processed # Get the last tool call that was being processed
tool_index = len(detector.prev_tool_call_arr) - 1 tool_index = len(parser.detector.prev_tool_call_arr) - 1
if tool_index < 0 or tool_index >= len(detector.streamed_args_for_tool): if tool_index < 0 or tool_index >= len(parser.detector.streamed_args_for_tool):
return None return None
# Get expected vs actual arguments # Get expected vs actual arguments
expected_args = detector.prev_tool_call_arr[tool_index].get("arguments", {}) expected_args = parser.detector.prev_tool_call_arr[tool_index].get(
"arguments", {}
)
expected_call = json.dumps(expected_args, ensure_ascii=False) expected_call = json.dumps(expected_args, ensure_ascii=False)
actual_call = detector.streamed_args_for_tool[tool_index] actual_call = parser.detector.streamed_args_for_tool[tool_index]
# Check if there are remaining arguments to send # Check if there are remaining arguments to send
remaining_call = ( remaining_call = (
......
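Editor's note: a compact restatement of the tool_call_id convention this file settles on, for quick reference: the Kimi-K2 parser keeps the model's functions.{name}:{index} form, every other parser gets a random call_ id. The helper name and sample values are illustrative:

import uuid

def make_tool_call_id(parser_name: str, name, tool_index: int) -> str:
    # Kimi-K2 aligns ids with the model output; others use a short uuid.
    if parser_name == "kimi_k2" and name is not None:
        return f"functions.{name}:{tool_index}"
    return f"call_{uuid.uuid4().hex[:24]}"

print(make_tool_call_id("kimi_k2", "get_weather", 0))  # functions.get_weather:0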
...@@ -90,8 +90,8 @@ class OpenAIServingCompletion(OpenAIServingBase): ...@@ -90,8 +90,8 @@ class OpenAIServingCompletion(OpenAIServingBase):
else: else:
prompt_kwargs = {"input_ids": prompt} prompt_kwargs = {"input_ids": prompt}
# Extract custom labels from raw request headers # Extract customer labels from raw request headers
custom_labels = self.extract_custom_labels(raw_request) customer_labels = self.extract_customer_labels(raw_request)
adapted_request = GenerateReqInput( adapted_request = GenerateReqInput(
**prompt_kwargs, **prompt_kwargs,
...@@ -107,9 +107,8 @@ class OpenAIServingCompletion(OpenAIServingBase): ...@@ -107,9 +107,8 @@ class OpenAIServingCompletion(OpenAIServingBase):
bootstrap_room=request.bootstrap_room, bootstrap_room=request.bootstrap_room,
return_hidden_states=request.return_hidden_states, return_hidden_states=request.return_hidden_states,
rid=request.rid, rid=request.rid,
extra_key=self._compute_extra_key(request),
priority=request.priority, priority=request.priority,
custom_labels=custom_labels, customer_labels=customer_labels,
) )
return adapted_request, request return adapted_request, request
......
...@@ -245,7 +245,6 @@ class OpenAIServingResponses(OpenAIServingChat): ...@@ -245,7 +245,6 @@ class OpenAIServingResponses(OpenAIServingChat):
sampling_params=sampling_params, sampling_params=sampling_params,
stream=request.stream, stream=request.stream,
rid=request.request_id, rid=request.request_id,
extra_key=self._compute_extra_key(request),
background=request.background, background=request.background,
) )
...@@ -1251,7 +1250,6 @@ class OpenAIServingResponses(OpenAIServingChat): ...@@ -1251,7 +1250,6 @@ class OpenAIServingResponses(OpenAIServingChat):
sampling_params=sampling_params, sampling_params=sampling_params,
stream=adapted_request.stream, stream=adapted_request.stream,
rid=request_id, rid=request_id,
extra_key=adapted_request.extra_key,
return_logprob=adapted_request.return_logprob, return_logprob=adapted_request.return_logprob,
logprob_start_len=adapted_request.logprob_start_len, logprob_start_len=adapted_request.logprob_start_len,
top_logprobs_num=adapted_request.top_logprobs_num, top_logprobs_num=adapted_request.top_logprobs_num,
......
...@@ -231,7 +231,6 @@ class ExpertLocationMetadata: ...@@ -231,7 +231,6 @@ class ExpertLocationMetadata:
logical_to_all_physical_map_num_valid=logical_to_all_physical_map_num_valid, logical_to_all_physical_map_num_valid=logical_to_all_physical_map_num_valid,
logical_to_rank_dispatch_physical_map=( logical_to_rank_dispatch_physical_map=(
compute_logical_to_rank_dispatch_physical_map( compute_logical_to_rank_dispatch_physical_map(
server_args=server_args,
logical_to_all_physical_map=logical_to_all_physical_map, logical_to_all_physical_map=logical_to_all_physical_map,
num_gpus=ep_size, num_gpus=ep_size,
num_physical_experts=num_physical_experts, num_physical_experts=num_physical_experts,
...@@ -341,7 +340,6 @@ def _pad_nested_array(arr, pad_value): ...@@ -341,7 +340,6 @@ def _pad_nested_array(arr, pad_value):
# TODO optimize performance (rewrite and/or run in separate process with overlap) # TODO optimize performance (rewrite and/or run in separate process with overlap)
def compute_logical_to_rank_dispatch_physical_map( def compute_logical_to_rank_dispatch_physical_map(
server_args: ServerArgs,
logical_to_all_physical_map: torch.Tensor, logical_to_all_physical_map: torch.Tensor,
num_gpus: int, num_gpus: int,
num_physical_experts: int, num_physical_experts: int,
...@@ -350,9 +348,7 @@ def compute_logical_to_rank_dispatch_physical_map( ...@@ -350,9 +348,7 @@ def compute_logical_to_rank_dispatch_physical_map(
): ):
r = random.Random(seed) r = random.Random(seed)
num_local_gpu_physical_experts = num_physical_experts // num_gpus num_local_physical_experts = num_physical_experts // num_gpus
num_gpus_per_node = server_args.ep_size // server_args.nnodes
num_local_node_physical_experts = num_local_gpu_physical_experts * num_gpus_per_node
num_layers, num_logical_experts, _ = logical_to_all_physical_map.shape num_layers, num_logical_experts, _ = logical_to_all_physical_map.shape
dtype = logical_to_all_physical_map.dtype dtype = logical_to_all_physical_map.dtype
...@@ -376,28 +372,13 @@ def compute_logical_to_rank_dispatch_physical_map( ...@@ -376,28 +372,13 @@ def compute_logical_to_rank_dispatch_physical_map(
physical_expert_id physical_expert_id
for physical_expert_id in candidate_physical_expert_ids for physical_expert_id in candidate_physical_expert_ids
if _compute_gpu_id_of_physical_expert( if _compute_gpu_id_of_physical_expert(
physical_expert_id, num_local_gpu_physical_experts physical_expert_id, num_local_physical_experts
) )
== gpu_id == gpu_id
] ]
if len(same_gpu_physical_expert_ids) > 0: if len(same_gpu_physical_expert_ids) > 0:
# 1. Prefer same-GPU experts
output_partial[gpu_id] = same_gpu_physical_expert_ids[0] output_partial[gpu_id] = same_gpu_physical_expert_ids[0]
else:
# 2. Otherwise, prefer same-node experts
node_id = gpu_id // num_gpus_per_node
same_node_physical_expert_ids = [
physical_expert_id
for physical_expert_id in candidate_physical_expert_ids
if _compute_node_id_of_physical_expert(
physical_expert_id, num_local_node_physical_experts
)
== node_id
]
if len(same_node_physical_expert_ids) > 0:
output_partial[gpu_id] = same_node_physical_expert_ids[0]
# 3. Fill remaining slots with fair random choices
num_remain = torch.sum(output_partial == -1).item() num_remain = torch.sum(output_partial == -1).item()
output_partial[output_partial == -1] = torch.tensor( output_partial[output_partial == -1] = torch.tensor(
_fair_choices(candidate_physical_expert_ids, k=num_remain, r=r), _fair_choices(candidate_physical_expert_ids, k=num_remain, r=r),
...@@ -423,15 +404,9 @@ def _logical_to_all_physical_raw( ...@@ -423,15 +404,9 @@ def _logical_to_all_physical_raw(
def _compute_gpu_id_of_physical_expert( def _compute_gpu_id_of_physical_expert(
physical_expert_id: int, num_local_gpu_physical_experts: int physical_expert_id: int, num_local_physical_experts: int
) -> int:
return physical_expert_id // num_local_gpu_physical_experts
def _compute_node_id_of_physical_expert(
physical_expert_id: int, num_local_host_physical_experts: int
) -> int: ) -> int:
return physical_expert_id // num_local_host_physical_experts return physical_expert_id // num_local_physical_experts
def _fair_choices(arr: List, k: int, r: random.Random) -> List: def _fair_choices(arr: List, k: int, r: random.Random) -> List:
......
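A small numeric illustration of the dispatch helpers above, using the left-hand variant's naming and made-up sizes:

# With 4 physical experts per GPU and 8 GPUs per node (example values),
# physical expert 13 lives on GPU 3, which sits on node 0.
num_local_gpu_physical_experts = 4
num_gpus_per_node = 8  # server_args.ep_size // server_args.nnodes
num_local_node_physical_experts = num_local_gpu_physical_experts * num_gpus_per_node

physical_expert_id = 13
gpu_id = physical_expert_id // num_local_gpu_physical_experts    # -> 3
node_id = physical_expert_id // num_local_node_physical_experts  # -> 0
print(gpu_id, node_id)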
...@@ -20,7 +20,6 @@ from sglang.srt.function_call.pythonic_detector import PythonicDetector ...@@ -20,7 +20,6 @@ from sglang.srt.function_call.pythonic_detector import PythonicDetector
from sglang.srt.function_call.qwen3_coder_detector import Qwen3CoderDetector from sglang.srt.function_call.qwen3_coder_detector import Qwen3CoderDetector
from sglang.srt.function_call.qwen25_detector import Qwen25Detector from sglang.srt.function_call.qwen25_detector import Qwen25Detector
from sglang.srt.function_call.step3_detector import Step3Detector from sglang.srt.function_call.step3_detector import Step3Detector
from sglang.srt.function_call.utils import get_json_schema_constraint
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -179,8 +178,8 @@ class FunctionCallParser: ...@@ -179,8 +178,8 @@ class FunctionCallParser:
strict_tag = self.get_structure_tag() strict_tag = self.get_structure_tag()
return ("structural_tag", strict_tag) return ("structural_tag", strict_tag)
elif tool_choice == "required" or isinstance(tool_choice, ToolChoice): elif tool_choice == "required" or isinstance(tool_choice, ToolChoice):
json_schema = get_json_schema_constraint(self.tools, tool_choice) ebnf = self.get_ebnf(tool_choice)
return ("json_schema", json_schema) return ("ebnf", ebnf) if ebnf is not None else None
def get_ebnf( def get_ebnf(
self, tool_choice: Union[ToolChoice, Literal["required"]] self, tool_choice: Union[ToolChoice, Literal["required"]]
......
...@@ -39,7 +39,7 @@ def parse_arguments(json_value): ...@@ -39,7 +39,7 @@ def parse_arguments(json_value):
class Glm4MoeDetector(BaseFormatDetector): class Glm4MoeDetector(BaseFormatDetector):
""" """
Detector for GLM-4.5 and GLM-4.6 models. Detector for GLM-4.5 models.
Assumes function call format: Assumes function call format:
<tool_call>get_weather\n<arg_key>city</arg_key>\n<arg_value>北京</arg_value>\n<arg_key>date</arg_key>\n<arg_value>2024-06-27</arg_value>\n</tool_call>\n<tool_call>get_weather\n<arg_key>city</arg_key>\n<arg_value>上海</arg_value>\n<arg_key>date</arg_key>\n<arg_value>2024-06-27</arg_value>\n</tool_call> <tool_call>get_weather\n<arg_key>city</arg_key>\n<arg_value>北京</arg_value>\n<arg_key>date</arg_key>\n<arg_value>2024-06-27</arg_value>\n</tool_call>\n<tool_call>get_weather\n<arg_key>city</arg_key>\n<arg_value>上海</arg_value>\n<arg_key>date</arg_key>\n<arg_value>2024-06-27</arg_value>\n</tool_call>
""" """
...@@ -53,7 +53,7 @@ class Glm4MoeDetector(BaseFormatDetector): ...@@ -53,7 +53,7 @@ class Glm4MoeDetector(BaseFormatDetector):
self.func_arg_regex = r"<arg_key>(.*?)</arg_key>\s*<arg_value>(.*?)</arg_value>" self.func_arg_regex = r"<arg_key>(.*?)</arg_key>\s*<arg_value>(.*?)</arg_value>"
def has_tool_call(self, text: str) -> bool: def has_tool_call(self, text: str) -> bool:
"""Check if the text contains a glm-4.5 / glm-4.6 format tool call.""" """Check if the text contains a glm-4.5 format tool call."""
return self.bot_token in text return self.bot_token in text
def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult: def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
...@@ -102,7 +102,7 @@ class Glm4MoeDetector(BaseFormatDetector): ...@@ -102,7 +102,7 @@ class Glm4MoeDetector(BaseFormatDetector):
self, new_text: str, tools: List[Tool] self, new_text: str, tools: List[Tool]
) -> StreamingParseResult: ) -> StreamingParseResult:
""" """
Streaming incremental parsing tool calls for GLM-4.5 and GLM-4.6 format. Streaming incremental parsing tool calls for GLM-4.5 format.
""" """
self._buffer += new_text self._buffer += new_text
current_text = self._buffer current_text = self._buffer
......
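The docstring's example payload can be checked against the func_arg_regex defined above; a standalone illustration using plain re, independent of the detector class:

import re

func_arg_regex = r"<arg_key>(.*?)</arg_key>\s*<arg_value>(.*?)</arg_value>"
payload = (
    "<tool_call>get_weather\n"
    "<arg_key>city</arg_key>\n<arg_value>北京</arg_value>\n"
    "<arg_key>date</arg_key>\n<arg_value>2024-06-27</arg_value>\n"
    "</tool_call>"
)
# findall yields (key, value) pairs; dict() turns them into the call arguments.
print(dict(re.findall(func_arg_regex, payload)))  # {'city': '北京', 'date': '2024-06-27'}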
import json
import re
from typing import List
from sglang.srt.entrypoints.openai.protocol import Tool
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
from sglang.srt.function_call.core_types import StreamingParseResult
class JsonArrayParser(BaseFormatDetector):
"""
Parser for JSON array tool calls when JSON schema constraints are active.
This parser is used when tool_choice="required" or a specific tool is named,
bypassing model-specific parsers in favor of direct JSON array parsing.
"""
def __init__(self):
super().__init__()
# Configure for JSON array parsing
self.bot_token = "["
self.eot_token = "]"
self.tool_call_separator = ","
def has_tool_call(self, text: str) -> bool:
"""
Check if the given text contains a JSON tool call (array or single object).
"""
return "[" in text or "{" in text
def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
"""
Parse JSON tool calls using the base class implementation.
"""
raise NotImplementedError(
"Detect and parse not supported for JSON schema constraints."
)
def build_ebnf(self, tools: List[Tool]) -> str:
"""
Build an EBNF grammar for constrained generation.
This is not used for JSON schema constraints as they are handled
by the constraint backends directly.
"""
raise NotImplementedError(
"EBNF generation is not supported for JSON schema constraints."
)
def parse_streaming_increment(
self, new_text: str, tools: List[Tool]
) -> StreamingParseResult:
"""
Streaming incremental parsing with tool validation.
"""
return super().parse_streaming_increment(new_text, tools)
def structure_info(self) -> callable:
"""
Return a function that creates StructureInfo for constrained generation.
This is not used for JSON schema constraints as they are handled
by the constraint backends directly.
"""
raise NotImplementedError("structure_info not used for JSON schema constraints")
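A hedged usage sketch for the new parser (illustrative only, assuming the JsonArrayParser class from this file): the payload it targets is a JSON array of {"name": ..., "parameters": ...} objects, i.e. exactly the shape enforced by the json_schema constraint for tool_choice="required".

sample_output = '[{"name": "get_weather", "parameters": {"city": "Beijing"}}]'

parser = JsonArrayParser()
# has_tool_call is a cheap structural check, not a full parse.
assert parser.has_tool_call(sample_output)
# Incremental consumption goes through BaseFormatDetector.parse_streaming_increment,
# driven by bot_token="[", eot_token="]" and the "," separator configured above.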
import json import json
from json import JSONDecodeError, JSONDecoder from json import JSONDecodeError, JSONDecoder
from json.decoder import WHITESPACE from typing import Any, Tuple
from typing import Any, List, Literal, Optional, Tuple, Union
import partial_json_parser import partial_json_parser
from partial_json_parser.core.options import Allow from partial_json_parser.core.options import Allow
from sglang.srt.entrypoints.openai.protocol import Tool, ToolChoice
def _find_common_prefix(s1: str, s2: str) -> str: def _find_common_prefix(s1: str, s2: str) -> str:
prefix = "" prefix = ""
...@@ -40,12 +37,10 @@ def _partial_json_loads(input_str: str, flags: Allow) -> Tuple[Any, int]: ...@@ -40,12 +37,10 @@ def _partial_json_loads(input_str: str, flags: Allow) -> Tuple[Any, int]:
""" """
try: try:
return (partial_json_parser.loads(input_str, flags), len(input_str)) return (partial_json_parser.loads(input_str, flags), len(input_str))
except (JSONDecodeError, IndexError) as e: except JSONDecodeError as e:
msg = getattr(e, "msg", str(e)) if "Extra data" in e.msg:
if "Extra data" in msg or "pop from empty list" in msg: dec = JSONDecoder()
start = WHITESPACE.match(input_str, 0).end() return dec.raw_decode(input_str)
obj, end = JSONDecoder().raw_decode(input_str, start)
return obj, end
raise raise
...@@ -55,89 +50,3 @@ def _is_complete_json(input_str: str) -> bool: ...@@ -55,89 +50,3 @@ def _is_complete_json(input_str: str) -> bool:
return True return True
except JSONDecodeError: except JSONDecodeError:
return False return False
def _get_tool_schema_defs(tools: List[Tool]) -> dict:
"""
Get consolidated $defs from all tools, validating for conflicts.
Args:
tools: List of tools to process
Returns:
Dictionary of consolidated $defs from all tools
Raises:
ValueError: If conflicting $defs are found
"""
all_defs = {}
for tool in tools:
if tool.function.parameters is None:
continue
defs = tool.function.parameters.get("$defs", {})
for def_name, def_schema in defs.items():
if def_name in all_defs and all_defs[def_name] != def_schema:
raise ValueError(
f"Tool definition '{def_name}' has "
"multiple schemas, which is not "
"supported."
)
else:
all_defs[def_name] = def_schema
return all_defs
def _get_tool_schema(tool: Tool) -> dict:
return {
"properties": {
"name": {"type": "string", "enum": [tool.function.name]},
"parameters": (
tool.function.parameters
if tool.function.parameters
else {"type": "object", "properties": {}}
),
},
"required": ["name", "parameters"],
}
def get_json_schema_constraint(
tools: List[Tool], tool_choice: Union[ToolChoice, Literal["required"]]
) -> Optional[dict]:
"""
Get the JSON schema constraint for the specified tool choice.
Args:
tools: The available tool definitions
tool_choice: The tool choice specification
Returns:
JSON schema dict, or None if no valid tools found
"""
if isinstance(tool_choice, ToolChoice):
# For specific function choice, return the user's parameters schema directly
fn_name = tool_choice.function.name
for tool in tools:
if tool.function.name == fn_name:
return {
"type": "array",
"minItems": 1,
"maxItems": 1,
"items": _get_tool_schema(tool),
}
return None
elif tool_choice == "required":
json_schema = {
"type": "array",
"minItems": 1,
"items": {
"type": "object",
"anyOf": [_get_tool_schema(tool) for tool in tools],
},
}
json_schema_defs = _get_tool_schema_defs(tools)
if json_schema_defs:
json_schema["$defs"] = json_schema_defs
return json_schema
return None
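For reference, a minimal self-contained sketch of the schema shape get_json_schema_constraint builds for tool_choice="required" with one hypothetical tool, using plain dicts in place of the sglang Tool model:

weather_parameters = {
    "type": "object",
    "properties": {"city": {"type": "string"}},
    "required": ["city"],
}
required_schema = {
    "type": "array",
    "minItems": 1,
    "items": {
        "type": "object",
        "anyOf": [
            {
                "properties": {
                    "name": {"type": "string", "enum": ["get_weather"]},
                    "parameters": weather_parameters,
                },
                "required": ["name", "parameters"],
            }
        ],
    },
}
# A grammar backend fed this schema can only emit a JSON array of
# {"name": "get_weather", "parameters": {...}} objects, which JsonArrayParser
# then turns back into tool calls.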
...@@ -36,9 +36,9 @@ message SamplingParams { ...@@ -36,9 +36,9 @@ message SamplingParams {
float presence_penalty = 6; float presence_penalty = 6;
float repetition_penalty = 7; float repetition_penalty = 7;
optional int32 max_new_tokens = 8; int32 max_new_tokens = 8;
repeated string stop = 9; repeated string stop = 9;
repeated uint32 stop_token_ids = 10; repeated int32 stop_token_ids = 10;
bool skip_special_tokens = 11; bool skip_special_tokens = 11;
bool spaces_between_special_tokens = 12; bool spaces_between_special_tokens = 12;
...@@ -47,24 +47,24 @@ message SamplingParams { ...@@ -47,24 +47,24 @@ message SamplingParams {
string regex = 13; string regex = 13;
string json_schema = 14; string json_schema = 14;
string ebnf_grammar = 15; string ebnf_grammar = 15;
string structural_tag = 16;
} }
// LoRA adapter // LoRA adapter
string lora_path = 17; string lora_path = 16;
// Speculative decoding // Speculative decoding
int32 n = 18; // Number of samples int32 n = 17; // Number of samples
// Token healing // Token healing
bool token_healing = 19; bool token_healing = 18;
// Additional parameters // Additional parameters
int32 min_new_tokens = 20; int32 min_new_tokens = 19;
bool ignore_eos = 21; bool ignore_eos = 20;
bool no_stop_trim = 22; bool no_stop_trim = 21;
int32 stream_interval = 23; int32 stream_interval = 22;
map<string, float> logit_bias = 24; map<string, float> logit_bias = 23;
string structural_tag = 24;
// Custom parameters for extensibility // Custom parameters for extensibility
google.protobuf.Struct custom_params = 25; google.protobuf.Struct custom_params = 25;
...@@ -98,7 +98,7 @@ message GenerateRequest { ...@@ -98,7 +98,7 @@ message GenerateRequest {
bool return_logprob = 5; bool return_logprob = 5;
int32 logprob_start_len = 6; int32 logprob_start_len = 6;
int32 top_logprobs_num = 7; int32 top_logprobs_num = 7;
repeated uint32 token_ids_logprob = 8; repeated int32 token_ids_logprob = 8;
bool return_hidden_states = 9; bool return_hidden_states = 9;
// For disaggregated serving // For disaggregated serving
...@@ -122,14 +122,11 @@ message GenerateRequest { ...@@ -122,14 +122,11 @@ message GenerateRequest {
// For load balancing // For load balancing
int32 dp_balance_id = 17; int32 dp_balance_id = 17;
// Whether client wants streaming response
bool stream = 18;
} }
message TokenizedInput { message TokenizedInput {
string original_text = 1; // For reference string original_text = 1; // For reference
repeated uint32 input_ids = 2; repeated int32 input_ids = 2;
} }
message MultimodalInputs { message MultimodalInputs {
...@@ -166,50 +163,51 @@ message GenerateResponse { ...@@ -166,50 +163,51 @@ message GenerateResponse {
} }
message GenerateStreamChunk { message GenerateStreamChunk {
// Generated tokens (incremental chunk) // Generated token
repeated uint32 token_ids = 1; int32 token_id = 1;
string text = 2;
// Cumulative counts // Cumulative counts
int32 prompt_tokens = 2; int32 prompt_tokens = 3;
int32 completion_tokens = 3; int32 completion_tokens = 4;
int32 cached_tokens = 4; int32 cached_tokens = 5;
// Output logprobs (if requested) - incremental for streaming // Logprobs (if requested)
LogProbs output_logprobs = 5; LogProbs logprobs = 6;
// Hidden states (if requested) // Hidden states (if requested)
repeated float hidden_states = 6; repeated float hidden_states = 7;
// Input logprobs (if requested) - only in first chunk // Metadata
LogProbs input_logprobs = 7; float generation_time = 8; // Time to generate this token
int32 queue_time = 9; // Time spent in queue
} }
message GenerateComplete { message GenerateComplete {
// Final output // Final output
repeated uint32 output_ids = 1; repeated int32 output_ids = 1;
string output_text = 2;
// Finish reason as OpenAI-compatible string ("stop", "length", "abort")
string finish_reason = 2; // Finish reason
enum FinishReason {
// Token usage counts // The model generated a stop sequence.
int32 prompt_tokens = 3; STOP = 0;
int32 completion_tokens = 4; // The model reached the maximum generation length.
int32 cached_tokens = 5; LENGTH = 1;
// The model generated an end-of-sequence (EOS) token.
EOS_TOKEN = 2;
// The model generated a user-provided stop string.
STOP_STR = 3;
// The request was aborted by the user or system.
ABORT = 4;
}
FinishReason finish_reason = 3;
// Output logprobs if requested (cumulative) // All logprobs if requested
LogProbs output_logprobs = 6; repeated LogProbs all_logprobs = 11;
// All hidden states if requested // All hidden states if requested
repeated HiddenStates all_hidden_states = 7; repeated HiddenStates all_hidden_states = 12;
// Matched stop information (for stop sequences)
oneof matched_stop {
uint32 matched_token_id = 8;
string matched_stop_str = 9;
}
// Input logprobs if requested (for prompt tokens)
LogProbs input_logprobs = 10;
} }
message GenerateError { message GenerateError {
...@@ -224,11 +222,15 @@ message LogProbs { ...@@ -224,11 +222,15 @@ message LogProbs {
// Top logprobs at each position // Top logprobs at each position
repeated TopLogProbs top_logprobs = 3; repeated TopLogProbs top_logprobs = 3;
// Decoded text for tokens
repeated string token_texts = 4;
} }
message TopLogProbs { message TopLogProbs {
repeated float values = 1; repeated float values = 1;
repeated int32 token_ids = 2; repeated int32 token_ids = 2;
repeated string token_texts = 3;
} }
message HiddenStates { message HiddenStates {
...@@ -283,9 +285,10 @@ message EmbedComplete { ...@@ -283,9 +285,10 @@ message EmbedComplete {
// Additional metadata // Additional metadata
int32 embedding_dim = 4; int32 embedding_dim = 4;
float generation_time = 5;
// For batch embeddings // For batch embeddings
repeated Embedding batch_embeddings = 5; repeated Embedding batch_embeddings = 6;
} }
message Embedding { message Embedding {
......
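A hedged sketch of building a request against this schema from Python, restricted to fields that appear on both sides of the diff; the generated-module import path is an assumption and may differ in this tree:

from sglang.srt.grpc import sglang_scheduler_pb2 as pb  # assumed import path

request = pb.GenerateRequest(
    request_id="grpc-example-1",
    tokenized=pb.TokenizedInput(original_text="Hello", input_ids=[9906]),
    sampling_params=pb.SamplingParams(
        temperature=0.7,
        top_p=0.9,
        max_new_tokens=32,
        stop=["</s>"],
        n=1,
    ),
    return_logprob=False,
)
print(request.sampling_params.max_new_tokens)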
...@@ -3,6 +3,7 @@ import datetime ...@@ -3,6 +3,7 @@ import datetime
from google.protobuf import timestamp_pb2 as _timestamp_pb2 from google.protobuf import timestamp_pb2 as _timestamp_pb2
from google.protobuf import struct_pb2 as _struct_pb2 from google.protobuf import struct_pb2 as _struct_pb2
from google.protobuf.internal import containers as _containers from google.protobuf.internal import containers as _containers
from google.protobuf.internal import enum_type_wrapper as _enum_type_wrapper
from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message from google.protobuf import message as _message
from collections.abc import Iterable as _Iterable, Mapping as _Mapping from collections.abc import Iterable as _Iterable, Mapping as _Mapping
...@@ -11,7 +12,7 @@ from typing import ClassVar as _ClassVar, Optional as _Optional, Union as _Union ...@@ -11,7 +12,7 @@ from typing import ClassVar as _ClassVar, Optional as _Optional, Union as _Union
DESCRIPTOR: _descriptor.FileDescriptor DESCRIPTOR: _descriptor.FileDescriptor
class SamplingParams(_message.Message): class SamplingParams(_message.Message):
__slots__ = ("temperature", "top_p", "top_k", "min_p", "frequency_penalty", "presence_penalty", "repetition_penalty", "max_new_tokens", "stop", "stop_token_ids", "skip_special_tokens", "spaces_between_special_tokens", "regex", "json_schema", "ebnf_grammar", "structural_tag", "lora_path", "n", "token_healing", "min_new_tokens", "ignore_eos", "no_stop_trim", "stream_interval", "logit_bias", "custom_params") __slots__ = ("temperature", "top_p", "top_k", "min_p", "frequency_penalty", "presence_penalty", "repetition_penalty", "max_new_tokens", "stop", "stop_token_ids", "skip_special_tokens", "spaces_between_special_tokens", "regex", "json_schema", "ebnf_grammar", "lora_path", "n", "token_healing", "min_new_tokens", "ignore_eos", "no_stop_trim", "stream_interval", "logit_bias", "structural_tag", "custom_params")
class LogitBiasEntry(_message.Message): class LogitBiasEntry(_message.Message):
__slots__ = ("key", "value") __slots__ = ("key", "value")
KEY_FIELD_NUMBER: _ClassVar[int] KEY_FIELD_NUMBER: _ClassVar[int]
...@@ -34,7 +35,6 @@ class SamplingParams(_message.Message): ...@@ -34,7 +35,6 @@ class SamplingParams(_message.Message):
REGEX_FIELD_NUMBER: _ClassVar[int] REGEX_FIELD_NUMBER: _ClassVar[int]
JSON_SCHEMA_FIELD_NUMBER: _ClassVar[int] JSON_SCHEMA_FIELD_NUMBER: _ClassVar[int]
EBNF_GRAMMAR_FIELD_NUMBER: _ClassVar[int] EBNF_GRAMMAR_FIELD_NUMBER: _ClassVar[int]
STRUCTURAL_TAG_FIELD_NUMBER: _ClassVar[int]
LORA_PATH_FIELD_NUMBER: _ClassVar[int] LORA_PATH_FIELD_NUMBER: _ClassVar[int]
N_FIELD_NUMBER: _ClassVar[int] N_FIELD_NUMBER: _ClassVar[int]
TOKEN_HEALING_FIELD_NUMBER: _ClassVar[int] TOKEN_HEALING_FIELD_NUMBER: _ClassVar[int]
...@@ -43,6 +43,7 @@ class SamplingParams(_message.Message): ...@@ -43,6 +43,7 @@ class SamplingParams(_message.Message):
NO_STOP_TRIM_FIELD_NUMBER: _ClassVar[int] NO_STOP_TRIM_FIELD_NUMBER: _ClassVar[int]
STREAM_INTERVAL_FIELD_NUMBER: _ClassVar[int] STREAM_INTERVAL_FIELD_NUMBER: _ClassVar[int]
LOGIT_BIAS_FIELD_NUMBER: _ClassVar[int] LOGIT_BIAS_FIELD_NUMBER: _ClassVar[int]
STRUCTURAL_TAG_FIELD_NUMBER: _ClassVar[int]
CUSTOM_PARAMS_FIELD_NUMBER: _ClassVar[int] CUSTOM_PARAMS_FIELD_NUMBER: _ClassVar[int]
temperature: float temperature: float
top_p: float top_p: float
...@@ -59,7 +60,6 @@ class SamplingParams(_message.Message): ...@@ -59,7 +60,6 @@ class SamplingParams(_message.Message):
regex: str regex: str
json_schema: str json_schema: str
ebnf_grammar: str ebnf_grammar: str
structural_tag: str
lora_path: str lora_path: str
n: int n: int
token_healing: bool token_healing: bool
...@@ -68,8 +68,9 @@ class SamplingParams(_message.Message): ...@@ -68,8 +68,9 @@ class SamplingParams(_message.Message):
no_stop_trim: bool no_stop_trim: bool
stream_interval: int stream_interval: int
logit_bias: _containers.ScalarMap[str, float] logit_bias: _containers.ScalarMap[str, float]
structural_tag: str
custom_params: _struct_pb2.Struct custom_params: _struct_pb2.Struct
def __init__(self, temperature: _Optional[float] = ..., top_p: _Optional[float] = ..., top_k: _Optional[int] = ..., min_p: _Optional[float] = ..., frequency_penalty: _Optional[float] = ..., presence_penalty: _Optional[float] = ..., repetition_penalty: _Optional[float] = ..., max_new_tokens: _Optional[int] = ..., stop: _Optional[_Iterable[str]] = ..., stop_token_ids: _Optional[_Iterable[int]] = ..., skip_special_tokens: bool = ..., spaces_between_special_tokens: bool = ..., regex: _Optional[str] = ..., json_schema: _Optional[str] = ..., ebnf_grammar: _Optional[str] = ..., structural_tag: _Optional[str] = ..., lora_path: _Optional[str] = ..., n: _Optional[int] = ..., token_healing: bool = ..., min_new_tokens: _Optional[int] = ..., ignore_eos: bool = ..., no_stop_trim: bool = ..., stream_interval: _Optional[int] = ..., logit_bias: _Optional[_Mapping[str, float]] = ..., custom_params: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ...) -> None: ... def __init__(self, temperature: _Optional[float] = ..., top_p: _Optional[float] = ..., top_k: _Optional[int] = ..., min_p: _Optional[float] = ..., frequency_penalty: _Optional[float] = ..., presence_penalty: _Optional[float] = ..., repetition_penalty: _Optional[float] = ..., max_new_tokens: _Optional[int] = ..., stop: _Optional[_Iterable[str]] = ..., stop_token_ids: _Optional[_Iterable[int]] = ..., skip_special_tokens: bool = ..., spaces_between_special_tokens: bool = ..., regex: _Optional[str] = ..., json_schema: _Optional[str] = ..., ebnf_grammar: _Optional[str] = ..., lora_path: _Optional[str] = ..., n: _Optional[int] = ..., token_healing: bool = ..., min_new_tokens: _Optional[int] = ..., ignore_eos: bool = ..., no_stop_trim: bool = ..., stream_interval: _Optional[int] = ..., logit_bias: _Optional[_Mapping[str, float]] = ..., structural_tag: _Optional[str] = ..., custom_params: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ...) -> None: ...
class DisaggregatedParams(_message.Message): class DisaggregatedParams(_message.Message):
__slots__ = ("bootstrap_host", "bootstrap_port", "bootstrap_room") __slots__ = ("bootstrap_host", "bootstrap_port", "bootstrap_room")
...@@ -82,7 +83,7 @@ class DisaggregatedParams(_message.Message): ...@@ -82,7 +83,7 @@ class DisaggregatedParams(_message.Message):
def __init__(self, bootstrap_host: _Optional[str] = ..., bootstrap_port: _Optional[int] = ..., bootstrap_room: _Optional[int] = ...) -> None: ... def __init__(self, bootstrap_host: _Optional[str] = ..., bootstrap_port: _Optional[int] = ..., bootstrap_room: _Optional[int] = ...) -> None: ...
class GenerateRequest(_message.Message): class GenerateRequest(_message.Message):
__slots__ = ("request_id", "tokenized", "mm_inputs", "sampling_params", "return_logprob", "logprob_start_len", "top_logprobs_num", "token_ids_logprob", "return_hidden_states", "disaggregated_params", "custom_logit_processor", "timestamp", "log_metrics", "input_embeds", "lora_id", "data_parallel_rank", "dp_balance_id", "stream") __slots__ = ("request_id", "tokenized", "mm_inputs", "sampling_params", "return_logprob", "logprob_start_len", "top_logprobs_num", "token_ids_logprob", "return_hidden_states", "disaggregated_params", "custom_logit_processor", "timestamp", "log_metrics", "input_embeds", "lora_id", "data_parallel_rank", "dp_balance_id")
REQUEST_ID_FIELD_NUMBER: _ClassVar[int] REQUEST_ID_FIELD_NUMBER: _ClassVar[int]
TOKENIZED_FIELD_NUMBER: _ClassVar[int] TOKENIZED_FIELD_NUMBER: _ClassVar[int]
MM_INPUTS_FIELD_NUMBER: _ClassVar[int] MM_INPUTS_FIELD_NUMBER: _ClassVar[int]
...@@ -100,7 +101,6 @@ class GenerateRequest(_message.Message): ...@@ -100,7 +101,6 @@ class GenerateRequest(_message.Message):
LORA_ID_FIELD_NUMBER: _ClassVar[int] LORA_ID_FIELD_NUMBER: _ClassVar[int]
DATA_PARALLEL_RANK_FIELD_NUMBER: _ClassVar[int] DATA_PARALLEL_RANK_FIELD_NUMBER: _ClassVar[int]
DP_BALANCE_ID_FIELD_NUMBER: _ClassVar[int] DP_BALANCE_ID_FIELD_NUMBER: _ClassVar[int]
STREAM_FIELD_NUMBER: _ClassVar[int]
request_id: str request_id: str
tokenized: TokenizedInput tokenized: TokenizedInput
mm_inputs: MultimodalInputs mm_inputs: MultimodalInputs
...@@ -118,8 +118,7 @@ class GenerateRequest(_message.Message): ...@@ -118,8 +118,7 @@ class GenerateRequest(_message.Message):
lora_id: str lora_id: str
data_parallel_rank: int data_parallel_rank: int
dp_balance_id: int dp_balance_id: int
stream: bool def __init__(self, request_id: _Optional[str] = ..., tokenized: _Optional[_Union[TokenizedInput, _Mapping]] = ..., mm_inputs: _Optional[_Union[MultimodalInputs, _Mapping]] = ..., sampling_params: _Optional[_Union[SamplingParams, _Mapping]] = ..., return_logprob: bool = ..., logprob_start_len: _Optional[int] = ..., top_logprobs_num: _Optional[int] = ..., token_ids_logprob: _Optional[_Iterable[int]] = ..., return_hidden_states: bool = ..., disaggregated_params: _Optional[_Union[DisaggregatedParams, _Mapping]] = ..., custom_logit_processor: _Optional[str] = ..., timestamp: _Optional[_Union[datetime.datetime, _timestamp_pb2.Timestamp, _Mapping]] = ..., log_metrics: bool = ..., input_embeds: _Optional[_Iterable[float]] = ..., lora_id: _Optional[str] = ..., data_parallel_rank: _Optional[int] = ..., dp_balance_id: _Optional[int] = ...) -> None: ...
def __init__(self, request_id: _Optional[str] = ..., tokenized: _Optional[_Union[TokenizedInput, _Mapping]] = ..., mm_inputs: _Optional[_Union[MultimodalInputs, _Mapping]] = ..., sampling_params: _Optional[_Union[SamplingParams, _Mapping]] = ..., return_logprob: bool = ..., logprob_start_len: _Optional[int] = ..., top_logprobs_num: _Optional[int] = ..., token_ids_logprob: _Optional[_Iterable[int]] = ..., return_hidden_states: bool = ..., disaggregated_params: _Optional[_Union[DisaggregatedParams, _Mapping]] = ..., custom_logit_processor: _Optional[str] = ..., timestamp: _Optional[_Union[datetime.datetime, _timestamp_pb2.Timestamp, _Mapping]] = ..., log_metrics: bool = ..., input_embeds: _Optional[_Iterable[float]] = ..., lora_id: _Optional[str] = ..., data_parallel_rank: _Optional[int] = ..., dp_balance_id: _Optional[int] = ..., stream: bool = ...) -> None: ...
class TokenizedInput(_message.Message): class TokenizedInput(_message.Message):
__slots__ = ("original_text", "input_ids") __slots__ = ("original_text", "input_ids")
...@@ -162,46 +161,52 @@ class GenerateResponse(_message.Message): ...@@ -162,46 +161,52 @@ class GenerateResponse(_message.Message):
def __init__(self, request_id: _Optional[str] = ..., chunk: _Optional[_Union[GenerateStreamChunk, _Mapping]] = ..., complete: _Optional[_Union[GenerateComplete, _Mapping]] = ..., error: _Optional[_Union[GenerateError, _Mapping]] = ...) -> None: ... def __init__(self, request_id: _Optional[str] = ..., chunk: _Optional[_Union[GenerateStreamChunk, _Mapping]] = ..., complete: _Optional[_Union[GenerateComplete, _Mapping]] = ..., error: _Optional[_Union[GenerateError, _Mapping]] = ...) -> None: ...
class GenerateStreamChunk(_message.Message): class GenerateStreamChunk(_message.Message):
__slots__ = ("token_ids", "prompt_tokens", "completion_tokens", "cached_tokens", "output_logprobs", "hidden_states", "input_logprobs") __slots__ = ("token_id", "text", "prompt_tokens", "completion_tokens", "cached_tokens", "logprobs", "hidden_states", "generation_time", "queue_time")
TOKEN_IDS_FIELD_NUMBER: _ClassVar[int] TOKEN_ID_FIELD_NUMBER: _ClassVar[int]
TEXT_FIELD_NUMBER: _ClassVar[int]
PROMPT_TOKENS_FIELD_NUMBER: _ClassVar[int] PROMPT_TOKENS_FIELD_NUMBER: _ClassVar[int]
COMPLETION_TOKENS_FIELD_NUMBER: _ClassVar[int] COMPLETION_TOKENS_FIELD_NUMBER: _ClassVar[int]
CACHED_TOKENS_FIELD_NUMBER: _ClassVar[int] CACHED_TOKENS_FIELD_NUMBER: _ClassVar[int]
OUTPUT_LOGPROBS_FIELD_NUMBER: _ClassVar[int] LOGPROBS_FIELD_NUMBER: _ClassVar[int]
HIDDEN_STATES_FIELD_NUMBER: _ClassVar[int] HIDDEN_STATES_FIELD_NUMBER: _ClassVar[int]
INPUT_LOGPROBS_FIELD_NUMBER: _ClassVar[int] GENERATION_TIME_FIELD_NUMBER: _ClassVar[int]
token_ids: _containers.RepeatedScalarFieldContainer[int] QUEUE_TIME_FIELD_NUMBER: _ClassVar[int]
token_id: int
text: str
prompt_tokens: int prompt_tokens: int
completion_tokens: int completion_tokens: int
cached_tokens: int cached_tokens: int
output_logprobs: LogProbs logprobs: LogProbs
hidden_states: _containers.RepeatedScalarFieldContainer[float] hidden_states: _containers.RepeatedScalarFieldContainer[float]
input_logprobs: LogProbs generation_time: float
def __init__(self, token_ids: _Optional[_Iterable[int]] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., output_logprobs: _Optional[_Union[LogProbs, _Mapping]] = ..., hidden_states: _Optional[_Iterable[float]] = ..., input_logprobs: _Optional[_Union[LogProbs, _Mapping]] = ...) -> None: ... queue_time: int
def __init__(self, token_id: _Optional[int] = ..., text: _Optional[str] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., logprobs: _Optional[_Union[LogProbs, _Mapping]] = ..., hidden_states: _Optional[_Iterable[float]] = ..., generation_time: _Optional[float] = ..., queue_time: _Optional[int] = ...) -> None: ...
class GenerateComplete(_message.Message): class GenerateComplete(_message.Message):
__slots__ = ("output_ids", "finish_reason", "prompt_tokens", "completion_tokens", "cached_tokens", "output_logprobs", "all_hidden_states", "matched_token_id", "matched_stop_str", "input_logprobs") __slots__ = ("output_ids", "output_text", "finish_reason", "all_logprobs", "all_hidden_states")
class FinishReason(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
__slots__ = ()
STOP: _ClassVar[GenerateComplete.FinishReason]
LENGTH: _ClassVar[GenerateComplete.FinishReason]
EOS_TOKEN: _ClassVar[GenerateComplete.FinishReason]
STOP_STR: _ClassVar[GenerateComplete.FinishReason]
ABORT: _ClassVar[GenerateComplete.FinishReason]
STOP: GenerateComplete.FinishReason
LENGTH: GenerateComplete.FinishReason
EOS_TOKEN: GenerateComplete.FinishReason
STOP_STR: GenerateComplete.FinishReason
ABORT: GenerateComplete.FinishReason
OUTPUT_IDS_FIELD_NUMBER: _ClassVar[int] OUTPUT_IDS_FIELD_NUMBER: _ClassVar[int]
OUTPUT_TEXT_FIELD_NUMBER: _ClassVar[int]
FINISH_REASON_FIELD_NUMBER: _ClassVar[int] FINISH_REASON_FIELD_NUMBER: _ClassVar[int]
PROMPT_TOKENS_FIELD_NUMBER: _ClassVar[int] ALL_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
COMPLETION_TOKENS_FIELD_NUMBER: _ClassVar[int]
CACHED_TOKENS_FIELD_NUMBER: _ClassVar[int]
OUTPUT_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
ALL_HIDDEN_STATES_FIELD_NUMBER: _ClassVar[int] ALL_HIDDEN_STATES_FIELD_NUMBER: _ClassVar[int]
MATCHED_TOKEN_ID_FIELD_NUMBER: _ClassVar[int]
MATCHED_STOP_STR_FIELD_NUMBER: _ClassVar[int]
INPUT_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
output_ids: _containers.RepeatedScalarFieldContainer[int] output_ids: _containers.RepeatedScalarFieldContainer[int]
finish_reason: str output_text: str
prompt_tokens: int finish_reason: GenerateComplete.FinishReason
completion_tokens: int all_logprobs: _containers.RepeatedCompositeFieldContainer[LogProbs]
cached_tokens: int
output_logprobs: LogProbs
all_hidden_states: _containers.RepeatedCompositeFieldContainer[HiddenStates] all_hidden_states: _containers.RepeatedCompositeFieldContainer[HiddenStates]
matched_token_id: int def __init__(self, output_ids: _Optional[_Iterable[int]] = ..., output_text: _Optional[str] = ..., finish_reason: _Optional[_Union[GenerateComplete.FinishReason, str]] = ..., all_logprobs: _Optional[_Iterable[_Union[LogProbs, _Mapping]]] = ..., all_hidden_states: _Optional[_Iterable[_Union[HiddenStates, _Mapping]]] = ...) -> None: ...
matched_stop_str: str
input_logprobs: LogProbs
def __init__(self, output_ids: _Optional[_Iterable[int]] = ..., finish_reason: _Optional[str] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., output_logprobs: _Optional[_Union[LogProbs, _Mapping]] = ..., all_hidden_states: _Optional[_Iterable[_Union[HiddenStates, _Mapping]]] = ..., matched_token_id: _Optional[int] = ..., matched_stop_str: _Optional[str] = ..., input_logprobs: _Optional[_Union[LogProbs, _Mapping]] = ...) -> None: ...
class GenerateError(_message.Message): class GenerateError(_message.Message):
__slots__ = ("message", "http_status_code", "details") __slots__ = ("message", "http_status_code", "details")
...@@ -214,22 +219,26 @@ class GenerateError(_message.Message): ...@@ -214,22 +219,26 @@ class GenerateError(_message.Message):
def __init__(self, message: _Optional[str] = ..., http_status_code: _Optional[str] = ..., details: _Optional[str] = ...) -> None: ... def __init__(self, message: _Optional[str] = ..., http_status_code: _Optional[str] = ..., details: _Optional[str] = ...) -> None: ...
class LogProbs(_message.Message): class LogProbs(_message.Message):
__slots__ = ("token_logprobs", "token_ids", "top_logprobs") __slots__ = ("token_logprobs", "token_ids", "top_logprobs", "token_texts")
TOKEN_LOGPROBS_FIELD_NUMBER: _ClassVar[int] TOKEN_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
TOKEN_IDS_FIELD_NUMBER: _ClassVar[int] TOKEN_IDS_FIELD_NUMBER: _ClassVar[int]
TOP_LOGPROBS_FIELD_NUMBER: _ClassVar[int] TOP_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
TOKEN_TEXTS_FIELD_NUMBER: _ClassVar[int]
token_logprobs: _containers.RepeatedScalarFieldContainer[float] token_logprobs: _containers.RepeatedScalarFieldContainer[float]
token_ids: _containers.RepeatedScalarFieldContainer[int] token_ids: _containers.RepeatedScalarFieldContainer[int]
top_logprobs: _containers.RepeatedCompositeFieldContainer[TopLogProbs] top_logprobs: _containers.RepeatedCompositeFieldContainer[TopLogProbs]
def __init__(self, token_logprobs: _Optional[_Iterable[float]] = ..., token_ids: _Optional[_Iterable[int]] = ..., top_logprobs: _Optional[_Iterable[_Union[TopLogProbs, _Mapping]]] = ...) -> None: ... token_texts: _containers.RepeatedScalarFieldContainer[str]
def __init__(self, token_logprobs: _Optional[_Iterable[float]] = ..., token_ids: _Optional[_Iterable[int]] = ..., top_logprobs: _Optional[_Iterable[_Union[TopLogProbs, _Mapping]]] = ..., token_texts: _Optional[_Iterable[str]] = ...) -> None: ...
class TopLogProbs(_message.Message): class TopLogProbs(_message.Message):
__slots__ = ("values", "token_ids") __slots__ = ("values", "token_ids", "token_texts")
VALUES_FIELD_NUMBER: _ClassVar[int] VALUES_FIELD_NUMBER: _ClassVar[int]
TOKEN_IDS_FIELD_NUMBER: _ClassVar[int] TOKEN_IDS_FIELD_NUMBER: _ClassVar[int]
TOKEN_TEXTS_FIELD_NUMBER: _ClassVar[int]
values: _containers.RepeatedScalarFieldContainer[float] values: _containers.RepeatedScalarFieldContainer[float]
token_ids: _containers.RepeatedScalarFieldContainer[int] token_ids: _containers.RepeatedScalarFieldContainer[int]
def __init__(self, values: _Optional[_Iterable[float]] = ..., token_ids: _Optional[_Iterable[int]] = ...) -> None: ... token_texts: _containers.RepeatedScalarFieldContainer[str]
def __init__(self, values: _Optional[_Iterable[float]] = ..., token_ids: _Optional[_Iterable[int]] = ..., token_texts: _Optional[_Iterable[str]] = ...) -> None: ...
class HiddenStates(_message.Message): class HiddenStates(_message.Message):
__slots__ = ("values", "layer", "position") __slots__ = ("values", "layer", "position")
...@@ -274,18 +283,20 @@ class EmbedResponse(_message.Message): ...@@ -274,18 +283,20 @@ class EmbedResponse(_message.Message):
def __init__(self, request_id: _Optional[str] = ..., complete: _Optional[_Union[EmbedComplete, _Mapping]] = ..., error: _Optional[_Union[EmbedError, _Mapping]] = ...) -> None: ... def __init__(self, request_id: _Optional[str] = ..., complete: _Optional[_Union[EmbedComplete, _Mapping]] = ..., error: _Optional[_Union[EmbedError, _Mapping]] = ...) -> None: ...
class EmbedComplete(_message.Message): class EmbedComplete(_message.Message):
__slots__ = ("embedding", "prompt_tokens", "cached_tokens", "embedding_dim", "batch_embeddings") __slots__ = ("embedding", "prompt_tokens", "cached_tokens", "embedding_dim", "generation_time", "batch_embeddings")
EMBEDDING_FIELD_NUMBER: _ClassVar[int] EMBEDDING_FIELD_NUMBER: _ClassVar[int]
PROMPT_TOKENS_FIELD_NUMBER: _ClassVar[int] PROMPT_TOKENS_FIELD_NUMBER: _ClassVar[int]
CACHED_TOKENS_FIELD_NUMBER: _ClassVar[int] CACHED_TOKENS_FIELD_NUMBER: _ClassVar[int]
EMBEDDING_DIM_FIELD_NUMBER: _ClassVar[int] EMBEDDING_DIM_FIELD_NUMBER: _ClassVar[int]
GENERATION_TIME_FIELD_NUMBER: _ClassVar[int]
BATCH_EMBEDDINGS_FIELD_NUMBER: _ClassVar[int] BATCH_EMBEDDINGS_FIELD_NUMBER: _ClassVar[int]
embedding: _containers.RepeatedScalarFieldContainer[float] embedding: _containers.RepeatedScalarFieldContainer[float]
prompt_tokens: int prompt_tokens: int
cached_tokens: int cached_tokens: int
embedding_dim: int embedding_dim: int
generation_time: float
batch_embeddings: _containers.RepeatedCompositeFieldContainer[Embedding] batch_embeddings: _containers.RepeatedCompositeFieldContainer[Embedding]
def __init__(self, embedding: _Optional[_Iterable[float]] = ..., prompt_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., embedding_dim: _Optional[int] = ..., batch_embeddings: _Optional[_Iterable[_Union[Embedding, _Mapping]]] = ...) -> None: ... def __init__(self, embedding: _Optional[_Iterable[float]] = ..., prompt_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., embedding_dim: _Optional[int] = ..., generation_time: _Optional[float] = ..., batch_embeddings: _Optional[_Iterable[_Union[Embedding, _Mapping]]] = ...) -> None: ...
class Embedding(_message.Message): class Embedding(_message.Message):
__slots__ = ("values", "index") __slots__ = ("values", "index")
......
# This file is auto-generated. Do not edit manually.
# Regenerate with: python compile_proto.py
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
"""Client and server classes corresponding to protobuf-defined services.""" """Client and server classes corresponding to protobuf-defined services."""
import grpc import grpc
......
...@@ -119,6 +119,37 @@ def get_hf_text_config(config: PretrainedConfig): ...@@ -119,6 +119,37 @@ def get_hf_text_config(config: PretrainedConfig):
return config return config
def _load_deepseek_v32_model(
model_path: str,
trust_remote_code: bool = False,
revision: Optional[str] = None,
**kwargs,
):
# first get the local path
local_path = download_from_hf(model_path)
# then load the JSON config file
config_file = os.path.join(local_path, "config.json")
if not os.path.exists(config_file):
raise RuntimeError(f"Can't find config file in {local_path}.")
with open(config_file, "r") as f:
config_json = json.load(f)
config_json["architectures"] = ["DeepseekV3ForCausalLM"]
config_json["model_type"] = "deepseek_v3"
tmp_path = os.path.join(local_path, "_tmp_config_folder")
os.makedirs(tmp_path, exist_ok=True)
unique_path = os.path.join(tmp_path, f"deepseek_v32_{os.getpid()}")
with open(unique_path, "w") as f:
json.dump(config_json, f)
return AutoConfig.from_pretrained(
unique_path, trust_remote_code=trust_remote_code, revision=revision, **kwargs
)
@lru_cache_frozenset(maxsize=32) @lru_cache_frozenset(maxsize=32)
def get_config( def get_config(
model: str, model: str,
...@@ -140,9 +171,17 @@ def get_config( ...@@ -140,9 +171,17 @@ def get_config(
client.pull_files(ignore_pattern=["*.pt", "*.safetensors", "*.bin"]) client.pull_files(ignore_pattern=["*.pt", "*.safetensors", "*.bin"])
model = client.get_local_dir() model = client.get_local_dir()
try:
config = AutoConfig.from_pretrained( config = AutoConfig.from_pretrained(
model, trust_remote_code=trust_remote_code, revision=revision, **kwargs model, trust_remote_code=trust_remote_code, revision=revision, **kwargs
) )
except ValueError as e:
if not "deepseek_v32" in str(e):
raise e
config = _load_deepseek_v32_model(
model, trust_remote_code=trust_remote_code, revision=revision, **kwargs
)
if ( if (
config.architectures is not None config.architectures is not None
and config.architectures[0] == "Phi4MMForCausalLM" and config.architectures[0] == "Phi4MMForCausalLM"
......
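The fallback in get_config amounts to rewriting two keys of config.json before handing it back to AutoConfig; a minimal sketch with a made-up config dict (the exact DeepSeek-V3.2 keys may differ):

original_config = {
    "model_type": "deepseek_v32",                 # rejected by older transformers
    "architectures": ["DeepseekV32ForCausalLM"],  # hypothetical value
    "hidden_size": 7168,
}
patched_config = dict(original_config)
patched_config["architectures"] = ["DeepseekV3ForCausalLM"]
patched_config["model_type"] = "deepseek_v3"
# _load_deepseek_v32_model writes the patched dict to a temp config.json and
# points AutoConfig.from_pretrained at that folder, so the checkpoint is
# parsed as a DeepSeek-V3 config.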
...@@ -619,11 +619,7 @@ class AiterAttnBackend(AttentionBackend): ...@@ -619,11 +619,7 @@ class AiterAttnBackend(AttentionBackend):
assert len(k.shape) == 3 assert len(k.shape) == 3
assert len(v.shape) == 3 assert len(v.shape) == 3
if ( if forward_batch.forward_mode.is_extend():
forward_batch.forward_mode.is_extend()
and not forward_batch.forward_mode.is_target_verify()
and not forward_batch.forward_mode.is_draft_extend()
):
if kv_indices.shape[0] == 0: if kv_indices.shape[0] == 0:
o = flash_attn_varlen_func( o = flash_attn_varlen_func(
q, q,
......
...@@ -3,6 +3,7 @@ from __future__ import annotations ...@@ -3,6 +3,7 @@ from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, List, Optional from typing import TYPE_CHECKING, List, Optional
import custom_ops
import torch import torch
import torch_npu import torch_npu
from torch.nn.functional import scaled_dot_product_attention from torch.nn.functional import scaled_dot_product_attention
...@@ -36,6 +37,8 @@ class ForwardMetadata: ...@@ -36,6 +37,8 @@ class ForwardMetadata:
seq_lens_cpu_int: Optional[torch.Tensor] = None seq_lens_cpu_int: Optional[torch.Tensor] = None
seq_lens_cpu_list: Optional[List[int]] = None seq_lens_cpu_list: Optional[List[int]] = None
seq_lens_list_cumsum: Optional[List[int]] = None seq_lens_list_cumsum: Optional[List[int]] = None
seq_lens: Optional[torch.Tensor] = None
actual_seq_lengths_q: Optional[torch.Tensor] = None
class AscendAttnBackend(AttentionBackend): class AscendAttnBackend(AttentionBackend):
...@@ -67,6 +70,9 @@ class AscendAttnBackend(AttentionBackend): ...@@ -67,6 +70,9 @@ class AscendAttnBackend(AttentionBackend):
if self.use_mla: if self.use_mla:
self.kv_lora_rank = model_runner.model_config.kv_lora_rank self.kv_lora_rank = model_runner.model_config.kv_lora_rank
self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
self.q_head_dim = (
self.qk_rope_head_dim + model_runner.model_config.qk_nope_head_dim
)
self.native_attn = TorchNativeAttnBackend(model_runner) self.native_attn = TorchNativeAttnBackend(model_runner)
self.graph_metadata = {} self.graph_metadata = {}
self.max_context_len = model_runner.model_config.context_len self.max_context_len = model_runner.model_config.context_len
...@@ -102,10 +108,6 @@ class AscendAttnBackend(AttentionBackend): ...@@ -102,10 +108,6 @@ class AscendAttnBackend(AttentionBackend):
self.forward_metadata.seq_lens_cpu_int = forward_batch.seq_lens_cpu.int() self.forward_metadata.seq_lens_cpu_int = forward_batch.seq_lens_cpu.int()
seq_lens_list_cumsum = np.cumsum(forward_batch.extend_seq_lens_cpu) seq_lens_list_cumsum = np.cumsum(forward_batch.extend_seq_lens_cpu)
if forward_batch.is_extend_in_batch:
seq_lens_list_cumsum[-1] = (
(seq_lens_list_cumsum[-1] - 1) // tp_size + 1
) * tp_size
self.forward_metadata.seq_lens_list_cumsum = seq_lens_list_cumsum self.forward_metadata.seq_lens_list_cumsum = seq_lens_list_cumsum
self.graph_mode = False self.graph_mode = False
...@@ -133,6 +135,10 @@ class AscendAttnBackend(AttentionBackend): ...@@ -133,6 +135,10 @@ class AscendAttnBackend(AttentionBackend):
metadata.block_tables = self.graph_metadata["block_tables"][:bs, :] metadata.block_tables = self.graph_metadata["block_tables"][:bs, :]
metadata.seq_lens_cpu_list = seq_lens.cpu().int().tolist() metadata.seq_lens_cpu_list = seq_lens.cpu().int().tolist()
metadata.seq_lens = seq_lens
metadata.actual_seq_lengths_q = torch.tensor(
[1 + i * 1 for i in range(bs)], dtype=torch.int32, device=seq_lens.device
)
self.graph_metadata[bs] = metadata self.graph_metadata[bs] = metadata
self.forward_metadata = metadata self.forward_metadata = metadata
...@@ -161,6 +167,8 @@ class AscendAttnBackend(AttentionBackend): ...@@ -161,6 +167,8 @@ class AscendAttnBackend(AttentionBackend):
metadata.block_tables[:bs, max_seq_pages:].fill_(0) metadata.block_tables[:bs, max_seq_pages:].fill_(0)
metadata.block_tables[bs:, :].fill_(0) metadata.block_tables[bs:, :].fill_(0)
metadata.seq_lens[:bs].copy_(seq_lens[:bs])
self.forward_metadata = metadata self.forward_metadata = metadata
self.graph_mode = True self.graph_mode = True
...@@ -168,6 +176,64 @@ class AscendAttnBackend(AttentionBackend): ...@@ -168,6 +176,64 @@ class AscendAttnBackend(AttentionBackend):
def get_cuda_graph_seq_len_fill_value(self): def get_cuda_graph_seq_len_fill_value(self):
return 0 return 0
def forward_sparse(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache: bool = True,
# For multi_head latent attention
q_rope: Optional[torch.Tensor] = None,
k_rope: Optional[torch.Tensor] = None,
topk_indices: torch.Tensor = None,
):
is_prefill = forward_batch.forward_mode.is_extend()
if save_kv_cache:
k = k.view(-1, layer.tp_k_head_num, self.kv_lora_rank)
k_rope = k_rope.view(-1, layer.tp_k_head_num, self.qk_rope_head_dim)
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, forward_batch.out_cache_loc, k, k_rope
)
q_nope, q_pe = q, q_rope
k_nope, k_pe = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
block_table = self.forward_metadata.block_tables
if is_prefill:
actual_seq_qlen = torch.cumsum(forward_batch.seq_lens, dim=0)
else:
if self.forward_metadata.actual_seq_lengths_q is None:
actual_seq_qlen = (
torch.arange(1, q.shape[0] + 1).to(q.device).to(torch.int32)
)
else:
actual_seq_qlen = self.forward_metadata.actual_seq_lengths_q
if self.forward_metadata.seq_lens_cpu_int is None:
actual_seq_lengths_kv = self.forward_metadata.seq_lens
else:
actual_seq_lengths_kv = self.forward_metadata.seq_lens_cpu_int
attn_out = torch.ops.custom.npu_sparse_flash_attention(
query=q_nope,
key=k_nope,
value=k_nope,
query_rope=q_pe,
key_rope=k_pe,
sparse_indices=topk_indices,
scale_value=layer.scaling,
actual_seq_lengths_query=actual_seq_qlen.to(torch.int32),
actual_seq_lengths_kv=actual_seq_lengths_kv.to(q.device),
block_table=block_table,
sparse_block_size=1,
layout_query="TND",
layout_kv="PA_BSND",
sparse_mode=3,
)
return attn_out
def forward_extend( def forward_extend(
self, self,
q, q,
...@@ -176,7 +242,23 @@ class AscendAttnBackend(AttentionBackend): ...@@ -176,7 +242,23 @@ class AscendAttnBackend(AttentionBackend):
layer: RadixAttention, layer: RadixAttention,
forward_batch: ForwardBatch, forward_batch: ForwardBatch,
save_kv_cache: bool = True, save_kv_cache: bool = True,
# For multi_head latent attention
q_rope: Optional[torch.Tensor] = None,
k_rope: Optional[torch.Tensor] = None,
topk_indices: Optional[torch.Tensor] = None,
): ):
if topk_indices is not None:
return self.forward_sparse(
q,
k,
v,
layer,
forward_batch,
save_kv_cache,
q_rope,
k_rope,
topk_indices,
)
if not self.use_mla: if not self.use_mla:
if save_kv_cache: if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer( forward_batch.token_to_kv_pool.set_kv_buffer(
...@@ -437,10 +519,23 @@ class AscendAttnBackend(AttentionBackend): ...@@ -437,10 +519,23 @@ class AscendAttnBackend(AttentionBackend):
# For multi-head latent attention # For multi-head latent attention
q_rope: Optional[torch.Tensor] = None, q_rope: Optional[torch.Tensor] = None,
k_rope: Optional[torch.Tensor] = None, k_rope: Optional[torch.Tensor] = None,
topk_indices: Optional[torch.Tensor] = None,
): ):
if is_mla_preprocess_enabled(): if is_mla_preprocess_enabled():
# MLAPO does saving kv_cache # MLAPO does saving kv_cache
save_kv_cache = False save_kv_cache = False
if topk_indices is not None:
return self.forward_sparse(
q,
k,
v,
layer,
forward_batch,
save_kv_cache,
q_rope,
k_rope,
topk_indices,
)
if self.graph_mode: if self.graph_mode:
return self.forward_decode_graph( return self.forward_decode_graph(
......
import logging
logger = logging.getLogger(__name__)
ATTENTION_BACKENDS = {} ATTENTION_BACKENDS = {}
...@@ -66,6 +62,13 @@ def create_ascend_backend(runner): ...@@ -66,6 +62,13 @@ def create_ascend_backend(runner):
return AscendAttnBackend(runner) return AscendAttnBackend(runner)
@register_attention_backend("nsa")
def create_nsa_backend(runner):
from sglang.srt.layers.attention.nsa_backend import NativeSparseAttnBackend
return NativeSparseAttnBackend(runner)
@register_attention_backend("triton") @register_attention_backend("triton")
def create_triton_backend(runner): def create_triton_backend(runner):
assert not runner.model_config.is_encoder_decoder, ( assert not runner.model_config.is_encoder_decoder, (
...@@ -162,37 +165,35 @@ def create_dual_chunk_flash_attn_backend(runner): ...@@ -162,37 +165,35 @@ def create_dual_chunk_flash_attn_backend(runner):
return DualChunkFlashAttentionBackend(runner) return DualChunkFlashAttentionBackend(runner)
def attn_backend_wrapper(runner, full_attn_backend): @register_attention_backend("hybrid_linear_attn")
""" def create_hybrid_linear_attn_backend(runner):
Wrapper for special models like hybrid GDN, so we don't
need to change the code of the original attention backend.
"""
assert not (
runner.is_hybrid_gdn and runner.use_mla_backend
), "hybrid_gdn can only be used with non-MLA models."
# wrap for hybrid GDN models
if runner.is_hybrid_gdn:
from sglang.srt.utils import is_blackwell, is_npu
if is_blackwell():
assert ( assert (
runner.server_args.attention_backend == "triton" runner.is_hybrid_gdn
), "triton backend is the only supported backend on Blackwell GPUs for hybrid GDN models, use --attention-backend triton to specify the backend." ), "hybrid_linear_attn backend can only be used with hybrid GDN models."
if is_npu():
assert (
runner.server_args.attention_backend == "ascend"
), "ascend backend is the only supported backend on NPU for hybrid GDN models, use --attention-backend ascend to specify the backend."
logger.info(f"Using hybrid linear attention backend for hybrid GDN models.")
from sglang.srt.layers.attention.hybrid_linear_attn_backend import ( from sglang.srt.layers.attention.hybrid_linear_attn_backend import (
HybridLinearAttnBackend, HybridLinearAttnBackend,
MambaAttnBackend, MambaAttnBackend,
) )
from sglang.srt.utils import is_blackwell, is_npu
if is_npu():
from sglang.srt.layers.attention.ascend_backend import AscendAttnBackend
full_attn_backend = AscendAttnBackend(runner)
elif is_blackwell():
from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
full_attn_backend = TritonAttnBackend(runner)
else:
from sglang.srt.layers.attention.flashattention_backend import (
FlashAttentionBackend,
)
full_attn_backend = FlashAttentionBackend(runner)
linear_attn_backend = MambaAttnBackend(runner) linear_attn_backend = MambaAttnBackend(runner)
full_attn_layers = runner.model_config.hf_config.full_attention_layer_ids full_attn_layers = runner.model_config.hf_config.full_attention_layer_ids
return HybridLinearAttnBackend( return HybridLinearAttnBackend(
full_attn_backend, linear_attn_backend, full_attn_layers full_attn_backend, linear_attn_backend, full_attn_layers
) )
return full_attn_backend