# loader.py 16.5 KB
# Newer Older
1
2
3
4
5
# ruff: noqa: SIM117
import copy
import glob
import os
from abc import ABC, abstractmethod
6
from typing import Any, Dict, Generator, List, Optional, Tuple, Type
7

8
import huggingface_hub
9
10
11
import torch
from torch import nn

12
13
14
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoadFormat,
                         LoRAConfig, ModelConfig, ParallelConfig,
                         SchedulerConfig, VisionLanguageConfig)
15
from vllm.envs import VLLM_USE_MODELSCOPE
16
from vllm.logger import init_logger
17
18
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
from vllm.model_executor.model_loader.tensorizer import (
    TensorizerConfig, is_vllm_serialized_tensorizer, load_with_tensorizer,
    tensorizer_weights_iterator)
from vllm.model_executor.model_loader.utils import (get_model_architecture,
                                                    set_default_torch_dtype)
from vllm.model_executor.model_loader.weight_utils import (
    download_weights_from_hf, filter_files_not_needed_for_inference,
    get_quant_config, initialize_dummy_weights, np_cache_weights_iterator,
    pt_weights_iterator, safetensors_weights_iterator)
from vllm.model_executor.models.llava import LlavaForConditionalGeneration

# Model classes that receive an extra ``vision_language_config`` keyword
# argument when they are constructed.
_VISION_MODEL_CLASSES = [LlavaForConditionalGeneration]

logger = init_logger(__name__)


def _get_quantization_config(
        model_config: ModelConfig,
        load_config: LoadConfig) -> Optional[QuantizationConfig]:
    """Get the quantization config.

    Args:
        model_config: Model configuration. Its ``quantization`` field selects
            the quantization method; ``None`` means the model is unquantized.
        load_config: Load configuration, forwarded to ``get_quant_config``.

    Returns:
        The quantization config, or ``None`` when ``model_config.quantization``
        is ``None``.

    Raises:
        ValueError: If the current GPU's compute capability is below the
            method's minimum, or if ``model_config.dtype`` is not one of the
            method's supported activation dtypes.
    """
    if model_config.quantization is None:
        return None

    quant_config = get_quant_config(model_config, load_config)

    # Encode (major, minor) as a two-digit integer, e.g. (8, 0) -> 80, so it
    # can be compared against the method's minimum capability.
    # NOTE(review): this queries the CUDA device unconditionally, so it
    # assumes a CUDA GPU is present whenever quantization is requested.
    major, minor = torch.cuda.get_device_capability()
    capability = major * 10 + minor
    min_capability = quant_config.get_min_capability()
    if capability < min_capability:
        raise ValueError(
            f"The quantization method {model_config.quantization} is not "
            "supported for the current GPU. "
            f"Minimum capability: {min_capability}. "
            f"Current capability: {capability}.")

    supported_dtypes = quant_config.get_supported_act_dtypes()
    if model_config.dtype not in supported_dtypes:
        raise ValueError(
            f"{model_config.dtype} is not supported for quantization "
            f"method {model_config.quantization}. Supported dtypes: "
            f"{supported_dtypes}")

    return quant_config


def _get_model_initialization_kwargs(
        model_class: Type[nn.Module], lora_config: Optional[LoRAConfig],
        vision_language_config: Optional[VisionLanguageConfig]
) -> Dict[str, Any]:
    """Get extra kwargs for model initialization.

    Args:
        model_class: The model class that is about to be instantiated.
        lora_config: LoRA configuration, or ``None`` when LoRA is disabled.
        vision_language_config: Vision-language configuration, forwarded to
            classes listed in ``_VISION_MODEL_CLASSES``.

    Returns:
        A dict that contains ``lora_config`` when ``model_class`` declares
        ``supported_lora_modules``, and ``vision_language_config`` when
        ``model_class`` is listed in ``_VISION_MODEL_CLASSES``. Both keys may
        be present.

    Raises:
        ValueError: If LoRA is enabled for a class without LoRA support.
    """
    extra_kwargs: Dict[str, Any] = {}
    if hasattr(model_class, "supported_lora_modules"):
        extra_kwargs["lora_config"] = lora_config
    elif lora_config:
        raise ValueError(
            f"Model {model_class.__name__} does not support LoRA, "
            "but LoRA is enabled. Support for this model may "
            "be added in the future. If this is important to you, "
            "please open an issue on github.")
    # Checked independently of the LoRA branch: a vision model that also
    # supports LoRA still needs its vision-language config.
    if model_class in _VISION_MODEL_CLASSES:
        extra_kwargs["vision_language_config"] = vision_language_config
    return extra_kwargs


def _initialize_model(model_config: ModelConfig, load_config: LoadConfig,
                      lora_config: Optional[LoRAConfig],
                      vision_language_config: Optional[VisionLanguageConfig],
                      cache_config: CacheConfig) -> nn.Module:
    """Initialize a model with the given configurations.

    Resolves the model class from ``model_config``, builds its quantization
    config and any LoRA / vision-language kwargs, and constructs the class
    with ``model_config.hf_config``. No weights are loaded here.
    """
    model_class = get_model_architecture(model_config)[0]
    quant_config = _get_quantization_config(model_config, load_config)
    extra_kwargs = _get_model_initialization_kwargs(model_class, lora_config,
                                                    vision_language_config)
    return model_class(config=model_config.hf_config,
                       cache_config=cache_config,
                       quant_config=quant_config,
                       **extra_kwargs)


class BaseModelLoader(ABC):
    """Base class for model loaders.

    Subclasses are constructed with a ``LoadConfig`` (stored as
    ``self.load_config``) and implement ``load_model``.
    """

    def __init__(self, load_config: LoadConfig):
        self.load_config = load_config

    @abstractmethod
    def load_model(self, *, model_config: ModelConfig,
                   device_config: DeviceConfig,
                   lora_config: Optional[LoRAConfig],
                   vision_language_config: Optional[VisionLanguageConfig],
                   parallel_config: ParallelConfig,
                   scheduler_config: SchedulerConfig,
                   cache_config: CacheConfig) -> nn.Module:
        """Load a model with the given configurations.

        All configuration arguments are keyword-only. Returns the loaded
        ``nn.Module``.
        """
        ...


class DefaultModelLoader(BaseModelLoader):
    """Model loader that can load different file types from disk.

    Handles the ``AUTO``, ``SAFETENSORS``, ``PT`` and ``NPCACHE`` load
    formats. A model that is not a local directory is fetched via
    ``download_weights_from_hf``, or from ModelScope first when
    ``VLLM_USE_MODELSCOPE`` is set.
    """

    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)
        if load_config.model_loader_extra_config:
            raise ValueError(f"Model loader extra config is not supported for "
                             f"load format {load_config.load_format}")

    def _maybe_download_from_modelscope(
            self, model: str, revision: Optional[str]) -> Optional[str]:
        """Download model from ModelScope hub if VLLM_USE_MODELSCOPE is True.

        Returns the path to the downloaded model (or ``model`` itself when it
        already exists on disk), or None if ModelScope is not in use."""
        if VLLM_USE_MODELSCOPE:
            # download model from ModelScope hub,
            # lazy import so that modelscope is not required for normal use.
            # pylint: disable=C.
            from modelscope.hub.snapshot_download import snapshot_download

            if not os.path.exists(model):
                model_path = snapshot_download(
                    model_id=model,
                    cache_dir=self.load_config.download_dir,
                    local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
                    revision=revision,
                )
            else:
                model_path = model
            return model_path
        return None

    def _prepare_weights(self, model_name_or_path: str,
                         revision: Optional[str],
                         fall_back_to_pt: bool) -> Tuple[str, List[str], bool]:
        """Prepare weights for the model.

        If the model is not local, it will be downloaded.

        Returns:
            ``(hf_folder, hf_weights_files, use_safetensors)``: the directory
            holding the weights, the weight files matched by the first
            pattern that found any, and whether those are safetensors files.

        Raises:
            ValueError: If the load format is not handled by this loader.
            RuntimeError: If no weight files are found.
        """
        model_name_or_path = self._maybe_download_from_modelscope(
            model_name_or_path, revision) or model_name_or_path

        is_local = os.path.isdir(model_name_or_path)
        load_format = self.load_config.load_format
        use_safetensors = False
        if load_format == LoadFormat.AUTO:
            allow_patterns = ["*.safetensors", "*.bin"]
        elif load_format == LoadFormat.SAFETENSORS:
            use_safetensors = True
            allow_patterns = ["*.safetensors"]
        elif load_format == LoadFormat.PT:
            allow_patterns = ["*.pt"]
        elif load_format == LoadFormat.NPCACHE:
            allow_patterns = ["*.bin"]
        else:
            raise ValueError(f"Unknown load_format: {load_format}")

        # Some quantized models use .pt files for storing the weights. Skip
        # the fallback when "*.pt" is already allowed (PT format) so the same
        # pattern is not downloaded/globbed twice.
        if fall_back_to_pt and "*.pt" not in allow_patterns:
            allow_patterns += ["*.pt"]

        if not is_local:
            hf_folder = download_weights_from_hf(model_name_or_path,
                                                 self.load_config.download_dir,
                                                 allow_patterns, revision)
        else:
            hf_folder = model_name_or_path

        # Patterns are tried in priority order; stop at the first one that
        # matches any file.
        hf_weights_files: List[str] = []
        for pattern in allow_patterns:
            hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
            if len(hf_weights_files) > 0:
                if pattern == "*.safetensors":
                    use_safetensors = True
                break

        if not use_safetensors:
            hf_weights_files = filter_files_not_needed_for_inference(
                hf_weights_files)

        if len(hf_weights_files) == 0:
            raise RuntimeError(
                f"Cannot find any model weights with `{model_name_or_path}`")

        return hf_folder, hf_weights_files, use_safetensors

    def _get_weights_iterator(
        self, model_name_or_path: str, revision: Optional[str],
        fall_back_to_pt: bool
    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
        """Get an iterator for the model weights based on the load format.

        Yields ``(name, tensor)`` pairs from the np-cache, safetensors or
        PyTorch checkpoint iterator, depending on the format and files found.
        """
        hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
            model_name_or_path, revision, fall_back_to_pt)
        if self.load_config.load_format == LoadFormat.NPCACHE:
            # Currently np_cache only support *.bin checkpoints
            assert use_safetensors is False
            return np_cache_weights_iterator(model_name_or_path,
                                             self.load_config.download_dir,
                                             hf_folder, hf_weights_files)
        if use_safetensors:
            return safetensors_weights_iterator(hf_weights_files)
        return pt_weights_iterator(hf_weights_files)

    def load_model(self, *, model_config: ModelConfig,
                   device_config: DeviceConfig,
                   lora_config: Optional[LoRAConfig],
                   vision_language_config: Optional[VisionLanguageConfig],
                   parallel_config: ParallelConfig,
                   scheduler_config: SchedulerConfig,
                   cache_config: CacheConfig) -> nn.Module:
        """Build the model on ``device_config.device`` and load its weights.

        ``parallel_config`` and ``scheduler_config`` are accepted for
        interface compatibility and are not used by this loader. Returns the
        model in eval mode.
        """
        with set_default_torch_dtype(model_config.dtype):
            with torch.device(device_config.device):
                model = _initialize_model(model_config, self.load_config,
                                          lora_config, vision_language_config,
                                          cache_config)
            model.load_weights(
                self._get_weights_iterator(model_config.model,
                                           model_config.revision,
                                           fall_back_to_pt=getattr(
                                               model,
                                               "fall_back_to_pt_during_load",
                                               True)))

            # Once all weights are loaded, let each module's quant_method
            # (and any module-level hook) post-process its weights.
            for module in model.modules():
                quant_method = getattr(module, "quant_method", None)
                if quant_method is not None:
                    quant_method.process_weights_after_loading(module)
                # FIXME: Remove this after Mixtral is updated
                # to use quant_method.
                if hasattr(module, "process_weights_after_loading"):
                    module.process_weights_after_loading()
        return model.eval()


class DummyModelLoader(BaseModelLoader):
    """Model loader that will set model weights to random values.

    No checkpoint files are read; ``initialize_dummy_weights`` fills the
    freshly constructed model instead.
    """

    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)
        if load_config.model_loader_extra_config:
            raise ValueError(f"Model loader extra config is not supported for "
                             f"load format {load_config.load_format}")

    def load_model(self, *, model_config: ModelConfig,
                   device_config: DeviceConfig,
                   lora_config: Optional[LoRAConfig],
                   vision_language_config: Optional[VisionLanguageConfig],
                   parallel_config: ParallelConfig,
                   scheduler_config: SchedulerConfig,
                   cache_config: CacheConfig) -> nn.Module:
        """Build the model on ``device_config.device`` with dummy weights.

        ``parallel_config`` and ``scheduler_config`` are accepted for
        interface compatibility and are not used. Returns the model in eval
        mode.
        """
        with set_default_torch_dtype(model_config.dtype):
            with torch.device(device_config.device):
                model = _initialize_model(model_config, self.load_config,
                                          lora_config, vision_language_config,
                                          cache_config)
            # NOTE(woosuk): For accurate performance evaluation, we assign
            # random values to the weights.
            initialize_dummy_weights(model)
        return model.eval()


class TensorizerLoader(BaseModelLoader):
    """Model loader using CoreWeave's tensorizer library."""

    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)
        # Accept either a ready-made TensorizerConfig or a mapping of its
        # constructor keyword arguments.
        if isinstance(load_config.model_loader_extra_config, TensorizerConfig):
            self.tensorizer_config = load_config.model_loader_extra_config
        else:
            self.tensorizer_config = TensorizerConfig(
                **load_config.model_loader_extra_config)

    def _verify_config(self, model_config: ModelConfig,
                       parallel_config: ParallelConfig):
        """Check the tensorizer config against the model and parallel configs."""
        self.tensorizer_config.verify_with_model_config(model_config)
        self.tensorizer_config.verify_with_parallel_config(parallel_config)

    def _get_weights_iterator(
            self) -> Generator[Tuple[str, torch.Tensor], None, None]:
        """Return a ``(name, tensor)`` iterator over the tensorizer weights."""
        tensorizer_args = self.tensorizer_config._construct_tensorizer_args()
        return tensorizer_weights_iterator(tensorizer_args)

    def _load_model_unserialized(
        self,
        model_config: ModelConfig,
        device_config: DeviceConfig,
        lora_config: Optional[LoRAConfig],
        vision_language_config: Optional[VisionLanguageConfig],
        cache_config: CacheConfig,
    ) -> nn.Module:
        """Load an unserialized model with tensorizer.

        Unserialized here means "not serialized with tensorizer". This
        should still be faster than default HuggingFace loading, but will
        be slower than loading a tensorizer-serialized model.
        """
        with set_default_torch_dtype(model_config.dtype):
            with torch.device(device_config.device):
                model = _initialize_model(model_config, self.load_config,
                                          lora_config, vision_language_config,
                                          cache_config)

            model.load_weights(self._get_weights_iterator())
        return model.eval()

    def _load_model_serialized(
        self,
        model_config: ModelConfig,
        device_config: DeviceConfig,
        lora_config: Optional[LoRAConfig],
        vision_language_config: Optional[VisionLanguageConfig],
        cache_config: CacheConfig,
    ) -> nn.Module:
        """Load a serialized model with tensorizer.

        See the examples/tensorize_vllm_model.py example script for
        serializing vLLM models."""
        with set_default_torch_dtype(model_config.dtype):
            with torch.device(device_config.device):
                model_class = get_model_architecture(model_config)[0]
                quant_config = _get_quantization_config(
                    model_config, self.load_config)
                extra_kwargs = _get_model_initialization_kwargs(
                    model_class, lora_config, vision_language_config)
                extra_kwargs["quant_config"] = quant_config
                extra_kwargs["cache_config"] = cache_config

                # Work on a shallow copy so the loader's own config is not
                # mutated by the per-model fields set below.
                tensorizer_config = copy.copy(self.tensorizer_config)
                tensorizer_config.model_class = model_class
                tensorizer_config.hf_config = model_config.hf_config
                tensorizer_config.dtype = model_config.dtype

                model = load_with_tensorizer(tensorizer_config, **extra_kwargs)
        return model.eval()

    def load_model(self, *, model_config: ModelConfig,
                   device_config: DeviceConfig,
                   lora_config: Optional[LoRAConfig],
                   vision_language_config: Optional[VisionLanguageConfig],
                   parallel_config: ParallelConfig,
                   scheduler_config: SchedulerConfig,
                   cache_config: CacheConfig) -> nn.Module:
        """Load a model, using the serialized path when the tensorizer
        config points at a vLLM-serialized checkpoint.

        ``scheduler_config`` is accepted for interface compatibility and is
        not used.
        """
        self._verify_config(model_config, parallel_config)

        if is_vllm_serialized_tensorizer(self.tensorizer_config):
            return self._load_model_serialized(model_config, device_config,
                                               lora_config,
                                               vision_language_config,
                                               cache_config)
        return self._load_model_unserialized(model_config, device_config,
                                             lora_config,
                                             vision_language_config,
                                             cache_config)


def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
    """Get a model loader based on the load format.

    If ``load_config.load_format`` is itself a class, that class is
    instantiated with ``load_config``. Otherwise ``DUMMY`` maps to
    ``DummyModelLoader``, ``TENSORIZER`` to ``TensorizerLoader``, and every
    other format to ``DefaultModelLoader``.
    """
    fmt = load_config.load_format
    if isinstance(fmt, type):
        loader_cls = fmt
    elif fmt == LoadFormat.DUMMY:
        loader_cls = DummyModelLoader
    elif fmt == LoadFormat.TENSORIZER:
        loader_cls = TensorizerLoader
    else:
        loader_cls = DefaultModelLoader
    return loader_cls(load_config)