Unverified Commit 1e890341 authored by Liangsheng Yin's avatar Liangsheng Yin Committed by GitHub
Browse files

Fix possible ZMQ hanging (#1800)

parent 715b16c1
...@@ -30,6 +30,7 @@ from sglang.srt.managers.scheduler import run_scheduler_process ...@@ -30,6 +30,7 @@ from sglang.srt.managers.scheduler import run_scheduler_process
from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import ( from sglang.srt.utils import (
configure_logger, configure_logger,
get_zmq_socket,
kill_parent_process, kill_parent_process,
suppress_other_loggers, suppress_other_loggers,
) )
...@@ -66,8 +67,9 @@ class DataParallelController: ...@@ -66,8 +67,9 @@ class DataParallelController:
# Init inter-process communication # Init inter-process communication
self.context = zmq.Context(1 + server_args.dp_size) self.context = zmq.Context(1 + server_args.dp_size)
self.recv_from_tokenizer = self.context.socket(zmq.PULL) self.recv_from_tokenizer = get_zmq_socket(
self.recv_from_tokenizer.bind(f"ipc://{port_args.scheduler_input_ipc_name}") self.context, zmq.PULL, port_args.scheduler_input_ipc_name
)
# Dispatch method # Dispatch method
self.round_robin_counter = 0 self.round_robin_counter = 0
...@@ -120,8 +122,9 @@ class DataParallelController: ...@@ -120,8 +122,9 @@ class DataParallelController:
scheduler_procs.append(proc) scheduler_procs.append(proc)
scheduler_pipe_readers.append(reader) scheduler_pipe_readers.append(reader)
send_to = self.context.socket(zmq.PUSH) send_to = get_zmq_socket(
send_to.connect(f"ipc://{port_args.scheduler_input_ipc_name}") self.context, zmq.PUSH, port_args.scheduler_input_ipc_name
)
# Wait for model to finish loading # Wait for model to finish loading
for i in range(len(scheduler_pipe_readers)): for i in range(len(scheduler_pipe_readers)):
......
...@@ -32,7 +32,7 @@ from sglang.srt.managers.io_struct import ( ...@@ -32,7 +32,7 @@ from sglang.srt.managers.io_struct import (
) )
from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR, FINISH_MATCHED_TOKEN from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR, FINISH_MATCHED_TOKEN
from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import configure_logger, kill_parent_process from sglang.srt.utils import configure_logger, get_zmq_socket, kill_parent_process
from sglang.utils import find_printable_text, get_exception_traceback from sglang.utils import find_printable_text, get_exception_traceback
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -59,11 +59,12 @@ class DetokenizerManager: ...@@ -59,11 +59,12 @@ class DetokenizerManager:
): ):
# Init inter-process communication # Init inter-process communication
context = zmq.Context(2) context = zmq.Context(2)
self.recv_from_scheduler = context.socket(zmq.PULL) self.recv_from_scheduler = get_zmq_socket(
self.recv_from_scheduler.bind(f"ipc://{port_args.detokenizer_ipc_name}") context, zmq.PULL, port_args.detokenizer_ipc_name
)
self.send_to_tokenizer = context.socket(zmq.PUSH) self.send_to_tokenizer = get_zmq_socket(
self.send_to_tokenizer.connect(f"ipc://{port_args.tokenizer_ipc_name}") context, zmq.PUSH, port_args.tokenizer_ipc_name
)
if server_args.skip_tokenizer_init: if server_args.skip_tokenizer_init:
self.tokenizer = None self.tokenizer = None
......
...@@ -67,6 +67,7 @@ from sglang.srt.server_args import PortArgs, ServerArgs ...@@ -67,6 +67,7 @@ from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import ( from sglang.srt.utils import (
broadcast_pyobj, broadcast_pyobj,
configure_logger, configure_logger,
get_zmq_socket,
is_generation_model, is_generation_model,
is_multimodal_model, is_multimodal_model,
kill_parent_process, kill_parent_process,
...@@ -110,20 +111,19 @@ class Scheduler: ...@@ -110,20 +111,19 @@ class Scheduler:
context = zmq.Context(2) context = zmq.Context(2)
if self.tp_rank == 0: if self.tp_rank == 0:
self.recv_from_tokenizer = context.socket(zmq.PULL) self.recv_from_tokenizer = get_zmq_socket(
self.recv_from_tokenizer.bind(f"ipc://{port_args.scheduler_input_ipc_name}") context, zmq.PULL, port_args.scheduler_input_ipc_name
)
if server_args.skip_tokenizer_init: if server_args.skip_tokenizer_init:
# Directly send to the tokenizer/api # Directly send to the tokenizer/api
self.send_to_detokenizer = context.socket(zmq.PUSH) self.send_to_detokenizer = get_zmq_socket(
self.send_to_detokenizer.connect( context, zmq.PUSH, port_args.tokenizer_ipc_name
f"ipc://{port_args.tokenizer_ipc_name}"
) )
else: else:
# Send to the detokenizer # Send to the detokenizer
self.send_to_detokenizer = context.socket(zmq.PUSH) self.send_to_detokenizer = get_zmq_socket(
self.send_to_detokenizer.connect( context, zmq.PUSH, port_args.detokenizer_ipc_name
f"ipc://{port_args.detokenizer_ipc_name}"
) )
else: else:
self.recv_from_tokenizer = None self.recv_from_tokenizer = None
......
...@@ -58,7 +58,7 @@ from sglang.srt.managers.io_struct import ( ...@@ -58,7 +58,7 @@ from sglang.srt.managers.io_struct import (
) )
from sglang.srt.sampling.sampling_params import SamplingParams from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import is_generation_model, is_multimodal_model from sglang.srt.utils import get_zmq_socket, is_generation_model, is_multimodal_model
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
...@@ -86,11 +86,12 @@ class TokenizerManager: ...@@ -86,11 +86,12 @@ class TokenizerManager:
# Init inter-process communication # Init inter-process communication
context = zmq.asyncio.Context(2) context = zmq.asyncio.Context(2)
self.recv_from_detokenizer = context.socket(zmq.PULL) self.recv_from_detokenizer = get_zmq_socket(
self.recv_from_detokenizer.bind(f"ipc://{port_args.tokenizer_ipc_name}") context, zmq.PULL, port_args.tokenizer_ipc_name
)
self.send_to_scheduler = context.socket(zmq.PUSH) self.send_to_scheduler = get_zmq_socket(
self.send_to_scheduler.connect(f"ipc://{port_args.scheduler_input_ipc_name}") context, zmq.PUSH, port_args.scheduler_input_ipc_name
)
# Read model args # Read model args
self.model_path = server_args.model_path self.model_path = server_args.model_path
......
...@@ -35,6 +35,7 @@ import psutil ...@@ -35,6 +35,7 @@ import psutil
import requests import requests
import torch import torch
import torch.distributed as dist import torch.distributed as dist
import zmq
from fastapi.responses import ORJSONResponse from fastapi.responses import ORJSONResponse
from packaging import version as pkg_version from packaging import version as pkg_version
from torch import nn from torch import nn
...@@ -720,3 +721,19 @@ def first_rank_print(*args, **kwargs): ...@@ -720,3 +721,19 @@ def first_rank_print(*args, **kwargs):
print(*args, **kwargs) print(*args, **kwargs)
else: else:
pass pass
def get_zmq_socket(context: zmq.Context, socket_type: zmq.SocketType, endpoint: str):
    """Create a configured ZMQ socket on an ``ipc://`` endpoint.

    PUSH sockets connect to the endpoint; PULL sockets bind to it. In both
    cases the high-water mark for the relevant direction is set to 0 and a
    buffer size of 100000000 is requested via SNDBUF/RCVBUF.

    Args:
        context: The ZMQ context used to create the socket.
        socket_type: Either ``zmq.PUSH`` or ``zmq.PULL``.
        endpoint: IPC name; it is prefixed with ``ipc://`` to form the address.

    Returns:
        The connected (PUSH) or bound (PULL) socket.

    Raises:
        ValueError: If ``socket_type`` is neither ``zmq.PUSH`` nor ``zmq.PULL``.
    """
    # Validate before allocating anything so that a rejected socket type
    # does not leave an orphaned socket attached to the context.
    if socket_type not in (zmq.PUSH, zmq.PULL):
        raise ValueError(f"Unsupported socket type: {socket_type}")

    socket = context.socket(socket_type)
    try:
        if socket_type == zmq.PUSH:
            # Per the ZMQ docs a high-water mark of 0 means "no limit".
            socket.setsockopt(zmq.SNDHWM, 0)
            socket.setsockopt(zmq.SNDBUF, 100000000)
            socket.connect(f"ipc://{endpoint}")
        else:
            socket.setsockopt(zmq.RCVHWM, 0)
            socket.setsockopt(zmq.RCVBUF, 100000000)
            socket.bind(f"ipc://{endpoint}")
    except Exception:
        # Release the half-configured socket; linger=0 so that closing it
        # cannot block on pending messages. Re-raise the original error.
        socket.close(linger=0)
        raise
    return socket
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment