Unverified Commit d736e0b6 authored by Simo Lin's avatar Simo Lin Committed by GitHub
Browse files

[router] add grpc router pd mode for chat and generate (#11140)

parent ffd03a9b
...@@ -67,8 +67,8 @@ dependencies = [ ...@@ -67,8 +67,8 @@ dependencies = [
"uvicorn", "uvicorn",
"uvloop", "uvloop",
"xgrammar==0.1.24", "xgrammar==0.1.24",
"grpcio==1.74.0", # keep it align with compile_proto.py "grpcio==1.75.1", # keep it align with compile_proto.py
"grpcio-tools==1.74.0" # keep it align with compile_proto.py "grpcio-tools==1.75.1" # keep it align with compile_proto.py
] ]
[project.optional-dependencies] [project.optional-dependencies]
......
...@@ -19,7 +19,6 @@ import grpc ...@@ -19,7 +19,6 @@ import grpc
import zmq import zmq
import zmq.asyncio import zmq.asyncio
from sglang.srt.managers.disagg_service import start_disagg_service
from sglang.srt.managers.io_struct import ( from sglang.srt.managers.io_struct import (
AbortReq, AbortReq,
BatchEmbeddingOutput, BatchEmbeddingOutput,
...@@ -111,6 +110,7 @@ class GrpcRequestManager: ...@@ -111,6 +110,7 @@ class GrpcRequestManager:
self, self,
server_args: ServerArgs, server_args: ServerArgs,
port_args: PortArgs, port_args: PortArgs,
bootstrap_server=None,
): ):
"""Initialize the gRPC request manager.""" """Initialize the gRPC request manager."""
self.server_args = server_args self.server_args = server_args
...@@ -147,8 +147,8 @@ class GrpcRequestManager: ...@@ -147,8 +147,8 @@ class GrpcRequestManager:
self.crash_dump_request_list = [] self.crash_dump_request_list = []
self.crash_dump_performed = False self.crash_dump_performed = False
# Bootstrap server for disaggregation mode # Bootstrap server (passed from serve_grpc, not started here)
self.bootstrap_server = start_disagg_service(server_args) self.bootstrap_server = bootstrap_server
logger.info( logger.info(
f"GrpcRequestManager initialized with ZMQ IPC: " f"GrpcRequestManager initialized with ZMQ IPC: "
...@@ -157,7 +157,7 @@ class GrpcRequestManager: ...@@ -157,7 +157,7 @@ class GrpcRequestManager:
) )
if self.bootstrap_server: if self.bootstrap_server:
logger.info( logger.info(
f"Bootstrap server started for disaggregation mode: " f"Bootstrap server initialized for disaggregation mode: "
f"{server_args.disaggregation_mode}" f"{server_args.disaggregation_mode}"
) )
......
...@@ -16,11 +16,13 @@ from typing import AsyncIterator, Dict, Optional, Tuple ...@@ -16,11 +16,13 @@ from typing import AsyncIterator, Dict, Optional, Tuple
import grpc import grpc
from grpc_reflection.v1alpha import reflection from grpc_reflection.v1alpha import reflection
from sglang.srt.disaggregation.utils import FAKE_BOOTSTRAP_HOST, DisaggregationMode
from sglang.srt.entrypoints.grpc_request_manager import GrpcRequestManager from sglang.srt.entrypoints.grpc_request_manager import GrpcRequestManager
from sglang.srt.grpc import sglang_scheduler_pb2, sglang_scheduler_pb2_grpc from sglang.srt.grpc import sglang_scheduler_pb2, sglang_scheduler_pb2_grpc
from sglang.srt.managers.data_parallel_controller import ( from sglang.srt.managers.data_parallel_controller import (
run_data_parallel_controller_process, run_data_parallel_controller_process,
) )
from sglang.srt.managers.disagg_service import start_disagg_service
from sglang.srt.managers.io_struct import ( from sglang.srt.managers.io_struct import (
TokenizedEmbeddingReqInput, TokenizedEmbeddingReqInput,
TokenizedGenerateReqInput, TokenizedGenerateReqInput,
...@@ -331,6 +333,10 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer) ...@@ -331,6 +333,10 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
token_ids_logprob=None, token_ids_logprob=None,
) )
if self.server_args.disaggregation_mode != DisaggregationMode.NULL:
health_request.bootstrap_host = FAKE_BOOTSTRAP_HOST
health_request.bootstrap_room = 0
logger.info(f"Sending health check request to request manager...") logger.info(f"Sending health check request to request manager...")
# Submit and wait for response # Submit and wait for response
...@@ -406,6 +412,15 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer) ...@@ -406,6 +412,15 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
# Convert sampling params # Convert sampling params
sampling_params = self._convert_sampling_params(grpc_req.sampling_params) sampling_params = self._convert_sampling_params(grpc_req.sampling_params)
# Extract disaggregated params if present
bootstrap_host = None
bootstrap_port = None
bootstrap_room = None
if grpc_req.HasField("disaggregated_params"):
bootstrap_host = grpc_req.disaggregated_params.bootstrap_host or None
bootstrap_port = grpc_req.disaggregated_params.bootstrap_port or None
bootstrap_room = grpc_req.disaggregated_params.bootstrap_room or None
# Create request # Create request
return TokenizedGenerateReqInput( return TokenizedGenerateReqInput(
rid=grpc_req.request_id, rid=grpc_req.request_id,
...@@ -425,6 +440,9 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer) ...@@ -425,6 +440,9 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
token_ids_logprob=( token_ids_logprob=(
list(grpc_req.token_ids_logprob) if grpc_req.token_ids_logprob else None list(grpc_req.token_ids_logprob) if grpc_req.token_ids_logprob else None
), ),
bootstrap_host=bootstrap_host,
bootstrap_port=bootstrap_port,
bootstrap_room=bootstrap_room,
) )
def _convert_embed_request( def _convert_embed_request(
...@@ -659,6 +677,16 @@ async def serve_grpc( ...@@ -659,6 +677,16 @@ async def serve_grpc(
): ):
"""Start the standalone gRPC server with integrated scheduler.""" """Start the standalone gRPC server with integrated scheduler."""
# Start bootstrap server BEFORE launching scheduler processes (only in PREFILL mode)
# This ensures the bootstrap server is ready when prefill schedulers try to register
bootstrap_server = None
if server_args.disaggregation_mode == "prefill":
bootstrap_server = start_disagg_service(server_args)
if bootstrap_server:
logger.info(
f"Bootstrap server started for disaggregation mode on {server_args.host}:{server_args.disaggregation_bootstrap_port}"
)
# Launch only the scheduler process(es) (no tokenizer/detokenizer needed for gRPC) # Launch only the scheduler process(es) (no tokenizer/detokenizer needed for gRPC)
logger.info("Launching scheduler process(es)...") logger.info("Launching scheduler process(es)...")
scheduler_info, port_args, scheduler_procs = _launch_scheduler_process_only( scheduler_info, port_args, scheduler_procs = _launch_scheduler_process_only(
...@@ -682,9 +710,11 @@ async def serve_grpc( ...@@ -682,9 +710,11 @@ async def serve_grpc(
} }
# Create request manager with the correct port args # Create request manager with the correct port args
# Note: We pass in the bootstrap_server started above (None outside prefill mode); the request manager does not start its own
request_manager = GrpcRequestManager( request_manager = GrpcRequestManager(
server_args=server_args, server_args=server_args,
port_args=port_args, port_args=port_args,
bootstrap_server=bootstrap_server,
) )
# Create gRPC server # Create gRPC server
...@@ -764,79 +794,9 @@ def main(): ...@@ -764,79 +794,9 @@ def main():
mp.set_start_method("spawn", force=True) mp.set_start_method("spawn", force=True)
parser = argparse.ArgumentParser(description="SGLang Standalone gRPC Server") parser = argparse.ArgumentParser(description="SGLang Standalone gRPC Server")
ServerArgs.add_cli_args(parser)
# Server arguments
parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind to")
parser.add_argument("--port", type=int, default=30000, help="gRPC server port")
# Model arguments
parser.add_argument("--model-path", type=str, required=True, help="Model path")
parser.add_argument("--tokenizer-path", type=str, help="Tokenizer path")
parser.add_argument("--context-length", type=int, help="Context length")
parser.add_argument("--tp-size", type=int, default=1, help="Tensor parallel size")
parser.add_argument("--dp-size", type=int, default=1, help="Data parallel size")
# Runtime arguments
parser.add_argument(
"--max-running-requests", type=int, default=2048, help="Max concurrent requests"
)
parser.add_argument(
"--max-total-tokens", type=int, default=1000000, help="Max total tokens"
)
parser.add_argument(
"--max-prefill-tokens", type=int, default=16384, help="Max prefill tokens"
)
parser.add_argument(
"--attention-backend", type=str, default="flashinfer", help="Attention backend"
)
parser.add_argument("--lora-paths", type=str, help="LoRA adapter paths")
# Logging
parser.add_argument("--log-level", type=str, default="INFO", help="Logging level")
# Disaggregation mode arguments
parser.add_argument(
"--disaggregation-mode",
type=str,
default="null",
choices=["null", "prefill", "decode"],
help='Only used for PD disaggregation. "prefill" for prefill-only server, and "decode" for decode-only server. If not specified, it is not PD disaggregated',
)
parser.add_argument(
"--disaggregation-transfer-backend",
type=str,
default="mooncake",
choices=["mooncake", "nixl", "ascend", "fake"],
help="The backend for disaggregation transfer. Default is mooncake.",
)
parser.add_argument(
"--disaggregation-bootstrap-port",
type=int,
default=8998,
help="Bootstrap server port on the prefill server. Default is 8998.",
)
args = parser.parse_args() args = parser.parse_args()
server_args = ServerArgs.from_cli_args(args)
# Convert to ServerArgs with gRPC host and port
server_args = ServerArgs(
model_path=args.model_path,
tokenizer_path=args.tokenizer_path or args.model_path,
context_length=args.context_length,
tp_size=args.tp_size,
dp_size=args.dp_size,
max_running_requests=args.max_running_requests,
max_total_tokens=args.max_total_tokens,
max_prefill_tokens=args.max_prefill_tokens,
attention_backend=args.attention_backend,
lora_paths=args.lora_paths.split(",") if args.lora_paths else None,
log_level=args.log_level,
disaggregation_mode=args.disaggregation_mode,
disaggregation_transfer_backend=args.disaggregation_transfer_backend,
disaggregation_bootstrap_port=args.disaggregation_bootstrap_port,
host=args.host,
port=args.port,
)
# Run server # Run server
asyncio.run( asyncio.run(
......
...@@ -31,6 +31,18 @@ pub trait Worker: Send + Sync + fmt::Debug { ...@@ -31,6 +31,18 @@ pub trait Worker: Send + Sync + fmt::Debug {
/// Get the worker's connection mode (HTTP or gRPC) /// Get the worker's connection mode (HTTP or gRPC)
fn connection_mode(&self) -> ConnectionMode; fn connection_mode(&self) -> ConnectionMode;
/// Bootstrap hostname used in PD mode.
///
/// Borrows the hostname cached in this worker's metadata.
fn bootstrap_host(&self) -> &str {
    self.metadata().bootstrap_host.as_str()
}
/// Bootstrap port used in PD mode, if one is set.
///
/// Reads the port cached in this worker's metadata.
fn bootstrap_port(&self) -> Option<u16> {
    let metadata = self.metadata();
    metadata.bootstrap_port
}
/// Check if the worker is currently healthy /// Check if the worker is currently healthy
fn is_healthy(&self) -> bool; fn is_healthy(&self) -> bool;
...@@ -147,21 +159,6 @@ pub trait Worker: Send + Sync + fmt::Debug { ...@@ -147,21 +159,6 @@ pub trait Worker: Send + Sync + fmt::Debug {
true true
} }
// TODO: - Enhanced Worker Discovery
// The Worker trait should handle async discovery of metadata from the worker itself
// rather than having service discovery or other components query /get_server_info.
// This keeps service discovery decoupled from worker-specific APIs.
//
// Proposed additions:
// - async fn discover_metadata(&mut self) -> Result<(), Error>
// Query /get_server_info and populate metadata labels with model_id, priority, cost, etc.
// - async fn validate_configuration(&self) -> Result<(), Error>
// Ensure worker has required configuration for its mode (e.g., tokenizer for gRPC)
// - Make worker creation async to allow metadata discovery during initialization
//
// This way service discovery just calls router.add_worker() and the worker
// handles its own metadata discovery internally.
/// Get the model ID this worker serves /// Get the model ID this worker serves
fn model_id(&self) -> &str { fn model_id(&self) -> &str {
self.metadata() self.metadata()
...@@ -325,6 +322,10 @@ pub struct WorkerMetadata { ...@@ -325,6 +322,10 @@ pub struct WorkerMetadata {
pub health_config: HealthConfig, pub health_config: HealthConfig,
/// API key /// API key
pub api_key: Option<String>, pub api_key: Option<String>,
/// Cached bootstrap hostname (parsed from URL at construction time)
pub bootstrap_host: String,
/// Cached bootstrap port (from WorkerType::Prefill)
pub bootstrap_port: Option<u16>,
} }
/// Basic worker implementation /// Basic worker implementation
......
...@@ -96,12 +96,29 @@ impl BasicWorkerBuilder { ...@@ -96,12 +96,29 @@ impl BasicWorkerBuilder {
/// Build the BasicWorker instance /// Build the BasicWorker instance
pub fn build(self) -> BasicWorker { pub fn build(self) -> BasicWorker {
use std::borrow::Cow;
use std::sync::{ use std::sync::{
atomic::{AtomicBool, AtomicUsize}, atomic::{AtomicBool, AtomicUsize},
Arc, Arc,
}; };
use tokio::sync::{Mutex, RwLock}; use tokio::sync::{Mutex, RwLock};
let url_to_parse = if self.url.contains("://") {
Cow::from(&self.url)
} else {
Cow::from(format!("http://{}", self.url))
};
let bootstrap_host = match url::Url::parse(&url_to_parse) {
Ok(parsed) => parsed.host_str().unwrap_or("localhost").to_string(),
Err(_) => "localhost".to_string(),
};
let bootstrap_port = match self.worker_type {
WorkerType::Prefill { bootstrap_port } => bootstrap_port,
_ => None,
};
let metadata = WorkerMetadata { let metadata = WorkerMetadata {
url: self.url.clone(), url: self.url.clone(),
api_key: self.api_key, api_key: self.api_key,
...@@ -109,6 +126,8 @@ impl BasicWorkerBuilder { ...@@ -109,6 +126,8 @@ impl BasicWorkerBuilder {
connection_mode: self.connection_mode, connection_mode: self.connection_mode,
labels: self.labels, labels: self.labels,
health_config: self.health_config, health_config: self.health_config,
bootstrap_host,
bootstrap_port,
}; };
let grpc_client = Arc::new(RwLock::new( let grpc_client = Arc::new(RwLock::new(
......
...@@ -342,6 +342,12 @@ impl SglangSchedulerClient { ...@@ -342,6 +342,12 @@ impl SglangSchedulerClient {
.map_err(|_| "min_tokens must fit into a 32-bit signed integer".to_string())?; .map_err(|_| "min_tokens must fit into a 32-bit signed integer".to_string())?;
} }
// Handle n with conversion
if let Some(n) = p.n {
sampling.n = i32::try_from(n)
.map_err(|_| "n must fit into a 32-bit signed integer".to_string())?;
}
// Handle constraints (exactly one allowed) // Handle constraints (exactly one allowed)
sampling.constraint = Self::build_single_constraint_from_plain(p)?; sampling.constraint = Self::build_single_constraint_from_plain(p)?;
......
...@@ -2,6 +2,11 @@ use serde::{Deserialize, Serialize}; ...@@ -2,6 +2,11 @@ use serde::{Deserialize, Serialize};
use serde_json::{to_value, Map, Number, Value}; use serde_json::{to_value, Map, Number, Value};
use std::collections::HashMap; use std::collections::HashMap;
/// Fallback model name used when a request does not specify one.
fn default_model() -> String {
    String::from("unknown")
}
// # Protocol Specifications // # Protocol Specifications
// //
// This module contains all protocol definitions for OpenAI and SGLang APIs. // This module contains all protocol definitions for OpenAI and SGLang APIs.
...@@ -169,6 +174,7 @@ pub struct ChatCompletionRequest { ...@@ -169,6 +174,7 @@ pub struct ChatCompletionRequest {
pub messages: Vec<ChatMessage>, pub messages: Vec<ChatMessage>,
/// ID of the model to use /// ID of the model to use
#[serde(default = "default_model")]
pub model: String, pub model: String,
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far /// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far
......
//! gRPC router implementations //! gRPC router implementations
use crate::grpc_client::proto;
use crate::protocols::spec::StringOrArray;
pub mod pd_router; pub mod pd_router;
pub mod router; pub mod router;
pub mod utils;
/// Processed chat messages ready for gRPC generation
///
/// Bundles the text, optional multimodal inputs, and optional stop sequences
/// that accompany a single generation request.
#[derive(Debug)]
pub struct ProcessedMessages {
// NOTE(review): presumably the prompt text rendered from the chat messages — confirm against the producer.
pub text: String,
// Multimodal payload in its gRPC proto form; `None` when there is none.
pub multimodal_inputs: Option<proto::MultimodalInputs>,
// NOTE(review): presumably the stop sequences carried over from the chat request — confirm.
pub stop_sequences: Option<StringOrArray>,
}
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment