Unverified Commit d736e0b6 authored by Simo Lin's avatar Simo Lin Committed by GitHub
Browse files

[router] add grpc router pd mode for chat and generate (#11140)

parent ffd03a9b
......@@ -67,8 +67,8 @@ dependencies = [
"uvicorn",
"uvloop",
"xgrammar==0.1.24",
"grpcio==1.74.0", # keep it align with compile_proto.py
"grpcio-tools==1.74.0" # keep it align with compile_proto.py
"grpcio==1.75.1", # keep it align with compile_proto.py
"grpcio-tools==1.75.1" # keep it align with compile_proto.py
]
[project.optional-dependencies]
......
......@@ -19,7 +19,6 @@ import grpc
import zmq
import zmq.asyncio
from sglang.srt.managers.disagg_service import start_disagg_service
from sglang.srt.managers.io_struct import (
AbortReq,
BatchEmbeddingOutput,
......@@ -111,6 +110,7 @@ class GrpcRequestManager:
self,
server_args: ServerArgs,
port_args: PortArgs,
bootstrap_server=None,
):
"""Initialize the gRPC request manager."""
self.server_args = server_args
......@@ -147,8 +147,8 @@ class GrpcRequestManager:
self.crash_dump_request_list = []
self.crash_dump_performed = False
# Bootstrap server for disaggregation mode
self.bootstrap_server = start_disagg_service(server_args)
# Bootstrap server (passed from serve_grpc, not started here)
self.bootstrap_server = bootstrap_server
logger.info(
f"GrpcRequestManager initialized with ZMQ IPC: "
......@@ -157,7 +157,7 @@ class GrpcRequestManager:
)
if self.bootstrap_server:
logger.info(
f"Bootstrap server started for disaggregation mode: "
f"Bootstrap server initialized for disaggregation mode: "
f"{server_args.disaggregation_mode}"
)
......
......@@ -16,11 +16,13 @@ from typing import AsyncIterator, Dict, Optional, Tuple
import grpc
from grpc_reflection.v1alpha import reflection
from sglang.srt.disaggregation.utils import FAKE_BOOTSTRAP_HOST, DisaggregationMode
from sglang.srt.entrypoints.grpc_request_manager import GrpcRequestManager
from sglang.srt.grpc import sglang_scheduler_pb2, sglang_scheduler_pb2_grpc
from sglang.srt.managers.data_parallel_controller import (
run_data_parallel_controller_process,
)
from sglang.srt.managers.disagg_service import start_disagg_service
from sglang.srt.managers.io_struct import (
TokenizedEmbeddingReqInput,
TokenizedGenerateReqInput,
......@@ -331,6 +333,10 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
token_ids_logprob=None,
)
if self.server_args.disaggregation_mode != DisaggregationMode.NULL:
health_request.bootstrap_host = FAKE_BOOTSTRAP_HOST
health_request.bootstrap_room = 0
logger.info(f"Sending health check request to request manager...")
# Submit and wait for response
......@@ -406,6 +412,15 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
# Convert sampling params
sampling_params = self._convert_sampling_params(grpc_req.sampling_params)
# Extract disaggregated params if present
bootstrap_host = None
bootstrap_port = None
bootstrap_room = None
if grpc_req.HasField("disaggregated_params"):
bootstrap_host = grpc_req.disaggregated_params.bootstrap_host or None
bootstrap_port = grpc_req.disaggregated_params.bootstrap_port or None
bootstrap_room = grpc_req.disaggregated_params.bootstrap_room or None
# Create request
return TokenizedGenerateReqInput(
rid=grpc_req.request_id,
......@@ -425,6 +440,9 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
token_ids_logprob=(
list(grpc_req.token_ids_logprob) if grpc_req.token_ids_logprob else None
),
bootstrap_host=bootstrap_host,
bootstrap_port=bootstrap_port,
bootstrap_room=bootstrap_room,
)
def _convert_embed_request(
......@@ -659,6 +677,16 @@ async def serve_grpc(
):
"""Start the standalone gRPC server with integrated scheduler."""
# Start bootstrap server BEFORE launching scheduler processes (only in PREFILL mode)
# This ensures the bootstrap server is ready when prefill schedulers try to register
bootstrap_server = None
if server_args.disaggregation_mode == "prefill":
bootstrap_server = start_disagg_service(server_args)
if bootstrap_server:
logger.info(
f"Bootstrap server started for disaggregation mode on {server_args.host}:{server_args.disaggregation_bootstrap_port}"
)
# Launch only the scheduler process(es) (no tokenizer/detokenizer needed for gRPC)
logger.info("Launching scheduler process(es)...")
scheduler_info, port_args, scheduler_procs = _launch_scheduler_process_only(
......@@ -682,9 +710,11 @@ async def serve_grpc(
}
# Create request manager with the correct port args
# Note: We pass None for bootstrap_server since it's already started above
request_manager = GrpcRequestManager(
server_args=server_args,
port_args=port_args,
bootstrap_server=bootstrap_server,
)
# Create gRPC server
......@@ -764,79 +794,9 @@ def main():
mp.set_start_method("spawn", force=True)
parser = argparse.ArgumentParser(description="SGLang Standalone gRPC Server")
# Server arguments
parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind to")
parser.add_argument("--port", type=int, default=30000, help="gRPC server port")
# Model arguments
parser.add_argument("--model-path", type=str, required=True, help="Model path")
parser.add_argument("--tokenizer-path", type=str, help="Tokenizer path")
parser.add_argument("--context-length", type=int, help="Context length")
parser.add_argument("--tp-size", type=int, default=1, help="Tensor parallel size")
parser.add_argument("--dp-size", type=int, default=1, help="Data parallel size")
# Runtime arguments
parser.add_argument(
"--max-running-requests", type=int, default=2048, help="Max concurrent requests"
)
parser.add_argument(
"--max-total-tokens", type=int, default=1000000, help="Max total tokens"
)
parser.add_argument(
"--max-prefill-tokens", type=int, default=16384, help="Max prefill tokens"
)
parser.add_argument(
"--attention-backend", type=str, default="flashinfer", help="Attention backend"
)
parser.add_argument("--lora-paths", type=str, help="LoRA adapter paths")
# Logging
parser.add_argument("--log-level", type=str, default="INFO", help="Logging level")
# Disaggregation mode arguments
parser.add_argument(
"--disaggregation-mode",
type=str,
default="null",
choices=["null", "prefill", "decode"],
help='Only used for PD disaggregation. "prefill" for prefill-only server, and "decode" for decode-only server. If not specified, it is not PD disaggregated',
)
parser.add_argument(
"--disaggregation-transfer-backend",
type=str,
default="mooncake",
choices=["mooncake", "nixl", "ascend", "fake"],
help="The backend for disaggregation transfer. Default is mooncake.",
)
parser.add_argument(
"--disaggregation-bootstrap-port",
type=int,
default=8998,
help="Bootstrap server port on the prefill server. Default is 8998.",
)
ServerArgs.add_cli_args(parser)
args = parser.parse_args()
# Convert to ServerArgs with gRPC host and port
server_args = ServerArgs(
model_path=args.model_path,
tokenizer_path=args.tokenizer_path or args.model_path,
context_length=args.context_length,
tp_size=args.tp_size,
dp_size=args.dp_size,
max_running_requests=args.max_running_requests,
max_total_tokens=args.max_total_tokens,
max_prefill_tokens=args.max_prefill_tokens,
attention_backend=args.attention_backend,
lora_paths=args.lora_paths.split(",") if args.lora_paths else None,
log_level=args.log_level,
disaggregation_mode=args.disaggregation_mode,
disaggregation_transfer_backend=args.disaggregation_transfer_backend,
disaggregation_bootstrap_port=args.disaggregation_bootstrap_port,
host=args.host,
port=args.port,
)
server_args = ServerArgs.from_cli_args(args)
# Run server
asyncio.run(
......
......@@ -31,6 +31,18 @@ pub trait Worker: Send + Sync + fmt::Debug {
/// Get the worker's connection mode (HTTP or gRPC)
fn connection_mode(&self) -> ConnectionMode;
/// Get the bootstrap hostname for PD mode
/// Returns cached hostname parsed from URL at construction time
fn bootstrap_host(&self) -> &str {
&self.metadata().bootstrap_host
}
/// Get the bootstrap port for PD mode
/// Returns cached port from WorkerType::Prefill
fn bootstrap_port(&self) -> Option<u16> {
self.metadata().bootstrap_port
}
/// Check if the worker is currently healthy
fn is_healthy(&self) -> bool;
......@@ -147,21 +159,6 @@ pub trait Worker: Send + Sync + fmt::Debug {
true
}
// TODO: - Enhanced Worker Discovery
// The Worker trait should handle async discovery of metadata from the worker itself
// rather than having service discovery or other components query /get_server_info.
// This keeps service discovery decoupled from worker-specific APIs.
//
// Proposed additions:
// - async fn discover_metadata(&mut self) -> Result<(), Error>
// Query /get_server_info and populate metadata labels with model_id, priority, cost, etc.
// - async fn validate_configuration(&self) -> Result<(), Error>
// Ensure worker has required configuration for its mode (e.g., tokenizer for gRPC)
// - Make worker creation async to allow metadata discovery during initialization
//
// This way service discovery just calls router.add_worker() and the worker
// handles its own metadata discovery internally.
/// Get the model ID this worker serves
fn model_id(&self) -> &str {
self.metadata()
......@@ -325,6 +322,10 @@ pub struct WorkerMetadata {
pub health_config: HealthConfig,
/// API key
pub api_key: Option<String>,
/// Cached bootstrap hostname (parsed from URL at construction time)
pub bootstrap_host: String,
/// Cached bootstrap port (from WorkerType::Prefill)
pub bootstrap_port: Option<u16>,
}
/// Basic worker implementation
......
......@@ -96,12 +96,29 @@ impl BasicWorkerBuilder {
/// Build the BasicWorker instance
pub fn build(self) -> BasicWorker {
use std::borrow::Cow;
use std::sync::{
atomic::{AtomicBool, AtomicUsize},
Arc,
};
use tokio::sync::{Mutex, RwLock};
let url_to_parse = if self.url.contains("://") {
Cow::from(&self.url)
} else {
Cow::from(format!("http://{}", self.url))
};
let bootstrap_host = match url::Url::parse(&url_to_parse) {
Ok(parsed) => parsed.host_str().unwrap_or("localhost").to_string(),
Err(_) => "localhost".to_string(),
};
let bootstrap_port = match self.worker_type {
WorkerType::Prefill { bootstrap_port } => bootstrap_port,
_ => None,
};
let metadata = WorkerMetadata {
url: self.url.clone(),
api_key: self.api_key,
......@@ -109,6 +126,8 @@ impl BasicWorkerBuilder {
connection_mode: self.connection_mode,
labels: self.labels,
health_config: self.health_config,
bootstrap_host,
bootstrap_port,
};
let grpc_client = Arc::new(RwLock::new(
......
......@@ -342,6 +342,12 @@ impl SglangSchedulerClient {
.map_err(|_| "min_tokens must fit into a 32-bit signed integer".to_string())?;
}
// Handle n with conversion
if let Some(n) = p.n {
sampling.n = i32::try_from(n)
.map_err(|_| "n must fit into a 32-bit signed integer".to_string())?;
}
// Handle constraints (exactly one allowed)
sampling.constraint = Self::build_single_constraint_from_plain(p)?;
......
......@@ -2,6 +2,11 @@ use serde::{Deserialize, Serialize};
use serde_json::{to_value, Map, Number, Value};
use std::collections::HashMap;
// Default model value when not specified
fn default_model() -> String {
"unknown".to_string()
}
// # Protocol Specifications
//
// This module contains all protocol definitions for OpenAI and SGLang APIs.
......@@ -169,6 +174,7 @@ pub struct ChatCompletionRequest {
pub messages: Vec<ChatMessage>,
/// ID of the model to use
#[serde(default = "default_model")]
pub model: String,
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far
......
//! gRPC router implementations
use crate::grpc_client::proto;
use crate::protocols::spec::StringOrArray;
pub mod pd_router;
pub mod router;
pub mod utils;
/// Processed chat messages ready for gRPC generation
#[derive(Debug)]
pub struct ProcessedMessages {
pub text: String,
pub multimodal_inputs: Option<proto::MultimodalInputs>,
pub stop_sequences: Option<StringOrArray>,
}
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment