"vllm/vscode:/vscode.git/clone" did not exist on "26e1188e51aca3b76184671d804a8b17c294b610"
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import TYPE_CHECKING, Any

from pydantic import Field, field_validator

from vllm.config.utils import config
from vllm.logger import init_logger
from vllm.utils.hashing import safe_hash

# The loader types are imported only while static type checkers analyse this
# module; at runtime both names are bound to `Any` so the annotations on
# `LoadConfig` still resolve.
# NOTE(review): presumably this keeps the model-loader package from being
# imported (or from creating an import cycle) at config-import time — confirm.
if TYPE_CHECKING:
    from vllm.model_executor.model_loader import LoadFormats
    from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
else:
    LoadFormats = Any
    TensorizerConfig = Any

# Module-level logger, created via `init_logger` with this module's name.
logger = init_logger(__name__)


@config
class LoadConfig:
    """Configuration for loading the model weights."""

    load_format: str | LoadFormats = "auto"
    """
    The format of the model weights to load.

    - "auto" will try to load the weights in the safetensors format and fall
      back to the pytorch bin format if safetensors format is not available.
    - "pt" will load the weights in the pytorch bin format.
    - "safetensors" will load the weights in the safetensors format.
    - "instanttensor" will load the Safetensors weights on CUDA devices using
      InstantTensor, which enables distributed loading with pipelined prefetching
      and fast direct I/O.
    - "npcache" will load the weights in pytorch format and store a numpy cache
      to speed up the loading.
    - "dummy" will initialize the weights with random values, which is mainly
      for profiling.
    - "tensorizer" will use CoreWeave's tensorizer library for fast weight
      loading. See the Tensorize vLLM Model script in the Examples section for
      more information.
    - "runai_streamer" will load the Safetensors weights using Run:ai Model
      Streamer.
    - "runai_streamer_sharded" will load weights from pre-sharded checkpoint
      files using Run:ai Model Streamer.
    - "bitsandbytes" will load the weights using bitsandbytes quantization.
    - "sharded_state" will load weights from pre-sharded checkpoint files,
      supporting efficient loading of tensor-parallel models.
    - "gguf" will load weights from GGUF format files (details specified in
      https://github.com/ggml-org/ggml/blob/master/docs/gguf.md).
    - "mistral" will load weights from consolidated safetensors files used by
      Mistral models.\n
    - Other custom values can be supported via plugins.
    """
    download_dir: str | None = None
    """Directory to download and load the weights, default to the default
    cache directory of Hugging Face."""
    safetensors_load_strategy: str | None = None
    """
    Specifies the loading strategy for safetensors weights.

    - None (default): Uses memory-mapped (lazy) loading. When an NFS
      filesystem is detected and the total checkpoint size fits within 90%%
      of available RAM, prefetching is enabled automatically.
    - "lazy": Weights are memory-mapped from the file. This enables
      on-demand loading and is highly efficient for models on local storage.
      Unlike the default (None), auto-prefetch on NFS is not performed.
    - "eager": The entire file is read into CPU memory upfront before loading.
      This is recommended for models on network filesystems (e.g., Lustre, NFS)
      as it avoids inefficient random reads, significantly speeding up model
      initialization. However, it uses more CPU RAM.
    - "prefetch": Checkpoint files are read into the OS page cache before
      workers load them, speeding up the model loading phase. Useful on
      network or high-latency storage.
    - "torchao": Weights are loaded in upfront and then reconstructed
      into torchao tensor subclasses. This is used when the checkpoint
      was quantized using torchao and saved using safetensors.
      Needs `torchao >= 0.14.0`.
    """
    model_loader_extra_config: dict | TensorizerConfig = Field(default_factory=dict)
    """Extra config for model loader. This will be passed to the model loader
    corresponding to the chosen load_format."""
    device: str | None = None
    """Device to which model weights will be loaded, default to
    device_config.device"""
    ignore_patterns: list[str] | str = Field(default_factory=lambda: ["original/**/*"])
    """The list of patterns to ignore when loading the model. Default to
    "original/**/*" to avoid repeated loading of llama's checkpoints."""
    use_tqdm_on_load: bool = True
    """Whether to enable tqdm for showing progress bar when loading model
    weights."""
    pt_load_map_location: str | dict[str, str] = "cpu"
    """
    The map location for loading pytorch checkpoint, to support loading
    checkpoints can only be loaded on certain devices like "cuda", this
    is equivalent to `{"": "cuda"}`. Another supported format is mapping
    from different devices like from GPU 1 to GPU 0: `{"cuda:1": "cuda:0"}`.
    Note that when passed from command line, the strings in dictionary
    need to be double quoted for json parsing. For more details, see
    the original doc for `map_location` parameter in [`torch.load`][] parameter.
    """

    def compute_hash(self) -> str:
        """
        WARNING: Whenever a new field is added to this config,
        ensure that it is included in the factors list if
        it affects the computation graph.

        Provide a hash that uniquely identifies all the configs
        that affect the structure of the computation
        graph from input ids/embeddings to the final hidden states,
        excluding anything before input ids/embeddings and after
        the final hidden states.

        Returns:
            The hex digest of `safe_hash` over the string form of the
            factors list, which is currently always empty.
        """
        # No factors to consider: this config does not affect the
        # computation graph, so every instance hashes the same empty list.
        factors: list[Any] = []
        return safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()

    @field_validator("load_format", mode="after")
    def _lowercase_load_format(cls, load_format: str) -> str:
        """Return `load_format` converted to lowercase."""
        return load_format.lower()

    @field_validator("ignore_patterns", mode="after")
    def _validate_ignore_patterns(
        cls, ignore_patterns: list[str] | str
    ) -> list[str] | str:
        """Log non-default ignore patterns and return them unchanged.

        Nothing is logged when the value is empty or equals the default
        `["original/**/*"]`.
        """
        if ignore_patterns and ignore_patterns != ["original/**/*"]:
            logger.info(
                "Ignoring the following patterns when downloading weights: %s",
                ignore_patterns,
            )

        return ignore_patterns