Unverified commit 6ade6a02, authored by Chang Su, committed by GitHub

[grpc] Support gRPC standard health check (#11955)

parent 983ef22c
pyproject.toml:

```diff
@@ -72,6 +72,7 @@ dependencies = [
     "grpcio==1.75.1",  # keep it align with compile_proto.py
     "grpcio-tools==1.75.1",  # keep it align with compile_proto.py
     "grpcio-reflection==1.75.1",  # required by srt/entrypoints/grpc_server.py
+    "grpcio-health-checking==1.75.1",  # required for Kubernetes gRPC health probes
 ]
 
 [project.optional-dependencies]
```
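The new `grpcio-health-checking` dependency ships the generated stubs for the standard `grpc.health.v1.Health` protocol that this commit implements. As a quick smoke test outside Kubernetes, the endpoint can be queried directly from Python; a minimal sketch, assuming the server listens on `localhost:30000` (address and port are illustrative):

```python
import grpc
from grpc_health.v1 import health_pb2, health_pb2_grpc

# Open a plaintext channel to the gRPC server (address is illustrative).
channel = grpc.insecure_channel("localhost:30000")
stub = health_pb2_grpc.HealthStub(channel)

# service="" queries overall server health (the liveness-style check).
response = stub.Check(health_pb2.HealthCheckRequest(service=""))
print(health_pb2.HealthCheckResponse.ServingStatus.Name(response.status))
```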
sglang/srt/entrypoints/grpc_server.py:

```diff
@@ -12,166 +12,35 @@
 import signal
 import threading
 import time
 from concurrent import futures
-from typing import AsyncIterator, Dict, Optional, Tuple
+from typing import AsyncIterator, Dict, Optional
 
 import grpc
 from google.protobuf.json_format import MessageToDict
 from google.protobuf.struct_pb2 import Struct
 from google.protobuf.timestamp_pb2 import Timestamp
+from grpc_health.v1 import health_pb2_grpc
 from grpc_reflection.v1alpha import reflection
 
 import sglang
 from sglang.srt.disaggregation.utils import FAKE_BOOTSTRAP_HOST, DisaggregationMode
 from sglang.srt.grpc import sglang_scheduler_pb2, sglang_scheduler_pb2_grpc
 from sglang.srt.grpc.grpc_request_manager import GrpcRequestManager
-from sglang.srt.managers.data_parallel_controller import (
-    run_data_parallel_controller_process,
-)
+from sglang.srt.grpc.health_servicer import SGLangHealthServicer
+from sglang.srt.grpc.scheduler_launcher import launch_scheduler_process_only
 from sglang.srt.managers.disagg_service import start_disagg_service
 from sglang.srt.managers.io_struct import (
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
 )
-from sglang.srt.managers.scheduler import run_scheduler_process
 from sglang.srt.sampling.sampling_params import SamplingParams as SGLSamplingParams
-from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.srt.utils import (
-    configure_logger,
-    kill_process_tree,
-    prepare_model_and_tokenizer,
-)
-from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter
+from sglang.srt.server_args import ServerArgs
+from sglang.srt.utils import kill_process_tree
 from sglang.utils import get_exception_traceback
 
 logger = logging.getLogger(__name__)
 
 HEALTH_CHECK_TIMEOUT = int(os.getenv("SGLANG_HEALTH_CHECK_TIMEOUT", 20))
 
 
-def _run_scheduler_with_signal_handling(*args, **kwargs):
-    """
-    Wrapper for run_scheduler_process that ignores SIGINT.
-
-    The scheduler process should not handle Ctrl+C - it should only terminate
-    when the parent gRPC server exits (via kill_itself_when_parent_died).
-    """
-    # Ignore SIGINT in this subprocess - let the parent handle it
-    signal.signal(signal.SIGINT, signal.SIG_IGN)
-
-    # Now run the actual scheduler process
-    run_scheduler_process(*args, **kwargs)
-
-
-def _launch_scheduler_process_only(
-    server_args: ServerArgs,
-    port_args: Optional[PortArgs] = None,
-) -> Tuple[Dict, PortArgs, list]:
-    """
-    Launch only the scheduler process(es) without tokenizer/detokenizer.
-    Returns scheduler info, port args, and list of scheduler processes.
-    """
-    # Configure global environment
-    configure_logger(server_args)
-    server_args.check_server_args()
-
-    # Fix CUDA multiprocessing issues - must be called before any CUDA operations
-    mp.set_start_method("spawn", force=True)
-
-    # Allocate ports for inter-process communications
-    if port_args is None:
-        port_args = PortArgs.init_new(server_args)
-    logger.info(f"{server_args=}")
-
-    # Prepare model and tokenizer paths
-    server_args.model_path, server_args.tokenizer_path = prepare_model_and_tokenizer(
-        server_args.model_path, server_args.tokenizer_path
-    )
-
-    scheduler_procs = []
-    if server_args.dp_size == 1:
-        memory_saver_adapter = TorchMemorySaverAdapter.create(
-            enable=server_args.enable_memory_saver
-        )
-        scheduler_pipe_readers = []
-
-        nnodes_per_tp_group = max(server_args.nnodes // server_args.pp_size, 1)
-        tp_size_per_node = server_args.tp_size // nnodes_per_tp_group
-        tp_rank_range = range(
-            tp_size_per_node * (server_args.node_rank % nnodes_per_tp_group),
-            tp_size_per_node * (server_args.node_rank % nnodes_per_tp_group + 1),
-        )
-
-        pp_size_per_node = max(server_args.pp_size // server_args.nnodes, 1)
-        pp_rank_range = range(
-            pp_size_per_node * (server_args.node_rank // nnodes_per_tp_group),
-            pp_size_per_node * (server_args.node_rank // nnodes_per_tp_group + 1),
-        )
-
-        for pp_rank in pp_rank_range:
-            for tp_rank in tp_rank_range:
-                reader, writer = mp.Pipe(duplex=False)
-                gpu_id = (
-                    server_args.base_gpu_id
-                    + ((pp_rank % pp_size_per_node) * tp_size_per_node)
-                    + (tp_rank % tp_size_per_node) * server_args.gpu_id_step
-                )
-                moe_ep_rank = tp_rank // (server_args.tp_size // server_args.ep_size)
-                proc = mp.Process(
-                    target=_run_scheduler_with_signal_handling,
-                    args=(
-                        server_args,
-                        port_args,
-                        gpu_id,
-                        tp_rank,
-                        moe_ep_rank,
-                        pp_rank,
-                        None,
-                        writer,
-                    ),
-                )
-                with memory_saver_adapter.configure_subprocess():
-                    proc.start()
-                scheduler_procs.append(proc)
-                scheduler_pipe_readers.append(reader)
-    else:
-        # Launch the data parallel controller
-        reader, writer = mp.Pipe(duplex=False)
-        scheduler_pipe_readers = [reader]
-        proc = mp.Process(
-            target=run_data_parallel_controller_process,
-            args=(server_args, port_args, writer),
-        )
-        proc.start()
-        scheduler_procs.append(proc)
-
-    # TODO(CatherineSue): handle cases for multi-node
-
-    # Wait for all scheduler processes to be ready
-    scheduler_infos = []
-    for i, reader in enumerate(scheduler_pipe_readers):
-        try:
-            data = reader.recv()
-        except EOFError:
-            logger.error(
-                f"Rank {i} scheduler is dead. Please check if there are relevant logs."
-            )
-            scheduler_procs[i].join()
-            logger.error(f"Exit code: {scheduler_procs[i].exitcode}")
-            raise RuntimeError(f"Failed to initialize scheduler rank {i}")
-
-        if data.get("status") != "ready":
-            raise RuntimeError(
-                f"Scheduler rank {i} initialization failed: {data.get('error', 'Unknown error')}"
-            )
-        scheduler_infos.append(data)
-
-    logger.info(
-        f"All {len(scheduler_procs)} scheduler process(es) initialized successfully"
-    )
-
-    # Return the first scheduler's info (they should all be the same)
-    return scheduler_infos[0], port_args, scheduler_procs
 
 
 class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer):
     """
     Standalone gRPC service implementation using GrpcRequestManager.
```
```diff
@@ -184,6 +53,7 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
         server_args: ServerArgs,
         model_info: Dict,
         scheduler_info: Dict,
+        health_servicer: Optional[SGLangHealthServicer] = None,
     ):
         """Initialize the standalone gRPC service."""
         self.request_manager = request_manager
@@ -191,6 +61,7 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
         self.model_info = model_info
         self.scheduler_info = scheduler_info
         self.start_time = time.time()
+        self.health_servicer = health_servicer
 
         # Start the request manager's event loop using auto_create_handle_loop
         self.request_manager.auto_create_handle_loop()
@@ -817,6 +688,10 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
         """Shutdown the service."""
         logger.info("Shutting down gRPC service")
 
+        # Mark health service as NOT_SERVING before shutdown
+        if self.health_servicer:
+            self.health_servicer.set_not_serving()
+
         # Shutdown request manager (handles its own tasks)
         await self.request_manager.shutdown()
@@ -839,7 +714,7 @@ async def serve_grpc(
     # Launch only the scheduler process(es) (no tokenizer/detokenizer needed for gRPC)
     logger.info("Launching scheduler process(es)...")
-    scheduler_info, port_args, scheduler_procs = _launch_scheduler_process_only(
+    scheduler_info, port_args, scheduler_procs = launch_scheduler_process_only(
         server_args=server_args,
     )
@@ -876,18 +751,27 @@ async def serve_grpc(
         ],
     )
 
-    # Add service
+    # Create standard health service (for Kubernetes probes)
+    health_servicer = SGLangHealthServicer(
+        request_manager=request_manager,
+        scheduler_info=scheduler_info,
+    )
+    health_pb2_grpc.add_HealthServicer_to_server(health_servicer, server)
+
+    # Add SGLang service
     servicer = SGLangSchedulerServicer(
         request_manager=request_manager,
         server_args=server_args,
         model_info=model_info,
         scheduler_info=scheduler_info,
+        health_servicer=health_servicer,
     )
     sglang_scheduler_pb2_grpc.add_SglangSchedulerServicer_to_server(servicer, server)
 
     # Enable reflection
     SERVICE_NAMES = (
         sglang_scheduler_pb2.DESCRIPTOR.services_by_name["SglangScheduler"].full_name,
+        "grpc.health.v1.Health",
         reflection.SERVICE_NAME,
     )
     reflection.enable_server_reflection(SERVICE_NAMES, server)
@@ -902,7 +786,7 @@ async def serve_grpc(
     # Start warmup in a separate thread
     warmup_thread = threading.Thread(
         target=_wait_and_warmup_grpc,
-        args=(server_args, None),
+        args=(server_args, None, health_servicer),
     )
     warmup_thread.start()
@@ -1103,6 +987,7 @@ def _execute_grpc_server_warmup(
 def _wait_and_warmup_grpc(
     server_args: ServerArgs,
     pipe_finish_writer: Optional[mp.connection.Connection],
+    health_servicer: Optional[SGLangHealthServicer] = None,
 ):
     """Wait for gRPC server to be ready and execute warmup."""
     if not server_args.skip_server_warmup:
@@ -1111,6 +996,11 @@ def _wait_and_warmup_grpc(
     else:
         logger.info("Skipping gRPC server warmup (skip_server_warmup=True)")
 
+    # Mark health service as SERVING after warmup completes
+    if health_servicer:
+        health_servicer.set_serving()
+        logger.info("Health service marked as SERVING")
+
     logger.info("The server is fired up and ready to roll!")
     if pipe_finish_writer is not None:
```
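Because the status only flips to SERVING once warmup completes (and back to NOT_SERVING on shutdown), external tooling can gate traffic on readiness by polling `Check`. A minimal sketch of such a gate; the address, poll interval, and timeout are illustrative:

```python
import time

import grpc
from grpc_health.v1 import health_pb2, health_pb2_grpc


def wait_until_serving(address: str = "localhost:30000", timeout_s: float = 300.0) -> bool:
    """Poll the standard health service until it reports SERVING or times out."""
    channel = grpc.insecure_channel(address)
    stub = health_pb2_grpc.HealthStub(channel)
    deadline = time.monotonic() + timeout_s
    while time.monotonic() < deadline:
        try:
            resp = stub.Check(health_pb2.HealthCheckRequest(service=""))
            if resp.status == health_pb2.HealthCheckResponse.SERVING:
                return True
        except grpc.RpcError:
            pass  # Server not up yet (or transiently unavailable); keep polling.
        time.sleep(2.0)
    return False
```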
"""
Standard gRPC health check service implementation for Kubernetes probes.
This module implements the grpc.health.v1.Health service protocol, enabling
native Kubernetes gRPC health probes for liveness and readiness checks.
"""
import logging
import time
from typing import AsyncIterator
import grpc
from grpc_health.v1 import health_pb2, health_pb2_grpc
logger = logging.getLogger(__name__)
class SGLangHealthServicer(health_pb2_grpc.HealthServicer):
"""
Standard gRPC health check service implementation for Kubernetes probes.
Implements grpc.health.v1.Health protocol.
Supports two service levels:
1. Overall server health (service="") - for liveness probes
2. SGLang service health (service="sglang.grpc.scheduler.SglangScheduler") - for readiness probes
Health status lifecycle:
- NOT_SERVING: Initial state, model loading, or shutting down
- SERVING: Model loaded and ready to serve requests
"""
# Service names we support
OVERALL_SERVER = "" # Empty string for overall server health
SGLANG_SERVICE = "sglang.grpc.scheduler.SglangScheduler"
def __init__(self, request_manager, scheduler_info: dict):
"""
Initialize health servicer.
Args:
request_manager: GrpcRequestManager instance for checking server state
scheduler_info: Dict containing scheduler metadata
"""
self.request_manager = request_manager
self.scheduler_info = scheduler_info
self._serving_status = {}
# Initially set to NOT_SERVING until model is loaded
self._serving_status[self.OVERALL_SERVER] = (
health_pb2.HealthCheckResponse.NOT_SERVING
)
self._serving_status[self.SGLANG_SERVICE] = (
health_pb2.HealthCheckResponse.NOT_SERVING
)
logger.info("Standard gRPC health service initialized")
def set_serving(self):
"""Mark services as SERVING - call this after model is loaded."""
self._serving_status[self.OVERALL_SERVER] = (
health_pb2.HealthCheckResponse.SERVING
)
self._serving_status[self.SGLANG_SERVICE] = (
health_pb2.HealthCheckResponse.SERVING
)
logger.info("Health service status set to SERVING")
def set_not_serving(self):
"""Mark services as NOT_SERVING - call this during shutdown."""
self._serving_status[self.OVERALL_SERVER] = (
health_pb2.HealthCheckResponse.NOT_SERVING
)
self._serving_status[self.SGLANG_SERVICE] = (
health_pb2.HealthCheckResponse.NOT_SERVING
)
logger.info("Health service status set to NOT_SERVING")
async def Check(
self,
request: health_pb2.HealthCheckRequest,
context: grpc.aio.ServicerContext,
) -> health_pb2.HealthCheckResponse:
"""
Standard health check for Kubernetes probes.
Args:
request: Contains service name ("" for overall, or specific service)
context: gRPC context
Returns:
HealthCheckResponse with SERVING/NOT_SERVING/SERVICE_UNKNOWN status
"""
service_name = request.service
logger.debug(f"Health check request for service: '{service_name}'")
# Check if shutting down
if self.request_manager.gracefully_exit:
logger.debug("Health check: Server is shutting down")
return health_pb2.HealthCheckResponse(
status=health_pb2.HealthCheckResponse.NOT_SERVING
)
# Overall server health - just check if process is alive
if service_name == self.OVERALL_SERVER:
status = self._serving_status.get(
self.OVERALL_SERVER, health_pb2.HealthCheckResponse.NOT_SERVING
)
logger.debug(
f"Overall health check: {health_pb2.HealthCheckResponse.ServingStatus.Name(status)}"
)
return health_pb2.HealthCheckResponse(status=status)
# Specific service health - check if ready to serve
elif service_name == self.SGLANG_SERVICE:
# Additional checks for service readiness
# Check base status first
base_status = self._serving_status.get(
self.SGLANG_SERVICE, health_pb2.HealthCheckResponse.NOT_SERVING
)
if base_status != health_pb2.HealthCheckResponse.SERVING:
logger.debug("Service health check: NOT_SERVING (base status)")
return health_pb2.HealthCheckResponse(status=base_status)
# Check if scheduler is responsive (received data recently)
time_since_last_receive = (
time.time() - self.request_manager.last_receive_tstamp
)
# If no recent activity and we have active requests, might be stuck
# NOTE: 30s timeout is hardcoded. This is more conservative than
# HEALTH_CHECK_TIMEOUT (20s) used for custom HealthCheck RPC.
# Consider making this configurable via environment variable in the future
# if different workloads need different responsiveness thresholds.
if (
time_since_last_receive > 30
and len(self.request_manager.rid_to_state) > 0
):
logger.warning(
f"Service health check: Scheduler not responsive "
f"({time_since_last_receive:.1f}s since last receive, "
f"{len(self.request_manager.rid_to_state)} pending requests)"
)
return health_pb2.HealthCheckResponse(
status=health_pb2.HealthCheckResponse.NOT_SERVING
)
logger.debug("Service health check: SERVING")
return health_pb2.HealthCheckResponse(
status=health_pb2.HealthCheckResponse.SERVING
)
# Unknown service
else:
logger.debug(f"Health check for unknown service: '{service_name}'")
context.set_code(grpc.StatusCode.NOT_FOUND)
context.set_details(f"Unknown service: {service_name}")
return health_pb2.HealthCheckResponse(
status=health_pb2.HealthCheckResponse.SERVICE_UNKNOWN
)
async def Watch(
self,
request: health_pb2.HealthCheckRequest,
context: grpc.aio.ServicerContext,
) -> AsyncIterator[health_pb2.HealthCheckResponse]:
"""
Streaming health check - sends updates when status changes.
For now, just send current status once (Kubernetes doesn't use Watch).
A full implementation would monitor status changes and stream updates.
Args:
request: Contains service name
context: gRPC context
Yields:
HealthCheckResponse messages when status changes
"""
service_name = request.service
logger.debug(f"Health watch request for service: '{service_name}'")
# Send current status
response = await self.Check(request, context)
yield response
# Note: Full Watch implementation would monitor status changes
# and stream updates. For K8s probes, Check is sufficient.
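The two service levels map naturally onto Kubernetes probes: a liveness probe can target the empty service name, while a readiness probe targets `sglang.grpc.scheduler.SglangScheduler` and thereby also picks up the scheduler-responsiveness check. A sketch that queries both levels from Python (the address is illustrative; a request for a name this servicer does not know fails with `NOT_FOUND` on the client side):

```python
import grpc
from grpc_health.v1 import health_pb2, health_pb2_grpc

channel = grpc.insecure_channel("localhost:30000")  # address is illustrative
stub = health_pb2_grpc.HealthStub(channel)

# "" = overall server health (liveness); full service name = readiness.
for service in ("", "sglang.grpc.scheduler.SglangScheduler"):
    try:
        resp = stub.Check(health_pb2.HealthCheckRequest(service=service))
        status = health_pb2.HealthCheckResponse.ServingStatus.Name(resp.status)
    except grpc.RpcError as err:
        status = f"RPC error: {err.code().name}"
    print(f"{service or '<overall>'}: {status}")
```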
"""
Scheduler process management for gRPC server.
This module handles launching and managing scheduler processes for the gRPC server,
including tensor parallelism, pipeline parallelism, and data parallelism configurations.
"""
import logging
import multiprocessing as mp
import signal
from typing import Dict, List, Optional, Tuple
from sglang.srt.managers.data_parallel_controller import (
run_data_parallel_controller_process,
)
from sglang.srt.managers.scheduler import run_scheduler_process
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import configure_logger, prepare_model_and_tokenizer
from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter
logger = logging.getLogger(__name__)
def run_scheduler_with_signal_handling(*args, **kwargs):
"""
Wrapper for run_scheduler_process that ignores SIGINT.
The scheduler process should not handle Ctrl+C - it should only terminate
when the parent gRPC server exits (via kill_itself_when_parent_died).
Args:
*args: Positional arguments for run_scheduler_process
**kwargs: Keyword arguments for run_scheduler_process
"""
# Ignore SIGINT in this subprocess - let the parent handle it
signal.signal(signal.SIGINT, signal.SIG_IGN)
# Now run the actual scheduler process
run_scheduler_process(*args, **kwargs)
def launch_scheduler_process_only(
server_args: ServerArgs,
port_args: Optional[PortArgs] = None,
) -> Tuple[Dict, PortArgs, List[mp.Process]]:
"""
Launch only the scheduler process(es) without tokenizer/detokenizer.
This function handles all scheduler startup logic including:
- Tensor parallelism (tp_size)
- Pipeline parallelism (pp_size)
- Data parallelism (dp_size)
- Multi-node distributed setup
Args:
server_args: Server configuration
port_args: Port configuration (created if None)
Returns:
Tuple of (scheduler_info, port_args, scheduler_processes):
- scheduler_info: Dict with model metadata and configuration
- port_args: Port configuration used for IPC
- scheduler_processes: List of launched scheduler Process objects
Raises:
RuntimeError: If any scheduler process fails to initialize
"""
# Configure global environment
configure_logger(server_args)
server_args.check_server_args()
# Fix CUDA multiprocessing issues - must be called before any CUDA operations
mp.set_start_method("spawn", force=True)
# Allocate ports for inter-process communications
if port_args is None:
port_args = PortArgs.init_new(server_args)
logger.info(f"{server_args=}")
# Prepare model and tokenizer paths
server_args.model_path, server_args.tokenizer_path = prepare_model_and_tokenizer(
server_args.model_path, server_args.tokenizer_path
)
scheduler_procs = []
if server_args.dp_size == 1:
# Single data parallel group - launch TP/PP schedulers
memory_saver_adapter = TorchMemorySaverAdapter.create(
enable=server_args.enable_memory_saver
)
scheduler_pipe_readers = []
# Calculate TP/PP distribution across nodes
nnodes_per_tp_group = max(server_args.nnodes // server_args.pp_size, 1)
tp_size_per_node = server_args.tp_size // nnodes_per_tp_group
tp_rank_range = range(
tp_size_per_node * (server_args.node_rank % nnodes_per_tp_group),
tp_size_per_node * (server_args.node_rank % nnodes_per_tp_group + 1),
)
pp_size_per_node = max(server_args.pp_size // server_args.nnodes, 1)
pp_rank_range = range(
pp_size_per_node * (server_args.node_rank // nnodes_per_tp_group),
pp_size_per_node * (server_args.node_rank // nnodes_per_tp_group + 1),
)
# Launch scheduler for each TP/PP rank combination
for pp_rank in pp_rank_range:
for tp_rank in tp_rank_range:
reader, writer = mp.Pipe(duplex=False)
# Calculate GPU ID for this rank
gpu_id = (
server_args.base_gpu_id
+ ((pp_rank % pp_size_per_node) * tp_size_per_node)
+ (tp_rank % tp_size_per_node) * server_args.gpu_id_step
)
# Calculate MoE expert parallel rank
moe_ep_rank = tp_rank // (server_args.tp_size // server_args.ep_size)
# Create scheduler process
proc = mp.Process(
target=run_scheduler_with_signal_handling,
args=(
server_args,
port_args,
gpu_id,
tp_rank,
moe_ep_rank,
pp_rank,
None, # dp_rank
writer,
),
)
with memory_saver_adapter.configure_subprocess():
proc.start()
scheduler_procs.append(proc)
scheduler_pipe_readers.append(reader)
else:
# Data parallelism - launch data parallel controller
reader, writer = mp.Pipe(duplex=False)
scheduler_pipe_readers = [reader]
proc = mp.Process(
target=run_data_parallel_controller_process,
args=(server_args, port_args, writer),
)
proc.start()
scheduler_procs.append(proc)
# TODO(CatherineSue): handle cases for multi-node
# Wait for all scheduler processes to be ready
scheduler_infos = []
for i, reader in enumerate(scheduler_pipe_readers):
try:
data = reader.recv()
except EOFError:
logger.error(
f"Rank {i} scheduler is dead. Please check if there are relevant logs."
)
scheduler_procs[i].join()
logger.error(f"Exit code: {scheduler_procs[i].exitcode}")
raise RuntimeError(f"Failed to initialize scheduler rank {i}")
if data.get("status") != "ready":
raise RuntimeError(
f"Scheduler rank {i} initialization failed: {data.get('error', 'Unknown error')}"
)
scheduler_infos.append(data)
logger.info(
f"All {len(scheduler_procs)} scheduler process(es) initialized successfully"
)
# Return the first scheduler's info (they should all be the same)
return scheduler_infos[0], port_args, scheduler_procs
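For reference, a hedged usage sketch of the extracted launcher: it assumes a minimal `ServerArgs` with only `model_path` set (the model path is illustrative) and a machine with the GPUs the configuration implies. The `__main__` guard matters because the launcher forces the `spawn` multiprocessing start method:

```python
from sglang.srt.grpc.scheduler_launcher import launch_scheduler_process_only
from sglang.srt.server_args import ServerArgs

if __name__ == "__main__":
    # Illustrative model path; all other ServerArgs fields keep their defaults.
    server_args = ServerArgs(model_path="meta-llama/Llama-3.1-8B-Instruct")

    # Blocks until every scheduler process reports "ready" (or raises RuntimeError).
    scheduler_info, port_args, procs = launch_scheduler_process_only(server_args)
    print(f"Launched {len(procs)} scheduler process(es)")
```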