Unverified Commit 0427416b authored by Lianmin Zheng's avatar Lianmin Zheng Committed by GitHub
Browse files

Fix zmq binding (#2930)


Co-authored-by: default avatarChunyuan WU <chunyuan.wu@intel.com>
parent bf3edc2c
...@@ -66,7 +66,7 @@ class DataParallelController: ...@@ -66,7 +66,7 @@ class DataParallelController:
self.context = zmq.Context(1 + server_args.dp_size) self.context = zmq.Context(1 + server_args.dp_size)
if server_args.node_rank == 0: if server_args.node_rank == 0:
self.recv_from_tokenizer = get_zmq_socket( self.recv_from_tokenizer = get_zmq_socket(
self.context, zmq.PULL, port_args.scheduler_input_ipc_name self.context, zmq.PULL, port_args.scheduler_input_ipc_name, False
) )
# Dispatch method # Dispatch method
...@@ -93,6 +93,7 @@ class DataParallelController: ...@@ -93,6 +93,7 @@ class DataParallelController:
self.context, self.context,
zmq.PUSH, zmq.PUSH,
dp_port_args[dp_rank].scheduler_input_ipc_name, dp_port_args[dp_rank].scheduler_input_ipc_name,
True,
) )
def launch_dp_schedulers(self, server_args, port_args): def launch_dp_schedulers(self, server_args, port_args):
......
...@@ -58,10 +58,10 @@ class DetokenizerManager: ...@@ -58,10 +58,10 @@ class DetokenizerManager:
# Init inter-process communication # Init inter-process communication
context = zmq.Context(2) context = zmq.Context(2)
self.recv_from_scheduler = get_zmq_socket( self.recv_from_scheduler = get_zmq_socket(
context, zmq.PULL, port_args.detokenizer_ipc_name context, zmq.PULL, port_args.detokenizer_ipc_name, True
) )
self.send_to_tokenizer = get_zmq_socket( self.send_to_tokenizer = get_zmq_socket(
context, zmq.PUSH, port_args.tokenizer_ipc_name context, zmq.PUSH, port_args.tokenizer_ipc_name, False
) )
if server_args.skip_tokenizer_init: if server_args.skip_tokenizer_init:
......
...@@ -162,21 +162,21 @@ class Scheduler: ...@@ -162,21 +162,21 @@ class Scheduler:
if self.attn_tp_rank == 0: if self.attn_tp_rank == 0:
self.recv_from_tokenizer = get_zmq_socket( self.recv_from_tokenizer = get_zmq_socket(
context, zmq.PULL, port_args.scheduler_input_ipc_name context, zmq.PULL, port_args.scheduler_input_ipc_name, False
) )
self.send_to_tokenizer = get_zmq_socket( self.send_to_tokenizer = get_zmq_socket(
context, zmq.PUSH, port_args.tokenizer_ipc_name context, zmq.PUSH, port_args.tokenizer_ipc_name, False
) )
if server_args.skip_tokenizer_init: if server_args.skip_tokenizer_init:
# Directly send to the TokenizerManager # Directly send to the TokenizerManager
self.send_to_detokenizer = get_zmq_socket( self.send_to_detokenizer = get_zmq_socket(
context, zmq.PUSH, port_args.tokenizer_ipc_name context, zmq.PUSH, port_args.tokenizer_ipc_name, False
) )
else: else:
# Send to the DetokenizerManager # Send to the DetokenizerManager
self.send_to_detokenizer = get_zmq_socket( self.send_to_detokenizer = get_zmq_socket(
context, zmq.PUSH, port_args.detokenizer_ipc_name context, zmq.PUSH, port_args.detokenizer_ipc_name, False
) )
else: else:
self.recv_from_tokenizer = None self.recv_from_tokenizer = None
......
...@@ -119,10 +119,10 @@ class TokenizerManager: ...@@ -119,10 +119,10 @@ class TokenizerManager:
# Init inter-process communication # Init inter-process communication
context = zmq.asyncio.Context(2) context = zmq.asyncio.Context(2)
self.recv_from_detokenizer = get_zmq_socket( self.recv_from_detokenizer = get_zmq_socket(
context, zmq.PULL, port_args.tokenizer_ipc_name context, zmq.PULL, port_args.tokenizer_ipc_name, True
) )
self.send_to_scheduler = get_zmq_socket( self.send_to_scheduler = get_zmq_socket(
context, zmq.PUSH, port_args.scheduler_input_ipc_name context, zmq.PUSH, port_args.scheduler_input_ipc_name, True
) )
# Read model args # Read model args
......
...@@ -789,7 +789,9 @@ def first_rank_print(*args, **kwargs): ...@@ -789,7 +789,9 @@ def first_rank_print(*args, **kwargs):
pass pass
def get_zmq_socket(context: zmq.Context, socket_type: zmq.SocketType, endpoint: str): def get_zmq_socket(
context: zmq.Context, socket_type: zmq.SocketType, endpoint: str, bind: bool
):
mem = psutil.virtual_memory() mem = psutil.virtual_memory()
total_mem = mem.total / 1024**3 total_mem = mem.total / 1024**3
available_mem = mem.available / 1024**3 available_mem = mem.available / 1024**3
...@@ -802,14 +804,17 @@ def get_zmq_socket(context: zmq.Context, socket_type: zmq.SocketType, endpoint: ...@@ -802,14 +804,17 @@ def get_zmq_socket(context: zmq.Context, socket_type: zmq.SocketType, endpoint:
if socket_type == zmq.PUSH: if socket_type == zmq.PUSH:
socket.setsockopt(zmq.SNDHWM, 0) socket.setsockopt(zmq.SNDHWM, 0)
socket.setsockopt(zmq.SNDBUF, buf_size) socket.setsockopt(zmq.SNDBUF, buf_size)
socket.connect(endpoint)
elif socket_type == zmq.PULL: elif socket_type == zmq.PULL:
socket.setsockopt(zmq.RCVHWM, 0) socket.setsockopt(zmq.RCVHWM, 0)
socket.setsockopt(zmq.RCVBUF, buf_size) socket.setsockopt(zmq.RCVBUF, buf_size)
socket.bind(endpoint)
else: else:
raise ValueError(f"Unsupported socket type: {socket_type}") raise ValueError(f"Unsupported socket type: {socket_type}")
if bind:
socket.bind(endpoint)
else:
socket.connect(endpoint)
return socket return socket
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment