"src/vscode:/vscode.git/clone" did not exist on "0a0fe69aa6a11b7723e83ca9e049e6096839ad4d"
Unverified commit 887c2b45, authored by Chang Su, committed by GitHub
Browse files

[router][grpc] Add `serve_grpc` to `launch_server` and log id for HealthCheck (#11564)

parent 065ce815
"""Launch the inference server.""" """Launch the inference server."""
import asyncio
import os import os
import sys import sys
from sglang.srt.entrypoints.http_server import launch_server
from sglang.srt.server_args import prepare_server_args from sglang.srt.server_args import prepare_server_args
from sglang.srt.utils import kill_process_tree from sglang.srt.utils import kill_process_tree
...@@ -11,6 +11,13 @@ if __name__ == "__main__": ...@@ -11,6 +11,13 @@ if __name__ == "__main__":
server_args = prepare_server_args(sys.argv[1:]) server_args = prepare_server_args(sys.argv[1:])
try: try:
launch_server(server_args) if server_args.grpc_mode:
from sglang.srt.entrypoints.grpc_server import serve_grpc
asyncio.run(serve_grpc(server_args))
else:
from sglang.srt.entrypoints.http_server import launch_server
launch_server(server_args)
finally: finally:
kill_process_tree(os.getpid(), include_parent=False) kill_process_tree(os.getpid(), include_parent=False)
...@@ -22,8 +22,8 @@ from grpc_reflection.v1alpha import reflection ...@@ -22,8 +22,8 @@ from grpc_reflection.v1alpha import reflection
import sglang import sglang
from sglang.srt.disaggregation.utils import FAKE_BOOTSTRAP_HOST, DisaggregationMode from sglang.srt.disaggregation.utils import FAKE_BOOTSTRAP_HOST, DisaggregationMode
from sglang.srt.entrypoints.grpc_request_manager import GrpcRequestManager
from sglang.srt.grpc import sglang_scheduler_pb2, sglang_scheduler_pb2_grpc from sglang.srt.grpc import sglang_scheduler_pb2, sglang_scheduler_pb2_grpc
from sglang.srt.grpc.grpc_request_manager import GrpcRequestManager
from sglang.srt.managers.data_parallel_controller import ( from sglang.srt.managers.data_parallel_controller import (
run_data_parallel_controller_process, run_data_parallel_controller_process,
) )
...@@ -68,6 +68,8 @@ def _launch_scheduler_process_only( ...@@ -68,6 +68,8 @@ def _launch_scheduler_process_only(
# Configure global environment # Configure global environment
configure_logger(server_args) configure_logger(server_args)
server_args.check_server_args() server_args.check_server_args()
# Fix CUDA multiprocessing issues - must be called before any CUDA operations
mp.set_start_method("spawn", force=True)
# Allocate ports for inter-process communications # Allocate ports for inter-process communications
if port_args is None: if port_args is None:
...@@ -317,7 +319,8 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer) ...@@ -317,7 +319,8 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
Check the health of the inference server by sending a special request to generate one token. Check the health of the inference server by sending a special request to generate one token.
Similar to HTTP server's /health endpoint. Similar to HTTP server's /health endpoint.
""" """
logger.info("Receive health check request") rid = f"HEALTH_CHECK_{time.time()}"
logger.info(f"Receive health check request: {rid}")
if self.request_manager.gracefully_exit: if self.request_manager.gracefully_exit:
logger.info( logger.info(
...@@ -328,7 +331,6 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer) ...@@ -328,7 +331,6 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
) )
# Create a special health check request # Create a special health check request
rid = f"HEALTH_CHECK_{time.time()}"
sampling_params = SGLSamplingParams(max_new_tokens=1, temperature=0.0) sampling_params = SGLSamplingParams(max_new_tokens=1, temperature=0.0)
sampling_params.normalize(tokenizer=None) sampling_params.normalize(tokenizer=None)
...@@ -919,25 +921,3 @@ async def serve_grpc( ...@@ -919,25 +921,3 @@ async def serve_grpc(
proc.join(timeout=1.0) proc.join(timeout=1.0)
logger.info("All scheduler processes terminated") logger.info("All scheduler processes terminated")
def main():
"""Main entry point for standalone gRPC server."""
# Fix CUDA multiprocessing issues - must be called before any CUDA operations
mp.set_start_method("spawn", force=True)
parser = argparse.ArgumentParser(description="SGLang Standalone gRPC Server")
ServerArgs.add_cli_args(parser)
args = parser.parse_args()
server_args = ServerArgs.from_cli_args(args)
# Run server
asyncio.run(
serve_grpc(
server_args=server_args,
)
)
if __name__ == "__main__":
main()
...@@ -326,10 +326,7 @@ message EmbedError { ...@@ -326,10 +326,7 @@ message EmbedError {
// Management Operations // Management Operations
// ===================== // =====================
message HealthCheckRequest { message HealthCheckRequest {}
// Input for health test generation (must be tokenized)
TokenizedInput tokenized = 1;
}
message HealthCheckResponse { message HealthCheckResponse {
bool healthy = 1; bool healthy = 1;
......
...@@ -320,10 +320,8 @@ class EmbedError(_message.Message): ...@@ -320,10 +320,8 @@ class EmbedError(_message.Message):
def __init__(self, message: _Optional[str] = ..., code: _Optional[str] = ..., details: _Optional[str] = ...) -> None: ... def __init__(self, message: _Optional[str] = ..., code: _Optional[str] = ..., details: _Optional[str] = ...) -> None: ...
class HealthCheckRequest(_message.Message): class HealthCheckRequest(_message.Message):
__slots__ = ("tokenized",) __slots__ = ()
TOKENIZED_FIELD_NUMBER: _ClassVar[int] def __init__(self) -> None: ...
tokenized: TokenizedInput
def __init__(self, tokenized: _Optional[_Union[TokenizedInput, _Mapping]] = ...) -> None: ...
class HealthCheckResponse(_message.Message): class HealthCheckResponse(_message.Message):
__slots__ = ("healthy", "message") __slots__ = ("healthy", "message")
......
...@@ -194,6 +194,7 @@ class ServerArgs: ...@@ -194,6 +194,7 @@ class ServerArgs:
# HTTP server # HTTP server
host: str = "127.0.0.1" host: str = "127.0.0.1"
port: int = 30000 port: int = 30000
grpc_mode: bool = False
skip_server_warmup: bool = False skip_server_warmup: bool = False
warmups: Optional[str] = None warmups: Optional[str] = None
nccl_port: Optional[int] = None nccl_port: Optional[int] = None
...@@ -1516,6 +1517,11 @@ class ServerArgs: ...@@ -1516,6 +1517,11 @@ class ServerArgs:
default=ServerArgs.port, default=ServerArgs.port,
help="The port of the HTTP server.", help="The port of the HTTP server.",
) )
parser.add_argument(
"--grpc-mode",
action="store_true",
help="If set, use gRPC server instead of HTTP server.",
)
parser.add_argument( parser.add_argument(
"--skip-server-warmup", "--skip-server-warmup",
action="store_true", action="store_true",
......
...@@ -169,8 +169,8 @@ impl SglangSchedulerClient { ...@@ -169,8 +169,8 @@ impl SglangSchedulerClient {
&self, &self,
) -> Result<proto::HealthCheckResponse, Box<dyn std::error::Error + Send + Sync>> { ) -> Result<proto::HealthCheckResponse, Box<dyn std::error::Error + Send + Sync>> {
debug!("Sending health check request"); debug!("Sending health check request");
// Server ignores the request body and creates its own health check internally // HealthCheckRequest is now empty - server generates its own health check internally
let request = Request::new(proto::HealthCheckRequest { tokenized: None }); let request = Request::new(proto::HealthCheckRequest {});
let mut client = self.client.clone(); let mut client = self.client.clone();
let response = client.health_check(request).await?; let response = client.health_check(request).await?;
...@@ -510,13 +510,8 @@ mod tests { ...@@ -510,13 +510,8 @@ mod tests {
#[test] #[test]
fn test_proto_types_compilation() { fn test_proto_types_compilation() {
let health_req = proto::HealthCheckRequest { let _health_req = proto::HealthCheckRequest {};
tokenized: Some(proto::TokenizedInput { // HealthCheckRequest is now empty - no fields to test
original_text: "test".to_string(),
input_ids: vec![1296],
}),
};
assert!(health_req.tokenized.is_some());
} }
#[test] #[test]
...@@ -558,13 +553,8 @@ mod tests { ...@@ -558,13 +553,8 @@ mod tests {
#[test] #[test]
fn test_health_check_request() { fn test_health_check_request() {
let health_req = proto::HealthCheckRequest { let _health_req = proto::HealthCheckRequest {};
tokenized: Some(proto::TokenizedInput { // HealthCheckRequest is now empty - server generates its own test internally
original_text: "test".to_string(),
input_ids: vec![1296], // Mock token ID for "test"
}),
};
assert!(health_req.tokenized.is_some());
} }
#[test] #[test]
......
...@@ -326,10 +326,7 @@ message EmbedError { ...@@ -326,10 +326,7 @@ message EmbedError {
// Management Operations // Management Operations
// ===================== // =====================
message HealthCheckRequest { message HealthCheckRequest {}
// Input for health test generation (must be tokenized)
TokenizedInput tokenized = 1;
}
message HealthCheckResponse { message HealthCheckResponse {
bool healthy = 1; bool healthy = 1;
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment