Commit d1ea1295 authored by zhuwenwen's avatar zhuwenwen
Browse files

update utils.py

parent 606e3c92
......@@ -62,12 +62,12 @@ from typing_extensions import Never, ParamSpec, TypeIs, assert_never
import vllm.envs as envs
from vllm.logger import enable_trace_function_call, init_logger
import json
if TYPE_CHECKING:
from vllm.config import ModelConfig, VllmConfig
logger = init_logger(__name__)
import json
# Exception strings for non-implemented encoder/decoder scenarios
......@@ -2330,6 +2330,8 @@ def make_zmq_socket(
ctx: Union[zmq.asyncio.Context, zmq.Context], # type: ignore[name-defined]
path: str,
socket_type: Any,
bind: Optional[bool] = None,
identity: Optional[bytes] = None,
) -> Union[zmq.Socket, zmq.asyncio.Socket]: # type: ignore[name-defined]
"""Make a ZMQ socket with the proper bind/connect semantics."""
......@@ -2348,16 +2350,24 @@ def make_zmq_socket(
else:
buf_size = -1 # Use system default buffer size
if socket_type == zmq.constants.PULL:
socket.setsockopt(zmq.constants.RCVHWM, 0)
socket.setsockopt(zmq.constants.RCVBUF, buf_size)
if bind is None:
bind = socket_type != zmq.PUSH
if socket_type in (zmq.PULL, zmq.DEALER, zmq.ROUTER):
socket.setsockopt(zmq.RCVHWM, 0)
socket.setsockopt(zmq.RCVBUF, buf_size)
if socket_type in (zmq.PUSH, zmq.DEALER, zmq.ROUTER):
socket.setsockopt(zmq.SNDHWM, 0)
socket.setsockopt(zmq.SNDBUF, buf_size)
if identity is not None:
socket.setsockopt(zmq.IDENTITY, identity)
if bind:
socket.bind(path)
elif socket_type == zmq.constants.PUSH:
socket.setsockopt(zmq.constants.SNDHWM, 0)
socket.setsockopt(zmq.constants.SNDBUF, buf_size)
socket.connect(path)
else:
raise ValueError(f"Unknown Socket Type: {socket_type}")
socket.connect(path)
return socket
......@@ -2366,14 +2376,19 @@ def make_zmq_socket(
def zmq_socket_ctx(
path: str,
socket_type: Any,
bind: Optional[bool] = None,
linger: int = 0,
identity: Optional[bytes] = None,
) -> Iterator[zmq.Socket]:
"""Context manager for a ZMQ socket"""
ctx = zmq.Context() # type: ignore[attr-defined]
try:
yield make_zmq_socket(ctx, path, socket_type)
yield make_zmq_socket(ctx,
path,
socket_type,
bind=bind,
identity=identity)
except KeyboardInterrupt:
logger.debug("Got Keyboard Interrupt.")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment