Unverified commit 9f1787fa authored by xianzhiT, committed by GitHub

Support multi-thread model weight loading (#7277)

parent 8ecad0b1
@@ -547,6 +547,7 @@ class ModelRunner:
         self.load_config = LoadConfig(
             load_format=self.server_args.load_format,
             download_dir=self.server_args.download_dir,
+            model_loader_extra_config=self.server_args.model_loader_extra_config,
         )
         if self.server_args.load_format == "gguf":
             monkey_patch_vllm_gguf_config()
......
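The `model_loader_extra_config` plumbing starts in this hunk: the server argument (a JSON string, see the `server_args.py` hunk at the end of this diff) ends up on `LoadConfig`. A minimal sketch of the parsing step, assuming the string is decoded with `json.loads` somewhere between `ServerArgs` and `LoadConfig`; the helper name and decode site are assumptions, not the PR's code:

```python
# Hypothetical sketch of how the --model-loader-extra-config JSON string
# could become the dict that DefaultModelLoader later inspects.
import json
from typing import Any, Dict


def parse_model_loader_extra_config(raw: str) -> Dict[str, Any]:
    """Decode the extra-config flag, defaulting to an empty dict."""
    config = json.loads(raw or "{}")
    if not isinstance(config, dict):
        raise ValueError("model_loader_extra_config must be a JSON object")
    return config


extra = parse_model_loader_extra_config(
    '{"enable_multithread_load": true, "num_threads": 8}'
)
assert extra == {"enable_multithread_load": True, "num_threads": 8}
```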
@@ -2,6 +2,7 @@
 # ruff: noqa: SIM117
 import collections
+import concurrent
 import dataclasses
 import fnmatch
 import glob
@@ -11,14 +12,17 @@ import math
 import os
 import time
 from abc import ABC, abstractmethod
+from concurrent.futures import ThreadPoolExecutor
 from contextlib import contextmanager
 from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, cast

 import huggingface_hub
 import numpy as np
+import safetensors.torch
 import torch
 from huggingface_hub import HfApi, hf_hub_download
 from torch import nn
+from tqdm.auto import tqdm
 from transformers import AutoModelForCausalLM
 from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
@@ -41,6 +45,7 @@ from sglang.srt.model_loader.utils import (
     set_default_torch_dtype,
 )
 from sglang.srt.model_loader.weight_utils import (
+    _BAR_FORMAT,
     download_safetensors_index_file_from_hf,
     download_weights_from_hf,
     filter_duplicate_safetensors_files,
@@ -49,6 +54,8 @@ from sglang.srt.model_loader.weight_utils import (
     get_quant_config,
     gguf_quant_weights_iterator,
     initialize_dummy_weights,
+    multi_thread_pt_weights_iterator,
+    multi_thread_safetensors_weights_iterator,
     np_cache_weights_iterator,
     pt_weights_iterator,
     safetensors_weights_iterator,
@@ -181,6 +188,9 @@ class BaseModelLoader(ABC):
 class DefaultModelLoader(BaseModelLoader):
     """Model loader that can load different file types from disk."""

+    # Default number of threads when multi-thread weight loading is enabled.
+    DEFAULT_NUM_THREADS = 8
+
     @dataclasses.dataclass
     class Source:
         """A source for weights."""
@@ -208,10 +218,15 @@ class DefaultModelLoader(BaseModelLoader):
     def __init__(self, load_config: LoadConfig):
         super().__init__(load_config)
-        if load_config.model_loader_extra_config:
+        extra_config = load_config.model_loader_extra_config
+        allowed_keys = {"enable_multithread_load", "num_threads"}
+        unexpected_keys = set(extra_config.keys()) - allowed_keys
+        if unexpected_keys:
             raise ValueError(
-                f"Model loader extra config is not supported for "
-                f"load format {load_config.load_format}"
+                f"Unexpected extra config keys for load format "
+                f"{load_config.load_format}: "
+                f"{unexpected_keys}"
             )
     def _maybe_download_from_modelscope(
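The validation above replaces the old blanket rejection of any extra config: now only keys outside `enable_multithread_load` and `num_threads` raise. A self-contained restatement of that check, for illustration:

```python
# Standalone re-statement of the __init__ validation in the hunk above.
def validate_extra_config(extra_config: dict, load_format: str = "auto") -> None:
    allowed_keys = {"enable_multithread_load", "num_threads"}
    unexpected_keys = set(extra_config.keys()) - allowed_keys
    if unexpected_keys:
        raise ValueError(
            f"Unexpected extra config keys for load format "
            f"{load_format}: {unexpected_keys}"
        )


validate_extra_config({"enable_multithread_load": True, "num_threads": 16})  # ok
validate_extra_config({})  # ok: both keys are optional
try:
    validate_extra_config({"num_thread": 16})  # typo: missing trailing "s"
except ValueError as exc:
    print(exc)  # Unexpected extra config keys for load format auto: {'num_thread'}
```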
@@ -324,6 +339,7 @@ class DefaultModelLoader(BaseModelLoader):
         self, source: "Source"
     ) -> Generator[Tuple[str, torch.Tensor], None, None]:
         """Get an iterator for the model weights based on the load format."""
+        extra_config = self.load_config.model_loader_extra_config
         hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
             source.model_or_path, source.revision, source.fall_back_to_pt
         )
@@ -342,11 +358,30 @@ class DefaultModelLoader(BaseModelLoader):
             weight_loader_disable_mmap = global_server_args_dict.get(
                 "weight_loader_disable_mmap"
             )
-            weights_iterator = safetensors_weights_iterator(
-                hf_weights_files, disable_mmap=weight_loader_disable_mmap
-            )
+
+            if extra_config.get("enable_multithread_load"):
+                weights_iterator = multi_thread_safetensors_weights_iterator(
+                    hf_weights_files,
+                    max_workers=extra_config.get(
+                        "num_threads", self.DEFAULT_NUM_THREADS
+                    ),
+                    disable_mmap=weight_loader_disable_mmap,
+                )
+            else:
+                weights_iterator = safetensors_weights_iterator(
+                    hf_weights_files, disable_mmap=weight_loader_disable_mmap
+                )
         else:
-            weights_iterator = pt_weights_iterator(hf_weights_files)
+            if extra_config.get("enable_multithread_load"):
+                weights_iterator = multi_thread_pt_weights_iterator(
+                    hf_weights_files,
+                    max_workers=extra_config.get(
+                        "num_threads", self.DEFAULT_NUM_THREADS
+                    ),
+                )
+            else:
+                weights_iterator = pt_weights_iterator(hf_weights_files)

         # Apply the prefix.
         return ((source.prefix + name, tensor) for (name, tensor) in weights_iterator)
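Condensed, the selection above is a two-by-two dispatch on file format and the `enable_multithread_load` flag. A compact equivalent (not the PR's code, but using the iterator functions this PR adds to `weight_utils.py`):

```python
from sglang.srt.model_loader.weight_utils import (
    multi_thread_pt_weights_iterator,
    multi_thread_safetensors_weights_iterator,
    pt_weights_iterator,
    safetensors_weights_iterator,
)

DEFAULT_NUM_THREADS = 8  # mirrors DefaultModelLoader.DEFAULT_NUM_THREADS


def select_weights_iterator(
    files, use_safetensors: bool, extra_config: dict, disable_mmap: bool = False
):
    """Pick one of the four iterators, as _get_weights_iterator does above."""
    workers = extra_config.get("num_threads", DEFAULT_NUM_THREADS)
    if use_safetensors:
        if extra_config.get("enable_multithread_load"):
            return multi_thread_safetensors_weights_iterator(
                files, max_workers=workers, disable_mmap=disable_mmap
            )
        return safetensors_weights_iterator(files, disable_mmap=disable_mmap)
    if extra_config.get("enable_multithread_load"):
        return multi_thread_pt_weights_iterator(files, max_workers=workers)
    return pt_weights_iterator(files)
```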
@@ -385,9 +420,9 @@ class DefaultModelLoader(BaseModelLoader):
             self.load_config,
         )
         self.load_weights_and_postprocess(
             model, self._get_all_weights(model_config, model), target_device
         )

         return model.eval()
......
 # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/model_loader/weight_utils.py
 """Utilities for downloading and initializing model weights."""
+import concurrent.futures
 import fnmatch
 import glob
 import hashlib
 import json
 import logging
 import os
+import queue
 import tempfile
 from collections import defaultdict
 from typing import (
@@ -453,6 +455,60 @@ def safetensors_weights_iterator(
             yield name, param

+def multi_thread_safetensors_weights_iterator(
+    hf_weights_files: List[str],
+    is_all_weights_sharded: bool = False,
+    decryption_key: Optional[str] = None,
+    max_workers: int = 4,
+    disable_mmap: bool = False,
+) -> Generator[Tuple[str, torch.Tensor], None, None]:
+    """Multi-threaded iteration over the weights in model safetensors files.
+
+    If is_all_weights_sharded is True, a more optimized read path is used:
+    each file is read in its entirety instead of tensor by tensor.
+    """
+    if decryption_key:
+        logger.warning(
+            "Multi-thread loading is not supported for encrypted safetensors "
+            "weights; falling back to the single-threaded iterator."
+        )
+        yield from safetensors_encrypted_weights_iterator(
+            hf_weights_files, is_all_weights_sharded, decryption_key
+        )
+        return
+
+    enable_tqdm = (
+        not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
+    )
+
+    def _load_file(st_file: str):
+        if disable_mmap:
+            # Read the whole file into memory and deserialize, bypassing mmap.
+            with open(st_file, "rb") as f:
+                result = safetensors.torch.load(f.read())
+        else:
+            result = safetensors.torch.load_file(st_file, device="cpu")
+        return result
+
+    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
+        futures = [executor.submit(_load_file, st_file) for st_file in hf_weights_files]
+        if enable_tqdm:
+            futures_iter = tqdm(
+                concurrent.futures.as_completed(futures),
+                total=len(hf_weights_files),
+                desc="Multi-thread loading shards",
+                disable=not enable_tqdm,
+                bar_format=_BAR_FORMAT,
+            )
+        else:
+            futures_iter = concurrent.futures.as_completed(futures)
+
+        for future in futures_iter:
+            state_dict = future.result()
+            for name, param in state_dict.items():
+                yield name, param
+
 def pt_weights_iterator(
     hf_weights_files: List[str],
 ) -> Generator[Tuple[str, torch.Tensor], None, None]:
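A usage sketch for the new safetensors iterator (the checkpoint path is a placeholder): stream every tensor of a sharded checkpoint with eight reader threads. Note two consequences of the eager-submit plus `as_completed` pattern above: shard order is nondeterministic, and several fully loaded shard state dicts can be resident at once, so peak memory is higher than with the sequential iterator.

```python
import glob

from sglang.srt.model_loader.weight_utils import (
    multi_thread_safetensors_weights_iterator,
)

# Placeholder directory; point at a real sharded safetensors checkpoint.
shard_files = sorted(glob.glob("/path/to/checkpoint/*.safetensors"))

for name, tensor in multi_thread_safetensors_weights_iterator(
    shard_files, max_workers=8, disable_mmap=False
):
    print(name, tuple(tensor.shape))
```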
@@ -471,6 +527,39 @@ def pt_weights_iterator(
         del state

+def multi_thread_pt_weights_iterator(
+    hf_weights_files: List[str],
+    max_workers: int = 4,
+) -> Generator[Tuple[str, torch.Tensor], None, None]:
+    """Multi-threaded iteration over the weights in model bin/pt files."""
+    enable_tqdm = (
+        not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
+    )
+
+    def _load_file(bin_file: str):
+        return torch.load(bin_file, map_location="cpu", weights_only=True)
+
+    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
+        futures = [
+            executor.submit(_load_file, bin_file) for bin_file in hf_weights_files
+        ]
+        if enable_tqdm:
+            futures_iter = tqdm(
+                concurrent.futures.as_completed(futures),
+                total=len(hf_weights_files),
+                desc="Multi-thread loading pt checkpoint shards",
+                disable=not enable_tqdm,
+                bar_format=_BAR_FORMAT,
+            )
+        else:
+            futures_iter = concurrent.futures.as_completed(futures)
+
+        for future in futures_iter:
+            state = future.result()
+            yield from state.items()
+
 def get_gguf_extra_tensor_names(
     gguf_file: str, gguf_to_hf_name_map: Dict[str, str]
 ) -> List[str]:
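Threads can speed this up despite the GIL because the work is dominated by disk reads and tensor deserialization, which largely release the GIL; the actual gain depends on storage bandwidth. A rough harness to measure it on a given machine (file names are placeholders; no numbers are claimed here):

```python
import time

from sglang.srt.model_loader.weight_utils import (
    multi_thread_pt_weights_iterator,
    pt_weights_iterator,
)

# Placeholder shard list for a .bin/.pt checkpoint.
bin_files = [
    "pytorch_model-00001-of-00002.bin",
    "pytorch_model-00002-of-00002.bin",
]

start = time.perf_counter()
n = sum(1 for _ in pt_weights_iterator(bin_files))
print(f"sequential: {n} tensors in {time.perf_counter() - start:.1f}s")

start = time.perf_counter()
n = sum(1 for _ in multi_thread_pt_weights_iterator(bin_files, max_workers=4))
print(f"multi-thread: {n} tensors in {time.perf_counter() - start:.1f}s")
```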
......
@@ -47,6 +47,7 @@ class ServerArgs:
     tokenizer_mode: str = "auto"
     skip_tokenizer_init: bool = False
     load_format: str = "auto"
+    model_loader_extra_config: str = "{}"
     trust_remote_code: bool = False
     dtype: str = "auto"
     kv_cache_dtype: str = "auto"
@@ -632,6 +633,13 @@ class ServerArgs:
             "layer before loading another to make the peak memory envelope "
             "smaller.",
         )
+        parser.add_argument(
+            "--model-loader-extra-config",
+            type=str,
+            help="Extra config for the model loader, as a JSON string. "
+            "Passed to the model loader corresponding to the chosen load_format.",
+            default=ServerArgs.model_loader_extra_config,
+        )
         parser.add_argument(
             "--trust-remote-code",
             action="store_true",
......
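Putting it together, multi-thread loading is opted into entirely through this flag. A sketch of an invocation, assuming the usual `sglang.launch_server` entry point:

```python
# Build the flag value programmatically to avoid JSON quoting mistakes.
import json

extra_config = json.dumps({"enable_multithread_load": True, "num_threads": 8})
print(extra_config)  # {"enable_multithread_load": true, "num_threads": 8}

# Shell equivalent (entry point assumed):
#   python -m sglang.launch_server --model-path <model> \
#       --model-loader-extra-config '{"enable_multithread_load": true, "num_threads": 8}'
```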