Unverified commit 96fe2d0f, authored by Simo Lin and committed by GitHub
Browse files

[router] add pd service in grpc router for pd (#11120)

parent bfa27438
......@@ -19,6 +19,7 @@ import grpc
import zmq
import zmq.asyncio
from sglang.srt.managers.disagg_service import start_disagg_service
from sglang.srt.managers.io_struct import (
AbortReq,
BatchEmbeddingOut,
......@@ -146,11 +147,19 @@ class GrpcRequestManager:
self.crash_dump_request_list = []
self.crash_dump_performed = False
# Bootstrap server for disaggregation mode
self.bootstrap_server = start_disagg_service(server_args)
logger.info(
f"GrpcRequestManager initialized with ZMQ IPC: "
f"recv={port_args.detokenizer_ipc_name}, "
f"send={port_args.scheduler_input_ipc_name}"
)
if self.bootstrap_server:
logger.info(
f"Bootstrap server started for disaggregation mode: "
f"{server_args.disaggregation_mode}"
)
async def generate_request(
self,
......@@ -759,6 +768,22 @@ class GrpcRequestManager:
state.finished = True
state.event.set()
# Wait for tasks to complete
if self.asyncio_tasks:
await asyncio.gather(*list(self.asyncio_tasks), return_exceptions=True)
# Shutdown bootstrap server if running
if self.bootstrap_server:
logger.info("Shutting down bootstrap server")
try:
if hasattr(self.bootstrap_server, "shutdown"):
if asyncio.iscoroutinefunction(self.bootstrap_server.shutdown):
await self.bootstrap_server.shutdown()
else:
self.bootstrap_server.shutdown()
except Exception as e:
logger.warning(f"Error shutting down bootstrap server: {e}")
# Close ZMQ sockets
self.recv_from_scheduler.close()
self.send_to_scheduler.close()
......
......@@ -793,6 +793,28 @@ def main():
# Logging
parser.add_argument("--log-level", type=str, default="INFO", help="Logging level")
# Disaggregation mode arguments
parser.add_argument(
"--disaggregation-mode",
type=str,
default="null",
choices=["null", "prefill", "decode"],
help='Only used for PD disaggregation. "prefill" for prefill-only server, and "decode" for decode-only server. If not specified, it is not PD disaggregated',
)
parser.add_argument(
"--disaggregation-transfer-backend",
type=str,
default="mooncake",
choices=["mooncake", "nixl", "ascend", "fake"],
help="The backend for disaggregation transfer. Default is mooncake.",
)
parser.add_argument(
"--disaggregation-bootstrap-port",
type=int,
default=8998,
help="Bootstrap server port on the prefill server. Default is 8998.",
)
args = parser.parse_args()
# Convert to ServerArgs with gRPC host and port
......@@ -808,7 +830,9 @@ def main():
attention_backend=args.attention_backend,
lora_paths=args.lora_paths.split(",") if args.lora_paths else None,
log_level=args.log_level,
# Override with gRPC server host and port
disaggregation_mode=args.disaggregation_mode,
disaggregation_transfer_backend=args.disaggregation_transfer_backend,
disaggregation_bootstrap_port=args.disaggregation_bootstrap_port,
host=args.host,
port=args.port,
)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment