Unverified Commit 1ce4878d authored by wangyu, committed by GitHub

feat(remote_model): support variable remote backend for model loader (#3964)


Signed-off-by: wangyu <wangyu.steph@bytedance.com>
parent 977d7cd2
# SPDX-License-Identifier: Apache-2.0
"""
Saves each worker's model state dict directly to remote storage (a KV store
such as Redis), which enables a fast load path for large tensor-parallel
models where each worker only needs to read its own shard rather than the
entire checkpoint.
Example usage:
python save_remote_state.py \
--model-path /path/to/load \
--tensor-parallel-size 8 \
--remote-model-save-url [protocol]://[host]:[port]/[model_name]
Then, the model can be loaded with
llm = Engine(
model_path="[protocol]://[host]:[port]/[model_name]",
load_format="remote",
tensor_parallel_size=8,
)
"""
import dataclasses
from argparse import ArgumentParser
from pathlib import Path
from sglang import Engine, ServerArgs
parser = ArgumentParser()
ServerArgs.add_cli_args(parser)
parser.add_argument(
"--remote-model-save-url",
required=True,
type=str,
help="remote address to store model weights",
)
def main(args):
engine_args = ServerArgs.from_cli_args(args)
model_path = engine_args.model_path
if not Path(model_path).is_dir():
raise ValueError("model path must be a local directory")
# Create LLM instance from arguments
llm = Engine(**dataclasses.asdict(engine_args))
llm.save_remote_model(url=args.remote_model_save_url)
if __name__ == "__main__":
args = parser.parse_args()
main(args)
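For concreteness, a hedged sketch of the round trip with a Redis backend (host, port, paths, and model name are all hypothetical):

# 1. Push each rank's shard plus the config/tokenizer files to the KV store:
#    python save_remote_state.py --model-path /data/llama-3-8b \
#        --tensor-parallel-size 8 \
#        --remote-model-save-url redis://10.0.0.1:6379/llama-3-8b
#
# 2. Later, start an engine that pulls the weights back from the same URL:
from sglang import Engine

llm = Engine(
    model_path="redis://10.0.0.1:6379/llama-3-8b",
    load_format="remote",
    tp_size=8,  # must match the tensor-parallel degree used when saving
)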
# SPDX-License-Identifier: Apache-2.0
"""
Saves each worker's model state dict directly to a checkpoint, which enables a
fast load path for large tensor-parallel models where each worker only needs to
read its own shard rather than the entire checkpoint.
Example usage:
python save_sharded_state.py \
--model-path /path/to/load \
--quantization deepspeedfp \
--tensor-parallel-size 8 \
--output /path/to/save
Then, the model can be loaded with
llm = Engine(
model_path="/path/to/save",
load_format="sharded_state",
quantization="deepspeedfp",
tensor_parallel_size=8,
)
"""
import dataclasses
import os
import shutil
from argparse import ArgumentParser
from pathlib import Path
from sglang import Engine, ServerArgs
parser = ArgumentParser()
ServerArgs.add_cli_args(parser)
parser.add_argument(
"--output", "-o", required=True, type=str, help="path to output checkpoint"
)
parser.add_argument(
"--file-pattern", type=str, help="string pattern of saved filenames"
)
parser.add_argument(
"--max-file-size",
type=int,
default=5 * 1024**3,
help="max size (in bytes) of each safetensors file",
)
def main(args):
engine_args = ServerArgs.from_cli_args(args)
model_path = engine_args.model_path
if not Path(model_path).is_dir():
raise ValueError("model path must be a local directory")
# Create LLM instance from arguments
llm = Engine(**dataclasses.asdict(engine_args))
Path(args.output).mkdir(parents=True, exist_ok=True)
llm.save_sharded_model(
path=args.output, pattern=args.file_pattern, max_size=args.max_file_size
)
# Copy metadata files to output directory
for file in os.listdir(model_path):
if os.path.splitext(file)[1] not in (".bin", ".pt", ".safetensors"):
if os.path.isdir(os.path.join(model_path, file)):
shutil.copytree(
os.path.join(model_path, file), os.path.join(args.output, file)
)
else:
shutil.copy(os.path.join(model_path, file), args.output)
if __name__ == "__main__":
args = parser.parse_args()
main(args)
......@@ -32,6 +32,7 @@ from sglang.lang.choices import (
)
from sglang.utils import LazyImport
ServerArgs = LazyImport("sglang.srt.server_args", "ServerArgs")
Anthropic = LazyImport("sglang.lang.backend.anthropic", "Anthropic")
LiteLLM = LazyImport("sglang.lang.backend.litellm", "LiteLLM")
OpenAI = LazyImport("sglang.lang.backend.openai", "OpenAI")
......@@ -67,6 +68,7 @@ __all__ = [
"greedy_token_selection",
"token_length_normalized",
"unconditional_likelihood_normalized",
"ServerArgs",
"Anthropic",
"LiteLLM",
"OpenAI",
......
......@@ -22,6 +22,7 @@ class LoadFormat(str, enum.Enum):
MISTRAL = "mistral"
LAYERED = "layered"
JAX = "jax"
REMOTE = "remote"
@dataclass
......
......@@ -51,13 +51,14 @@ class ModelConfig:
self.quantization = quantization
# Parse args
self.maybe_pull_model_tokenizer_from_remote()
self.model_override_args = json.loads(model_override_args)
kwargs = {}
if override_config_file and override_config_file.strip():
kwargs["_configuration_file"] = override_config_file.strip()
self.hf_config = get_config(
model_path,
self.model_path,
trust_remote_code=trust_remote_code,
revision=revision,
model_override_args=self.model_override_args,
......@@ -318,6 +319,29 @@ class ModelConfig:
eos_ids = {eos_ids} if isinstance(eos_ids, int) else set(eos_ids)
return eos_ids
def maybe_pull_model_tokenizer_from_remote(self) -> None:
"""
Pull the model config files to a temporary
directory in case of remote.
Args:
model: The model name or path.
"""
from sglang.srt.connector import create_remote_connector
from sglang.srt.utils import is_remote_url
if is_remote_url(self.model_path):
logger.info("Pulling model configs from remote...")
# BaseConnector implements __del__() to clean up the local dir.
# Because the config files must stay available for the lifetime of the
# process, we deliberately do NOT use a `with` statement, which would
# close the client and delete the local directory.
client = create_remote_connector(self.model_path)
if is_remote_url(self.model_path):
client.pull_files(allow_pattern=["*config.json"])
self.model_weights = self.model_path
self.model_path = client.get_local_dir()
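For reference, a minimal sketch of the same pull using the connector API added in this commit (the S3 path is hypothetical):

from sglang.srt.connector import create_remote_connector
from sglang.srt.utils import is_remote_url

model_path = "s3://my-bucket/llama-3-8b"  # hypothetical remote location
if is_remote_url(model_path):
    client = create_remote_connector(model_path)
    client.pull_files(allow_pattern=["*config.json"])  # fetch only the config files
    model_weights = model_path               # weights still stream from the URL
    model_path = client.get_local_dir()      # HF config is now readable locally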
def get_hf_text_config(config: PretrainedConfig):
"""Get the "sub" config relevant to llm for multi modal models.
......
# SPDX-License-Identifier: Apache-2.0
import enum
import logging
from sglang.srt.connector.base_connector import (
BaseConnector,
BaseFileConnector,
BaseKVConnector,
)
from sglang.srt.connector.redis import RedisConnector
from sglang.srt.connector.s3 import S3Connector
from sglang.srt.utils import parse_connector_type
logger = logging.getLogger(__name__)
class ConnectorType(str, enum.Enum):
FS = "filesystem"
KV = "KV"
def create_remote_connector(url, device="cpu") -> BaseConnector:
connector_type = parse_connector_type(url)
if connector_type == "redis":
return RedisConnector(url, device=device)
elif connector_type == "s3":
return S3Connector(url)
else:
raise ValueError(f"Unsupported connector type '{connector_type}' in url: {url}")
def get_connector_type(client: BaseConnector) -> ConnectorType:
if isinstance(client, BaseKVConnector):
return ConnectorType.KV
if isinstance(client, BaseFileConnector):
return ConnectorType.FS
raise ValueError(f"Invalid connector type: {client}")
__all__ = [
"BaseConnector",
"BaseFileConnector",
"BaseKVConnector",
"RedisConnector",
"S3Connector",
"ConnectorType",
"create_remote_connector",
"get_connector_type",
]
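A short sketch of how the factory dispatches on the URL scheme (hosts and model names are hypothetical; the redis and boto3 packages must be installed):

from sglang.srt.connector import (
    ConnectorType,
    create_remote_connector,
    get_connector_type,
)

# "redis://" maps to the KV-style RedisConnector ...
kv_client = create_remote_connector("redis://127.0.0.1:6379/llama-3-8b")
assert get_connector_type(kv_client) == ConnectorType.KV

# ... while "s3://" maps to the file-based S3Connector.
fs_client = create_remote_connector("s3://my-bucket/llama-3-8b")
assert get_connector_type(fs_client) == ConnectorType.FS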
# SPDX-License-Identifier: Apache-2.0
import os
import shutil
import signal
import tempfile
from abc import ABC, abstractmethod
from typing import Generator, List, Optional, Tuple
import torch
class BaseConnector(ABC):
"""
For fs connector such as s3:
<connector_type>://<path>/<filename>
For kv connector such as redis:
<connector_type>://<host>:<port>/<model_name>/keys/<key>
<connector_type>://<host>:<port>/<model_name>/files/<filename>
"""
def __init__(self, url: str, device: torch.device = "cpu"):
self.url = url
self.device = device
self.closed = False
self.local_dir = tempfile.mkdtemp()
for sig in (signal.SIGINT, signal.SIGTERM):
existing_handler = signal.getsignal(sig)
signal.signal(sig, self._close_by_signal(existing_handler))
def get_local_dir(self):
return self.local_dir
@abstractmethod
def weight_iterator(
self, rank: int = 0
) -> Generator[Tuple[str, torch.Tensor], None, None]:
raise NotImplementedError()
@abstractmethod
def pull_files(
self,
allow_pattern: Optional[List[str]] = None,
ignore_pattern: Optional[List[str]] = None,
) -> None:
raise NotImplementedError()
def close(self):
if self.closed:
return
self.closed = True
if os.path.exists(self.local_dir):
shutil.rmtree(self.local_dir)
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, traceback):
self.close()
def __del__(self):
self.close()
def _close_by_signal(self, existing_handler=None):
def new_handler(signum, frame):
self.close()
if existing_handler:
existing_handler(signum, frame)
return new_handler
class BaseKVConnector(BaseConnector):
@abstractmethod
def get(self, key: str) -> Optional[torch.Tensor]:
raise NotImplementedError()
@abstractmethod
def getstr(self, key: str) -> Optional[str]:
raise NotImplementedError()
@abstractmethod
def set(self, key: str, obj: torch.Tensor) -> None:
raise NotImplementedError()
@abstractmethod
def setstr(self, key: str, obj: str) -> None:
raise NotImplementedError()
@abstractmethod
def list(self, prefix: str) -> List[str]:
raise NotImplementedError()
class BaseFileConnector(BaseConnector):
"""
List full file names from remote fs path and filter by allow pattern.
Args:
allow_pattern: A list of patterns of which files to pull.
Returns:
list[str]: List of full paths allowed by the pattern
"""
@abstractmethod
def glob(self, allow_pattern: str) -> List[str]:
raise NotImplementedError()
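Because the point of this change is a pluggable remote backend, adding a new backend mostly means subclassing one of these bases. Below is a hypothetical sketch (not part of this commit) of a "file://" connector that treats a local directory as the remote; wiring it up would also require a branch for its scheme in create_remote_connector:

import fnmatch
import os
import shutil
from typing import Generator, List, Optional, Tuple

import torch
from safetensors.torch import safe_open

from sglang.srt.connector.base_connector import BaseFileConnector


class LocalDirConnector(BaseFileConnector):
    """Hypothetical FS backend that 'pulls' from a local directory."""

    def __init__(self, url: str):
        super().__init__(url)
        self.root = url.removeprefix("file://")

    def glob(self, allow_pattern: Optional[List[str]] = None) -> List[str]:
        names = os.listdir(self.root)
        if allow_pattern:
            names = [
                n for n in names
                if any(fnmatch.fnmatch(n, p) for p in allow_pattern)
            ]
        return [os.path.join(self.root, n) for n in names]

    def pull_files(
        self,
        allow_pattern: Optional[List[str]] = None,
        ignore_pattern: Optional[List[str]] = None,
    ) -> None:
        # Copy matching files into the connector's temporary directory.
        for src in self.glob(allow_pattern):
            if ignore_pattern and any(
                fnmatch.fnmatch(src, p) for p in ignore_pattern
            ):
                continue
            shutil.copy(src, self.local_dir)

    def weight_iterator(
        self, rank: int = 0
    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
        # Yield (name, tensor) pairs from every safetensors file in the dir.
        for f in self.glob(["*.safetensors"]):
            with safe_open(f, framework="pt") as sf:
                for name in sf.keys():
                    yield name, sf.get_tensor(name)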
# SPDX-License-Identifier: Apache-2.0
import logging
from typing import Generator, List, Optional, Tuple
from urllib.parse import urlparse
import torch
from sglang.srt.connector import BaseKVConnector
from sglang.srt.connector.serde import create_serde
from sglang.srt.connector.utils import pull_files_from_db
logger = logging.getLogger(__name__)
class RedisConnector(BaseKVConnector):
def __init__(self, url: str, device: torch.device = "cpu"):
import redis
super().__init__(url, device)
parsed_url = urlparse(url)
self.connection = redis.Redis(host=parsed_url.hostname, port=parsed_url.port)
self.model_name = parsed_url.path.lstrip("/")
# TODO: more serde options
self.s, self.d = create_serde("safe")
def get(self, key: str) -> Optional[torch.Tensor]:
val = self.connection.get(key)
if val is None:
logger.error("Key %s not found", key)
return None
return self.d.from_bytes(val)
def getstr(self, key: str) -> Optional[str]:
val = self.connection.get(key)
if val is None:
logger.error("Key %s not found", key)
return None
return val.decode("utf-8")
def set(self, key: str, tensor: torch.Tensor) -> None:
assert tensor is not None
self.connection.set(key, self.s.to_bytes(tensor))
def setstr(self, key: str, obj: str) -> None:
self.connection.set(key, obj)
def list(self, prefix: str) -> List[str]:
cursor = 0
all_keys: List[bytes] = []
while True:
ret: Tuple[int, List[bytes]] = self.connection.scan(
cursor=cursor, match=f"{prefix}*"
) # type: ignore
cursor, keys = ret
all_keys.extend(keys)
if cursor == 0:
break
return [key.decode("utf-8") for key in all_keys]
def weight_iterator(
self, rank: int = 0
) -> Generator[Tuple[str, torch.Tensor], None, None]:
keys = self.list(f"{self.model_name}/keys/rank_{rank}/")
for key in keys:
val = self.get(key)
key = key.removeprefix(f"{self.model_name}/keys/rank_{rank}/")
yield key, val
def pull_files(
self,
allow_pattern: Optional[List[str]] = None,
ignore_pattern: Optional[List[str]] = None,
) -> None:
pull_files_from_db(self, self.model_name, allow_pattern, ignore_pattern)
def close(self):
self.connection.close()
super().close()
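To make the key layout concrete, a small sketch against a local Redis (address and model name hypothetical); it mirrors the <model_name>/keys/rank_<rank>/... and <model_name>/files/... scheme that RemoteModelLoader.save_model writes further down:

import torch

from sglang.srt.connector.redis import RedisConnector

conn = RedisConnector("redis://127.0.0.1:6379/llama-3-8b")

# Tensors live under <model_name>/keys/rank_<rank>/<param_name> ...
conn.set("llama-3-8b/keys/rank_0/model.embed_tokens.weight", torch.randn(16, 16))

# ... and plain text files under <model_name>/files/<filename>.
conn.setstr("llama-3-8b/files/config.json", '{"architectures": ["LlamaForCausalLM"]}')

# weight_iterator strips the prefix and yields (param_name, tensor) pairs.
# (Shapes are preserved; note the "safe" deserializer above casts values to uint8.)
for name, tensor in conn.weight_iterator(rank=0):
    print(name, tuple(tensor.shape))  # model.embed_tokens.weight (16, 16)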
# SPDX-License-Identifier: Apache-2.0
import fnmatch
import os
from pathlib import Path
from typing import Generator, Optional, Tuple
import torch
from sglang.srt.connector import BaseFileConnector
def _filter_allow(paths: list[str], patterns: list[str]) -> list[str]:
return [
path
for path in paths
if any(fnmatch.fnmatch(path, pattern) for pattern in patterns)
]
def _filter_ignore(paths: list[str], patterns: list[str]) -> list[str]:
return [
path
for path in paths
if not any(fnmatch.fnmatch(path, pattern) for pattern in patterns)
]
def list_files(
s3,
path: str,
allow_pattern: Optional[list[str]] = None,
ignore_pattern: Optional[list[str]] = None,
) -> tuple[str, str, list[str]]:
"""
List files from S3 path and filter by pattern.
Args:
s3: S3 client to use.
path: The S3 path to list from.
allow_pattern: A list of patterns of which files to pull.
ignore_pattern: A list of patterns of which files not to pull.
Returns:
tuple[str, str, list[str]]: A tuple where:
- The first element is the bucket name
- The second element is the key prefix, i.e. the directory-like
path inside the bucket
- The third element is the list of object keys that remain after the
allow/ignore patterns are applied
"""
parts = path.removeprefix("s3://").split("/")
prefix = "/".join(parts[1:])
bucket_name = parts[0]
objects = s3.list_objects_v2(Bucket=bucket_name, Prefix=prefix)
paths = [obj["Key"] for obj in objects.get("Contents", [])]
paths = _filter_ignore(paths, ["*/"])
if allow_pattern is not None:
paths = _filter_allow(paths, allow_pattern)
if ignore_pattern is not None:
paths = _filter_ignore(paths, ignore_pattern)
return bucket_name, prefix, paths
class S3Connector(BaseFileConnector):
def __init__(self, url: str) -> None:
import boto3
super().__init__(url)
self.client = boto3.client("s3")
def glob(self, allow_pattern: Optional[list[str]] = None) -> list[str]:
bucket_name, _, paths = list_files(
self.client, path=self.url, allow_pattern=allow_pattern
)
return [f"s3://{bucket_name}/{path}" for path in paths]
def pull_files(
self,
allow_pattern: Optional[list[str]] = None,
ignore_pattern: Optional[list[str]] = None,
) -> None:
"""
Pull files from S3 storage into the temporary directory.
Args:
allow_pattern: A list of patterns of which files to pull.
ignore_pattern: A list of patterns of which files not to pull.
"""
bucket_name, base_dir, files = list_files(
self.client, self.url, allow_pattern, ignore_pattern
)
if len(files) == 0:
return
for file in files:
destination_file = os.path.join(self.local_dir, file.removeprefix(base_dir))
local_dir = Path(destination_file).parent
os.makedirs(local_dir, exist_ok=True)
self.client.download_file(bucket_name, file, destination_file)
def weight_iterator(
self, rank: int = 0
) -> Generator[Tuple[str, torch.Tensor], None, None]:
from sglang.srt.model_loader.weight_utils import (
runai_safetensors_weights_iterator,
)
# Only safetensors files are supported for now.
hf_weights_files = self.glob(allow_pattern=["*.safetensors"])
return runai_safetensors_weights_iterator(hf_weights_files)
def close(self):
self.client.close()
super().close()
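A brief sketch of using the S3 connector directly (bucket and prefix hypothetical; boto3 credentials and, for weight streaming, runai_model_streamer are assumed to be available):

from sglang.srt.connector.s3 import S3Connector

conn = S3Connector("s3://my-bucket/llama-3-8b")

# Copy only the small metadata files into the connector's temp dir.
conn.pull_files(ignore_pattern=["*.pt", "*.safetensors", "*.bin"])
print(conn.get_local_dir())  # temp dir now holding config.json, tokenizer files, ...

# Stream the tensors straight from S3 without staging them on disk.
for name, tensor in conn.weight_iterator():
    print(name, tensor.dtype, tuple(tensor.shape))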
# SPDX-License-Identifier: Apache-2.0
# inspired by LMCache
from typing import Optional, Tuple
import torch
from sglang.srt.connector.serde.safe_serde import SafeDeserializer, SafeSerializer
from sglang.srt.connector.serde.serde import Deserializer, Serializer
def create_serde(serde_type: str) -> Tuple[Serializer, Deserializer]:
s: Optional[Serializer] = None
d: Optional[Deserializer] = None
if serde_type == "safe":
s = SafeSerializer()
d = SafeDeserializer(torch.uint8)
else:
raise ValueError(f"Unknown serde type: {serde_type}")
return s, d
__all__ = [
"Serializer",
"Deserializer",
"SafeSerializer",
"SafeDeserializer",
"create_serde",
]
# SPDX-License-Identifier: Apache-2.0
from typing import Union
import torch
from safetensors.torch import load, save
from sglang.srt.connector.serde.serde import Deserializer, Serializer
class SafeSerializer(Serializer):
def __init__(self):
super().__init__()
def to_bytes(self, t: torch.Tensor) -> bytes:
return save({"tensor_bytes": t.cpu().contiguous()})
class SafeDeserializer(Deserializer):
def __init__(self, dtype):
super().__init__(dtype)
def from_bytes_normal(self, b: Union[bytearray, bytes]) -> torch.Tensor:
return load(bytes(b))["tensor_bytes"].to(dtype=self.dtype)
def from_bytes(self, b: Union[bytearray, bytes]) -> torch.Tensor:
return self.from_bytes_normal(b)
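A quick roundtrip sketch for the "safe" serde pair; note that the deserializer created by create_serde("safe") casts values to torch.uint8, so a byte tensor is used here:

import torch

from sglang.srt.connector.serde import create_serde

s, d = create_serde("safe")
t = torch.arange(6, dtype=torch.uint8).reshape(2, 3)
blob = s.to_bytes(t)     # safetensors-encoded bytes (data + shape + dtype)
t2 = d.from_bytes(blob)  # decoded and cast to the deserializer's dtype
assert torch.equal(t, t2)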
# SPDX-License-Identifier: Apache-2.0
import abc
from abc import ABC, abstractmethod
import torch
class Serializer(ABC):
@abstractmethod
def to_bytes(self, t: torch.Tensor) -> bytes:
"""
Serialize a pytorch tensor to bytes. The serialized bytes should contain
both the data and the metadata (shape, dtype, etc.) of the tensor.
Input:
t: the input pytorch tensor, can be on any device, in any shape,
with any dtype
Returns:
bytes: the serialized bytes
"""
raise NotImplementedError
class Deserializer(metaclass=abc.ABCMeta):
def __init__(self, dtype):
self.dtype = dtype
@abstractmethod
def from_bytes(self, bs: bytes) -> torch.Tensor:
"""
Deserialize a pytorch tensor from bytes.
Input:
bs: the serialized bytes produced by a matching Serializer
Output:
torch.Tensor: the deserialized pytorch tensor
"""
raise NotImplementedError
# SPDX-License-Identifier: Apache-2.0
import os
from pathlib import Path
from typing import Optional
from urllib.parse import urlparse
from sglang.srt.connector import BaseConnector
def parse_model_name(url: str) -> str:
"""
Parse the model name from the URL.
Only used for KV (database) connectors.
"""
parsed_url = urlparse(url)
return parsed_url.path.lstrip("/")
def pull_files_from_db(
connector: BaseConnector,
model_name: str,
allow_pattern: Optional[list[str]] = None,
ignore_pattern: Optional[list[str]] = None,
) -> None:
prefix = f"{model_name}/files/"
local_dir = connector.get_local_dir()
files = connector.list(prefix)
for file in files:
destination_file = os.path.join(local_dir, file.removeprefix(prefix))
# Use a separate name so `local_dir` keeps pointing at the connector's
# temporary root directory on later iterations.
parent_dir = Path(destination_file).parent
os.makedirs(parent_dir, exist_ok=True)
with open(destination_file, "wb") as f:
f.write(connector.getstr(file).encode("utf-8"))
......@@ -27,6 +27,9 @@ import signal
import threading
from typing import AsyncIterator, Dict, Iterator, List, Optional, Tuple, Union
import zmq
import zmq.asyncio
# Fix a bug of Python threading
setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
......@@ -44,6 +47,8 @@ from sglang.srt.managers.io_struct import (
InitWeightsUpdateGroupReqInput,
ReleaseMemoryOccupationReqInput,
ResumeMemoryOccupationReqInput,
RpcReqInput,
RpcReqOutput,
UpdateWeightFromDiskReqInput,
UpdateWeightsFromDistributedReqInput,
UpdateWeightsFromTensorReqInput,
......@@ -57,6 +62,7 @@ from sglang.srt.utils import (
MultiprocessingSerializer,
assert_pkg_version,
configure_logger,
get_zmq_socket,
kill_process_tree,
launch_dummy_health_check_server,
maybe_set_triton_cache_manager,
......@@ -102,15 +108,25 @@ class Engine:
# Shutdown the subprocesses automatically when the program exits
atexit.register(self.shutdown)
# Allocate ports for inter-process communications
port_args = PortArgs.init_new(server_args)
logger.info(f"{server_args=}")
# Launch subprocesses
tokenizer_manager, scheduler_info = _launch_subprocesses(
server_args=server_args
server_args=server_args,
port_args=port_args,
)
self.server_args = server_args
self.tokenizer_manager = tokenizer_manager
self.scheduler_info = scheduler_info
context = zmq.Context(2)
self.send_to_rpc = get_zmq_socket(
context, zmq.DEALER, port_args.rpc_ipc_name, True
)
def generate(
self,
# The input prompt. It can be a single prompt or a batch of prompts.
......@@ -350,6 +366,23 @@ class Engine:
self.tokenizer_manager.resume_memory_occupation(obj, None)
)
"""
Execute an RPC call on all scheduler processes.
"""
def collective_rpc(self, method: str, **kwargs):
obj = RpcReqInput(method=method, parameters=kwargs)
self.send_to_rpc.send_pyobj(obj)
recv_req = self.send_to_rpc.recv_pyobj(zmq.BLOCKY)
assert isinstance(recv_req, RpcReqOutput)
assert recv_req.success, recv_req.message
def save_remote_model(self, **kwargs):
self.collective_rpc("save_remote_model", **kwargs)
def save_sharded_model(self, **kwargs):
self.collective_rpc("save_sharded_model", **kwargs)
def _set_envs_and_config(server_args: ServerArgs):
# Set global environments
......@@ -408,7 +441,9 @@ def _set_envs_and_config(server_args: ServerArgs):
mp.set_start_method("spawn", force=True)
def _launch_subprocesses(server_args: ServerArgs) -> Tuple[TokenizerManager, Dict]:
def _launch_subprocesses(
server_args: ServerArgs, port_args: Optional[PortArgs] = None
) -> Tuple[TokenizerManager, Dict]:
"""
Launch the TokenizerManager in the main process, the Scheduler in a subprocess, and the DetokenizerManager in another subprocess.
"""
......@@ -418,8 +453,9 @@ def _launch_subprocesses(server_args: ServerArgs) -> Tuple[TokenizerManager, Dic
_set_envs_and_config(server_args)
# Allocate ports for inter-process communications
port_args = PortArgs.init_new(server_args)
logger.info(f"{server_args=}")
if port_args is None:
port_args = PortArgs.init_new(server_args)
logger.info(f"{server_args=}")
# If using model from www.modelscope.cn, first download the model.
server_args.model_path, server_args.tokenizer_path = prepare_model_and_tokenizer(
......
......@@ -37,6 +37,8 @@ from sglang.srt.configs import (
MultiModalityConfig,
Qwen2_5_VLConfig,
)
from sglang.srt.connector import create_remote_connector
from sglang.srt.utils import is_remote_url
_CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
ChatGLMConfig.model_type: ChatGLMConfig,
......@@ -155,6 +157,14 @@ def get_tokenizer(
kwargs["gguf_file"] = tokenizer_name
tokenizer_name = Path(tokenizer_name).parent
if is_remote_url(tokenizer_name):
# BaseConnector implements __del__() to clean up the local dir.
# Because the config files must stay available for the lifetime of the
# process, we deliberately do NOT use a `with` statement, which would
# close the client and delete the local directory.
client = create_remote_connector(tokenizer_name)
client.pull_files(ignore_pattern=["*.pt", "*.safetensors", "*.bin"])
tokenizer_name = client.get_local_dir()
try:
tokenizer = AutoTokenizer.from_pretrained(
tokenizer_name,
......
......@@ -723,3 +723,15 @@ class SeparateReasoningReqInput:
class VertexGenerateReqInput:
instances: List[dict]
parameters: Optional[dict] = None
@dataclass
class RpcReqInput:
method: str
parameters: Optional[Dict] = None
@dataclass
class RpcReqOutput:
success: bool
message: str
......@@ -32,6 +32,7 @@ import psutil
import setproctitle
import torch
import zmq
from torch.distributed import barrier
from sglang.global_config import global_config
from sglang.srt.configs.model_config import ModelConfig
......@@ -59,6 +60,8 @@ from sglang.srt.managers.io_struct import (
ReleaseMemoryOccupationReqOutput,
ResumeMemoryOccupationReqInput,
ResumeMemoryOccupationReqOutput,
RpcReqInput,
RpcReqOutput,
SetInternalStateReq,
SetInternalStateReqOutput,
TokenizedEmbeddingReqInput,
......@@ -193,8 +196,13 @@ class Scheduler(SchedulerOutputProcessorMixin):
self.send_to_detokenizer = get_zmq_socket(
context, zmq.PUSH, port_args.detokenizer_ipc_name, False
)
self.recv_from_rpc = get_zmq_socket(
context, zmq.DEALER, port_args.rpc_ipc_name, False
)
else:
self.recv_from_tokenizer = None
self.recv_from_rpc = None
self.send_to_tokenizer = SimpleNamespace(send_pyobj=lambda x: None)
self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None)
......@@ -376,6 +384,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
(ProfileReq, self.profile),
(GetInternalStateReq, self.get_internal_state),
(SetInternalStateReq, self.set_internal_state),
(RpcReqInput, self.handle_rpc_request),
]
)
......@@ -549,6 +558,13 @@ class Scheduler(SchedulerOutputProcessorMixin):
except zmq.ZMQError:
break
recv_reqs.append(recv_req)
while True:
try:
recv_rpc = self.recv_from_rpc.recv_pyobj(zmq.NOBLOCK)
except zmq.ZMQError:
break
recv_reqs.append(recv_rpc)
else:
recv_reqs = None
......@@ -600,7 +616,11 @@ class Scheduler(SchedulerOutputProcessorMixin):
output = self._request_dispatcher(recv_req)
if output is not None:
self.send_to_tokenizer.send_pyobj(output)
if isinstance(output, RpcReqOutput):
if self.recv_from_rpc is not None:
self.recv_from_rpc.send_pyobj(output)
else:
self.send_to_tokenizer.send_pyobj(output)
def handle_generate_request(
self,
......@@ -1492,6 +1512,47 @@ class Scheduler(SchedulerOutputProcessorMixin):
server_args=global_server_args_dict,
)
def handle_rpc_request(self, recv_req: RpcReqInput):
# Handle RPC requests
logger.info(
f"handle_rpc_request: {recv_req.method}, param: {recv_req.parameters}"
)
success = True
err = None  # avoid shadowing the built-in `exec`
try:
func = getattr(self, recv_req.method)
func(recv_req.parameters)
except Exception as e:
success = False
err = e
logger.error(f"Failed to call rpc {recv_req.method}: {str(e)}")
barrier()
return RpcReqOutput(success, "" if not err else str(err))
def save_remote_model(self, params):
url = params["url"]
if isinstance(self.tp_worker, TpModelWorkerClient):
worker = self.tp_worker.worker
else:
worker = self.tp_worker
worker.model_runner.save_remote_model(url)
def save_sharded_model(self, params):
if isinstance(self.tp_worker, TpModelWorkerClient):
worker = self.tp_worker.worker
else:
worker = self.tp_worker
worker.model_runner.save_sharded_model(
path=params["path"],
pattern=params["pattern"],
max_size=params["max_size"],
)
def abort_request(self, recv_req: AbortReq):
# Delete requests in the waiting queue
to_del = []
......
......@@ -1009,6 +1009,22 @@ class ModelRunner:
return False
return rope_scaling.get("type", None) == "mrope"
def save_remote_model(self, url: str):
from sglang.srt.model_loader.loader import RemoteModelLoader
logger.info(f"Saving model to {url}")
RemoteModelLoader.save_model(self.model, self.model_config.model_path, url)
def save_sharded_model(
self, path: str, pattern: Optional[str] = None, max_size: Optional[int] = None
):
from sglang.srt.model_loader.loader import ShardedStateLoader
logger.info(
f"Save sharded model to {path} with pattern {pattern} and max_size {max_size}"
)
ShardedStateLoader.save_model(self.model, path, pattern, max_size)
def _model_load_weights_direct(model, named_tensors: List[Tuple[str, torch.Tensor]]):
params_dict = dict(model.named_parameters())
......
......@@ -9,6 +9,7 @@ import json
import logging
import math
import os
import time
from abc import ABC, abstractmethod
from contextlib import contextmanager
from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, cast
......@@ -25,6 +26,12 @@ from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
from sglang.srt.configs.device_config import DeviceConfig
from sglang.srt.configs.load_config import LoadConfig, LoadFormat
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.connector import (
ConnectorType,
create_remote_connector,
get_connector_type,
)
from sglang.srt.connector.utils import parse_model_name
from sglang.srt.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
......@@ -46,6 +53,7 @@ from sglang.srt.model_loader.weight_utils import (
np_cache_weights_iterator,
pt_weights_iterator,
safetensors_weights_iterator,
set_runai_streamer_env,
)
from sglang.srt.utils import (
get_bool_env_var,
......@@ -490,7 +498,7 @@ class ShardedStateLoader(BaseModelLoader):
Model loader that directly loads each worker's model state dict, which
enables a fast load path for large tensor-parallel models where each worker
only needs to read its own shard rather than the entire checkpoint. See
`examples/save_sharded_state.py` for creating a sharded checkpoint.
`examples/runtime/engine/save_sharded_state.py` for creating a sharded checkpoint.
"""
DEFAULT_PATTERN = "model-rank-{rank}-part-{part}.safetensors"
......@@ -1204,6 +1212,153 @@ class GGUFModelLoader(BaseModelLoader):
return model
class RemoteModelLoader(BaseModelLoader):
"""Model loader that can load Tensors from remote database."""
def __init__(self, load_config: LoadConfig):
super().__init__(load_config)
# TODO @DellCurry: move to s3 connector only
set_runai_streamer_env(load_config)
def _get_weights_iterator_kv(
self,
client,
) -> Generator[Tuple[str, torch.Tensor], None, None]:
"""Get an iterator for the model weights from remote storage."""
assert get_connector_type(client) == ConnectorType.KV
rank = get_tensor_model_parallel_rank()
return client.weight_iterator(rank)
def _get_weights_iterator_fs(
self,
client,
) -> Generator[Tuple[str, torch.Tensor], None, None]:
"""Get an iterator for the model weights from remote storage."""
assert get_connector_type(client) == ConnectorType.FS
return client.weight_iterator()
def download_model(self, model_config: ModelConfig) -> None:
pass
@staticmethod
def save_model(
model: torch.nn.Module,
model_path: str,
url: str,
) -> None:
with create_remote_connector(url) as client:
assert get_connector_type(client) == ConnectorType.KV
model_name = parse_model_name(url)
rank = get_tensor_model_parallel_rank()
state_dict = ShardedStateLoader._filter_subtensors(model.state_dict())
for key, tensor in state_dict.items():
r_key = f"{model_name}/keys/rank_{rank}/{key}"
client.set(r_key, tensor)
for root, _, files in os.walk(model_path):
for file_name in files:
# ignore hidden files
if file_name.startswith("."):
continue
if os.path.splitext(file_name)[1] not in (
".bin",
".pt",
".safetensors",
):
file_path = os.path.join(root, file_name)
with open(file_path, encoding="utf-8") as file:
file_content = file.read()
f_key = f"{model_name}/files/{file_name}"
client.setstr(f_key, file_content)
def _load_model_from_remote_kv(self, model: nn.Module, client):
for _, module in model.named_modules():
quant_method = getattr(module, "quant_method", None)
if quant_method is not None:
quant_method.process_weights_after_loading(module)
weights_iterator = self._get_weights_iterator_kv(client)
state_dict = ShardedStateLoader._filter_subtensors(model.state_dict())
for key, tensor in weights_iterator:
# If loading with LoRA enabled, additional padding may
# be added to certain parameters. We only load into a
# narrowed view of the parameter data.
param_data = state_dict[key].data
param_shape = state_dict[key].shape
for dim, size in enumerate(tensor.shape):
if size < param_shape[dim]:
param_data = param_data.narrow(dim, 0, size)
if tensor.shape != param_shape:
logger.warning(
"loading tensor of shape %s into " "parameter '%s' of shape %s",
tensor.shape,
key,
param_shape,
)
param_data.copy_(tensor)
state_dict.pop(key)
if state_dict:
raise ValueError(f"Missing keys {tuple(state_dict)} in loaded state!")
def _load_model_from_remote_fs(
self, model, client, model_config: ModelConfig, device_config: DeviceConfig
) -> None:
target_device = torch.device(device_config.device)
with set_default_torch_dtype(model_config.dtype):
model.load_weights(self._get_weights_iterator_fs(client))
for _, module in model.named_modules():
quant_method = getattr(module, "quant_method", None)
if quant_method is not None:
# When quant methods need to process weights after loading
# (for repacking, quantizing, etc), they expect parameters
# to be on the global target device. This scope is for the
# case where cpu offloading is used, where we will move the
# parameters onto device for processing and back off after.
with device_loading_context(module, target_device):
quant_method.process_weights_after_loading(module)
def load_model(
self,
*,
model_config: ModelConfig,
device_config: DeviceConfig,
) -> nn.Module:
logger.info("Loading weights from remote storage ...")
start = time.perf_counter()
load_config = self.load_config
assert load_config.load_format == LoadFormat.REMOTE, (
f"Model loader {self.load_config.load_format} is not supported for "
f"load format {load_config.load_format}"
)
model_weights = model_config.model_path
if hasattr(model_config, "model_weights"):
model_weights = model_config.model_weights
with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device):
model = _initialize_model(model_config, self.load_config)
for _, module in model.named_modules():
quant_method = getattr(module, "quant_method", None)
if quant_method is not None:
quant_method.process_weights_after_loading(module)
with create_remote_connector(model_weights, device_config.device) as client:
connector_type = get_connector_type(client)
if connector_type == ConnectorType.KV:
self._load_model_from_remote_kv(model, client)
elif connector_type == ConnectorType.FS:
self._load_model_from_remote_fs(
model, client, model_config, device_config
)
end = time.perf_counter()
logger.info("Loaded weights from remote storage in %.2f seconds.", end - start)
return model.eval()
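Putting the loader together with the engine and config changes above, a hedged sketch of serving straight from remote weights (the S3 URL is hypothetical; boto3 and runai_model_streamer are assumed to be installed):

from sglang import Engine

# load_format="remote" selects RemoteModelLoader; an s3:// URL takes the FS
# path (streamed safetensors), while a redis:// URL would take the KV path.
llm = Engine(
    model_path="s3://my-bucket/llama-3-8b",
    load_format="remote",
)
out = llm.generate("The capital of France is")
print(out)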
def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
"""Get a model loader based on the load format."""
......@@ -1225,4 +1380,7 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
if load_config.load_format == LoadFormat.LAYERED:
return LayeredModelLoader(load_config)
if load_config.load_format == LoadFormat.REMOTE:
return RemoteModelLoader(load_config)
return DefaultModelLoader(load_config)
......@@ -585,6 +585,51 @@ def composed_weight_loader(
return composed_loader
def runai_safetensors_weights_iterator(
hf_weights_files: List[str],
) -> Generator[Tuple[str, torch.Tensor], None, None]:
"""Iterate over the weights in the model safetensor files."""
from runai_model_streamer import SafetensorsStreamer
enable_tqdm = (
not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
)
with SafetensorsStreamer() as streamer:
for st_file in tqdm(
hf_weights_files,
desc="Loading safetensors using Runai Model Streamer",
disable=not enable_tqdm,
bar_format=_BAR_FORMAT,
):
streamer.stream_file(st_file)
yield from streamer.get_tensors()
def set_runai_streamer_env(load_config: LoadConfig):
if load_config.model_loader_extra_config:
extra_config = load_config.model_loader_extra_config
if "concurrency" in extra_config and isinstance(
extra_config.get("concurrency"), int
):
os.environ["RUNAI_STREAMER_CONCURRENCY"] = str(
extra_config.get("concurrency")
)
if "memory_limit" in extra_config and isinstance(
extra_config.get("memory_limit"), int
):
os.environ["RUNAI_STREAMER_MEMORY_LIMIT"] = str(
extra_config.get("memory_limit")
)
runai_streamer_s3_endpoint = os.getenv("RUNAI_STREAMER_S3_ENDPOINT")
aws_endpoint_url = os.getenv("AWS_ENDPOINT_URL")
if runai_streamer_s3_endpoint is None and aws_endpoint_url is not None:
os.environ["RUNAI_STREAMER_S3_ENDPOINT"] = aws_endpoint_url
def initialize_dummy_weights(
model: torch.nn.Module,
low: float = -1e-3,
......