Unverified Commit 53ca1552 authored by Chang Su's avatar Chang Su Committed by GitHub
Browse files

Implement Standalone gRPC Server for SGLang Python Scheduler (#10283)

parent a23bdeaf
......@@ -22,17 +22,19 @@ repos:
rev: 5.13.2
hooks:
- id: isort
exclude: '^python/sglang/srt/grpc/.*_pb2\.py$|^python/sglang/srt/grpc/.*_pb2_grpc\.py$|^python/sglang/srt/grpc/.*_pb2\.pyi$|^python/sglang/srt/grpc/.*_pb2_grpc\.pyi$'
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.11.7
hooks:
- id: ruff
args: [--select=F401, --fixable=F401]
files: ^(benchmark/|docs/|examples/)
exclude: \.ipynb$
exclude: \.ipynb$|^python/sglang/srt/grpc/.*_pb2\.py$|^python/sglang/srt/grpc/.*_pb2_grpc\.py$|^python/sglang/srt/grpc/.*_pb2\.pyi$|^python/sglang/srt/grpc/.*_pb2_grpc\.pyi$
- repo: https://github.com/psf/black
rev: 24.10.0
hooks:
- id: black-jupyter
exclude: '^python/sglang/srt/grpc/.*_pb2\.py$|^python/sglang/srt/grpc/.*_pb2_grpc\.py$|^python/sglang/srt/grpc/.*_pb2\.pyi$|^python/sglang/srt/grpc/.*_pb2_grpc\.pyi$'
- repo: https://github.com/codespell-project/codespell
rev: v2.4.1
hooks:
......@@ -42,7 +44,11 @@ repos:
exclude: |
(?x)^(
test/srt/test_reasoning_parser\.py|
docs/advanced_features/vlm_query\.ipynb
docs/advanced_features/vlm_query\.ipynb|
python/sglang/srt/grpc/.*_pb2\.py|
python/sglang/srt/grpc/.*_pb2_grpc\.py|
python/sglang/srt/grpc/.*_pb2\.pyi|
python/sglang/srt/grpc/.*_pb2_grpc\.pyi
)$
- repo: https://github.com/pre-commit/mirrors-clang-format
rev: v18.1.8
......
"""
gRPC Request Manager - Orchestrates request lifecycle without tokenization.
Mimics TokenizerManager's state management and ZMQ communication patterns.
"""
import asyncio
import dataclasses
import logging
import os
import signal
import sys
import threading
import time
from typing import Any, Dict, List, Optional, Union
import grpc
import zmq
import zmq.asyncio
from sglang.srt.managers.io_struct import (
AbortReq,
BatchEmbeddingOut,
BatchTokenIDOut,
HealthCheckOutput,
TokenizedEmbeddingReqInput,
TokenizedGenerateReqInput,
)
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import get_zmq_socket, kill_process_tree
from sglang.utils import get_exception_traceback
logger = logging.getLogger(__name__)
class GrpcSignalHandler:
"""Minimal signal handler for gRPC server - delegates real crash handling to scheduler."""
def __init__(self, grpc_manager):
self.grpc_manager = grpc_manager
def sigterm_handler(self, signum=None, frame=None):
"""Handle SIGTERM by gracefully shutting down gRPC server."""
logger.warning(
f"SIGTERM received. {signum=} {frame=}. Shutting down gRPC server..."
)
self.grpc_manager.gracefully_exit = True
def running_phase_sigquit_handler(self, signum=None, frame=None):
"""Handle SIGQUIT from failed scheduler process."""
logger.error(
"Received SIGQUIT from scheduler process. Scheduler failed, shutting down gRPC server."
)
logger.info(
"Note: Crash dumps are handled by the scheduler process, not the gRPC server."
)
# Just exit cleanly - the scheduler handles crash dumps
kill_process_tree(os.getpid(), include_parent=True)
@dataclasses.dataclass
class GrpcReqState:
"""State tracking for a gRPC request."""
# Request identification
request_id: str
grpc_context: Optional[grpc.aio.ServicerContext]
# Communication
out_queue: asyncio.Queue
finished: bool
event: asyncio.Event
obj: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput]
# Metrics (same as TokenizerManager's ReqState)
created_time: float
finished_time: float = 0.0
first_token_time: float = 0.0
last_time: float = 0.0
last_completion_tokens: int = 1
# Streaming state
last_output_offset: int = 0
stream_finished: bool = False
# Output accumulation
text: str = ""
output_ids: List[int] = dataclasses.field(default_factory=list)
input_token_logprobs_val: List[float] = dataclasses.field(default_factory=list)
input_token_logprobs_idx: List[int] = dataclasses.field(default_factory=list)
output_token_logprobs_val: List[float] = dataclasses.field(default_factory=list)
output_token_logprobs_idx: List[int] = dataclasses.field(default_factory=list)
input_top_logprobs_val: List[List[float]] = dataclasses.field(default_factory=list)
input_top_logprobs_idx: List[List[int]] = dataclasses.field(default_factory=list)
output_top_logprobs_val: List[List[float]] = dataclasses.field(default_factory=list)
output_top_logprobs_idx: List[List[int]] = dataclasses.field(default_factory=list)
# Session state
session_id: Optional[str] = None
is_session_request: bool = False
class GrpcRequestManager:
"""
Manages gRPC request lifecycle, mimicking TokenizerManager's orchestration
behaviors without tokenization.
"""
def __init__(
self,
server_args: ServerArgs,
port_args: PortArgs,
):
"""Initialize the gRPC request manager."""
self.server_args = server_args
self.port_args = port_args
# ZMQ Communication Setup (same pattern as TokenizerManager)
context = zmq.asyncio.Context(2)
# Socket for receiving outputs from scheduler
self.recv_from_scheduler = get_zmq_socket(
context, zmq.PULL, port_args.detokenizer_ipc_name, bind=True
)
# Socket for sending requests to scheduler
self.send_to_scheduler = get_zmq_socket(
context, zmq.PUSH, port_args.scheduler_input_ipc_name, bind=True
)
# State Management (from TokenizerManager)
self.rid_to_state: Dict[str, GrpcReqState] = {}
self.asyncio_tasks: set = set()
self.gracefully_exit = False
self.no_create_loop = False
self.event_loop = None
# Pause/Resume Control
self.is_pause = False
self.is_pause_cond = asyncio.Condition()
# Metrics
self.request_counter = 0
self.request_counter_lock = asyncio.Lock()
self.last_receive_tstamp = time.time()
# Crash dump for debugging
self.crash_dump_request_list = []
self.crash_dump_performed = False
logger.info(
f"GrpcRequestManager initialized with ZMQ IPC: "
f"recv={port_args.detokenizer_ipc_name}, "
f"send={port_args.scheduler_input_ipc_name}"
)
async def generate_request(
self,
obj: TokenizedGenerateReqInput,
request_id: Optional[str] = None,
grpc_context: Optional[grpc.aio.ServicerContext] = None,
) -> asyncio.Queue:
"""
Submit a generation request to the scheduler.
Returns a queue for streaming outputs.
"""
# Generate request ID if not provided
if request_id is None:
async with self.request_counter_lock:
request_id = f"grpc-{self.request_counter}"
self.request_counter += 1
obj.rid = request_id
# TODO: support log_request
# Create request state
state = GrpcReqState(
request_id=request_id,
grpc_context=grpc_context,
out_queue=asyncio.Queue(),
finished=False,
event=asyncio.Event(),
obj=obj,
created_time=time.time(),
)
# Track session if needed
if hasattr(obj, "session_params") and obj.session_params:
state.session_id = obj.session_params.session_id
state.is_session_request = True
# Register state
self.rid_to_state[request_id] = state
self.record_request_for_crash_dump(obj)
# Send to scheduler via ZMQ
try:
await self._send_to_scheduler(obj)
except Exception as e:
# Clean up on failure
del self.rid_to_state[request_id]
raise RuntimeError(f"Failed to send request to scheduler: {e}")
return state.out_queue
async def embedding_request(
self,
obj: TokenizedEmbeddingReqInput,
request_id: Optional[str] = None,
) -> asyncio.Future:
"""
Submit an embedding request to the scheduler.
Returns a future that will contain the embedding result.
"""
# Generate request ID if not provided
if request_id is None:
async with self.request_counter_lock:
request_id = f"grpc-embed-{self.request_counter}"
self.request_counter += 1
obj.rid = request_id
# Create request state
state = GrpcReqState(
request_id=request_id,
grpc_context=None,
out_queue=asyncio.Queue(),
finished=False,
event=asyncio.Event(),
obj=obj,
created_time=time.time(),
)
# Register state
self.rid_to_state[request_id] = state
# Create future for result
future = asyncio.Future()
# Send to scheduler
try:
await self._send_to_scheduler(obj)
except Exception as e:
del self.rid_to_state[request_id]
future.set_exception(e)
return future
# Wait for result in background
async def wait_for_result():
try:
# Wait for completion
await state.event.wait()
# Get result from queue
result = await state.out_queue.get()
future.set_result(result)
except Exception as e:
future.set_exception(e)
finally:
# Clean up
if request_id in self.rid_to_state:
del self.rid_to_state[request_id]
asyncio.create_task(wait_for_result())
return future
async def abort_request(self, request_id: str) -> bool:
"""Abort a running request."""
if request_id not in self.rid_to_state:
return False
# Send abort to scheduler
abort_req = AbortReq(rid=request_id)
try:
await self._send_to_scheduler(abort_req)
except Exception as e:
logger.error(f"Failed to send abort request: {e}")
return False
# Mark as finished
state = self.rid_to_state.get(request_id)
if state:
state.finished = True
state.stream_finished = True
state.event.set()
# Send abort notification to output queue
await state.out_queue.put({"error": "Request aborted", "abort": True})
return True
async def pause_generation(self):
"""Pause generation processing."""
async with self.is_pause_cond:
self.is_pause = True
logger.info("Generation paused")
async def resume_generation(self):
"""Resume generation processing."""
async with self.is_pause_cond:
self.is_pause = False
self.is_pause_cond.notify_all()
logger.info("Generation resumed")
async def handle_loop(self):
"""
Main event loop - processes outputs from scheduler.
Mimics TokenizerManager's handle_loop.
"""
while not self.gracefully_exit:
try:
# Receive from scheduler
recv_obj = await self.recv_from_scheduler.recv_pyobj()
self.last_receive_tstamp = time.time()
# Check for pause
async with self.is_pause_cond:
while self.is_pause:
await self.is_pause_cond.wait()
# Handle different output types
if isinstance(recv_obj, BatchTokenIDOut):
await self._handle_batch_output(recv_obj)
elif isinstance(recv_obj, BatchEmbeddingOut):
await self._handle_embedding_output(recv_obj)
elif isinstance(recv_obj, HealthCheckOutput):
await self._handle_health_check_output(recv_obj)
else:
logger.warning(f"Unknown output type: {type(recv_obj)}")
except zmq.error.Again:
# Timeout, check if we should exit
if self.gracefully_exit:
break
continue
except Exception as e:
logger.error(f"Handle loop error: {e}\n{get_exception_traceback()}")
if self.gracefully_exit:
break
async def _handle_batch_output(self, batch_out: BatchTokenIDOut):
"""Handle batch generation output from scheduler."""
# Process each request in the batch
for i, rid in enumerate(batch_out.rids):
if rid not in self.rid_to_state:
continue
state = self.rid_to_state[rid]
# Update metrics
now = time.time()
if state.first_token_time == 0.0:
state.first_token_time = now
state.last_time = now
# Extract output for this request
output_data = {
"request_id": rid,
"text": batch_out.decoded_texts[i] if batch_out.decoded_texts else "",
"token_ids": batch_out.output_ids[i] if batch_out.output_ids else [],
"finished": batch_out.finished_reasons[i] is not None,
"meta_info": {
"prompt_tokens": (
batch_out.prompt_tokens[i] if batch_out.prompt_tokens else 0
),
"completion_tokens": (
batch_out.completion_tokens[i]
if batch_out.completion_tokens
else 0
),
"finish_reason": (
str(batch_out.finished_reasons[i])
if batch_out.finished_reasons[i]
else None
),
},
}
# Add logprobs if available
if batch_out.output_token_logprobs_val and i < len(
batch_out.output_token_logprobs_val
):
output_data["logprobs"] = {
"tokens": batch_out.output_token_logprobs_val[i],
"top_logprobs": (
batch_out.output_top_logprobs_val[i]
if batch_out.output_top_logprobs_val
and i < len(batch_out.output_top_logprobs_val)
else None
),
}
# Update state
if output_data["text"]:
state.text += output_data["text"][state.last_output_offset :]
state.last_output_offset = len(output_data["text"])
if output_data["token_ids"]:
state.output_ids.extend(output_data["token_ids"])
# Send to output queue
await state.out_queue.put(output_data)
# Handle completion
if output_data["finished"]:
state.finished = True
state.finished_time = now
state.stream_finished = True
state.event.set()
# Remove from tracking after a delay
async def cleanup():
await asyncio.sleep(5.0)
if rid in self.rid_to_state:
del self.rid_to_state[rid]
asyncio.create_task(cleanup())
async def _handle_embedding_output(self, batch_out: BatchEmbeddingOut):
"""Handle batch embedding output from scheduler."""
for i, rid in enumerate(batch_out.rids):
if rid not in self.rid_to_state:
continue
state = self.rid_to_state[rid]
# Create result
result = {
"request_id": rid,
"embedding": batch_out.embeddings[i],
"prompt_tokens": (
batch_out.prompt_tokens[i] if batch_out.prompt_tokens else 0
),
"finish_reason": (
batch_out.finish_reason[i] if batch_out.finish_reason else None
),
}
# Send result
await state.out_queue.put(result)
# Mark as finished
state.finished = True
state.finished_time = time.time()
state.event.set()
async def _handle_health_check_output(self, health_out: HealthCheckOutput):
"""Handle health check output from scheduler."""
rid = health_out.rid
if rid not in self.rid_to_state:
logger.warning(f"Health check output for unknown request: {rid}")
return
state = self.rid_to_state[rid]
# Create health check result
result = {
"request_id": rid,
"healthy": True, # If we got a response, scheduler is healthy
"output_text": (
health_out.output_str if hasattr(health_out, "output_str") else ""
),
"finish_reason": (
health_out.finish_reason
if hasattr(health_out, "finish_reason")
else "stop"
),
}
# Send result
await state.out_queue.put(result)
# Mark as finished
state.finished = True
state.finished_time = time.time()
state.event.set()
async def _send_to_scheduler(self, obj):
"""Send an object to the scheduler via ZMQ."""
try:
self.send_to_scheduler.send_pyobj(obj)
except Exception as e:
logger.error(f"Failed to send to scheduler: {e}")
raise
def record_request_for_crash_dump(self, obj):
"""Record request for potential crash dump."""
if len(self.crash_dump_request_list) < 100:
self.crash_dump_request_list.append(
{
"time": time.time(),
"request_id": getattr(obj, "rid", "unknown"),
"type": type(obj).__name__,
}
)
async def shutdown(self):
"""Gracefully shutdown the request manager."""
logger.info("Shutting down GrpcRequestManager")
self.gracefully_exit = True
# Cancel all pending requests
for rid, state in self.rid_to_state.items():
if not state.finished:
await state.out_queue.put(
{"error": "Server shutting down", "shutdown": True}
)
state.finished = True
state.event.set()
# Wait for tasks to complete
if self.asyncio_tasks:
await asyncio.gather(*list(self.asyncio_tasks), return_exceptions=True)
# Close ZMQ sockets
self.recv_from_scheduler.close()
self.send_to_scheduler.close()
logger.info("GrpcRequestManager shutdown complete")
def get_server_info(self) -> Dict[str, Any]:
"""Get server information for health checks."""
return {
"active_requests": len(self.rid_to_state),
"paused": self.is_pause,
"last_receive_time": self.last_receive_tstamp,
}
def auto_create_handle_loop(self):
"""Automatically create and start the handle_loop task, matching TokenizerManager pattern."""
if self.no_create_loop:
return
self.no_create_loop = True
loop = asyncio.get_event_loop()
self.asyncio_tasks.add(
loop.create_task(print_exception_wrapper(self.handle_loop))
)
self.event_loop = loop
# We cannot add signal handler when the grpc manager is not in
# the main thread due to the CPython limitation.
if threading.current_thread() is threading.main_thread():
signal_handler = GrpcSignalHandler(self)
loop.add_signal_handler(signal.SIGTERM, signal_handler.sigterm_handler)
# Update the signal handler for the process. It overrides the sigquit handler in the launch phase.
loop.add_signal_handler(
signal.SIGQUIT, signal_handler.running_phase_sigquit_handler
)
else:
logger.warning(
"Signal handler is not added because the grpc request manager is "
"not in the main thread. This disables graceful shutdown of the "
"grpc request manager when SIGTERM is received."
)
self.asyncio_tasks.add(
loop.create_task(print_exception_wrapper(self.sigterm_watchdog))
)
async def sigterm_watchdog(self):
"""Watchdog to handle SIGTERM gracefully, matching TokenizerManager pattern."""
while not self.gracefully_exit:
await asyncio.sleep(1.0)
async def print_exception_wrapper(func):
"""
Sometimes an asyncio function does not print exception.
We do another wrapper to handle the exception.
"""
try:
await func()
except Exception:
traceback = get_exception_traceback()
logger.error(f"GrpcRequestManager hit an exception: {traceback}")
if hasattr(func, "__self__") and isinstance(func.__self__, GrpcRequestManager):
func.__self__.dump_requests_before_crash()
kill_process_tree(os.getpid(), include_parent=True)
sys.exit(1)
"""
Standalone gRPC Server for SGLang - Fully separated from HTTP server.
Uses GrpcRequestManager for orchestration without tokenization.
"""
import argparse
import asyncio
import logging
import multiprocessing as mp
import os
import signal
import time
from concurrent import futures
from typing import AsyncIterator, Dict, Optional, Tuple
import grpc
from grpc_reflection.v1alpha import reflection
from sglang.srt.entrypoints.grpc_request_manager import GrpcRequestManager
from sglang.srt.grpc import sglang_scheduler_pb2, sglang_scheduler_pb2_grpc
from sglang.srt.managers.data_parallel_controller import (
run_data_parallel_controller_process,
)
from sglang.srt.managers.io_struct import (
TokenizedEmbeddingReqInput,
TokenizedGenerateReqInput,
)
from sglang.srt.managers.scheduler import run_scheduler_process
from sglang.srt.sampling.sampling_params import SamplingParams as SGLSamplingParams
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
from sglang.srt.utils import configure_logger, prepare_model_and_tokenizer
from sglang.utils import get_exception_traceback
logger = logging.getLogger(__name__)
HEALTH_CHECK_TIMEOUT = int(os.getenv("SGLANG_HEALTH_CHECK_TIMEOUT", 20))
def _launch_scheduler_process_only(
server_args: ServerArgs,
port_args: Optional[PortArgs] = None,
) -> Tuple[Dict, PortArgs, list]:
"""
Launch only the scheduler process(es) without tokenizer/detokenizer.
Returns scheduler info, port args, and list of scheduler processes.
"""
# Configure global environment
configure_logger(server_args)
server_args.check_server_args()
# Allocate ports for inter-process communications
if port_args is None:
port_args = PortArgs.init_new(server_args)
logger.info(f"{server_args=}")
# Prepare model and tokenizer paths
server_args.model_path, server_args.tokenizer_path = prepare_model_and_tokenizer(
server_args.model_path, server_args.tokenizer_path
)
scheduler_procs = []
if server_args.dp_size == 1:
memory_saver_adapter = TorchMemorySaverAdapter.create(
enable=server_args.enable_memory_saver
)
scheduler_pipe_readers = []
nnodes_per_tp_group = max(server_args.nnodes // server_args.pp_size, 1)
tp_size_per_node = server_args.tp_size // nnodes_per_tp_group
tp_rank_range = range(
tp_size_per_node * (server_args.node_rank % nnodes_per_tp_group),
tp_size_per_node * (server_args.node_rank % nnodes_per_tp_group + 1),
)
pp_size_per_node = max(server_args.pp_size // server_args.nnodes, 1)
pp_rank_range = range(
pp_size_per_node * (server_args.node_rank // nnodes_per_tp_group),
pp_size_per_node * (server_args.node_rank // nnodes_per_tp_group + 1),
)
for pp_rank in pp_rank_range:
for tp_rank in tp_rank_range:
reader, writer = mp.Pipe(duplex=False)
gpu_id = (
server_args.base_gpu_id
+ ((pp_rank % pp_size_per_node) * tp_size_per_node)
+ (tp_rank % tp_size_per_node) * server_args.gpu_id_step
)
moe_ep_rank = tp_rank // (server_args.tp_size // server_args.ep_size)
proc = mp.Process(
target=run_scheduler_process,
args=(
server_args,
port_args,
gpu_id,
tp_rank,
moe_ep_rank,
pp_rank,
None,
writer,
None,
),
)
with memory_saver_adapter.configure_subprocess():
proc.start()
scheduler_procs.append(proc)
scheduler_pipe_readers.append(reader)
else:
# Launch the data parallel controller
reader, writer = mp.Pipe(duplex=False)
scheduler_pipe_readers = [reader]
proc = mp.Process(
target=run_data_parallel_controller_process,
args=(server_args, port_args, writer),
)
proc.start()
scheduler_procs.append(proc)
# TODO(CatherineSue): handle cases for multi-node
# Wait for all scheduler processes to be ready
scheduler_infos = []
for i, reader in enumerate(scheduler_pipe_readers):
try:
data = reader.recv()
except EOFError:
logger.error(
f"Rank {i} scheduler is dead. Please check if there are relevant logs."
)
scheduler_procs[i].join()
logger.error(f"Exit code: {scheduler_procs[i].exitcode}")
raise RuntimeError(f"Failed to initialize scheduler rank {i}")
if data.get("status") != "ready":
raise RuntimeError(
f"Scheduler rank {i} initialization failed: {data.get('error', 'Unknown error')}"
)
scheduler_infos.append(data)
logger.info(
f"All {len(scheduler_procs)} scheduler process(es) initialized successfully"
)
# Return the first scheduler's info (they should all be the same)
return scheduler_infos[0], port_args, scheduler_procs
class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer):
"""
Standalone gRPC service implementation using GrpcRequestManager.
Fully separated from HTTP server with its own process and no shared globals.
"""
def __init__(
self,
request_manager: GrpcRequestManager,
server_args: ServerArgs,
model_info: Dict,
):
"""Initialize the standalone gRPC service."""
self.request_manager = request_manager
self.server_args = server_args
self.model_info = model_info
self.start_time = time.time()
# Start the request manager's event loop using auto_create_handle_loop
self.request_manager.auto_create_handle_loop()
logger.info("Standalone gRPC scheduler service initialized")
async def Generate(
self,
request: sglang_scheduler_pb2.GenerateRequest,
context: grpc.aio.ServicerContext,
) -> AsyncIterator[sglang_scheduler_pb2.GenerateResponse]:
"""Handle generation requests with streaming responses."""
logger.info(f"Generation request: {request.request_id}")
try:
# Convert gRPC request to internal format
tokenized_req = self._convert_generate_request(request)
# Submit to request manager
output_queue = await self.request_manager.generate_request(
obj=tokenized_req,
request_id=request.request_id,
grpc_context=context,
)
# Stream outputs
while True:
try:
# Get output with timeout
output = await asyncio.wait_for(output_queue.get(), timeout=4)
# Check for errors
if "error" in output:
yield sglang_scheduler_pb2.GenerateResponse(
request_id=request.request_id,
error=sglang_scheduler_pb2.GenerateError(
message=output["error"],
http_status_code=(
"500" if "abort" not in output else "499"
),
),
)
break
# Check if finished
if output.get("finished", False):
# Send completion
yield self._create_completion_response(
request.request_id, output
)
break
else:
# Send chunk
yield self._create_chunk_response(request.request_id, output)
except asyncio.TimeoutError:
# Check if context is still active
if context.cancelled():
# Abort the request
await self.request_manager.abort_request(request.request_id)
break
continue
except Exception as e:
logger.error(f"Generate failed: {e}\n{get_exception_traceback()}")
yield sglang_scheduler_pb2.GenerateResponse(
request_id=request.request_id,
error=sglang_scheduler_pb2.GenerateError(
message=str(e),
http_status_code="500",
details=get_exception_traceback(),
),
)
async def Embed(
self,
request: sglang_scheduler_pb2.EmbedRequest,
context: grpc.aio.ServicerContext,
) -> sglang_scheduler_pb2.EmbedResponse:
"""Handle embedding requests."""
logger.info(f"Embedding request: {request.request_id}")
try:
# Convert request
tokenized_req = self._convert_embed_request(request)
# Submit to request manager
future = await self.request_manager.embedding_request(
obj=tokenized_req,
request_id=request.request_id,
)
# Wait for result
result = await future
# Create response
return sglang_scheduler_pb2.EmbedResponse(
request_id=request.request_id,
complete=sglang_scheduler_pb2.EmbedComplete(
embedding=result["embedding"],
prompt_tokens=result.get("prompt_tokens", 0),
cached_tokens=0,
embedding_dim=len(result["embedding"]),
generation_time=time.time() - self.start_time,
),
)
except Exception as e:
logger.error(f"Embed failed: {e}\n{get_exception_traceback()}")
return sglang_scheduler_pb2.EmbedResponse(
request_id=request.request_id,
error=sglang_scheduler_pb2.EmbedError(
message=str(e),
code="INTERNAL_ERROR",
details=get_exception_traceback(),
),
)
async def HealthCheck(
self,
request: sglang_scheduler_pb2.HealthCheckRequest,
context: grpc.aio.ServicerContext,
) -> sglang_scheduler_pb2.HealthCheckResponse:
"""Health check by generating from client input."""
try:
# Check if request manager is shutting down
if self.request_manager.gracefully_exit:
return sglang_scheduler_pb2.HealthCheckResponse(
healthy=False, message="Server shutting down"
)
# Extract tokenized input from request
if not request.HasField("tokenized"):
return sglang_scheduler_pb2.HealthCheckResponse(
healthy=False, message="Tokenized input required for health check"
)
input_text = request.tokenized.original_text
input_ids = list(request.tokenized.input_ids)
# Create health check request
rid = f"HEALTH_CHECK_GRPC_{time.time()}"
health_request = TokenizedGenerateReqInput(
rid=rid,
input_text=input_text,
input_ids=input_ids,
sampling_params=SGLSamplingParams(max_new_tokens=1, temperature=0.0),
stream=False,
mm_inputs=None,
return_logprob=False,
logprob_start_len=-1,
top_logprobs_num=0,
token_ids_logprob=None,
)
logger.info(f"Sending health check request to request manager...")
# Submit and wait for response
output_queue = await self.request_manager.generate_request(
health_request, request_id=rid
)
try:
# Wait for response with configurable timeout
response = await asyncio.wait_for(
output_queue.get(), timeout=HEALTH_CHECK_TIMEOUT
)
# Clean up
if rid in self.request_manager.rid_to_state:
del self.request_manager.rid_to_state[rid]
return sglang_scheduler_pb2.HealthCheckResponse(
healthy=True, message="Health check passed"
)
except asyncio.TimeoutError:
# Clean up on timeout
if rid in self.request_manager.rid_to_state:
del self.request_manager.rid_to_state[rid]
return sglang_scheduler_pb2.HealthCheckResponse(
healthy=False, message="Health check timeout"
)
except Exception as e:
logger.error(f"Health check failed: {e}")
return sglang_scheduler_pb2.HealthCheckResponse(
healthy=False, message=f"Health check error: {str(e)}"
)
async def Abort(
self,
request: sglang_scheduler_pb2.AbortRequest,
context: grpc.aio.ServicerContext,
) -> sglang_scheduler_pb2.AbortResponse:
"""Abort an ongoing request."""
logger.info(f"Aborting request: {request.request_id}")
try:
success = await self.request_manager.abort_request(request.request_id)
return sglang_scheduler_pb2.AbortResponse(
success=success,
message=f"Request {request.request_id} {'aborted' if success else 'not found'}",
)
except Exception as e:
logger.error(f"Abort failed: {e}")
return sglang_scheduler_pb2.AbortResponse(
success=False,
message=str(e),
)
# Helper methods for request/response conversion
def _convert_generate_request(
self, grpc_req: sglang_scheduler_pb2.GenerateRequest
) -> TokenizedGenerateReqInput:
"""Convert gRPC GenerateRequest to internal format."""
# Extract tokenized input
if not grpc_req.HasField("tokenized"):
raise ValueError("Tokenized input must be provided")
input_text = grpc_req.tokenized.original_text
input_ids = list(grpc_req.tokenized.input_ids)
# Convert sampling params
sampling_params = self._convert_sampling_params(grpc_req.sampling_params)
# Create request
return TokenizedGenerateReqInput(
rid=grpc_req.request_id,
input_text=input_text,
input_ids=input_ids,
mm_inputs=None, # TODO: implement mm support
sampling_params=sampling_params,
return_logprob=grpc_req.return_logprob,
logprob_start_len=grpc_req.logprob_start_len or -1,
top_logprobs_num=grpc_req.top_logprobs_num or 0,
stream=True, # Always stream for gRPC
lora_path=grpc_req.lora_id if grpc_req.lora_id else None,
token_ids_logprob=(
list(grpc_req.token_ids_logprob) if grpc_req.token_ids_logprob else None
),
)
def _convert_embed_request(
self, grpc_req: sglang_scheduler_pb2.EmbedRequest
) -> TokenizedEmbeddingReqInput:
"""Convert gRPC EmbedRequest to internal format."""
# Extract tokenized input
if not grpc_req.HasField("tokenized"):
raise ValueError("Tokenized input must be provided")
input_text = grpc_req.tokenized.original_text
input_ids = list(grpc_req.tokenized.input_ids)
return TokenizedEmbeddingReqInput(
rid=grpc_req.request_id,
input_text=input_text,
input_ids=input_ids,
)
def _convert_sampling_params(
self, grpc_params: sglang_scheduler_pb2.SamplingParams
) -> SGLSamplingParams:
"""Convert gRPC SamplingParams to internal format."""
# Handle constraint types
regex = None
json_schema = None
ebnf_grammar = None
if grpc_params.HasField("regex"):
regex = grpc_params.regex
elif grpc_params.HasField("json_schema"):
json_schema = grpc_params.json_schema
elif grpc_params.HasField("ebnf_grammar"):
ebnf_grammar = grpc_params.ebnf_grammar
return SGLSamplingParams(
temperature=grpc_params.temperature or 1.0,
top_p=grpc_params.top_p or 1.0,
top_k=grpc_params.top_k or -1,
min_p=grpc_params.min_p or 0.0,
frequency_penalty=grpc_params.frequency_penalty or 0.0,
presence_penalty=grpc_params.presence_penalty or 0.0,
repetition_penalty=grpc_params.repetition_penalty or 1.0,
max_new_tokens=grpc_params.max_new_tokens or 128,
min_new_tokens=grpc_params.min_new_tokens or 0,
stop=list(grpc_params.stop) if grpc_params.stop else None,
stop_token_ids=(
list(grpc_params.stop_token_ids) if grpc_params.stop_token_ids else None
),
skip_special_tokens=grpc_params.skip_special_tokens,
spaces_between_special_tokens=grpc_params.spaces_between_special_tokens,
regex=regex,
json_schema=json_schema,
ebnf=ebnf_grammar,
n=grpc_params.n or 1,
ignore_eos=grpc_params.ignore_eos,
)
def _create_chunk_response(
self, request_id: str, output: Dict
) -> sglang_scheduler_pb2.GenerateResponse:
"""Create a streaming chunk response."""
return sglang_scheduler_pb2.GenerateResponse(
request_id=request_id,
chunk=sglang_scheduler_pb2.GenerateStreamChunk(
token_id=output["token_ids"][-1] if output.get("token_ids") else 0,
text=output.get("text", ""),
prompt_tokens=0,
completion_tokens=len(output.get("token_ids", [])),
cached_tokens=0,
generation_time=time.time() - self.start_time,
queue_time=0.0,
),
)
def _create_completion_response(
self, request_id: str, output: Dict
) -> sglang_scheduler_pb2.GenerateResponse:
"""Create a completion response."""
# Determine finish reason
finish_reason = sglang_scheduler_pb2.GenerateComplete.STOP
meta_info = output.get("meta_info", {})
if meta_info.get("finish_reason") == "length":
finish_reason = sglang_scheduler_pb2.GenerateComplete.LENGTH
elif meta_info.get("finish_reason") == "eos_token":
finish_reason = sglang_scheduler_pb2.GenerateComplete.EOS_TOKEN
return sglang_scheduler_pb2.GenerateResponse(
request_id=request_id,
complete=sglang_scheduler_pb2.GenerateComplete(
output_ids=output.get("token_ids", []),
output_text=output.get("text", ""),
finish_reason=finish_reason,
),
)
async def shutdown(self):
"""Shutdown the service."""
logger.info("Shutting down gRPC service")
# Shutdown request manager (handles its own tasks)
await self.request_manager.shutdown()
async def serve_grpc(
server_args: ServerArgs,
model_info: Optional[Dict] = None,
):
"""Start the standalone gRPC server with integrated scheduler."""
# Launch only the scheduler process(es) (no tokenizer/detokenizer needed for gRPC)
logger.info("Launching scheduler process(es)...")
scheduler_info, port_args, scheduler_procs = _launch_scheduler_process_only(
server_args=server_args,
)
# Update model info from scheduler info
if model_info is None:
model_info = {
"model_name": server_args.model_path,
"max_context_length": scheduler_info.get(
"max_total_num_tokens", server_args.context_length or 8192
),
"vocab_size": scheduler_info.get("vocab_size", 128256),
"supports_vision": scheduler_info.get("supports_vision", False),
"model_type": scheduler_info.get("model_type", "transformer"),
"max_req_input_len": scheduler_info.get("max_req_input_len", 8192),
"eos_token_ids": scheduler_info.get("eos_token_ids", []),
"pad_token_id": scheduler_info.get("pad_token_id", 0),
"bos_token_id": scheduler_info.get("bos_token_id", 1),
}
# Create request manager with the correct port args
request_manager = GrpcRequestManager(
server_args=server_args,
port_args=port_args,
)
# Create gRPC server
server = grpc.aio.server(
futures.ThreadPoolExecutor(max_workers=10),
options=[
("grpc.max_send_message_length", 1024 * 1024 * 256),
("grpc.max_receive_message_length", 1024 * 1024 * 256),
],
)
# Add service
servicer = SGLangSchedulerServicer(
request_manager=request_manager,
server_args=server_args,
model_info=model_info,
)
sglang_scheduler_pb2_grpc.add_SglangSchedulerServicer_to_server(servicer, server)
# Enable reflection
SERVICE_NAMES = (
sglang_scheduler_pb2.DESCRIPTOR.services_by_name["SglangScheduler"].full_name,
reflection.SERVICE_NAME,
)
reflection.enable_server_reflection(SERVICE_NAMES, server)
# Start server
listen_addr = f"{server_args.host}:{server_args.port}"
server.add_insecure_port(listen_addr)
logger.info(f"Starting standalone gRPC server on {listen_addr}")
await server.start()
# Handle shutdown signals
loop = asyncio.get_running_loop()
stop_event = asyncio.Event()
def signal_handler():
logger.info("Received shutdown signal")
stop_event.set()
for sig in (signal.SIGTERM, signal.SIGINT):
loop.add_signal_handler(sig, signal_handler)
try:
await stop_event.wait()
finally:
logger.info("Shutting down gRPC server")
await servicer.shutdown()
await server.stop(5.0)
# Terminate scheduler processes
for i, proc in enumerate(scheduler_procs):
if proc and proc.is_alive():
logger.info(f"Terminating scheduler process {i}...")
proc.terminate()
proc.join(timeout=5.0)
if proc.is_alive():
logger.warning(f"Force killing scheduler process {i}...")
proc.kill()
proc.join()
def main():
"""Main entry point for standalone gRPC server."""
# Fix CUDA multiprocessing issues - must be called before any CUDA operations
mp.set_start_method("spawn", force=True)
parser = argparse.ArgumentParser(description="SGLang Standalone gRPC Server")
# Server arguments
parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind to")
parser.add_argument("--port", type=int, default=30000, help="gRPC server port")
# Model arguments
parser.add_argument("--model-path", type=str, required=True, help="Model path")
parser.add_argument("--tokenizer-path", type=str, help="Tokenizer path")
parser.add_argument("--context-length", type=int, help="Context length")
parser.add_argument("--tp-size", type=int, default=1, help="Tensor parallel size")
parser.add_argument("--dp-size", type=int, default=1, help="Data parallel size")
# Runtime arguments
parser.add_argument(
"--max-running-requests", type=int, default=2048, help="Max concurrent requests"
)
parser.add_argument(
"--max-total-tokens", type=int, default=1000000, help="Max total tokens"
)
parser.add_argument(
"--max-prefill-tokens", type=int, default=16384, help="Max prefill tokens"
)
parser.add_argument(
"--attention-backend", type=str, default="flashinfer", help="Attention backend"
)
parser.add_argument("--lora-paths", type=str, help="LoRA adapter paths")
# Logging
parser.add_argument("--log-level", type=str, default="INFO", help="Logging level")
args = parser.parse_args()
# Convert to ServerArgs with gRPC host and port
server_args = ServerArgs(
model_path=args.model_path,
tokenizer_path=args.tokenizer_path or args.model_path,
context_length=args.context_length,
tp_size=args.tp_size,
dp_size=args.dp_size,
max_running_requests=args.max_running_requests,
max_total_tokens=args.max_total_tokens,
max_prefill_tokens=args.max_prefill_tokens,
attention_backend=args.attention_backend,
lora_paths=args.lora_paths.split(",") if args.lora_paths else None,
log_level=args.log_level,
# Override with gRPC server host and port
host=args.host,
port=args.port,
)
# Run server
asyncio.run(
serve_grpc(
server_args=server_args,
)
)
if __name__ == "__main__":
main()
syntax = "proto3";
package sglang.grpc.scheduler;
import "google/protobuf/timestamp.proto";
import "google/protobuf/struct.proto";
// Service definition for SGLang scheduler communication
// This protocol bridges the Rust router and Python scheduler
service SglangScheduler {
// Submit a generation request (supports streaming)
rpc Generate(GenerateRequest) returns (stream GenerateResponse);
// Submit an embedding request
rpc Embed(EmbedRequest) returns (EmbedResponse);
// Health check and metrics
rpc HealthCheck(HealthCheckRequest) returns (HealthCheckResponse);
// Abort a running request
rpc Abort(AbortRequest) returns (AbortResponse);
}
// =====================
// Common Types
// =====================
// Sampling parameters matching SGLang's SamplingParams
message SamplingParams {
float temperature = 1;
float top_p = 2;
int32 top_k = 3;
float min_p = 4;
float frequency_penalty = 5;
float presence_penalty = 6;
float repetition_penalty = 7;
int32 max_new_tokens = 8;
repeated string stop = 9;
repeated int32 stop_token_ids = 10;
bool skip_special_tokens = 11;
bool spaces_between_special_tokens = 12;
// Structured generation
oneof constraint {
string regex = 13;
string json_schema = 14;
string ebnf_grammar = 15;
}
// LoRA adapter
string lora_path = 16;
// Speculative decoding
int32 n = 17; // Number of samples
// Token healing
bool token_healing = 18;
// Additional parameters
int32 min_new_tokens = 19;
bool ignore_eos = 20;
bool no_stop_trim = 21;
int32 stream_interval = 22;
map<string, float> logit_bias = 23;
string structural_tag = 24;
// Custom parameters for extensibility
google.protobuf.Struct custom_params = 25;
}
// Disaggregated serving parameters
message DisaggregatedParams {
string bootstrap_host = 1;
int32 bootstrap_port = 2;
int32 bootstrap_room = 3;
}
// =====================
// Generate Request
// =====================
message GenerateRequest {
string request_id = 1;
// Input must be tokenized (no raw text)
TokenizedInput tokenized = 2;
// Multimodal inputs
MultimodalInputs mm_inputs = 3;
// Generation parameters
SamplingParams sampling_params = 4;
// Return options
bool return_logprob = 5;
int32 logprob_start_len = 6;
int32 top_logprobs_num = 7;
repeated int32 token_ids_logprob = 8;
bool return_hidden_states = 9;
// For disaggregated serving
DisaggregatedParams disaggregated_params = 10;
// Custom logit processor (serialized)
string custom_logit_processor = 11;
// Request metadata
google.protobuf.Timestamp timestamp = 12;
bool log_metrics = 13;
// Input embeddings (alternative to text/tokens)
repeated float input_embeds = 14;
// LoRA adapter ID (if pre-loaded)
string lora_id = 15;
// Data parallel routing
int32 data_parallel_rank = 16;
// For load balancing
int32 dp_balance_id = 17;
}
message TokenizedInput {
string original_text = 1; // For reference
repeated int32 input_ids = 2;
}
message MultimodalInputs {
// Simplified multimodal handling - actual data processed by tokenizer
repeated string image_urls = 1;
repeated string video_urls = 2;
repeated string audio_urls = 3;
// Pre-processed multimodal features (if available)
google.protobuf.Struct processed_features = 4;
// Raw data for direct processing
repeated bytes image_data = 5;
repeated bytes video_data = 6;
repeated bytes audio_data = 7;
// Modality metadata
repeated string modalities = 8;
}
// =====================
// Generate Response
// =====================
message GenerateResponse {
string request_id = 1;
// Response type
oneof response {
GenerateStreamChunk chunk = 2;
GenerateComplete complete = 3;
GenerateError error = 4;
}
}
message GenerateStreamChunk {
// Generated token
int32 token_id = 1;
string text = 2;
// Cumulative counts
int32 prompt_tokens = 3;
int32 completion_tokens = 4;
int32 cached_tokens = 5;
// Logprobs (if requested)
LogProbs logprobs = 6;
// Hidden states (if requested)
repeated float hidden_states = 7;
// Metadata
float generation_time = 8; // Time to generate this token
int32 queue_time = 9; // Time spent in queue
}
message GenerateComplete {
// Final output
repeated int32 output_ids = 1;
string output_text = 2;
// Finish reason
enum FinishReason {
// The model generated a stop sequence.
STOP = 0;
// The model reached the maximum generation length.
LENGTH = 1;
// The model generated an end-of-sequence (EOS) token.
EOS_TOKEN = 2;
// The model generated a user-provided stop string.
STOP_STR = 3;
// The request was aborted by the user or system.
ABORT = 4;
}
FinishReason finish_reason = 3;
// All logprobs if requested
repeated LogProbs all_logprobs = 11;
// All hidden states if requested
repeated HiddenStates all_hidden_states = 12;
}
message GenerateError {
string message = 1;
string http_status_code = 2;
string details = 3;
}
message LogProbs {
repeated float token_logprobs = 1;
repeated int32 token_ids = 2;
// Top logprobs at each position
repeated TopLogProbs top_logprobs = 3;
// Decoded text for tokens
repeated string token_texts = 4;
}
message TopLogProbs {
repeated float values = 1;
repeated int32 token_ids = 2;
repeated string token_texts = 3;
}
message HiddenStates {
repeated float values = 1;
int32 layer = 2;
int32 position = 3;
}
// =====================
// Embedding Request
// =====================
message EmbedRequest {
string request_id = 1;
// Input must be tokenized (no raw text)
TokenizedInput tokenized = 2;
// Multimodal inputs
MultimodalInputs mm_inputs = 4;
// Dummy sampling params for compatibility
// EmbedRequest doesn't use sampling_params
SamplingParams sampling_params = 5;
bool log_metrics = 6;
// Token type IDs for models that require them
repeated int32 token_type_ids = 7;
// Data parallel routing
int32 data_parallel_rank = 8;
// For cross-encoder requests
bool is_cross_encoder = 9;
repeated string texts = 10; // For cross-encoder batch
}
message EmbedResponse {
string request_id = 1;
oneof response {
EmbedComplete complete = 2;
EmbedError error = 3;
}
}
message EmbedComplete {
repeated float embedding = 1;
int32 prompt_tokens = 2;
int32 cached_tokens = 3;
// Additional metadata
int32 embedding_dim = 4;
float generation_time = 5;
// For batch embeddings
repeated Embedding batch_embeddings = 6;
}
message Embedding {
repeated float values = 1;
int32 index = 2;
}
message EmbedError {
string message = 1;
string code = 2;
string details = 3;
}
// =====================
// Management Operations
// =====================
message HealthCheckRequest {
// Input for health test generation (must be tokenized)
TokenizedInput tokenized = 1;
}
message HealthCheckResponse {
bool healthy = 1;
string message = 2;
}
message AbortRequest {
string request_id = 1;
string reason = 2;
}
message AbortResponse {
bool success = 1;
string message = 2;
}
// =====================
// Additional Operations (Future)
// =====================
// Load LoRA adapter
message LoadLoRARequest {
string adapter_id = 1;
string adapter_path = 2;
int32 rank = 3;
}
message LoadLoRAResponse {
bool success = 1;
string adapter_id = 2;
string message = 3;
}
// Unload LoRA adapter
message UnloadLoRARequest {
string adapter_id = 1;
}
message UnloadLoRAResponse {
bool success = 1;
string message = 2;
}
// Update weights
message UpdateWeightsRequest {
oneof source {
string disk_path = 1;
bytes tensor_data = 2;
string remote_url = 3;
}
string weight_name = 4;
}
message UpdateWeightsResponse {
bool success = 1;
string message = 2;
}
// Get internal state for debugging
message GetInternalStateRequest {
repeated string state_keys = 1;
}
message GetInternalStateResponse {
google.protobuf.Struct state = 1;
}
// Set internal state for testing
message SetInternalStateRequest {
google.protobuf.Struct state = 1;
}
message SetInternalStateResponse {
bool success = 1;
string message = 2;
}
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# NO CHECKED-IN PROTOBUF GENCODE
# source: sglang_scheduler.proto
# Protobuf Python Version: 6.31.1
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import runtime_version as _runtime_version
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
_runtime_version.ValidateProtobufRuntimeVersion(
_runtime_version.Domain.PUBLIC,
6,
31,
1,
'',
'sglang_scheduler.proto'
)
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
from google.protobuf import timestamp_pb2 as google_dot_protobuf_dot_timestamp__pb2
from google.protobuf import struct_pb2 as google_dot_protobuf_dot_struct__pb2
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x16sglang_scheduler.proto\x12\x15sglang.grpc.scheduler\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1cgoogle/protobuf/struct.proto\"\xc7\x05\n\x0eSamplingParams\x12\x13\n\x0btemperature\x18\x01 \x01(\x02\x12\r\n\x05top_p\x18\x02 \x01(\x02\x12\r\n\x05top_k\x18\x03 \x01(\x05\x12\r\n\x05min_p\x18\x04 \x01(\x02\x12\x19\n\x11\x66requency_penalty\x18\x05 \x01(\x02\x12\x18\n\x10presence_penalty\x18\x06 \x01(\x02\x12\x1a\n\x12repetition_penalty\x18\x07 \x01(\x02\x12\x16\n\x0emax_new_tokens\x18\x08 \x01(\x05\x12\x0c\n\x04stop\x18\t \x03(\t\x12\x16\n\x0estop_token_ids\x18\n \x03(\x05\x12\x1b\n\x13skip_special_tokens\x18\x0b \x01(\x08\x12%\n\x1dspaces_between_special_tokens\x18\x0c \x01(\x08\x12\x0f\n\x05regex\x18\r \x01(\tH\x00\x12\x15\n\x0bjson_schema\x18\x0e \x01(\tH\x00\x12\x16\n\x0c\x65\x62nf_grammar\x18\x0f \x01(\tH\x00\x12\x11\n\tlora_path\x18\x10 \x01(\t\x12\t\n\x01n\x18\x11 \x01(\x05\x12\x15\n\rtoken_healing\x18\x12 \x01(\x08\x12\x16\n\x0emin_new_tokens\x18\x13 \x01(\x05\x12\x12\n\nignore_eos\x18\x14 \x01(\x08\x12\x14\n\x0cno_stop_trim\x18\x15 \x01(\x08\x12\x17\n\x0fstream_interval\x18\x16 \x01(\x05\x12H\n\nlogit_bias\x18\x17 \x03(\x0b\x32\x34.sglang.grpc.scheduler.SamplingParams.LogitBiasEntry\x12\x16\n\x0estructural_tag\x18\x18 \x01(\t\x12.\n\rcustom_params\x18\x19 \x01(\x0b\x32\x17.google.protobuf.Struct\x1a\x30\n\x0eLogitBiasEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x02:\x02\x38\x01\x42\x0c\n\nconstraint\"]\n\x13\x44isaggregatedParams\x12\x16\n\x0e\x62ootstrap_host\x18\x01 \x01(\t\x12\x16\n\x0e\x62ootstrap_port\x18\x02 \x01(\x05\x12\x16\n\x0e\x62ootstrap_room\x18\x03 \x01(\x05\"\xe9\x04\n\x0fGenerateRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x38\n\ttokenized\x18\x02 \x01(\x0b\x32%.sglang.grpc.scheduler.TokenizedInput\x12:\n\tmm_inputs\x18\x03 \x01(\x0b\x32\'.sglang.grpc.scheduler.MultimodalInputs\x12>\n\x0fsampling_params\x18\x04 \x01(\x0b\x32%.sglang.grpc.scheduler.SamplingParams\x12\x16\n\x0ereturn_logprob\x18\x05 \x01(\x08\x12\x19\n\x11logprob_start_len\x18\x06 \x01(\x05\x12\x18\n\x10top_logprobs_num\x18\x07 \x01(\x05\x12\x19\n\x11token_ids_logprob\x18\x08 \x03(\x05\x12\x1c\n\x14return_hidden_states\x18\t \x01(\x08\x12H\n\x14\x64isaggregated_params\x18\n \x01(\x0b\x32*.sglang.grpc.scheduler.DisaggregatedParams\x12\x1e\n\x16\x63ustom_logit_processor\x18\x0b \x01(\t\x12-\n\ttimestamp\x18\x0c \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12\x13\n\x0blog_metrics\x18\r \x01(\x08\x12\x14\n\x0cinput_embeds\x18\x0e \x03(\x02\x12\x0f\n\x07lora_id\x18\x0f \x01(\t\x12\x1a\n\x12\x64\x61ta_parallel_rank\x18\x10 \x01(\x05\x12\x15\n\rdp_balance_id\x18\x11 \x01(\x05\":\n\x0eTokenizedInput\x12\x15\n\roriginal_text\x18\x01 \x01(\t\x12\x11\n\tinput_ids\x18\x02 \x03(\x05\"\xd3\x01\n\x10MultimodalInputs\x12\x12\n\nimage_urls\x18\x01 \x03(\t\x12\x12\n\nvideo_urls\x18\x02 \x03(\t\x12\x12\n\naudio_urls\x18\x03 \x03(\t\x12\x33\n\x12processed_features\x18\x04 \x01(\x0b\x32\x17.google.protobuf.Struct\x12\x12\n\nimage_data\x18\x05 \x03(\x0c\x12\x12\n\nvideo_data\x18\x06 \x03(\x0c\x12\x12\n\naudio_data\x18\x07 \x03(\x0c\x12\x12\n\nmodalities\x18\x08 \x03(\t\"\xe3\x01\n\x10GenerateResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12;\n\x05\x63hunk\x18\x02 \x01(\x0b\x32*.sglang.grpc.scheduler.GenerateStreamChunkH\x00\x12;\n\x08\x63omplete\x18\x03 \x01(\x0b\x32\'.sglang.grpc.scheduler.GenerateCompleteH\x00\x12\x35\n\x05\x65rror\x18\x04 \x01(\x0b\x32$.sglang.grpc.scheduler.GenerateErrorH\x00\x42\n\n\x08response\"\xf5\x01\n\x13GenerateStreamChunk\x12\x10\n\x08token_id\x18\x01 \x01(\x05\x12\x0c\n\x04text\x18\x02 \x01(\t\x12\x15\n\rprompt_tokens\x18\x03 \x01(\x05\x12\x19\n\x11\x63ompletion_tokens\x18\x04 \x01(\x05\x12\x15\n\rcached_tokens\x18\x05 \x01(\x05\x12\x31\n\x08logprobs\x18\x06 \x01(\x0b\x32\x1f.sglang.grpc.scheduler.LogProbs\x12\x15\n\rhidden_states\x18\x07 \x03(\x02\x12\x17\n\x0fgeneration_time\x18\x08 \x01(\x02\x12\x12\n\nqueue_time\x18\t \x01(\x05\"\xcd\x02\n\x10GenerateComplete\x12\x12\n\noutput_ids\x18\x01 \x03(\x05\x12\x13\n\x0boutput_text\x18\x02 \x01(\t\x12K\n\rfinish_reason\x18\x03 \x01(\x0e\x32\x34.sglang.grpc.scheduler.GenerateComplete.FinishReason\x12\x35\n\x0c\x61ll_logprobs\x18\x0b \x03(\x0b\x32\x1f.sglang.grpc.scheduler.LogProbs\x12>\n\x11\x61ll_hidden_states\x18\x0c \x03(\x0b\x32#.sglang.grpc.scheduler.HiddenStates\"L\n\x0c\x46inishReason\x12\x08\n\x04STOP\x10\x00\x12\n\n\x06LENGTH\x10\x01\x12\r\n\tEOS_TOKEN\x10\x02\x12\x0c\n\x08STOP_STR\x10\x03\x12\t\n\x05\x41\x42ORT\x10\x04\"K\n\rGenerateError\x12\x0f\n\x07message\x18\x01 \x01(\t\x12\x18\n\x10http_status_code\x18\x02 \x01(\t\x12\x0f\n\x07\x64\x65tails\x18\x03 \x01(\t\"\x84\x01\n\x08LogProbs\x12\x16\n\x0etoken_logprobs\x18\x01 \x03(\x02\x12\x11\n\ttoken_ids\x18\x02 \x03(\x05\x12\x38\n\x0ctop_logprobs\x18\x03 \x03(\x0b\x32\".sglang.grpc.scheduler.TopLogProbs\x12\x13\n\x0btoken_texts\x18\x04 \x03(\t\"E\n\x0bTopLogProbs\x12\x0e\n\x06values\x18\x01 \x03(\x02\x12\x11\n\ttoken_ids\x18\x02 \x03(\x05\x12\x13\n\x0btoken_texts\x18\x03 \x03(\t\"?\n\x0cHiddenStates\x12\x0e\n\x06values\x18\x01 \x03(\x02\x12\r\n\x05layer\x18\x02 \x01(\x05\x12\x10\n\x08position\x18\x03 \x01(\x05\"\xca\x02\n\x0c\x45mbedRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x38\n\ttokenized\x18\x02 \x01(\x0b\x32%.sglang.grpc.scheduler.TokenizedInput\x12:\n\tmm_inputs\x18\x04 \x01(\x0b\x32\'.sglang.grpc.scheduler.MultimodalInputs\x12>\n\x0fsampling_params\x18\x05 \x01(\x0b\x32%.sglang.grpc.scheduler.SamplingParams\x12\x13\n\x0blog_metrics\x18\x06 \x01(\x08\x12\x16\n\x0etoken_type_ids\x18\x07 \x03(\x05\x12\x1a\n\x12\x64\x61ta_parallel_rank\x18\x08 \x01(\x05\x12\x18\n\x10is_cross_encoder\x18\t \x01(\x08\x12\r\n\x05texts\x18\n \x03(\t\"\x9d\x01\n\rEmbedResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x38\n\x08\x63omplete\x18\x02 \x01(\x0b\x32$.sglang.grpc.scheduler.EmbedCompleteH\x00\x12\x32\n\x05\x65rror\x18\x03 \x01(\x0b\x32!.sglang.grpc.scheduler.EmbedErrorH\x00\x42\n\n\x08response\"\xbc\x01\n\rEmbedComplete\x12\x11\n\tembedding\x18\x01 \x03(\x02\x12\x15\n\rprompt_tokens\x18\x02 \x01(\x05\x12\x15\n\rcached_tokens\x18\x03 \x01(\x05\x12\x15\n\rembedding_dim\x18\x04 \x01(\x05\x12\x17\n\x0fgeneration_time\x18\x05 \x01(\x02\x12:\n\x10\x62\x61tch_embeddings\x18\x06 \x03(\x0b\x32 .sglang.grpc.scheduler.Embedding\"*\n\tEmbedding\x12\x0e\n\x06values\x18\x01 \x03(\x02\x12\r\n\x05index\x18\x02 \x01(\x05\"<\n\nEmbedError\x12\x0f\n\x07message\x18\x01 \x01(\t\x12\x0c\n\x04\x63ode\x18\x02 \x01(\t\x12\x0f\n\x07\x64\x65tails\x18\x03 \x01(\t\"N\n\x12HealthCheckRequest\x12\x38\n\ttokenized\x18\x01 \x01(\x0b\x32%.sglang.grpc.scheduler.TokenizedInput\"7\n\x13HealthCheckResponse\x12\x0f\n\x07healthy\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"2\n\x0c\x41\x62ortRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0e\n\x06reason\x18\x02 \x01(\t\"1\n\rAbortResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"I\n\x0fLoadLoRARequest\x12\x12\n\nadapter_id\x18\x01 \x01(\t\x12\x14\n\x0c\x61\x64\x61pter_path\x18\x02 \x01(\t\x12\x0c\n\x04rank\x18\x03 \x01(\x05\"H\n\x10LoadLoRAResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x12\n\nadapter_id\x18\x02 \x01(\t\x12\x0f\n\x07message\x18\x03 \x01(\t\"\'\n\x11UnloadLoRARequest\x12\x12\n\nadapter_id\x18\x01 \x01(\t\"6\n\x12UnloadLoRAResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"w\n\x14UpdateWeightsRequest\x12\x13\n\tdisk_path\x18\x01 \x01(\tH\x00\x12\x15\n\x0btensor_data\x18\x02 \x01(\x0cH\x00\x12\x14\n\nremote_url\x18\x03 \x01(\tH\x00\x12\x13\n\x0bweight_name\x18\x04 \x01(\tB\x08\n\x06source\"9\n\x15UpdateWeightsResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"-\n\x17GetInternalStateRequest\x12\x12\n\nstate_keys\x18\x01 \x03(\t\"B\n\x18GetInternalStateResponse\x12&\n\x05state\x18\x01 \x01(\x0b\x32\x17.google.protobuf.Struct\"A\n\x17SetInternalStateRequest\x12&\n\x05state\x18\x01 \x01(\x0b\x32\x17.google.protobuf.Struct\"<\n\x18SetInternalStateResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t2\xfe\x02\n\x0fSglangScheduler\x12]\n\x08Generate\x12&.sglang.grpc.scheduler.GenerateRequest\x1a\'.sglang.grpc.scheduler.GenerateResponse0\x01\x12R\n\x05\x45mbed\x12#.sglang.grpc.scheduler.EmbedRequest\x1a$.sglang.grpc.scheduler.EmbedResponse\x12\x64\n\x0bHealthCheck\x12).sglang.grpc.scheduler.HealthCheckRequest\x1a*.sglang.grpc.scheduler.HealthCheckResponse\x12R\n\x05\x41\x62ort\x12#.sglang.grpc.scheduler.AbortRequest\x1a$.sglang.grpc.scheduler.AbortResponseb\x06proto3')
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'sglang_scheduler_pb2', _globals)
if not _descriptor._USE_C_DESCRIPTORS:
DESCRIPTOR._loaded_options = None
_globals['_SAMPLINGPARAMS_LOGITBIASENTRY']._loaded_options = None
_globals['_SAMPLINGPARAMS_LOGITBIASENTRY']._serialized_options = b'8\001'
_globals['_SAMPLINGPARAMS']._serialized_start=113
_globals['_SAMPLINGPARAMS']._serialized_end=824
_globals['_SAMPLINGPARAMS_LOGITBIASENTRY']._serialized_start=762
_globals['_SAMPLINGPARAMS_LOGITBIASENTRY']._serialized_end=810
_globals['_DISAGGREGATEDPARAMS']._serialized_start=826
_globals['_DISAGGREGATEDPARAMS']._serialized_end=919
_globals['_GENERATEREQUEST']._serialized_start=922
_globals['_GENERATEREQUEST']._serialized_end=1539
_globals['_TOKENIZEDINPUT']._serialized_start=1541
_globals['_TOKENIZEDINPUT']._serialized_end=1599
_globals['_MULTIMODALINPUTS']._serialized_start=1602
_globals['_MULTIMODALINPUTS']._serialized_end=1813
_globals['_GENERATERESPONSE']._serialized_start=1816
_globals['_GENERATERESPONSE']._serialized_end=2043
_globals['_GENERATESTREAMCHUNK']._serialized_start=2046
_globals['_GENERATESTREAMCHUNK']._serialized_end=2291
_globals['_GENERATECOMPLETE']._serialized_start=2294
_globals['_GENERATECOMPLETE']._serialized_end=2627
_globals['_GENERATECOMPLETE_FINISHREASON']._serialized_start=2551
_globals['_GENERATECOMPLETE_FINISHREASON']._serialized_end=2627
_globals['_GENERATEERROR']._serialized_start=2629
_globals['_GENERATEERROR']._serialized_end=2704
_globals['_LOGPROBS']._serialized_start=2707
_globals['_LOGPROBS']._serialized_end=2839
_globals['_TOPLOGPROBS']._serialized_start=2841
_globals['_TOPLOGPROBS']._serialized_end=2910
_globals['_HIDDENSTATES']._serialized_start=2912
_globals['_HIDDENSTATES']._serialized_end=2975
_globals['_EMBEDREQUEST']._serialized_start=2978
_globals['_EMBEDREQUEST']._serialized_end=3308
_globals['_EMBEDRESPONSE']._serialized_start=3311
_globals['_EMBEDRESPONSE']._serialized_end=3468
_globals['_EMBEDCOMPLETE']._serialized_start=3471
_globals['_EMBEDCOMPLETE']._serialized_end=3659
_globals['_EMBEDDING']._serialized_start=3661
_globals['_EMBEDDING']._serialized_end=3703
_globals['_EMBEDERROR']._serialized_start=3705
_globals['_EMBEDERROR']._serialized_end=3765
_globals['_HEALTHCHECKREQUEST']._serialized_start=3767
_globals['_HEALTHCHECKREQUEST']._serialized_end=3845
_globals['_HEALTHCHECKRESPONSE']._serialized_start=3847
_globals['_HEALTHCHECKRESPONSE']._serialized_end=3902
_globals['_ABORTREQUEST']._serialized_start=3904
_globals['_ABORTREQUEST']._serialized_end=3954
_globals['_ABORTRESPONSE']._serialized_start=3956
_globals['_ABORTRESPONSE']._serialized_end=4005
_globals['_LOADLORAREQUEST']._serialized_start=4007
_globals['_LOADLORAREQUEST']._serialized_end=4080
_globals['_LOADLORARESPONSE']._serialized_start=4082
_globals['_LOADLORARESPONSE']._serialized_end=4154
_globals['_UNLOADLORAREQUEST']._serialized_start=4156
_globals['_UNLOADLORAREQUEST']._serialized_end=4195
_globals['_UNLOADLORARESPONSE']._serialized_start=4197
_globals['_UNLOADLORARESPONSE']._serialized_end=4251
_globals['_UPDATEWEIGHTSREQUEST']._serialized_start=4253
_globals['_UPDATEWEIGHTSREQUEST']._serialized_end=4372
_globals['_UPDATEWEIGHTSRESPONSE']._serialized_start=4374
_globals['_UPDATEWEIGHTSRESPONSE']._serialized_end=4431
_globals['_GETINTERNALSTATEREQUEST']._serialized_start=4433
_globals['_GETINTERNALSTATEREQUEST']._serialized_end=4478
_globals['_GETINTERNALSTATERESPONSE']._serialized_start=4480
_globals['_GETINTERNALSTATERESPONSE']._serialized_end=4546
_globals['_SETINTERNALSTATEREQUEST']._serialized_start=4548
_globals['_SETINTERNALSTATEREQUEST']._serialized_end=4613
_globals['_SETINTERNALSTATERESPONSE']._serialized_start=4615
_globals['_SETINTERNALSTATERESPONSE']._serialized_end=4675
_globals['_SGLANGSCHEDULER']._serialized_start=4678
_globals['_SGLANGSCHEDULER']._serialized_end=5060
# @@protoc_insertion_point(module_scope)
import datetime
from google.protobuf import timestamp_pb2 as _timestamp_pb2
from google.protobuf import struct_pb2 as _struct_pb2
from google.protobuf.internal import containers as _containers
from google.protobuf.internal import enum_type_wrapper as _enum_type_wrapper
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from collections.abc import Iterable as _Iterable, Mapping as _Mapping
from typing import ClassVar as _ClassVar, Optional as _Optional, Union as _Union
DESCRIPTOR: _descriptor.FileDescriptor
class SamplingParams(_message.Message):
__slots__ = ("temperature", "top_p", "top_k", "min_p", "frequency_penalty", "presence_penalty", "repetition_penalty", "max_new_tokens", "stop", "stop_token_ids", "skip_special_tokens", "spaces_between_special_tokens", "regex", "json_schema", "ebnf_grammar", "lora_path", "n", "token_healing", "min_new_tokens", "ignore_eos", "no_stop_trim", "stream_interval", "logit_bias", "structural_tag", "custom_params")
class LogitBiasEntry(_message.Message):
__slots__ = ("key", "value")
KEY_FIELD_NUMBER: _ClassVar[int]
VALUE_FIELD_NUMBER: _ClassVar[int]
key: str
value: float
def __init__(self, key: _Optional[str] = ..., value: _Optional[float] = ...) -> None: ...
TEMPERATURE_FIELD_NUMBER: _ClassVar[int]
TOP_P_FIELD_NUMBER: _ClassVar[int]
TOP_K_FIELD_NUMBER: _ClassVar[int]
MIN_P_FIELD_NUMBER: _ClassVar[int]
FREQUENCY_PENALTY_FIELD_NUMBER: _ClassVar[int]
PRESENCE_PENALTY_FIELD_NUMBER: _ClassVar[int]
REPETITION_PENALTY_FIELD_NUMBER: _ClassVar[int]
MAX_NEW_TOKENS_FIELD_NUMBER: _ClassVar[int]
STOP_FIELD_NUMBER: _ClassVar[int]
STOP_TOKEN_IDS_FIELD_NUMBER: _ClassVar[int]
SKIP_SPECIAL_TOKENS_FIELD_NUMBER: _ClassVar[int]
SPACES_BETWEEN_SPECIAL_TOKENS_FIELD_NUMBER: _ClassVar[int]
REGEX_FIELD_NUMBER: _ClassVar[int]
JSON_SCHEMA_FIELD_NUMBER: _ClassVar[int]
EBNF_GRAMMAR_FIELD_NUMBER: _ClassVar[int]
LORA_PATH_FIELD_NUMBER: _ClassVar[int]
N_FIELD_NUMBER: _ClassVar[int]
TOKEN_HEALING_FIELD_NUMBER: _ClassVar[int]
MIN_NEW_TOKENS_FIELD_NUMBER: _ClassVar[int]
IGNORE_EOS_FIELD_NUMBER: _ClassVar[int]
NO_STOP_TRIM_FIELD_NUMBER: _ClassVar[int]
STREAM_INTERVAL_FIELD_NUMBER: _ClassVar[int]
LOGIT_BIAS_FIELD_NUMBER: _ClassVar[int]
STRUCTURAL_TAG_FIELD_NUMBER: _ClassVar[int]
CUSTOM_PARAMS_FIELD_NUMBER: _ClassVar[int]
temperature: float
top_p: float
top_k: int
min_p: float
frequency_penalty: float
presence_penalty: float
repetition_penalty: float
max_new_tokens: int
stop: _containers.RepeatedScalarFieldContainer[str]
stop_token_ids: _containers.RepeatedScalarFieldContainer[int]
skip_special_tokens: bool
spaces_between_special_tokens: bool
regex: str
json_schema: str
ebnf_grammar: str
lora_path: str
n: int
token_healing: bool
min_new_tokens: int
ignore_eos: bool
no_stop_trim: bool
stream_interval: int
logit_bias: _containers.ScalarMap[str, float]
structural_tag: str
custom_params: _struct_pb2.Struct
def __init__(self, temperature: _Optional[float] = ..., top_p: _Optional[float] = ..., top_k: _Optional[int] = ..., min_p: _Optional[float] = ..., frequency_penalty: _Optional[float] = ..., presence_penalty: _Optional[float] = ..., repetition_penalty: _Optional[float] = ..., max_new_tokens: _Optional[int] = ..., stop: _Optional[_Iterable[str]] = ..., stop_token_ids: _Optional[_Iterable[int]] = ..., skip_special_tokens: bool = ..., spaces_between_special_tokens: bool = ..., regex: _Optional[str] = ..., json_schema: _Optional[str] = ..., ebnf_grammar: _Optional[str] = ..., lora_path: _Optional[str] = ..., n: _Optional[int] = ..., token_healing: bool = ..., min_new_tokens: _Optional[int] = ..., ignore_eos: bool = ..., no_stop_trim: bool = ..., stream_interval: _Optional[int] = ..., logit_bias: _Optional[_Mapping[str, float]] = ..., structural_tag: _Optional[str] = ..., custom_params: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ...) -> None: ...
class DisaggregatedParams(_message.Message):
__slots__ = ("bootstrap_host", "bootstrap_port", "bootstrap_room")
BOOTSTRAP_HOST_FIELD_NUMBER: _ClassVar[int]
BOOTSTRAP_PORT_FIELD_NUMBER: _ClassVar[int]
BOOTSTRAP_ROOM_FIELD_NUMBER: _ClassVar[int]
bootstrap_host: str
bootstrap_port: int
bootstrap_room: int
def __init__(self, bootstrap_host: _Optional[str] = ..., bootstrap_port: _Optional[int] = ..., bootstrap_room: _Optional[int] = ...) -> None: ...
class GenerateRequest(_message.Message):
__slots__ = ("request_id", "tokenized", "mm_inputs", "sampling_params", "return_logprob", "logprob_start_len", "top_logprobs_num", "token_ids_logprob", "return_hidden_states", "disaggregated_params", "custom_logit_processor", "timestamp", "log_metrics", "input_embeds", "lora_id", "data_parallel_rank", "dp_balance_id")
REQUEST_ID_FIELD_NUMBER: _ClassVar[int]
TOKENIZED_FIELD_NUMBER: _ClassVar[int]
MM_INPUTS_FIELD_NUMBER: _ClassVar[int]
SAMPLING_PARAMS_FIELD_NUMBER: _ClassVar[int]
RETURN_LOGPROB_FIELD_NUMBER: _ClassVar[int]
LOGPROB_START_LEN_FIELD_NUMBER: _ClassVar[int]
TOP_LOGPROBS_NUM_FIELD_NUMBER: _ClassVar[int]
TOKEN_IDS_LOGPROB_FIELD_NUMBER: _ClassVar[int]
RETURN_HIDDEN_STATES_FIELD_NUMBER: _ClassVar[int]
DISAGGREGATED_PARAMS_FIELD_NUMBER: _ClassVar[int]
CUSTOM_LOGIT_PROCESSOR_FIELD_NUMBER: _ClassVar[int]
TIMESTAMP_FIELD_NUMBER: _ClassVar[int]
LOG_METRICS_FIELD_NUMBER: _ClassVar[int]
INPUT_EMBEDS_FIELD_NUMBER: _ClassVar[int]
LORA_ID_FIELD_NUMBER: _ClassVar[int]
DATA_PARALLEL_RANK_FIELD_NUMBER: _ClassVar[int]
DP_BALANCE_ID_FIELD_NUMBER: _ClassVar[int]
request_id: str
tokenized: TokenizedInput
mm_inputs: MultimodalInputs
sampling_params: SamplingParams
return_logprob: bool
logprob_start_len: int
top_logprobs_num: int
token_ids_logprob: _containers.RepeatedScalarFieldContainer[int]
return_hidden_states: bool
disaggregated_params: DisaggregatedParams
custom_logit_processor: str
timestamp: _timestamp_pb2.Timestamp
log_metrics: bool
input_embeds: _containers.RepeatedScalarFieldContainer[float]
lora_id: str
data_parallel_rank: int
dp_balance_id: int
def __init__(self, request_id: _Optional[str] = ..., tokenized: _Optional[_Union[TokenizedInput, _Mapping]] = ..., mm_inputs: _Optional[_Union[MultimodalInputs, _Mapping]] = ..., sampling_params: _Optional[_Union[SamplingParams, _Mapping]] = ..., return_logprob: bool = ..., logprob_start_len: _Optional[int] = ..., top_logprobs_num: _Optional[int] = ..., token_ids_logprob: _Optional[_Iterable[int]] = ..., return_hidden_states: bool = ..., disaggregated_params: _Optional[_Union[DisaggregatedParams, _Mapping]] = ..., custom_logit_processor: _Optional[str] = ..., timestamp: _Optional[_Union[datetime.datetime, _timestamp_pb2.Timestamp, _Mapping]] = ..., log_metrics: bool = ..., input_embeds: _Optional[_Iterable[float]] = ..., lora_id: _Optional[str] = ..., data_parallel_rank: _Optional[int] = ..., dp_balance_id: _Optional[int] = ...) -> None: ...
class TokenizedInput(_message.Message):
__slots__ = ("original_text", "input_ids")
ORIGINAL_TEXT_FIELD_NUMBER: _ClassVar[int]
INPUT_IDS_FIELD_NUMBER: _ClassVar[int]
original_text: str
input_ids: _containers.RepeatedScalarFieldContainer[int]
def __init__(self, original_text: _Optional[str] = ..., input_ids: _Optional[_Iterable[int]] = ...) -> None: ...
class MultimodalInputs(_message.Message):
__slots__ = ("image_urls", "video_urls", "audio_urls", "processed_features", "image_data", "video_data", "audio_data", "modalities")
IMAGE_URLS_FIELD_NUMBER: _ClassVar[int]
VIDEO_URLS_FIELD_NUMBER: _ClassVar[int]
AUDIO_URLS_FIELD_NUMBER: _ClassVar[int]
PROCESSED_FEATURES_FIELD_NUMBER: _ClassVar[int]
IMAGE_DATA_FIELD_NUMBER: _ClassVar[int]
VIDEO_DATA_FIELD_NUMBER: _ClassVar[int]
AUDIO_DATA_FIELD_NUMBER: _ClassVar[int]
MODALITIES_FIELD_NUMBER: _ClassVar[int]
image_urls: _containers.RepeatedScalarFieldContainer[str]
video_urls: _containers.RepeatedScalarFieldContainer[str]
audio_urls: _containers.RepeatedScalarFieldContainer[str]
processed_features: _struct_pb2.Struct
image_data: _containers.RepeatedScalarFieldContainer[bytes]
video_data: _containers.RepeatedScalarFieldContainer[bytes]
audio_data: _containers.RepeatedScalarFieldContainer[bytes]
modalities: _containers.RepeatedScalarFieldContainer[str]
def __init__(self, image_urls: _Optional[_Iterable[str]] = ..., video_urls: _Optional[_Iterable[str]] = ..., audio_urls: _Optional[_Iterable[str]] = ..., processed_features: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ..., image_data: _Optional[_Iterable[bytes]] = ..., video_data: _Optional[_Iterable[bytes]] = ..., audio_data: _Optional[_Iterable[bytes]] = ..., modalities: _Optional[_Iterable[str]] = ...) -> None: ...
class GenerateResponse(_message.Message):
__slots__ = ("request_id", "chunk", "complete", "error")
REQUEST_ID_FIELD_NUMBER: _ClassVar[int]
CHUNK_FIELD_NUMBER: _ClassVar[int]
COMPLETE_FIELD_NUMBER: _ClassVar[int]
ERROR_FIELD_NUMBER: _ClassVar[int]
request_id: str
chunk: GenerateStreamChunk
complete: GenerateComplete
error: GenerateError
def __init__(self, request_id: _Optional[str] = ..., chunk: _Optional[_Union[GenerateStreamChunk, _Mapping]] = ..., complete: _Optional[_Union[GenerateComplete, _Mapping]] = ..., error: _Optional[_Union[GenerateError, _Mapping]] = ...) -> None: ...
class GenerateStreamChunk(_message.Message):
__slots__ = ("token_id", "text", "prompt_tokens", "completion_tokens", "cached_tokens", "logprobs", "hidden_states", "generation_time", "queue_time")
TOKEN_ID_FIELD_NUMBER: _ClassVar[int]
TEXT_FIELD_NUMBER: _ClassVar[int]
PROMPT_TOKENS_FIELD_NUMBER: _ClassVar[int]
COMPLETION_TOKENS_FIELD_NUMBER: _ClassVar[int]
CACHED_TOKENS_FIELD_NUMBER: _ClassVar[int]
LOGPROBS_FIELD_NUMBER: _ClassVar[int]
HIDDEN_STATES_FIELD_NUMBER: _ClassVar[int]
GENERATION_TIME_FIELD_NUMBER: _ClassVar[int]
QUEUE_TIME_FIELD_NUMBER: _ClassVar[int]
token_id: int
text: str
prompt_tokens: int
completion_tokens: int
cached_tokens: int
logprobs: LogProbs
hidden_states: _containers.RepeatedScalarFieldContainer[float]
generation_time: float
queue_time: int
def __init__(self, token_id: _Optional[int] = ..., text: _Optional[str] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., logprobs: _Optional[_Union[LogProbs, _Mapping]] = ..., hidden_states: _Optional[_Iterable[float]] = ..., generation_time: _Optional[float] = ..., queue_time: _Optional[int] = ...) -> None: ...
class GenerateComplete(_message.Message):
__slots__ = ("output_ids", "output_text", "finish_reason", "all_logprobs", "all_hidden_states")
class FinishReason(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
__slots__ = ()
STOP: _ClassVar[GenerateComplete.FinishReason]
LENGTH: _ClassVar[GenerateComplete.FinishReason]
EOS_TOKEN: _ClassVar[GenerateComplete.FinishReason]
STOP_STR: _ClassVar[GenerateComplete.FinishReason]
ABORT: _ClassVar[GenerateComplete.FinishReason]
STOP: GenerateComplete.FinishReason
LENGTH: GenerateComplete.FinishReason
EOS_TOKEN: GenerateComplete.FinishReason
STOP_STR: GenerateComplete.FinishReason
ABORT: GenerateComplete.FinishReason
OUTPUT_IDS_FIELD_NUMBER: _ClassVar[int]
OUTPUT_TEXT_FIELD_NUMBER: _ClassVar[int]
FINISH_REASON_FIELD_NUMBER: _ClassVar[int]
ALL_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
ALL_HIDDEN_STATES_FIELD_NUMBER: _ClassVar[int]
output_ids: _containers.RepeatedScalarFieldContainer[int]
output_text: str
finish_reason: GenerateComplete.FinishReason
all_logprobs: _containers.RepeatedCompositeFieldContainer[LogProbs]
all_hidden_states: _containers.RepeatedCompositeFieldContainer[HiddenStates]
def __init__(self, output_ids: _Optional[_Iterable[int]] = ..., output_text: _Optional[str] = ..., finish_reason: _Optional[_Union[GenerateComplete.FinishReason, str]] = ..., all_logprobs: _Optional[_Iterable[_Union[LogProbs, _Mapping]]] = ..., all_hidden_states: _Optional[_Iterable[_Union[HiddenStates, _Mapping]]] = ...) -> None: ...
class GenerateError(_message.Message):
__slots__ = ("message", "http_status_code", "details")
MESSAGE_FIELD_NUMBER: _ClassVar[int]
HTTP_STATUS_CODE_FIELD_NUMBER: _ClassVar[int]
DETAILS_FIELD_NUMBER: _ClassVar[int]
message: str
http_status_code: str
details: str
def __init__(self, message: _Optional[str] = ..., http_status_code: _Optional[str] = ..., details: _Optional[str] = ...) -> None: ...
class LogProbs(_message.Message):
__slots__ = ("token_logprobs", "token_ids", "top_logprobs", "token_texts")
TOKEN_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
TOKEN_IDS_FIELD_NUMBER: _ClassVar[int]
TOP_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
TOKEN_TEXTS_FIELD_NUMBER: _ClassVar[int]
token_logprobs: _containers.RepeatedScalarFieldContainer[float]
token_ids: _containers.RepeatedScalarFieldContainer[int]
top_logprobs: _containers.RepeatedCompositeFieldContainer[TopLogProbs]
token_texts: _containers.RepeatedScalarFieldContainer[str]
def __init__(self, token_logprobs: _Optional[_Iterable[float]] = ..., token_ids: _Optional[_Iterable[int]] = ..., top_logprobs: _Optional[_Iterable[_Union[TopLogProbs, _Mapping]]] = ..., token_texts: _Optional[_Iterable[str]] = ...) -> None: ...
class TopLogProbs(_message.Message):
__slots__ = ("values", "token_ids", "token_texts")
VALUES_FIELD_NUMBER: _ClassVar[int]
TOKEN_IDS_FIELD_NUMBER: _ClassVar[int]
TOKEN_TEXTS_FIELD_NUMBER: _ClassVar[int]
values: _containers.RepeatedScalarFieldContainer[float]
token_ids: _containers.RepeatedScalarFieldContainer[int]
token_texts: _containers.RepeatedScalarFieldContainer[str]
def __init__(self, values: _Optional[_Iterable[float]] = ..., token_ids: _Optional[_Iterable[int]] = ..., token_texts: _Optional[_Iterable[str]] = ...) -> None: ...
class HiddenStates(_message.Message):
__slots__ = ("values", "layer", "position")
VALUES_FIELD_NUMBER: _ClassVar[int]
LAYER_FIELD_NUMBER: _ClassVar[int]
POSITION_FIELD_NUMBER: _ClassVar[int]
values: _containers.RepeatedScalarFieldContainer[float]
layer: int
position: int
def __init__(self, values: _Optional[_Iterable[float]] = ..., layer: _Optional[int] = ..., position: _Optional[int] = ...) -> None: ...
class EmbedRequest(_message.Message):
__slots__ = ("request_id", "tokenized", "mm_inputs", "sampling_params", "log_metrics", "token_type_ids", "data_parallel_rank", "is_cross_encoder", "texts")
REQUEST_ID_FIELD_NUMBER: _ClassVar[int]
TOKENIZED_FIELD_NUMBER: _ClassVar[int]
MM_INPUTS_FIELD_NUMBER: _ClassVar[int]
SAMPLING_PARAMS_FIELD_NUMBER: _ClassVar[int]
LOG_METRICS_FIELD_NUMBER: _ClassVar[int]
TOKEN_TYPE_IDS_FIELD_NUMBER: _ClassVar[int]
DATA_PARALLEL_RANK_FIELD_NUMBER: _ClassVar[int]
IS_CROSS_ENCODER_FIELD_NUMBER: _ClassVar[int]
TEXTS_FIELD_NUMBER: _ClassVar[int]
request_id: str
tokenized: TokenizedInput
mm_inputs: MultimodalInputs
sampling_params: SamplingParams
log_metrics: bool
token_type_ids: _containers.RepeatedScalarFieldContainer[int]
data_parallel_rank: int
is_cross_encoder: bool
texts: _containers.RepeatedScalarFieldContainer[str]
def __init__(self, request_id: _Optional[str] = ..., tokenized: _Optional[_Union[TokenizedInput, _Mapping]] = ..., mm_inputs: _Optional[_Union[MultimodalInputs, _Mapping]] = ..., sampling_params: _Optional[_Union[SamplingParams, _Mapping]] = ..., log_metrics: bool = ..., token_type_ids: _Optional[_Iterable[int]] = ..., data_parallel_rank: _Optional[int] = ..., is_cross_encoder: bool = ..., texts: _Optional[_Iterable[str]] = ...) -> None: ...
class EmbedResponse(_message.Message):
__slots__ = ("request_id", "complete", "error")
REQUEST_ID_FIELD_NUMBER: _ClassVar[int]
COMPLETE_FIELD_NUMBER: _ClassVar[int]
ERROR_FIELD_NUMBER: _ClassVar[int]
request_id: str
complete: EmbedComplete
error: EmbedError
def __init__(self, request_id: _Optional[str] = ..., complete: _Optional[_Union[EmbedComplete, _Mapping]] = ..., error: _Optional[_Union[EmbedError, _Mapping]] = ...) -> None: ...
class EmbedComplete(_message.Message):
__slots__ = ("embedding", "prompt_tokens", "cached_tokens", "embedding_dim", "generation_time", "batch_embeddings")
EMBEDDING_FIELD_NUMBER: _ClassVar[int]
PROMPT_TOKENS_FIELD_NUMBER: _ClassVar[int]
CACHED_TOKENS_FIELD_NUMBER: _ClassVar[int]
EMBEDDING_DIM_FIELD_NUMBER: _ClassVar[int]
GENERATION_TIME_FIELD_NUMBER: _ClassVar[int]
BATCH_EMBEDDINGS_FIELD_NUMBER: _ClassVar[int]
embedding: _containers.RepeatedScalarFieldContainer[float]
prompt_tokens: int
cached_tokens: int
embedding_dim: int
generation_time: float
batch_embeddings: _containers.RepeatedCompositeFieldContainer[Embedding]
def __init__(self, embedding: _Optional[_Iterable[float]] = ..., prompt_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., embedding_dim: _Optional[int] = ..., generation_time: _Optional[float] = ..., batch_embeddings: _Optional[_Iterable[_Union[Embedding, _Mapping]]] = ...) -> None: ...
class Embedding(_message.Message):
__slots__ = ("values", "index")
VALUES_FIELD_NUMBER: _ClassVar[int]
INDEX_FIELD_NUMBER: _ClassVar[int]
values: _containers.RepeatedScalarFieldContainer[float]
index: int
def __init__(self, values: _Optional[_Iterable[float]] = ..., index: _Optional[int] = ...) -> None: ...
class EmbedError(_message.Message):
__slots__ = ("message", "code", "details")
MESSAGE_FIELD_NUMBER: _ClassVar[int]
CODE_FIELD_NUMBER: _ClassVar[int]
DETAILS_FIELD_NUMBER: _ClassVar[int]
message: str
code: str
details: str
def __init__(self, message: _Optional[str] = ..., code: _Optional[str] = ..., details: _Optional[str] = ...) -> None: ...
class HealthCheckRequest(_message.Message):
__slots__ = ("tokenized",)
TOKENIZED_FIELD_NUMBER: _ClassVar[int]
tokenized: TokenizedInput
def __init__(self, tokenized: _Optional[_Union[TokenizedInput, _Mapping]] = ...) -> None: ...
class HealthCheckResponse(_message.Message):
__slots__ = ("healthy", "message")
HEALTHY_FIELD_NUMBER: _ClassVar[int]
MESSAGE_FIELD_NUMBER: _ClassVar[int]
healthy: bool
message: str
def __init__(self, healthy: bool = ..., message: _Optional[str] = ...) -> None: ...
class AbortRequest(_message.Message):
__slots__ = ("request_id", "reason")
REQUEST_ID_FIELD_NUMBER: _ClassVar[int]
REASON_FIELD_NUMBER: _ClassVar[int]
request_id: str
reason: str
def __init__(self, request_id: _Optional[str] = ..., reason: _Optional[str] = ...) -> None: ...
class AbortResponse(_message.Message):
__slots__ = ("success", "message")
SUCCESS_FIELD_NUMBER: _ClassVar[int]
MESSAGE_FIELD_NUMBER: _ClassVar[int]
success: bool
message: str
def __init__(self, success: bool = ..., message: _Optional[str] = ...) -> None: ...
class LoadLoRARequest(_message.Message):
__slots__ = ("adapter_id", "adapter_path", "rank")
ADAPTER_ID_FIELD_NUMBER: _ClassVar[int]
ADAPTER_PATH_FIELD_NUMBER: _ClassVar[int]
RANK_FIELD_NUMBER: _ClassVar[int]
adapter_id: str
adapter_path: str
rank: int
def __init__(self, adapter_id: _Optional[str] = ..., adapter_path: _Optional[str] = ..., rank: _Optional[int] = ...) -> None: ...
class LoadLoRAResponse(_message.Message):
__slots__ = ("success", "adapter_id", "message")
SUCCESS_FIELD_NUMBER: _ClassVar[int]
ADAPTER_ID_FIELD_NUMBER: _ClassVar[int]
MESSAGE_FIELD_NUMBER: _ClassVar[int]
success: bool
adapter_id: str
message: str
def __init__(self, success: bool = ..., adapter_id: _Optional[str] = ..., message: _Optional[str] = ...) -> None: ...
class UnloadLoRARequest(_message.Message):
__slots__ = ("adapter_id",)
ADAPTER_ID_FIELD_NUMBER: _ClassVar[int]
adapter_id: str
def __init__(self, adapter_id: _Optional[str] = ...) -> None: ...
class UnloadLoRAResponse(_message.Message):
__slots__ = ("success", "message")
SUCCESS_FIELD_NUMBER: _ClassVar[int]
MESSAGE_FIELD_NUMBER: _ClassVar[int]
success: bool
message: str
def __init__(self, success: bool = ..., message: _Optional[str] = ...) -> None: ...
class UpdateWeightsRequest(_message.Message):
__slots__ = ("disk_path", "tensor_data", "remote_url", "weight_name")
DISK_PATH_FIELD_NUMBER: _ClassVar[int]
TENSOR_DATA_FIELD_NUMBER: _ClassVar[int]
REMOTE_URL_FIELD_NUMBER: _ClassVar[int]
WEIGHT_NAME_FIELD_NUMBER: _ClassVar[int]
disk_path: str
tensor_data: bytes
remote_url: str
weight_name: str
def __init__(self, disk_path: _Optional[str] = ..., tensor_data: _Optional[bytes] = ..., remote_url: _Optional[str] = ..., weight_name: _Optional[str] = ...) -> None: ...
class UpdateWeightsResponse(_message.Message):
__slots__ = ("success", "message")
SUCCESS_FIELD_NUMBER: _ClassVar[int]
MESSAGE_FIELD_NUMBER: _ClassVar[int]
success: bool
message: str
def __init__(self, success: bool = ..., message: _Optional[str] = ...) -> None: ...
class GetInternalStateRequest(_message.Message):
__slots__ = ("state_keys",)
STATE_KEYS_FIELD_NUMBER: _ClassVar[int]
state_keys: _containers.RepeatedScalarFieldContainer[str]
def __init__(self, state_keys: _Optional[_Iterable[str]] = ...) -> None: ...
class GetInternalStateResponse(_message.Message):
__slots__ = ("state",)
STATE_FIELD_NUMBER: _ClassVar[int]
state: _struct_pb2.Struct
def __init__(self, state: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ...) -> None: ...
class SetInternalStateRequest(_message.Message):
__slots__ = ("state",)
STATE_FIELD_NUMBER: _ClassVar[int]
state: _struct_pb2.Struct
def __init__(self, state: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ...) -> None: ...
class SetInternalStateResponse(_message.Message):
__slots__ = ("success", "message")
SUCCESS_FIELD_NUMBER: _ClassVar[int]
MESSAGE_FIELD_NUMBER: _ClassVar[int]
success: bool
message: str
def __init__(self, success: bool = ..., message: _Optional[str] = ...) -> None: ...
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
"""Client and server classes corresponding to protobuf-defined services."""
import grpc
import warnings
from . import sglang_scheduler_pb2 as sglang__scheduler__pb2
GRPC_GENERATED_VERSION = '1.74.0'
GRPC_VERSION = grpc.__version__
_version_not_supported = False
try:
from grpc._utilities import first_version_is_lower
_version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
except ImportError:
_version_not_supported = True
if _version_not_supported:
raise RuntimeError(
f'The grpc package installed is at version {GRPC_VERSION},'
+ f' but the generated code in sglang_scheduler_pb2_grpc.py depends on'
+ f' grpcio>={GRPC_GENERATED_VERSION}.'
+ f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
+ f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
)
class SglangSchedulerStub(object):
"""Service definition for SGLang scheduler communication
This protocol bridges the Rust router and Python scheduler
"""
def __init__(self, channel):
"""Constructor.
Args:
channel: A grpc.Channel.
"""
self.Generate = channel.unary_stream(
'/sglang.grpc.scheduler.SglangScheduler/Generate',
request_serializer=sglang__scheduler__pb2.GenerateRequest.SerializeToString,
response_deserializer=sglang__scheduler__pb2.GenerateResponse.FromString,
_registered_method=True)
self.Embed = channel.unary_unary(
'/sglang.grpc.scheduler.SglangScheduler/Embed',
request_serializer=sglang__scheduler__pb2.EmbedRequest.SerializeToString,
response_deserializer=sglang__scheduler__pb2.EmbedResponse.FromString,
_registered_method=True)
self.HealthCheck = channel.unary_unary(
'/sglang.grpc.scheduler.SglangScheduler/HealthCheck',
request_serializer=sglang__scheduler__pb2.HealthCheckRequest.SerializeToString,
response_deserializer=sglang__scheduler__pb2.HealthCheckResponse.FromString,
_registered_method=True)
self.Abort = channel.unary_unary(
'/sglang.grpc.scheduler.SglangScheduler/Abort',
request_serializer=sglang__scheduler__pb2.AbortRequest.SerializeToString,
response_deserializer=sglang__scheduler__pb2.AbortResponse.FromString,
_registered_method=True)
class SglangSchedulerServicer(object):
"""Service definition for SGLang scheduler communication
This protocol bridges the Rust router and Python scheduler
"""
def Generate(self, request, context):
"""Submit a generation request (supports streaming)
"""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def Embed(self, request, context):
"""Submit an embedding request
"""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def HealthCheck(self, request, context):
"""Health check and metrics
"""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def Abort(self, request, context):
"""Abort a running request
"""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def add_SglangSchedulerServicer_to_server(servicer, server):
rpc_method_handlers = {
'Generate': grpc.unary_stream_rpc_method_handler(
servicer.Generate,
request_deserializer=sglang__scheduler__pb2.GenerateRequest.FromString,
response_serializer=sglang__scheduler__pb2.GenerateResponse.SerializeToString,
),
'Embed': grpc.unary_unary_rpc_method_handler(
servicer.Embed,
request_deserializer=sglang__scheduler__pb2.EmbedRequest.FromString,
response_serializer=sglang__scheduler__pb2.EmbedResponse.SerializeToString,
),
'HealthCheck': grpc.unary_unary_rpc_method_handler(
servicer.HealthCheck,
request_deserializer=sglang__scheduler__pb2.HealthCheckRequest.FromString,
response_serializer=sglang__scheduler__pb2.HealthCheckResponse.SerializeToString,
),
'Abort': grpc.unary_unary_rpc_method_handler(
servicer.Abort,
request_deserializer=sglang__scheduler__pb2.AbortRequest.FromString,
response_serializer=sglang__scheduler__pb2.AbortResponse.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'sglang.grpc.scheduler.SglangScheduler', rpc_method_handlers)
server.add_generic_rpc_handlers((generic_handler,))
server.add_registered_method_handlers('sglang.grpc.scheduler.SglangScheduler', rpc_method_handlers)
# This class is part of an EXPERIMENTAL API.
class SglangScheduler(object):
"""Service definition for SGLang scheduler communication
This protocol bridges the Rust router and Python scheduler
"""
@staticmethod
def Generate(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_stream(
request,
target,
'/sglang.grpc.scheduler.SglangScheduler/Generate',
sglang__scheduler__pb2.GenerateRequest.SerializeToString,
sglang__scheduler__pb2.GenerateResponse.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)
@staticmethod
def Embed(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/sglang.grpc.scheduler.SglangScheduler/Embed',
sglang__scheduler__pb2.EmbedRequest.SerializeToString,
sglang__scheduler__pb2.EmbedResponse.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)
@staticmethod
def HealthCheck(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/sglang.grpc.scheduler.SglangScheduler/HealthCheck',
sglang__scheduler__pb2.HealthCheckRequest.SerializeToString,
sglang__scheduler__pb2.HealthCheckResponse.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)
@staticmethod
def Abort(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/sglang.grpc.scheduler.SglangScheduler/Abort',
sglang__scheduler__pb2.AbortRequest.SerializeToString,
sglang__scheduler__pb2.AbortResponse.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)
......@@ -2238,6 +2238,7 @@ class ServerArgs:
args.pp_size = args.pipeline_parallel_size
args.dp_size = args.data_parallel_size
args.ep_size = args.expert_parallel_size
attrs = [attr.name for attr in dataclasses.fields(cls)]
return cls(**{attr: getattr(args, attr) for attr in attrs})
......
......@@ -37,21 +37,6 @@ impl SglangSchedulerClient {
Ok(Self { client })
}
/// Initialize the connection
pub async fn initialize(
&mut self,
client_id: String,
) -> Result<proto::InitializeResponse, Box<dyn std::error::Error>> {
let request = Request::new(proto::InitializeRequest {
client_id,
client_version: "0.1.0".to_string(),
mode: proto::initialize_request::Mode::Regular as i32,
});
let response = self.client.initialize(request).await?;
Ok(response.into_inner())
}
/// Submit a generation request (returns streaming response)
pub async fn generate_stream(
&mut self,
......@@ -68,7 +53,10 @@ impl SglangSchedulerClient {
) -> Result<proto::HealthCheckResponse, Box<dyn std::error::Error>> {
debug!("Sending health check request");
let request = Request::new(proto::HealthCheckRequest {
include_detailed_metrics: false,
tokenized: Some(proto::TokenizedInput {
original_text: "Hello".to_string(),
input_ids: vec![9906], // Mock token ID for "Hello"
}),
});
let response = self.client.health_check(request).await?;
......@@ -87,21 +75,6 @@ impl SglangSchedulerClient {
self.client.abort(request).await?;
Ok(())
}
/// Flush cache
pub async fn flush_cache(
&mut self,
flush_all: bool,
session_ids: &[String],
) -> Result<proto::FlushCacheResponse, Box<dyn std::error::Error>> {
let request = Request::new(proto::FlushCacheRequest {
flush_all,
session_ids: session_ids.to_vec(),
});
let response = self.client.flush_cache(request).await?;
Ok(response.into_inner())
}
}
#[cfg(test)]
......@@ -111,14 +84,13 @@ mod tests {
#[test]
fn test_proto_types_compilation() {
// Test that protobuf types can be constructed
let init_req = proto::InitializeRequest {
client_id: "test-client".to_string(),
client_version: "0.1.0".to_string(),
mode: 0,
let health_req = proto::HealthCheckRequest {
tokenized: Some(proto::TokenizedInput {
original_text: "test".to_string(),
input_ids: vec![1296],
}),
};
assert_eq!(init_req.client_id, "test-client");
assert_eq!(init_req.client_version, "0.1.0");
assert_eq!(init_req.mode, 0);
assert!(health_req.tokenized.is_some());
}
#[test]
......@@ -134,9 +106,10 @@ mod tests {
let gen_req = proto::GenerateRequest {
request_id: "test-req-123".to_string(),
input: Some(proto::generate_request::Input::Text(
"Hello world".to_string(),
)),
tokenized: Some(proto::TokenizedInput {
original_text: "Hello world".to_string(),
input_ids: vec![9906, 1917], // Mock token IDs for "Hello world"
}),
sampling_params: Some(sampling_params),
return_logprob: true,
logprob_start_len: 0,
......@@ -145,8 +118,8 @@ mod tests {
};
assert_eq!(gen_req.request_id, "test-req-123");
if let Some(proto::generate_request::Input::Text(text)) = &gen_req.input {
assert_eq!(text, "Hello world");
if let Some(ref tokenized) = &gen_req.tokenized {
assert_eq!(tokenized.original_text, "Hello world");
}
assert!(gen_req.return_logprob);
assert_eq!(gen_req.top_logprobs_num, 5);
......@@ -160,9 +133,12 @@ mod tests {
#[test]
fn test_health_check_request() {
let health_req = proto::HealthCheckRequest {
include_detailed_metrics: true,
tokenized: Some(proto::TokenizedInput {
original_text: "test".to_string(),
input_ids: vec![1296], // Mock token ID for "test"
}),
};
assert!(health_req.include_detailed_metrics);
assert!(health_req.tokenized.is_some());
}
#[test]
......@@ -175,17 +151,6 @@ mod tests {
assert_eq!(abort_req.reason, "User canceled");
}
#[test]
fn test_flush_cache_request() {
let flush_req = proto::FlushCacheRequest {
flush_all: true,
session_ids: vec!["session1".to_string(), "session2".to_string()],
};
assert!(flush_req.flush_all);
assert_eq!(flush_req.session_ids.len(), 2);
assert_eq!(flush_req.session_ids[0], "session1");
}
#[test]
fn test_sampling_params_defaults() {
let params = proto::SamplingParams::default();
......@@ -214,38 +179,29 @@ mod tests {
assert_eq!(mm_inputs.modalities[0], "image");
}
#[test]
fn test_session_params() {
let session_params = proto::SessionParams {
session_id: "sess-789".to_string(),
request_id: "req-101".to_string(),
offset: 100,
replace: true,
drop_previous_output: false,
};
assert_eq!(session_params.session_id, "sess-789");
assert_eq!(session_params.request_id, "req-101");
assert_eq!(session_params.offset, 100);
assert!(session_params.replace);
assert!(!session_params.drop_previous_output);
}
// TODO: SessionParams not in current proto - skip test
// #[test]
// fn test_session_params() { ... }
#[test]
fn test_embed_request() {
let embed_req = proto::EmbedRequest {
request_id: "embed-req-202".to_string(),
input: Some(proto::embed_request::Input::Text(
"This is a test sentence for embedding".to_string(),
)),
tokenized: Some(proto::TokenizedInput {
original_text: "This is a test sentence for embedding".to_string(),
input_ids: vec![2028, 374, 264, 1296, 11914, 369, 28537], // Mock token IDs
}),
log_metrics: true,
data_parallel_rank: 0,
..Default::default()
};
assert_eq!(embed_req.request_id, "embed-req-202");
if let Some(proto::embed_request::Input::Text(text)) = &embed_req.input {
assert_eq!(text, "This is a test sentence for embedding");
if let Some(ref tokenized) = &embed_req.tokenized {
assert_eq!(
tokenized.original_text,
"This is a test sentence for embedding"
);
}
assert!(embed_req.log_metrics);
assert_eq!(embed_req.data_parallel_rank, 0);
......@@ -292,36 +248,7 @@ mod tests {
assert_eq!(chunk.queue_time, 10);
}
#[test]
fn test_model_info() {
let model_info = proto::ModelInfo {
model_name: "Meta-Llama-3-8B-Instruct".to_string(),
max_context_length: 8192,
vocab_size: 128256,
supports_tool_calling: true,
supports_vision: false,
special_tokens: vec![
"<|begin_of_text|>".to_string(),
"<|end_of_text|>".to_string(),
],
model_type: "llama".to_string(),
num_layers: 32,
hidden_size: 4096,
num_attention_heads: 32,
num_key_value_heads: 8,
tokenizer_type: "llama".to_string(),
eos_token_ids: vec![128001, 128009],
pad_token_id: 128001,
bos_token_id: 128000,
};
assert_eq!(model_info.model_name, "Meta-Llama-3-8B-Instruct");
assert_eq!(model_info.max_context_length, 8192);
assert_eq!(model_info.vocab_size, 128256);
assert!(model_info.supports_tool_calling);
assert!(!model_info.supports_vision);
assert_eq!(model_info.special_tokens.len(), 2);
assert_eq!(model_info.num_layers, 32);
assert_eq!(model_info.eos_token_ids, vec![128001, 128009]);
}
// TODO: ModelInfo not in current proto - skip test
// #[test]
// fn test_model_info() { ... }
}
......@@ -8,9 +8,6 @@ import "google/protobuf/struct.proto";
// Service definition for SGLang scheduler communication
// This protocol bridges the Rust router and Python scheduler
service SglangScheduler {
// Initialize connection and get model info
rpc Initialize(InitializeRequest) returns (InitializeResponse);
// Submit a generation request (supports streaming)
rpc Generate(GenerateRequest) returns (stream GenerateResponse);
......@@ -23,8 +20,6 @@ service SglangScheduler {
// Abort a running request
rpc Abort(AbortRequest) returns (AbortResponse);
// Flush KV cache
rpc FlushCache(FlushCacheRequest) returns (FlushCacheResponse);
}
// =====================
......@@ -75,14 +70,6 @@ message SamplingParams {
google.protobuf.Struct custom_params = 25;
}
// Session parameters for continual prompting
message SessionParams {
string session_id = 1;
string request_id = 2;
int32 offset = 3;
bool replace = 4;
bool drop_previous_output = 5;
}
// Disaggregated serving parameters
message DisaggregatedParams {
......@@ -91,87 +78,6 @@ message DisaggregatedParams {
int32 bootstrap_room = 3;
}
// =====================
// Initialize
// =====================
message InitializeRequest {
string client_id = 1;
string client_version = 2;
// Operating mode
enum Mode {
REGULAR = 0; // Normal mode with local scheduler
PREFILL = 1; // Prefill-only mode for disaggregated serving
DECODE = 2; // Decode-only mode for disaggregated serving
}
Mode mode = 3;
}
message InitializeResponse {
bool success = 1;
string scheduler_version = 2;
// Model information
ModelInfo model_info = 3;
// Server capabilities
ServerCapabilities capabilities = 4;
// Error message if success is false
string error_message = 5;
}
message ModelInfo {
string model_name = 1;
int32 max_context_length = 2;
int32 vocab_size = 3;
bool supports_tool_calling = 4;
bool supports_vision = 5;
repeated string special_tokens = 6;
// Additional model metadata
string model_type = 7;
int32 num_layers = 8;
int32 hidden_size = 9;
int32 num_attention_heads = 10;
int32 num_key_value_heads = 11;
// Tokenizer info
string tokenizer_type = 12;
repeated int32 eos_token_ids = 13;
int32 pad_token_id = 14;
int32 bos_token_id = 15;
}
message ServerCapabilities {
bool continuous_batching = 1;
bool disaggregated_serving = 2;
bool speculative_decoding = 3;
int32 max_batch_size = 4;
int32 max_num_batched_tokens = 5;
int32 max_prefill_tokens = 6;
string attention_backend = 7; // "flashinfer", "triton", "torch"
// Additional capabilities
bool supports_lora = 8;
bool supports_grammar = 9;
bool supports_multimodal = 10;
repeated string supported_modalities = 11; // ["image", "video", "audio"]
bool supports_custom_logit_processor = 12;
bool supports_session = 13;
// Hardware info
int32 num_gpus = 14;
string gpu_type = 15;
int64 total_gpu_memory = 16;
// Parallelism info
int32 tensor_parallel_size = 17;
int32 pipeline_parallel_size = 18;
int32 data_parallel_size = 19;
}
// =====================
// Generate Request
// =====================
......@@ -179,49 +85,43 @@ message ServerCapabilities {
message GenerateRequest {
string request_id = 1;
// Input can be either text or tokenized
oneof input {
string text = 2;
TokenizedInput tokenized = 3;
}
// Input must be tokenized (no raw text)
TokenizedInput tokenized = 2;
// Multimodal inputs
MultimodalInputs mm_inputs = 4;
MultimodalInputs mm_inputs = 3;
// Generation parameters
SamplingParams sampling_params = 5;
SamplingParams sampling_params = 4;
// Return options
bool return_logprob = 6;
int32 logprob_start_len = 7;
int32 top_logprobs_num = 8;
repeated int32 token_ids_logprob = 9;
bool return_hidden_states = 10;
// Session management
SessionParams session_params = 11;
bool return_logprob = 5;
int32 logprob_start_len = 6;
int32 top_logprobs_num = 7;
repeated int32 token_ids_logprob = 8;
bool return_hidden_states = 9;
// For disaggregated serving
DisaggregatedParams disaggregated_params = 12;
DisaggregatedParams disaggregated_params = 10;
// Custom logit processor (serialized)
string custom_logit_processor = 13;
string custom_logit_processor = 11;
// Request metadata
google.protobuf.Timestamp timestamp = 14;
bool log_metrics = 15;
google.protobuf.Timestamp timestamp = 12;
bool log_metrics = 13;
// Input embeddings (alternative to text/tokens)
repeated float input_embeds = 16;
repeated float input_embeds = 14;
// LoRA adapter ID (if pre-loaded)
string lora_id = 17;
string lora_id = 15;
// Data parallel routing
int32 data_parallel_rank = 18;
int32 data_parallel_rank = 16;
// For load balancing
int32 dp_balance_id = 19;
int32 dp_balance_id = 17;
}
message TokenizedInput {
......@@ -303,19 +203,6 @@ message GenerateComplete {
}
FinishReason finish_reason = 3;
// Final counts
int32 prompt_tokens = 4;
int32 completion_tokens = 5;
int32 cached_tokens = 6;
// Performance metrics
float total_generation_time = 7;
float time_to_first_token = 8;
float tokens_per_second = 9;
// Spec decode metrics
int32 spec_verify_count = 10;
// All logprobs if requested
repeated LogProbs all_logprobs = 11;
......@@ -359,10 +246,8 @@ message HiddenStates {
message EmbedRequest {
string request_id = 1;
oneof input {
string text = 2;
TokenizedInput tokenized = 3;
}
// Input must be tokenized (no raw text)
TokenizedInput tokenized = 2;
// Multimodal inputs
MultimodalInputs mm_inputs = 4;
......@@ -422,39 +307,13 @@ message EmbedError {
// =====================
message HealthCheckRequest {
bool include_detailed_metrics = 1;
// Input for health test generation (must be tokenized)
TokenizedInput tokenized = 1;
}
message HealthCheckResponse {
bool healthy = 1;
// Current load metrics
int32 num_requests_running = 2;
int32 num_requests_waiting = 3;
float gpu_cache_usage = 4;
float gpu_memory_usage = 5;
// KV cache metrics
int32 kv_cache_total_blocks = 6;
int32 kv_cache_used_blocks = 7;
float kv_cache_hit_rate = 8;
// Additional metrics
int32 num_grammar_queue_requests = 9;
float generation_throughput = 10; // tokens/sec
float average_queue_time = 11; // seconds
float average_generation_time = 12; // seconds
// System metrics
float cpu_usage = 13;
int64 memory_usage = 14;
// Disaggregation metrics
int32 num_prefill_requests = 15;
int32 num_decode_requests = 16;
// Detailed metrics (optional)
google.protobuf.Struct detailed_metrics = 17;
string message = 2;
}
message AbortRequest {
......@@ -467,17 +326,6 @@ message AbortResponse {
string message = 2;
}
message FlushCacheRequest {
bool flush_all = 1;
repeated string session_ids = 2; // Flush specific sessions
}
message FlushCacheResponse {
bool success = 1;
int32 num_entries_flushed = 2;
int64 memory_freed = 3; // bytes
string message = 4;
}
// =====================
// Additional Operations (Future)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment