Unverified Commit 32595146 authored by Simo Lin's avatar Simo Lin Committed by GitHub
Browse files

[router][grpc] add warm up to grpc server (#11627)


Co-authored-by: default avatarChang Su <chang.s.su@oracle.com>
parent 86373b9e
...@@ -3,13 +3,13 @@ Standalone gRPC Server for SGLang - Fully separated from HTTP server. ...@@ -3,13 +3,13 @@ Standalone gRPC Server for SGLang - Fully separated from HTTP server.
Uses GrpcRequestManager for orchestration without tokenization. Uses GrpcRequestManager for orchestration without tokenization.
""" """
import argparse
import asyncio import asyncio
import dataclasses import dataclasses
import logging import logging
import multiprocessing as mp import multiprocessing as mp
import os import os
import signal import signal
import threading
import time import time
from concurrent import futures from concurrent import futures
from typing import AsyncIterator, Dict, Optional, Tuple from typing import AsyncIterator, Dict, Optional, Tuple
...@@ -35,7 +35,11 @@ from sglang.srt.managers.io_struct import ( ...@@ -35,7 +35,11 @@ from sglang.srt.managers.io_struct import (
from sglang.srt.managers.scheduler import run_scheduler_process from sglang.srt.managers.scheduler import run_scheduler_process
from sglang.srt.sampling.sampling_params import SamplingParams as SGLSamplingParams from sglang.srt.sampling.sampling_params import SamplingParams as SGLSamplingParams
from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import configure_logger, prepare_model_and_tokenizer from sglang.srt.utils import (
configure_logger,
kill_process_tree,
prepare_model_and_tokenizer,
)
from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter
from sglang.utils import get_exception_traceback from sglang.utils import get_exception_traceback
...@@ -884,6 +888,13 @@ async def serve_grpc( ...@@ -884,6 +888,13 @@ async def serve_grpc(
await server.start() await server.start()
logger.info(f"gRPC server listening on {listen_addr}") logger.info(f"gRPC server listening on {listen_addr}")
# Start warmup in a separate thread
warmup_thread = threading.Thread(
target=_wait_and_warmup_grpc,
args=(server_args, None),
)
warmup_thread.start()
# Handle shutdown signals # Handle shutdown signals
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
stop_event = asyncio.Event() stop_event = asyncio.Event()
...@@ -906,6 +917,11 @@ async def serve_grpc( ...@@ -906,6 +917,11 @@ async def serve_grpc(
# Stop the gRPC server # Stop the gRPC server
await server.stop(5.0) await server.stop(5.0)
# Wait for warmup thread to finish
if warmup_thread.is_alive():
logger.info("Waiting for warmup thread to finish...")
warmup_thread.join(timeout=5.0)
# Terminate scheduler processes before exiting to avoid atexit hang # Terminate scheduler processes before exiting to avoid atexit hang
# The scheduler processes have SIGINT ignored, so they won't get KeyboardInterrupt # The scheduler processes have SIGINT ignored, so they won't get KeyboardInterrupt
for i, proc in enumerate(scheduler_procs): for i, proc in enumerate(scheduler_procs):
...@@ -921,3 +937,158 @@ async def serve_grpc( ...@@ -921,3 +937,158 @@ async def serve_grpc(
proc.join(timeout=1.0) proc.join(timeout=1.0)
logger.info("All scheduler processes terminated") logger.info("All scheduler processes terminated")
def _execute_grpc_server_warmup(
    server_args: ServerArgs,
    pipe_finish_writer: Optional[mp.connection.Connection],
):
    """Execute warmup for gRPC server by checking health and sending a test request.

    Polls ``GetModelInfo`` once per second for up to 120 seconds until the
    scheduler answers, then issues one tiny generation (or embedding) request so
    the model is fully initialized before real traffic arrives.

    Args:
        server_args: Server configuration; ``host``/``port`` locate the gRPC
            endpoint to warm up.
        pipe_finish_writer: Optional pipe connection used to report a failure
            message to the parent process.

    Returns:
        True if the warmup request completed successfully, False otherwise.
        On fatal errors (server never reachable, warmup RPC raised, unexpected
        exception) the whole process tree is killed via ``kill_process_tree``.
    """
    channel = None

    def _abort(error_msg: str) -> bool:
        # Single fatal-error path: log, notify the parent (if any), release the
        # channel, and take down the process tree — the server cannot serve.
        logger.error(error_msg)
        if pipe_finish_writer is not None:
            pipe_finish_writer.send(error_msg)
        if channel is not None:
            try:
                channel.close()
            except Exception:
                # Best-effort close; never let cleanup mask the real failure.
                pass
        kill_process_tree(os.getpid())
        return False

    try:
        # Connect to the gRPC server.
        grpc_url = f"{server_args.host}:{server_args.port}"
        channel = grpc.insecure_channel(
            grpc_url,
            options=[
                ("grpc.max_send_message_length", 1024 * 1024 * 256),
                ("grpc.max_receive_message_length", 1024 * 1024 * 256),
            ],
        )
        stub = sglang_scheduler_pb2_grpc.SglangSchedulerStub(channel)

        # Wait until the server is launched (poll GetModelInfo).
        model_info = None
        last_error = None
        for _ in range(120):
            time.sleep(1)
            try:
                request = sglang_scheduler_pb2.GetModelInfoRequest()
                model_info = stub.GetModelInfo(request, timeout=5)
                break
            except Exception as e:
                last_error = str(e)

        if model_info is None:
            return _abort(
                f"gRPC server warmup failed: Could not connect to server after 120 seconds. Last error: {last_error}"
            )

        # Model info tells us whether this is a generation or embedding model.
        is_generation = model_info.is_generation

        # Send a warmup request.
        logger.info("Sending warmup request to gRPC server...")
        if is_generation:
            # Create tokenized input for warmup (pre-tokenized, since this
            # server path performs no tokenization itself).
            warmup_request = sglang_scheduler_pb2.GenerateRequest(
                request_id=f"WARMUP_{time.time()}",
                tokenized=sglang_scheduler_pb2.TokenizedInput(
                    input_ids=[
                        954,
                        15541,
                        2181,
                        23496,
                        1476,
                        64710,
                        280,
                    ],  # Simple token sequence
                    original_text="The capital city of France is",
                ),
                sampling_params=sglang_scheduler_pb2.SamplingParams(
                    temperature=0.0,
                    max_new_tokens=8,
                ),
                stream=False,
            )
            try:
                responses = list(stub.Generate(warmup_request, timeout=600))
            except Exception as e:
                return _abort(f"gRPC warmup request failed: {e}")

            # A valid warmup ends with a non-error final message.
            if responses and not responses[-1].HasField("error"):
                logger.info("gRPC warmup request completed successfully")
                success = True
            else:
                error_msg = responses[-1].error.message if responses else "No response"
                logger.warning(f"gRPC warmup request returned error: {error_msg}")
                success = False
        else:
            # For embedding models a fixed tiny token sequence suffices.
            warmup_request = sglang_scheduler_pb2.EmbedRequest(
                request_id=f"WARMUP_{time.time()}",
                tokenized=sglang_scheduler_pb2.TokenizedInput(
                    input_ids=[10, 11, 12],
                    original_text="test embedding",
                ),
            )
            try:
                response = stub.Embed(warmup_request, timeout=600)
            except Exception as e:
                return _abort(f"gRPC warmup request failed: {e}")

            if not response.HasField("error"):
                logger.info("gRPC warmup request completed successfully")
                success = True
            else:
                logger.warning(
                    f"gRPC warmup request returned error: {response.error.message}"
                )
                success = False

        # Non-fatal outcome: close the channel and report; a warmup error
        # response does NOT kill the server (matching prior behavior).
        channel.close()
        return success

    except Exception as e:
        return _abort(
            f"gRPC warmup failed with exception: {e}\n{get_exception_traceback()}"
        )
def _wait_and_warmup_grpc(
    server_args: ServerArgs,
    pipe_finish_writer: Optional[mp.connection.Connection],
):
    """Wait for gRPC server to be ready and execute warmup.

    Runs the warmup unless ``server_args.skip_server_warmup`` is set; on
    warmup failure nothing further is reported. On success (or skip) the
    readiness message is logged and "ready" is written to the pipe, if given.
    """
    if server_args.skip_server_warmup:
        logger.info("Skipping gRPC server warmup (skip_server_warmup=True)")
    elif not _execute_grpc_server_warmup(server_args, pipe_finish_writer):
        # Warmup failed; the failure path has already notified the pipe.
        return

    logger.info("The server is fired up and ready to roll!")
    if pipe_finish_writer is not None:
        pipe_finish_writer.send("ready")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment