Unverified commit 9f1787fa authored by xianzhiT, committed by GitHub

Support multi-thread model weight loading (#7277)

parent 8ecad0b1
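
For context, the whole feature is driven by the new `--model-loader-extra-config` flag added below. A minimal sketch of building its value (the flag and the two keys are from this diff; the thread count is illustrative):

```python
# Illustrative only: construct the JSON value for --model-loader-extra-config.
# "enable_multithread_load" and "num_threads" are the two keys this PR accepts.
import json

extra_config = {
    "enable_multithread_load": True,  # turn on the multi-thread iterators
    "num_threads": 16,  # overrides DefaultModelLoader.DEFAULT_NUM_THREADS (8)
}
print(f"--model-loader-extra-config '{json.dumps(extra_config)}'")
```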
......@@ -547,6 +547,7 @@ class ModelRunner:
        self.load_config = LoadConfig(
            load_format=self.server_args.load_format,
            download_dir=self.server_args.download_dir,
            model_loader_extra_config=self.server_args.model_loader_extra_config,
        )
        if self.server_args.load_format == "gguf":
            monkey_patch_vllm_gguf_config()
......
......@@ -2,6 +2,7 @@
# ruff: noqa: SIM117
import collections
import concurrent
import dataclasses
import fnmatch
import glob
......@@ -11,14 +12,17 @@ import math
import os
import time
from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, cast
import huggingface_hub
import numpy as np
import safetensors.torch
import torch
from huggingface_hub import HfApi, hf_hub_download
from torch import nn
from tqdm.auto import tqdm
from transformers import AutoModelForCausalLM
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
......@@ -41,6 +45,7 @@ from sglang.srt.model_loader.utils import (
set_default_torch_dtype,
)
from sglang.srt.model_loader.weight_utils import (
_BAR_FORMAT,
download_safetensors_index_file_from_hf,
download_weights_from_hf,
filter_duplicate_safetensors_files,
......@@ -49,6 +54,8 @@ from sglang.srt.model_loader.weight_utils import (
get_quant_config,
gguf_quant_weights_iterator,
initialize_dummy_weights,
multi_thread_pt_weights_iterator,
multi_thread_safetensors_weights_iterator,
np_cache_weights_iterator,
pt_weights_iterator,
safetensors_weights_iterator,
......@@ -181,6 +188,9 @@ class BaseModelLoader(ABC):

class DefaultModelLoader(BaseModelLoader):
    """Model loader that can load different file types from disk."""

    # Default number of threads to use when multi-thread weight loading is enabled.
    DEFAULT_NUM_THREADS = 8

    @dataclasses.dataclass
    class Source:
        """A source for weights."""

......@@ -208,10 +218,15 @@ class DefaultModelLoader(BaseModelLoader):

    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)
        if load_config.model_loader_extra_config:
            extra_config = load_config.model_loader_extra_config
            allowed_keys = {"enable_multithread_load", "num_threads"}
            unexpected_keys = set(extra_config.keys()) - allowed_keys
            if unexpected_keys:
                raise ValueError(
                    f"Unexpected extra config keys for load format "
                    f"{load_config.load_format}: {unexpected_keys}"
                )
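
The guard above only rejects unknown keys; both keys remain optional. A standalone restatement of the same check, useful for seeing which configs pass (`allowed_keys` is copied from the diff; everything else is illustrative):

```python
# Standalone restatement of the key validation in __init__ above (illustrative).
ALLOWED_KEYS = {"enable_multithread_load", "num_threads"}

def validate_extra_config(extra_config: dict, load_format: str = "auto") -> None:
    unexpected_keys = set(extra_config) - ALLOWED_KEYS
    if unexpected_keys:
        raise ValueError(
            f"Unexpected extra config keys for load format "
            f"{load_format}: {unexpected_keys}"
        )

validate_extra_config({"enable_multithread_load": True, "num_threads": 4})  # passes
try:
    validate_extra_config({"enable_multithread_loading": True})  # misspelled key
except ValueError as err:
    print(err)
```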
    def _maybe_download_from_modelscope(

......@@ -324,6 +339,7 @@ class DefaultModelLoader(BaseModelLoader):
        self, source: "Source"
    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
        """Get an iterator for the model weights based on the load format."""
        extra_config = self.load_config.model_loader_extra_config
        hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
            source.model_or_path, source.revision, source.fall_back_to_pt
        )
......@@ -342,11 +358,30 @@ class DefaultModelLoader(BaseModelLoader):
            weight_loader_disable_mmap = global_server_args_dict.get(
                "weight_loader_disable_mmap"
            )
            if extra_config.get("enable_multithread_load"):
                weights_iterator = multi_thread_safetensors_weights_iterator(
                    hf_weights_files,
                    max_workers=extra_config.get(
                        "num_threads", self.DEFAULT_NUM_THREADS
                    ),
                    disable_mmap=weight_loader_disable_mmap,
                )
            else:
                weights_iterator = safetensors_weights_iterator(
                    hf_weights_files, disable_mmap=weight_loader_disable_mmap
                )
        else:
            if extra_config.get("enable_multithread_load"):
                weights_iterator = multi_thread_pt_weights_iterator(
                    hf_weights_files,
                    max_workers=extra_config.get(
                        "num_threads", self.DEFAULT_NUM_THREADS
                    ),
                )
            else:
                weights_iterator = pt_weights_iterator(hf_weights_files)

        # Apply the prefix.
        return ((source.prefix + name, tensor) for (name, tensor) in weights_iterator)
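
The branching above is a two-axis choice: file format (safetensors vs. bin/pt) times threading mode. A compact restatement of the same dispatch, purely illustrative (it assumes sglang is importable; the literal `8` mirrors `DefaultModelLoader.DEFAULT_NUM_THREADS`):

```python
# Illustrative restatement of the iterator dispatch above; not the PR's code.
from sglang.srt.model_loader.weight_utils import (
    multi_thread_pt_weights_iterator,
    multi_thread_safetensors_weights_iterator,
    pt_weights_iterator,
    safetensors_weights_iterator,
)

def pick_weights_iterator(
    hf_weights_files, use_safetensors, extra_config, disable_mmap=False
):
    # 8 mirrors DefaultModelLoader.DEFAULT_NUM_THREADS in this diff.
    num_threads = extra_config.get("num_threads", 8)
    if extra_config.get("enable_multithread_load"):
        if use_safetensors:
            return multi_thread_safetensors_weights_iterator(
                hf_weights_files, max_workers=num_threads, disable_mmap=disable_mmap
            )
        return multi_thread_pt_weights_iterator(
            hf_weights_files, max_workers=num_threads
        )
    if use_safetensors:
        return safetensors_weights_iterator(
            hf_weights_files, disable_mmap=disable_mmap
        )
    return pt_weights_iterator(hf_weights_files)
```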
......@@ -385,9 +420,9 @@ class DefaultModelLoader(BaseModelLoader):
                    self.load_config,
                )

            self.load_weights_and_postprocess(
                model, self._get_all_weights(model_config, model), target_device
            )

        return model.eval()
......
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/model_loader/weight_utils.py
"""Utilities for downloading and initializing model weights."""
import concurrent.futures
import fnmatch
import glob
import hashlib
import json
import logging
import os
import queue
import tempfile
from collections import defaultdict
from typing import (
......@@ -453,6 +455,60 @@ def safetensors_weights_iterator(
            yield name, param
def multi_thread_safetensors_weights_iterator(
    hf_weights_files: List[str],
    is_all_weights_sharded: bool = False,
    decryption_key: Optional[str] = None,
    max_workers: int = 4,
    disable_mmap: bool = False,
) -> Generator[Tuple[str, torch.Tensor], None, None]:
    """Iterate over the weights in the model safetensors files using multiple threads.

    If is_all_weights_sharded is True, use a more optimized read that loads an
    entire file at once instead of reading tensors one by one.
    """
    if decryption_key:
        logger.warning(
            "Multi-thread loading does not support encrypted safetensors weights; "
            "falling back to the encrypted weights iterator."
        )
        yield from safetensors_encrypted_weights_iterator(
            hf_weights_files, is_all_weights_sharded, decryption_key
        )
        return

    enable_tqdm = (
        not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
    )

    def _load_file(st_file: str):
        if disable_mmap:
            # Read the whole file into memory and deserialize, bypassing mmap.
            with open(st_file, "rb") as f:
                result = safetensors.torch.load(f.read())
        else:
            result = safetensors.torch.load_file(st_file, device="cpu")
        return result

    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [executor.submit(_load_file, st_file) for st_file in hf_weights_files]

        if enable_tqdm:
            futures_iter = tqdm(
                concurrent.futures.as_completed(futures),
                total=len(hf_weights_files),
                desc="Multi-thread loading shards",
                disable=not enable_tqdm,
                bar_format=_BAR_FORMAT,
            )
        else:
            futures_iter = concurrent.futures.as_completed(futures)

        for future in futures_iter:
            state_dict = future.result()
            for name, param in state_dict.items():
                yield name, param
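
Because shards are submitted to a thread pool and drained with `as_completed`, tensors are yielded in shard-completion order rather than file order; that is safe here since consumers key weights by name. A small self-contained usage sketch with synthetic shards (assumes sglang is importable; file names are illustrative):

```python
# Illustrative usage of the new iterator with synthetic safetensors shards.
import os
import tempfile

import safetensors.torch
import torch

from sglang.srt.model_loader.weight_utils import (
    multi_thread_safetensors_weights_iterator,
)

with tempfile.TemporaryDirectory() as tmp:
    files = []
    for i in range(2):
        path = os.path.join(tmp, f"shard-{i}.safetensors")
        safetensors.torch.save_file({f"layers.{i}.weight": torch.ones(4, 4)}, path)
        files.append(path)

    # Tensors arrive keyed by name, in whichever order shards finish loading.
    for name, tensor in multi_thread_safetensors_weights_iterator(
        files, max_workers=2
    ):
        print(name, tuple(tensor.shape))
```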
def pt_weights_iterator(
    hf_weights_files: List[str],
) -> Generator[Tuple[str, torch.Tensor], None, None]:
......@@ -471,6 +527,39 @@ def pt_weights_iterator(
        del state
def multi_thread_pt_weights_iterator(
    hf_weights_files: List[str],
    max_workers: int = 4,
) -> Generator[Tuple[str, torch.Tensor], None, None]:
    """Iterate over the weights in the model bin/pt files using multiple threads."""
    enable_tqdm = (
        not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
    )

    def _load_file(bin_file: str):
        return torch.load(bin_file, map_location="cpu", weights_only=True)

    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [
            executor.submit(_load_file, bin_file) for bin_file in hf_weights_files
        ]

        if enable_tqdm:
            futures_iter = tqdm(
                concurrent.futures.as_completed(futures),
                total=len(hf_weights_files),
                desc="Multi-thread loading pt checkpoint shards",
                disable=not enable_tqdm,
                bar_format=_BAR_FORMAT,
            )
        else:
            futures_iter = concurrent.futures.as_completed(futures)

        for future in futures_iter:
            state = future.result()
            yield from state.items()
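
Threads can pay off here even under the GIL, since most of the time likely goes to disk I/O and tensor deserialization in native code; the actual speedup depends on storage and shard count. A rough, self-contained timing sketch with synthetic .bin shards (numbers are illustrative; assumes sglang is importable):

```python
# Rough timing comparison between the single- and multi-thread pt iterators.
import os
import tempfile
import time

import torch

from sglang.srt.model_loader.weight_utils import (
    multi_thread_pt_weights_iterator,
    pt_weights_iterator,
)

with tempfile.TemporaryDirectory() as tmp:
    files = []
    for i in range(4):
        path = os.path.join(tmp, f"shard-{i}.bin")
        torch.save({f"layers.{i}.weight": torch.randn(512, 512)}, path)
        files.append(path)

    t0 = time.perf_counter()
    single = sum(1 for _ in pt_weights_iterator(files))
    t1 = time.perf_counter()
    multi = sum(1 for _ in multi_thread_pt_weights_iterator(files, max_workers=4))
    t2 = time.perf_counter()
    print(f"single-thread: {single} tensors in {t1 - t0:.3f}s")
    print(f"multi-thread:  {multi} tensors in {t2 - t1:.3f}s")
```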
def get_gguf_extra_tensor_names(
    gguf_file: str, gguf_to_hf_name_map: Dict[str, str]
) -> List[str]:
......
......@@ -47,6 +47,7 @@ class ServerArgs:
    tokenizer_mode: str = "auto"
    skip_tokenizer_init: bool = False
    load_format: str = "auto"
    model_loader_extra_config: str = "{}"
    trust_remote_code: bool = False
    dtype: str = "auto"
    kv_cache_dtype: str = "auto"
......@@ -632,6 +633,13 @@ class ServerArgs:
            "layer before loading another to make the peak memory envelope "
            "smaller.",
        )
        parser.add_argument(
            "--model-loader-extra-config",
            type=str,
            help="Extra config for the model loader, passed as a JSON string. "
            "It is forwarded to the model loader corresponding to the chosen load_format.",
            default=ServerArgs.model_loader_extra_config,
        )
        parser.add_argument(
            "--trust-remote-code",
            action="store_true",
......
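
Since the flag's default is the JSON string `"{}"`, the value must be decoded into a dict before `DefaultModelLoader.__init__` can inspect its keys. This diff does not show where that conversion happens, so the parsing step below is an assumption:

```python
# Assumed JSON-string-to-dict step; the exact conversion point in sglang is
# not shown in this diff.
import json

raw = '{"enable_multithread_load": true, "num_threads": 16}'  # CLI value
extra_config = json.loads(raw)
assert set(extra_config) <= {"enable_multithread_load", "num_threads"}
print(extra_config)
```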