Unverified Commit 1ce4878d authored by wangyu's avatar wangyu Committed by GitHub
Browse files

feat(remote_model): support variable remote backend for model loader (#3964)


Signed-off-by: default avatarwangyu <wangyu.steph@bytedance.com>
parent 977d7cd2
......@@ -30,6 +30,7 @@ from sglang.srt.utils import (
is_flashinfer_available,
is_hip,
is_port_available,
is_remote_url,
is_valid_ipv6_address,
nullable_str,
)
......@@ -296,6 +297,9 @@ class ServerArgs:
) and check_gguf_file(self.model_path):
self.quantization = self.load_format = "gguf"
if is_remote_url(self.model_path):
self.load_format = "remote"
# AMD-specific Triton attention KV splits default number
if is_hip():
self.triton_attention_num_kv_splits = 16
......@@ -345,9 +349,11 @@ class ServerArgs:
"safetensors",
"npcache",
"dummy",
"sharded_state",
"gguf",
"bitsandbytes",
"layered",
"remote",
],
help="The format of the model weights to load. "
'"auto" will try to load the weights in the safetensors format '
......@@ -1088,6 +1094,9 @@ class PortArgs:
# The port for nccl initialization (torch.dist)
nccl_port: int
# The ipc filename for rpc call between Engine and Scheduler
rpc_ipc_name: str
@staticmethod
def init_new(server_args, dp_rank: Optional[int] = None) -> "PortArgs":
port = server_args.port + random.randint(100, 1000)
......@@ -1106,6 +1115,7 @@ class PortArgs:
scheduler_input_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
detokenizer_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
nccl_port=port,
rpc_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
)
else:
# DP attention. Use TCP + port to handle both single-node and multi-node.
......@@ -1131,6 +1141,7 @@ class PortArgs:
scheduler_input_ipc_name=f"tcp://{dist_init_host}:{scheduler_input_port}",
detokenizer_ipc_name=f"tcp://{dist_init_host}:{port_base + 1}",
nccl_port=port,
rpc_ipc_name=f"tcp://{dist_init_host}:{port_base + 2}",
)
......
......@@ -42,6 +42,7 @@ from importlib.util import find_spec
from io import BytesIO
from multiprocessing import Pool
from multiprocessing.reduction import ForkingPickler
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Protocol, Set, Tuple, Union
import numpy as np
......@@ -774,12 +775,22 @@ def get_zmq_socket(
buf_size = -1
socket = context.socket(socket_type)
if socket_type == zmq.PUSH:
def set_send_opt():
socket.setsockopt(zmq.SNDHWM, 0)
socket.setsockopt(zmq.SNDBUF, buf_size)
elif socket_type == zmq.PULL:
def set_recv_opt():
socket.setsockopt(zmq.RCVHWM, 0)
socket.setsockopt(zmq.RCVBUF, buf_size)
if socket_type == zmq.PUSH:
set_send_opt()
elif socket_type == zmq.PULL:
set_recv_opt()
elif socket_type == zmq.DEALER:
set_send_opt()
set_recv_opt()
else:
raise ValueError(f"Unsupported socket type: {socket_type}")
......@@ -1572,3 +1583,29 @@ def add_prefix(name: str, prefix: str) -> str:
The string `prefix.name` if prefix is non-empty, otherwise just `name`.
"""
return name if not prefix else f"{prefix}.{name}"
def is_remote_url(url: Union[str, Path]) -> bool:
"""
Check if the URL is a remote URL of the format:
<connector_type>://<host>:<port>/<model_name>
"""
if isinstance(url, Path):
return False
pattern = r"(.+)://(.*)"
m = re.match(pattern, url)
return m is not None
def parse_connector_type(url: str) -> str:
"""
Parse the connector type from the URL of the format:
<connector_type>://<path>
"""
pattern = r"(.+)://(.*)"
m = re.match(pattern, url)
if m is None:
return ""
return m.group(1)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment