loader.py 52.8 KB
Newer Older
1
# ruff: noqa: SIM117
2
import collections
3
import copy
4
import dataclasses
5
import fnmatch
6
import glob
7
8
import json
import math
9
10
import os
from abc import ABC, abstractmethod
11
from contextlib import contextmanager
12
13
from typing import (Any, Dict, Generator, Iterable, List, Optional, Tuple,
                    Type, cast)
14

15
import gguf
16
import huggingface_hub
17
import numpy as np
18
import torch
19
from huggingface_hub import HfApi, hf_hub_download
20
from torch import nn
21
from transformers import AutoModelForCausalLM, PretrainedConfig
22
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
23

24
25
26
from vllm.config import (CacheConfig, LoadConfig, LoadFormat, LoRAConfig,
                         ModelConfig, MultiModalConfig, ParallelConfig,
                         PoolerConfig, SchedulerConfig, VllmConfig)
27
28
from vllm.distributed import (get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size)
29
from vllm.envs import VLLM_USE_MODELSCOPE
30
from vllm.logger import init_logger
31
from vllm.model_executor.layers.linear import ReplicatedLinear
32
33
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
34
from vllm.model_executor.model_loader.tensorizer import (
35
    TensorizerConfig, is_vllm_tensorized, load_with_tensorizer,
36
    serialize_vllm_model, tensorizer_weights_iterator)
37
38
39
from vllm.model_executor.model_loader.utils import (get_model_architecture,
                                                    set_default_torch_dtype)
from vllm.model_executor.model_loader.weight_utils import (
40
41
    download_safetensors_index_file_from_hf, download_weights_from_hf,
    filter_duplicate_safetensors_files, filter_files_not_needed_for_inference,
42
43
44
    get_gguf_extra_tensor_names, get_quant_config, gguf_quant_weights_iterator,
    initialize_dummy_weights, np_cache_weights_iterator, pt_weights_iterator,
    safetensors_weights_iterator)
45
46
from vllm.model_executor.models import (has_inner_state, supports_lora,
                                        supports_multimodal)
47
from vllm.model_executor.utils import set_weight_attrs
48
from vllm.platforms import current_platform
49
from vllm.utils import is_pin_memory_available
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91


@contextmanager
def device_loading_context(module: torch.nn.Module,
                           target_device: torch.device):
    """Temporarily move a module's CPU-resident parameters to a device.

    Yields ``module`` itself. When ``target_device`` is the CPU nothing is
    moved. Otherwise every parameter that sits on the CPU on entry is moved
    to ``target_device`` for the duration of the ``with`` block. On exit
    (including exit via an exception) those same parameters are copied back
    into freshly allocated CPU storage, pinned when
    ``is_pin_memory_available()`` says so. Parameters that were not on the
    CPU on entry, and parameters registered inside the block, are left where
    they are.
    """
    if target_device.type == "cpu":
        # Target is the CPU: there is nothing to move.
        yield module
        return

    # name -> device the parameter lived on before we moved it.
    moved: Dict[str, torch.device] = {}
    for param_name, param in module.named_parameters():
        if param.device.type != "cpu":
            continue  # already off the CPU; not touched
        moved[param_name] = param.device
        param.data = param.data.to(target_device)

    try:
        yield module
    finally:
        use_pinned = is_pin_memory_available()
        for param_name, param in module.named_parameters():
            home = moved.get(param_name)
            if home is None:
                # Parameter created inside the block, or never moved.
                continue
            if home.type != "cpu":
                param.data = param.data.to(home)
                continue
            # `torch.empty_like` does not support `pin_memory` argument, so
            # allocate the CPU buffer explicitly with the same geometry.
            host_copy = torch.empty_strided(size=param.data.size(),
                                            stride=param.data.stride(),
                                            dtype=param.data.dtype,
                                            layout=param.data.layout,
                                            device="cpu",
                                            pin_memory=use_pinned)
            host_copy.copy_(param.data)
            param.data = host_copy

92
93
94
95

# Module-level logger; used below, e.g. for tensor/parameter shape-mismatch
# warnings while loading sharded checkpoints.
logger = init_logger(__name__)


96
def _get_quantization_config(
        model_config: ModelConfig,
        load_config: LoadConfig) -> Optional[QuantizationConfig]:
    """Get the quantization config.

    Returns ``None`` when ``model_config.quantization`` is unset. Otherwise
    the config is built with ``get_quant_config`` and validated against the
    current platform's device capability and the model's dtype.

    Raises:
        ValueError: if the platform reports a device capability below the
            quantization method's minimum, or if ``model_config.dtype`` is
            not among the method's supported activation dtypes.
    """
    if model_config.quantization is None:
        return None

    quant_config = get_quant_config(model_config, load_config)

    capability_tuple = current_platform.get_device_capability()
    # Platforms that report no capability skip the capability check.
    if capability_tuple is not None:
        capability = capability_tuple.to_int()
        if capability < quant_config.get_min_capability():
            raise ValueError(
                f"The quantization method {model_config.quantization} "
                "is not supported for the current GPU. "
                f"Minimum capability: {quant_config.get_min_capability()}. "
                f"Current capability: {capability}.")

    supported_dtypes = quant_config.get_supported_act_dtypes()
    if model_config.dtype not in supported_dtypes:
        raise ValueError(
            f"{model_config.dtype} is not supported for quantization "
            f"method {model_config.quantization}. Supported dtypes: "
            f"{supported_dtypes}")

    return quant_config
120
121
122


def _get_model_initialization_kwargs(
        model_class: Type[nn.Module],
        lora_config: Optional[LoRAConfig],
        multimodal_config: Optional[MultiModalConfig],
        scheduler_config: Optional[SchedulerConfig] = None,
        pooler_config: Optional[PoolerConfig] = None) -> Dict[str, Any]:
    """Get extra kwargs for model initialization.

    Only keys that apply to ``model_class`` are added:

    * ``lora_config`` when the class supports LoRA (a ``None`` value is
      still passed through, which disables LoRA);
    * ``multimodal_config`` when the class supports multimodal inputs;
    * ``scheduler_config`` when the class has inner state and a truthy
      scheduler config is given;
    * ``pooler_config`` whenever a truthy pooler config is given.

    Raises:
        ValueError: if LoRA is enabled for a model class that does not
            support it.
    """
    extra_kwargs: Dict[str, Any] = {}

    if supports_lora(model_class):
        # lora_config=None is used to disable LoRA
        extra_kwargs["lora_config"] = lora_config
    elif lora_config:
        raise ValueError(
            f"Model {model_class.__name__} does not support LoRA, "
            "but LoRA is enabled. Support for this model may "
            "be added in the future. If this is important to you, "
            "please open an issue on github.")

    if supports_multimodal(model_class):
        assert multimodal_config is not None
        extra_kwargs["multimodal_config"] = multimodal_config

    if has_inner_state(model_class) and scheduler_config:
        extra_kwargs["scheduler_config"] = scheduler_config

    if pooler_config:
        extra_kwargs["pooler_config"] = pooler_config

    return extra_kwargs


153
def build_model(model_class: Type[nn.Module],
                vllm_config: Optional[VllmConfig],
                hf_config: PretrainedConfig,
                cache_config: Optional[CacheConfig],
                quant_config: Optional[QuantizationConfig],
                *,
                lora_config: Optional[LoRAConfig],
                multimodal_config: Optional[MultiModalConfig],
                scheduler_config: Optional[SchedulerConfig],
                prefix: Optional[str] = None,
                pooler_config: Optional[PoolerConfig] = None) -> nn.Module:
    """Instantiate ``model_class`` from the given configs.

    ``config``, ``cache_config`` and ``quant_config`` are always passed.
    Model-specific kwargs (LoRA, multimodal, scheduler, pooler) come from
    ``_get_model_initialization_kwargs``. ``prefix`` is forwarded only when
    it is non-empty. Before construction, ``vllm_config`` is registered via
    ``vllm.plugins.set_vllm_config``.
    """
    extra_kwargs = _get_model_initialization_kwargs(model_class, lora_config,
                                                    multimodal_config,
                                                    scheduler_config,
                                                    pooler_config)
    if prefix:
        extra_kwargs["prefix"] = prefix

    # TODO: unify all the module initialization code
    # to only take the `VllmConfig` object as input
    from vllm.plugins import set_vllm_config
    set_vllm_config(vllm_config)

    return model_class(config=hf_config,
                       cache_config=cache_config,
                       quant_config=quant_config,
                       **extra_kwargs)


182
def _initialize_model(vllm_config: VllmConfig) -> nn.Module:
    """Initialize a model with the given configurations.

    Resolves the model class from ``vllm_config.model_config``, derives the
    quantization config, and delegates construction to ``build_model``.
    Weights are not loaded here.
    """
    model_config = vllm_config.model_config
    lora_config = vllm_config.lora_config
    scheduler_config = vllm_config.scheduler_config
    cache_config = vllm_config.cache_config
    load_config = vllm_config.load_config
    model_class, _ = get_model_architecture(model_config)

    return build_model(
        model_class,
        vllm_config,
        model_config.hf_config,
        cache_config=cache_config,
        quant_config=_get_quantization_config(model_config, load_config),
        lora_config=lora_config,
        multimodal_config=model_config.multimodal_config,
        scheduler_config=scheduler_config,
        pooler_config=model_config.pooler_config,
    )
202
203
204
205
206
207
208
209


class BaseModelLoader(ABC):
    """Base class for model loaders.

    Subclasses implement ``download_model`` (fetch whatever is needed so a
    later load can proceed immediately) and ``load_model`` (build the model
    and populate its weights).
    """

    def __init__(self, load_config: LoadConfig):
        self.load_config = load_config

    @abstractmethod
    def download_model(self, model_config: ModelConfig) -> None:
        """Download a model so that it can be immediately loaded."""
        raise NotImplementedError

    @abstractmethod
    def load_model(self, *, vllm_config: VllmConfig) -> nn.Module:
        """Load a model with the given configurations."""
        raise NotImplementedError
219
220
221
222
223


class DefaultModelLoader(BaseModelLoader):
    """Model loader that can load different file types from disk.

    Weights come from the model's primary source (``model_config.model``)
    followed by any extra ``Source`` objects the model exposes through an
    optional ``secondary_weights`` attribute.
    """

    @dataclasses.dataclass
    class Source:
        """A source for weights."""

        model_or_path: str
        """The model ID or path."""

        revision: Optional[str]
        """The optional model revision."""

        prefix: str = ""
        """A prefix to prepend to all weights."""

        fall_back_to_pt: bool = True
        """Whether .pt weights can be used."""

    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)
        if load_config.model_loader_extra_config:
            raise ValueError(f"Model loader extra config is not supported for "
                             f"load format {load_config.load_format}")

    def _maybe_download_from_modelscope(
            self, model: str, revision: Optional[str]) -> Optional[str]:
        """Download model from ModelScope hub if VLLM_USE_MODELSCOPE is True.

        Returns the path to the downloaded model (or ``model`` unchanged when
        it already exists on the local filesystem), or None if ModelScope is
        not enabled."""
        if not VLLM_USE_MODELSCOPE:
            return None

        # Download model from ModelScope hub;
        # lazy import so that modelscope is not required for normal use.
        # pylint: disable=C.
        from modelscope.hub.snapshot_download import snapshot_download

        if os.path.exists(model):
            return model
        return snapshot_download(
            model_id=model,
            cache_dir=self.load_config.download_dir,
            local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
            revision=revision,
            ignore_file_pattern=self.load_config.ignore_patterns,
        )

    def _prepare_weights(self, model_name_or_path: str,
                         revision: Optional[str],
                         fall_back_to_pt: bool) -> Tuple[str, List[str], bool]:
        """Prepare weights for the model.

        If the model is not local, it will be downloaded.

        Returns:
            ``(hf_folder, hf_weights_files, use_safetensors)``: the folder
            holding the weights, the weight files selected for loading, and
            whether those files are safetensors.

        Raises:
            ValueError: if ``load_config.load_format`` is not handled here.
            RuntimeError: if no weight files are found.
        """
        model_name_or_path = self._maybe_download_from_modelscope(
            model_name_or_path, revision) or model_name_or_path

        is_local = os.path.isdir(model_name_or_path)
        load_format = self.load_config.load_format
        use_safetensors = False
        index_file = SAFE_WEIGHTS_INDEX_NAME

        if load_format == LoadFormat.AUTO:
            allow_patterns = ["*.safetensors", "*.bin"]
        elif load_format == LoadFormat.SAFETENSORS:
            use_safetensors = True
            allow_patterns = ["*.safetensors"]
        elif load_format == LoadFormat.MISTRAL:
            use_safetensors = True
            allow_patterns = ["consolidated*.safetensors"]
            index_file = "consolidated.safetensors.index.json"
        elif load_format == LoadFormat.PT:
            allow_patterns = ["*.pt"]
        elif load_format == LoadFormat.NPCACHE:
            allow_patterns = ["*.bin"]
        else:
            raise ValueError(f"Unknown load_format: {load_format}")

        # Some quantized models use .pt files for storing the weights.
        if fall_back_to_pt:
            allow_patterns += ["*.pt"]

        if not is_local:
            hf_folder = download_weights_from_hf(
                model_name_or_path,
                self.load_config.download_dir,
                allow_patterns,
                revision,
                ignore_patterns=self.load_config.ignore_patterns,
            )
        else:
            hf_folder = model_name_or_path

        # Patterns are tried in order; the first pattern that matches any
        # file decides which files are loaded.
        hf_weights_files: List[str] = []
        for pattern in allow_patterns:
            hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
            if len(hf_weights_files) > 0:
                if pattern == "*.safetensors":
                    use_safetensors = True
                break

        if use_safetensors:
            # For models like Mistral-7B-Instruct-v0.3
            # there are both sharded safetensors files and a consolidated
            # safetensors file. Using both breaks.
            # Here, we download the `model.safetensors.index.json` and filter
            # any files not found in the index.
            if not is_local:
                download_safetensors_index_file_from_hf(
                    model_name_or_path, index_file,
                    self.load_config.download_dir, revision)
            hf_weights_files = filter_duplicate_safetensors_files(
                hf_weights_files, hf_folder, index_file)
        else:
            hf_weights_files = filter_files_not_needed_for_inference(
                hf_weights_files)

        if len(hf_weights_files) == 0:
            raise RuntimeError(
                f"Cannot find any model weights with `{model_name_or_path}`")

        return hf_folder, hf_weights_files, use_safetensors

    def _get_weights_iterator(
            self, source: "Source"
    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
        """Get an iterator for the model weights based on the load format.

        Yields ``(name, tensor)`` pairs with ``source.prefix`` prepended to
        every name.
        """
        hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
            source.model_or_path, source.revision, source.fall_back_to_pt)
        if self.load_config.load_format == LoadFormat.NPCACHE:
            # Currently np_cache only support *.bin checkpoints
            assert use_safetensors is False
            weights_iterator = np_cache_weights_iterator(
                source.model_or_path, self.load_config.download_dir, hf_folder,
                hf_weights_files)
        elif use_safetensors:
            weights_iterator = safetensors_weights_iterator(hf_weights_files)
        else:
            weights_iterator = pt_weights_iterator(hf_weights_files)

        if current_platform.is_tpu():
            # In PyTorch XLA, we should call `xm.mark_step` frequently so that
            # not too many ops are accumulated in the XLA program.
            import torch_xla.core.xla_model as xm

            def _xla_weights_iterator(iterator: Generator):
                for weights in iterator:
                    yield weights
                    xm.mark_step()

            weights_iterator = _xla_weights_iterator(weights_iterator)

        # Apply the prefix.
        return ((source.prefix + name, tensor)
                for (name, tensor) in weights_iterator)

    def _get_all_weights(
        self,
        model_config: ModelConfig,
        model: nn.Module,
    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
        """Yield weights from the primary source, then any secondary ones.

        The model may opt out of the ``.pt`` fallback by setting
        ``fall_back_to_pt_during_load`` and may contribute extra sources via
        ``secondary_weights``; both attributes are optional.
        """
        primary_weights = DefaultModelLoader.Source(
            model_config.model,
            model_config.revision,
            prefix="",
            fall_back_to_pt=getattr(model, "fall_back_to_pt_during_load",
                                    True))
        yield from self._get_weights_iterator(primary_weights)

        secondary_weights = cast(Iterable[DefaultModelLoader.Source],
                                 getattr(model, "secondary_weights", ()))
        for source in secondary_weights:
            yield from self._get_weights_iterator(source)

    def download_model(self, model_config: ModelConfig) -> None:
        self._prepare_weights(model_config.model,
                              model_config.revision,
                              fall_back_to_pt=True)

    def load_model(self, vllm_config: VllmConfig) -> nn.Module:
        """Build the model on the target device and load its weights."""
        device_config = vllm_config.device_config
        model_config = vllm_config.model_config

        target_device = torch.device(device_config.device)
        with set_default_torch_dtype(model_config.dtype):
            with target_device:
                model = _initialize_model(vllm_config=vllm_config)

            model.load_weights(self._get_all_weights(model_config, model))

            for _, module in model.named_modules():
                quant_method = getattr(module, "quant_method", None)
                if quant_method is not None:
                    # When quant methods need to process weights after loading
                    # (for repacking, quantizing, etc), they expect parameters
                    # to be on the global target device. This scope is for the
                    # case where cpu offloading is used, where we will move the
                    # parameters onto device for processing and back off after.
                    with device_loading_context(module, target_device):
                        quant_method.process_weights_after_loading(module)
        return model.eval()


class DummyModelLoader(BaseModelLoader):
    """Model loader that will set model weights to random values."""

    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)
        if load_config.model_loader_extra_config:
            raise ValueError(f"Model loader extra config is not supported for "
                             f"load format {load_config.load_format}")

    def download_model(self, model_config: ModelConfig) -> None:
        pass  # Nothing to download

    def load_model(self, vllm_config: VllmConfig) -> nn.Module:
        """Build the model and fill its weights via
        ``initialize_dummy_weights`` instead of reading a checkpoint."""
        device_config = vllm_config.device_config
        model_config = vllm_config.model_config

        target_device = torch.device(device_config.device)
        with set_default_torch_dtype(model_config.dtype):
            with target_device:
                model = _initialize_model(vllm_config=vllm_config)
            # NOTE(woosuk): For accurate performance evaluation, we assign
            # random values to the weights.
            initialize_dummy_weights(model)

            for _, module in model.named_modules():
                quant_method = getattr(module, "quant_method", None)
                if quant_method is not None:
                    # When quant methods need to process weights after loading
                    # (for repacking, quantizing, etc), they expect parameters
                    # to be on the global target device. This scope is for the
                    # case where cpu offloading is used, where we will move the
                    # parameters onto device for processing and back off after.
                    with device_loading_context(module, target_device):
                        quant_method.process_weights_after_loading(module)
        return model.eval()


class TensorizerLoader(BaseModelLoader):
    """Model loader using CoreWeave's tensorizer library."""

    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)
        if isinstance(load_config.model_loader_extra_config, TensorizerConfig):
            self.tensorizer_config = load_config.model_loader_extra_config
        else:
            # Otherwise the extra config is expanded as keyword arguments.
            self.tensorizer_config = TensorizerConfig(
                **load_config.model_loader_extra_config)

    def _verify_config(self, model_config: ModelConfig,
                       parallel_config: ParallelConfig):
        self.tensorizer_config.verify_with_model_config(model_config)
        self.tensorizer_config.verify_with_parallel_config(parallel_config)

    def _get_weights_iterator(
            self) -> Generator[Tuple[str, torch.Tensor], None, None]:
        tensorizer_args = self.tensorizer_config._construct_tensorizer_args()
        return tensorizer_weights_iterator(tensorizer_args)

    def _load_model_serialized_cpu(
        self,
        vllm_config: VllmConfig,
    ) -> nn.Module:
        """Load a serialized model with tensorizer to the CPU.

        This is only necessary when the model isn't vLLM-tensorized (see
        examples/tensorize_vllm_model.py). This should still be faster than
        default HuggingFace loading, but will be slower than loading a
        vLLM-tensorized model.
        """
        device_config = vllm_config.device_config
        model_config = vllm_config.model_config

        with set_default_torch_dtype(model_config.dtype):
            with torch.device(device_config.device):
                model = _initialize_model(vllm_config=vllm_config)

            model.load_weights(self._get_weights_iterator())
        return model.eval()

    def _load_model_serialized(
        self,
        vllm_config: VllmConfig,
    ) -> nn.Module:
        """Load a serialized model with tensorizer.

        Expects a vLLM-tensorized model. See the
        examples/tensorize_vllm_model.py example script
        for serializing vLLM models."""
        device_config = vllm_config.device_config
        model_config = vllm_config.model_config
        lora_config = vllm_config.lora_config
        cache_config = vllm_config.cache_config

        with set_default_torch_dtype(model_config.dtype):
            with torch.device(device_config.device):
                model_class = get_model_architecture(model_config)[0]
                quant_config = _get_quantization_config(
                    model_config, self.load_config)
                # NOTE(review): unlike `build_model`, scheduler/pooler configs
                # are not forwarded here — confirm this is intended.
                extra_kwargs = _get_model_initialization_kwargs(
                    model_class, lora_config, model_config.multimodal_config)
                extra_kwargs["quant_config"] = quant_config
                extra_kwargs["cache_config"] = cache_config

                # Work on a shallow copy so the loader's own config is not
                # annotated with model-specific fields.
                tensorizer_config = copy.copy(self.tensorizer_config)
                tensorizer_config.model_class = model_class
                tensorizer_config.hf_config = model_config.hf_config
                tensorizer_config.dtype = model_config.dtype

                model = load_with_tensorizer(tensorizer_config, **extra_kwargs)
        return model.eval()

    def download_model(self, model_config: ModelConfig) -> None:
        self.tensorizer_config.verify_with_model_config(model_config)

        # Opening the stream is enough to check the URI is reachable.
        with self.tensorizer_config.open_stream():
            pass

    def load_model(self, vllm_config: VllmConfig) -> nn.Module:
        """Verify the config, then load either a vLLM-tensorized model or
        fall back to loading tensorized weights into a freshly built model."""
        model_config = vllm_config.model_config
        parallel_config = vllm_config.parallel_config
        self._verify_config(model_config, parallel_config)

        if parallel_config.tensor_parallel_size > 1:
            # With tensor parallelism the URI is a %-template filled with this
            # worker's rank.
            # NOTE(review): this rewrites `self.tensorizer_config` in place, so
            # calling `load_model` twice on one loader formats the URI twice.
            self.tensorizer_config.tensorizer_uri = (
                self.tensorizer_config.tensorizer_uri %
                get_tensor_model_parallel_rank())

        if is_vllm_tensorized(self.tensorizer_config):
            return self._load_model_serialized(vllm_config=vllm_config)
        return self._load_model_serialized_cpu(vllm_config=vllm_config)

    @staticmethod
    def save_model(
        model: torch.nn.Module,
        tensorizer_config: TensorizerConfig,
    ) -> None:
        serialize_vllm_model(
            model=model,
            tensorizer_config=tensorizer_config,
        )

567

568
569
570
571
572
class ShardedStateLoader(BaseModelLoader):
    """
    Model loader that directly loads each worker's model state dict, which
    enables a fast load path for large tensor-parallel models where each worker
    only needs to read its own shard rather than the entire checkpoint. See
    `examples/save_sharded_state.py` for creating a sharded checkpoint.
    """

    DEFAULT_PATTERN = "model-rank-{rank}-part-{part}.safetensors"

    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)
        extra_config = ({} if load_config.model_loader_extra_config is None
                        else load_config.model_loader_extra_config.copy())
        self.pattern = extra_config.pop("pattern", self.DEFAULT_PATTERN)
        if extra_config:
            # Report only the keys that were not consumed above.
            raise ValueError(f"Unexpected extra config keys for load format "
                             f"{load_config.load_format}: "
                             f"{extra_config.keys()}")

    @staticmethod
    def _filter_subtensors(
            tensors: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Filter out all tensors that share the same memory or a subset of the
        memory of another tensor.

        Zero-element tensors are dropped. A tensor is dropped when a
        contiguous tensor on the same storage covers strictly more memory, or
        (for non-contiguous tensors) any memory span that includes its own.
        Among tensors covering exactly the same contiguous span, the one with
        the smallest key is kept.
        """
        same_storage_groups: Dict[Any, List[Tuple[
            str, torch.Tensor]]] = collections.defaultdict(list)
        for key, tensor in tensors.items():
            if tensor.numel():
                ptr = tensor.untyped_storage().data_ptr()
                same_storage_groups[tensor.device, ptr].append((key, tensor))

        def get_end_ptr(tensor: torch.Tensor) -> int:
            # One past the last byte of the tensor's final element.
            return tensor.view(-1)[-1].data_ptr() + tensor.element_size()

        result: Dict[str, torch.Tensor] = {}
        for group in same_storage_groups.values():
            for k, t in group:
                a, b = t.data_ptr(), get_end_ptr(t)
                for k2, t2 in group:
                    if not t2.is_contiguous():
                        continue
                    a2, b2 = t2.data_ptr(), get_end_ptr(t2)
                    if a < a2 or b2 < b:
                        continue
                    if a2 < a or b < b2 or not t.is_contiguous():
                        break  # t2 covers strictly more memory than t.
                    if k2 < k:
                        # Same tensors, keep the one with the smaller key.
                        break
                else:
                    result[k] = t
        return result

    def _prepare_weights(self, model_name_or_path: str,
                         revision: Optional[str]):
        """Return a local directory holding the checkpoint, downloading the
        ``*.safetensors`` files first when the path is not a local dir."""
        if os.path.isdir(model_name_or_path):
            return model_name_or_path
        allow_patterns = ["*.safetensors"]
        return download_weights_from_hf(
            model_name_or_path,
            self.load_config.download_dir,
            allow_patterns,
            revision,
            ignore_patterns=self.load_config.ignore_patterns,
        )

    def download_model(self, model_config: ModelConfig) -> None:
        self._prepare_weights(model_config.model, model_config.revision)

    def load_model(self, vllm_config: VllmConfig) -> nn.Module:
        """Load this tensor-parallel rank's shard files into the model.

        Raises:
            ValueError: if no shard files match the pattern for this rank, or
                if some model state is not present in the shard files.
        """
        device_config = vllm_config.device_config
        model_config = vllm_config.model_config

        from safetensors.torch import safe_open

        local_model_path = self._prepare_weights(model_config.model,
                                                 model_config.revision)

        with set_default_torch_dtype(model_config.dtype):
            with torch.device(device_config.device):
                model = _initialize_model(vllm_config=vllm_config)
                # Post-loading weight processing runs *before* the shards are
                # copied in — presumably because shards are saved from an
                # already-processed model; confirm against the save script.
                for _, module in model.named_modules():
                    quant_method = getattr(module, "quant_method", None)
                    if quant_method is not None:
                        quant_method.process_weights_after_loading(module)
            rank = get_tensor_model_parallel_rank()
            pattern = os.path.join(
                local_model_path,
                self.pattern.format(rank=rank, part="*"),
            )
            filepaths = glob.glob(pattern)
            if not filepaths:
                # TODO: support un-sharded checkpoints too
                raise ValueError(
                    f"Could not find checkpoint files '{pattern}', only "
                    f"pre-sharded checkpoints are currently supported!")
            state_dict = self._filter_subtensors(model.state_dict())
            for path in filepaths:
                with safe_open(path, framework="pt") as f:
                    for key in f.keys():  # noqa: SIM118
                        tensor = f.get_tensor(key)
                        # If loading with LoRA enabled, additional padding may
                        # be added to certain parameters. We only load into a
                        # narrowed view of the parameter data.
                        param_data = state_dict[key].data
                        param_shape = state_dict[key].shape
                        for dim, size in enumerate(tensor.shape):
                            if size < param_shape[dim]:
                                param_data = param_data.narrow(dim, 0, size)
                        if tensor.shape != param_shape:
                            logger.warning(
                                "loading tensor of shape %s into "
                                "parameter '%s' of shape %s", tensor.shape,
                                key, param_shape)
                        param_data.copy_(tensor)
                        # Track what is still missing.
                        state_dict.pop(key)
            if state_dict:
                raise ValueError(
                    f"Missing keys {tuple(state_dict)} in loaded state!")
        return model.eval()

    @staticmethod
    def save_model(
        model: torch.nn.Module,
        path: str,
        pattern: Optional[str] = None,
        max_size: Optional[int] = None,
    ) -> None:
        """Save this rank's (deduplicated) state dict as safetensors parts.

        Args:
            model: Model whose ``state_dict()`` is written.
            path: Existing directory to write the part files into.
            pattern: Filename pattern with ``{rank}`` and ``{part}`` fields;
                defaults to ``DEFAULT_PATTERN``.
            max_size: Soft per-part size limit in bytes. A new part is started
                before a tensor that would push the current part past it; a
                single tensor larger than the limit gets a part of its own.
                ``None`` writes everything into one part.
        """
        from safetensors.torch import save_file

        if pattern is None:
            pattern = ShardedStateLoader.DEFAULT_PATTERN
        rank = get_tensor_model_parallel_rank()
        part_idx = 0
        total_size = 0
        state_dict = ShardedStateLoader._filter_subtensors(model.state_dict())
        state_dict_part: Dict[str, torch.Tensor] = {}
        for key, tensor in state_dict.items():
            param_size = tensor.nelement() * tensor.element_size()
            # Never flush an empty part (which would otherwise happen when a
            # single tensor alone exceeds max_size).
            if (max_size is not None and state_dict_part
                    and total_size + param_size > max_size):
                filename = pattern.format(rank=rank, part=part_idx)
                save_file(
                    state_dict_part,
                    os.path.join(path, filename),
                )
                part_idx += 1
                total_size = 0
                state_dict_part = {}
            state_dict_part[key] = tensor
            total_size += param_size
        if len(state_dict_part) > 0:
            filename = pattern.format(rank=rank, part=part_idx)
            save_file(
                state_dict_part,
                os.path.join(path, filename),
            )


732
733
734
class BitsAndBytesModelLoader(BaseModelLoader):
    """Model loader to load model weights with BitAndBytes quantization.

    Two kinds of checkpoints are handled:

    * pre-quantized checkpoints, i.e. ``hf_config.quantization_config`` has
      ``quant_method == "bitsandbytes"``; these are loaded in 4-bit or, when
      ``load_in_8bit`` is set, 8-bit form;
    * unquantized checkpoints, whose target-module weights are quantized to
      4-bit NF4 on the fly while loading.

    Target modules come from a QLoRA adapter config when
    ``model_loader_extra_config["qlora_adapter_name_or_path"]`` is given;
    otherwise from the model's ``default_bitsandbytes_target_modules`` or
    from :attr:`default_target_modules`.
    """

    possible_config_file_names = ["adapter_config.json"]

    default_target_modules = [
        ".gate_proj.",
        ".down_proj.",
        ".up_proj.",
        ".q_proj.",
        ".k_proj.",
        ".v_proj.",
        ".o_proj.",
        '.fc1.',
        '.fc2.',
        '.dense.',
        '.query_key_value.',
        '.qkv_proj.',
        '.dense_h_to_4h.',
        '.dense_4h_to_h.',
        '.out_proj.',
    ]

    # Minimum supported bitsandbytes release. It is compared against a parsed
    # integer tuple, never against the raw version string.
    _MIN_BITSANDBYTES_VERSION = (0, 44)

    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)

        # Recomputed from the model in `_load_weights`. They are initialised
        # here so the attributes exist on every path, including the early
        # return below.
        self.unsharded_weights_modules: List[str] = []
        self.column_parallel_weights_modules: List[str] = []
        self.target_modules: List[str] = []

        # We don't need to quantize the whole model, only the target modules
        # that are specified in the adapter config file. If the adapter config
        # file is not provided, the defaults are chosen in `_load_weights`.
        extra_config = load_config.model_loader_extra_config
        if (not extra_config
                or "qlora_adapter_name_or_path" not in extra_config):
            return

        qlora_adapter = extra_config["qlora_adapter_name_or_path"]
        config_file_path = self._get_config_file(qlora_adapter)

        with open(config_file_path, "r", encoding="utf-8") as f:
            config = json.load(f)
        self.target_modules = config["target_modules"]

    def _get_config_file(self, qlora_adapter: str) -> str:
        """Return a local path to the adapter config of ``qlora_adapter``.

        ``qlora_adapter`` is either a local directory or a Hugging Face Hub
        repo id. For a repo id, the first matching config file is downloaded.

        Raises:
            ValueError: if none of :attr:`possible_config_file_names` exists.
        """
        config_file_path = None
        if os.path.isdir(qlora_adapter):
            for file in self.possible_config_file_names:
                candidate = os.path.join(qlora_adapter, file)
                # Only accept paths that actually exist, so a missing config
                # reaches the ValueError below instead of failing in `open`.
                if os.path.exists(candidate):
                    config_file_path = candidate
                    break
        else:
            hf_api = HfApi()
            repo_files = hf_api.list_repo_files(repo_id=qlora_adapter)
            for file in self.possible_config_file_names:
                if file in repo_files:
                    config_file_path = hf_hub_download(repo_id=qlora_adapter,
                                                       filename=file)
                    break

        if not config_file_path:
            raise ValueError(
                f"Cannot find adapter config file in {qlora_adapter}")

        return config_file_path

    def _get_weight_files(
            self,
            model_name_or_path: str,
            allowed_patterns: List[str],
            revision: Optional[str] = None) -> Tuple[List[str], str]:
        """Retrieve weight files. Download the files if necessary.

        The patterns are tried in order; the first one that matches any file
        wins. For a Hub repo, only files matching that pattern are
        downloaded.

        Returns:
            The matching weight file paths and the pattern that matched.

        Raises:
            RuntimeError: if no pattern matches any file.
        """
        is_local = os.path.isdir(model_name_or_path)

        if is_local:
            for pattern in allowed_patterns:
                weight_files = glob.glob(
                    os.path.join(model_name_or_path, pattern))
                if weight_files:
                    return weight_files, pattern
        else:
            hf_api = HfApi()
            repo_files = hf_api.list_repo_files(repo_id=model_name_or_path)
            for pattern in allowed_patterns:
                matching_files = fnmatch.filter(repo_files, pattern)
                if matching_files:
                    hf_folder = download_weights_from_hf(
                        model_name_or_path,
                        self.load_config.download_dir,
                        [pattern],
                        revision,
                        ignore_patterns=self.load_config.ignore_patterns,
                    )
                    return glob.glob(os.path.join(hf_folder, pattern)), pattern

        raise RuntimeError(
            f"No model weights found in: `{model_name_or_path}`")

    def _prepare_weights(self, model_name_or_path: str,
                         revision: Optional[str]) -> Tuple[List[str], bool]:
        """Prepare weight files for the model.

        Returns:
            The weight file paths, and whether they are safetensors files.
        """

        allowed_patterns = ["*.safetensors", "*.bin", "*.pt"]

        hf_weights_files, matched_pattern = self._get_weight_files(
            model_name_or_path, allowed_patterns, revision)

        if matched_pattern != "*.safetensors":
            hf_weights_files = filter_files_not_needed_for_inference(
                hf_weights_files)

        if len(hf_weights_files) == 0:
            raise RuntimeError(
                f"Cannot find any model weights with `{model_name_or_path}`")

        return hf_weights_files, matched_pattern == "*.safetensors"

    def _hf_weight_iter(self, hf_weights_files, use_safetensors: bool):
        """Return a fresh ``(name, tensor)`` iterator over the weight files."""
        if use_safetensors:
            return safetensors_weights_iterator(hf_weights_files)
        else:
            return pt_weights_iterator(hf_weights_files)

    @staticmethod
    def _version_tuple(version: str) -> Tuple[int, ...]:
        """Return the leading dotted-integer components of ``version``.

        For example, ``"0.44.1"`` gives ``(0, 44, 1)`` and ``"0.45.0.dev0"``
        gives ``(0, 45, 0)``. Returns ``()`` if ``version`` does not start
        with a digit. Tuples compare numerically, unlike the raw strings:
        ``"0.100.0" < "0.44.0"`` is True lexicographically.
        """
        import re
        match = re.match(r"\d+(?:\.\d+)*", version)
        if match is None:
            return ()
        return tuple(int(part) for part in match.group(0).split("."))

    def _get_quantized_weights_iterator(
        self,
        model_name_or_path: str,
        revision: Optional[str],
        pre_quant: bool,
        load_8bit: bool,
    ) -> Tuple[Generator[Tuple[str, torch.Tensor], None, None], Dict[str,
                                                                     Any]]:
        """Get an iterator to the model weights with bitsandbytes quantization,
        as well as the quantization state dictionary.

        The returned dict starts empty. The generator fills it in as it is
        consumed, so it is only complete once the iterator is exhausted.

        Raises:
            ImportError: if bitsandbytes is missing or older than 0.44.0.
        """

        # only load the bitsandbytes module when needed
        try:
            import bitsandbytes
            installed = self._version_tuple(bitsandbytes.__version__)
            if installed < self._MIN_BITSANDBYTES_VERSION:
                raise ImportError("bitsandbytes version is wrong. Please "
                                  "install bitsandbytes>=0.44.0.")
        except ImportError as err:
            raise ImportError("Please install bitsandbytes>=0.44.0 via "
                              "`pip install bitsandbytes>=0.44.0` to use "
                              "bitsandbytes quantizer.") from err

        hf_weights_files, use_safetensors = self._prepare_weights(
            model_name_or_path, revision)

        quant_state_dict: Dict[str, Any] = {}

        if pre_quant:
            if load_8bit:
                return self._quantized_8bit_generator(
                    hf_weights_files, use_safetensors,
                    quant_state_dict), quant_state_dict
            else:
                return self._quantized_4bit_generator(
                    hf_weights_files, use_safetensors,
                    quant_state_dict), quant_state_dict

        return self._unquantized_generator(hf_weights_files, use_safetensors,
                                           quant_state_dict), quant_state_dict

    def _is_8bit_weight_name(self, weight_name: str):
        """True for 8-bit quant-state entries (``.SCB`` / ``.weight_format``),
        compared case-insensitively."""
        quantized_suffix = {".scb", ".weight_format"}
        return any(weight_name.lower().endswith(suffix)
                   for suffix in quantized_suffix)

    def _is_4bit_weight_name(self, weight_name: str):
        """True when the last dotted component of ``weight_name`` contains
        one of the 4-bit quant-state markers (``absmax``, ``quant_map``,
        ``bitsandbytes``, ...)."""
        quantized_suffix = {
            "absmax", "quant_map", "nested_absmax", "nested_quant_map",
            "bitsandbytes"
        }
        suffix = weight_name.split(".")[-1]
        return any(q_suffix in suffix for q_suffix in quantized_suffix)

    def _quantized_8bit_generator(self, hf_weights_files, use_safetensors,
                                  quant_state_dict) -> Generator:
        """Yield the weights of a pre-quantized 8-bit checkpoint.

        Pass 1 stores every ``*.SCB`` tensor in ``quant_state_dict`` under its
        lowercased ``.qweight`` name. Pass 2 skips the quant-state entries.
        Weights with an SCB entry are tagged ``load_in_8bit`` and yielded as
        ``.qweight``; all other weights are yielded unchanged.
        """
        for weight_name, weight_tensor in self._hf_weight_iter(
                hf_weights_files, use_safetensors):
            if not weight_name.lower().endswith(".scb"):
                continue

            weight_key = weight_name.lower().replace(".scb", ".qweight")
            quant_state_dict[weight_key] = weight_tensor

        for weight_name, weight_tensor in self._hf_weight_iter(
                hf_weights_files, use_safetensors):

            if self._is_8bit_weight_name(weight_name):
                continue

            qweight_name = weight_name.replace(".weight", ".qweight")

            if qweight_name in quant_state_dict:
                set_weight_attrs(weight_tensor, {"load_in_8bit": True})
                yield qweight_name, weight_tensor
            else:
                yield weight_name, weight_tensor

    def _quantized_4bit_generator(self, hf_weights_files, use_safetensors,
                                  quant_state_dict) -> Generator:
        """Yield the weights of a pre-quantized 4-bit (nf4/fp4) checkpoint.

        Pass 1 buffers all quant-state tensors. Pass 2 yields the remaining
        weights. A weight that has a ``quant_state.bitsandbytes__nf4`` or
        ``__fp4`` entry gets its ``QuantState`` rebuilt into
        ``quant_state_dict`` and is yielded under the ``.qweight`` name.
        """
        from bitsandbytes.functional import QuantState

        # First iterate over all quant state weights
        weight_iterator = self._hf_weight_iter(hf_weights_files,
                                               use_safetensors)
        temp_state_dict: Dict[str, torch.Tensor] = {}
        for weight_name, weight_tensor in weight_iterator:
            if not self._is_4bit_weight_name(weight_name):
                continue
            # bitsandbytes library requires
            # weight.quant_state.bitsandbytes__* in CPU
            if "quant_state.bitsandbytes" in weight_name:
                temp_state_dict[weight_name] = weight_tensor.cpu().data
            else:
                temp_state_dict[weight_name] = weight_tensor

        # Closure to parse quant_state for each prequant weight
        def _parse_quant_state(param_name: str,
                               temp_state_dict: Dict) -> QuantState:
            # Match on the full "<param_name>." prefix. A plain substring test
            # would also collect entries of any other parameter whose name
            # contains this one, e.g. "model.x.weight." inside
            # "vision_model.x.weight.absmax".
            prefix = param_name + "."
            quant_state = {
                k: v
                for k, v in temp_state_dict.items() if k.startswith(prefix)
            }
            return QuantState.from_dict(quant_state, device="cuda")

        # Second iterate over all prequant and normal weights
        # pre quantized weights would have a quant_state
        for weight_name, weight_tensor in self._hf_weight_iter(
                hf_weights_files, use_safetensors):

            if self._is_4bit_weight_name(weight_name):
                continue

            if (f"{weight_name}.quant_state.bitsandbytes__nf4"
                    in temp_state_dict) or (
                        f"{weight_name}.quant_state.bitsandbytes__fp4"
                        in temp_state_dict):
                quant_state = _parse_quant_state(weight_name, temp_state_dict)
                qweight_name = weight_name.replace(".weight", ".qweight")
                quant_state_dict[qweight_name] = quant_state
                yield qweight_name, weight_tensor
            else:
                yield weight_name, weight_tensor

    def _unquantized_generator(self, hf_weights_files, use_safetensors,
                               quant_state_dict) -> Generator:
        """Yield weights, quantizing target-module weights on the fly.

        A weight whose name contains a target module and ends in ``.weight``
        goes through these steps:

        1. It is sliced to this tensor-parallel rank's shard, unless it
           belongs to an unsharded module.
        2. It is moved to CUDA and made contiguous.
        3. It is quantized to 4-bit NF4 and yielded under its ``.qweight``
           name.

        Its ``QuantState`` is stored in ``quant_state_dict``. All other
        weights are yielded unchanged.
        """
        from bitsandbytes.functional import quantize_4bit

        tp_size = get_tensor_model_parallel_world_size()
        tp_rank = get_tensor_model_parallel_rank()

        for weight_name, weight_tensor in self._hf_weight_iter(
                hf_weights_files, use_safetensors):

            if any(target_module in weight_name for target_module in
                   self.target_modules) and weight_name.endswith(".weight"):
                weight_name = weight_name.replace(".weight", ".qweight")

                # Without sharding
                if any(
                        weight_name.startswith(module)
                        for module in self.unsharded_weights_modules):
                    weight_sub_tensor = weight_tensor
                # Shard by column
                elif any(module in weight_name
                         for module in self.column_parallel_weights_modules):
                    total_size = weight_tensor.size(-1)
                    start_index = total_size // tp_size * tp_rank
                    end_index = total_size // tp_size * (tp_rank + 1)
                    weight_sub_tensor = weight_tensor[...,
                                                      start_index:end_index]
                # Shard by row
                else:
                    total_size = weight_tensor.size(0)
                    start_index = total_size // tp_size * tp_rank
                    end_index = total_size // tp_size * (tp_rank + 1)
                    weight_sub_tensor = weight_tensor[start_index:end_index,
                                                      ...]

                # bitsandbytes requires data in GPU
                if weight_sub_tensor.is_cuda:
                    loaded_weight = weight_sub_tensor
                else:
                    loaded_weight = weight_sub_tensor.cuda()

                # remove the following after the issue is fixed:
                # https://github.com/bitsandbytes-foundation/bitsandbytes/issues/1342
                if not loaded_weight.is_contiguous():
                    loaded_weight = loaded_weight.contiguous()

                with set_default_torch_dtype(torch.float32):
                    processed_weight, quant_state = quantize_4bit(
                        loaded_weight,
                        compress_statistics=True,
                        quant_type="nf4")

                quant_state_dict[weight_name] = quant_state
            else:
                processed_weight = weight_tensor

            yield weight_name, processed_weight

    def _load_weights(self, model_config: ModelConfig,
                      model: nn.Module) -> None:
        """Load (and if needed quantize) weights into ``model``.

        After ``model.load_weights`` has consumed the weight iterator, the
        collected quant states are attached to the matching parameters as
        ``bnb_quant_state``, grouped by stacked shard index. Each such
        parameter also gets ``bnb_shard_offsets``, plus ``matmul_state`` in
        8-bit mode.

        Raises:
            AttributeError: if the model lacks ``load_weights`` or
                ``bitsandbytes_stacked_params_mapping``.
            ValueError: on an unsupported ``quant_method``, on a
                pre-quantized model with TP > 1, or when a quantized
                parameter cannot be matched to the model.
        """
        if not hasattr(model, 'load_weights'):
            raise AttributeError(
                "The required method 'load_weights' is not defined in class"
                f" {type(model).__name__}.")

        if not hasattr(model, 'bitsandbytes_stacked_params_mapping'):
            raise AttributeError(
                f"Model {type(model).__name__} does not support BitsAndBytes "
                "quantization yet.")

        if not self.target_modules:
            if hasattr(model, 'default_bitsandbytes_target_modules'):
                self.target_modules = model.default_bitsandbytes_target_modules
            else:
                self.target_modules = self.default_target_modules

        self.column_parallel_weights_modules = getattr(
            model, 'column_parallel_weights_modules', [])

        # Some modules like `ReplicatedLinear` should not have their weights
        # sharded. The reason for implementing it this way is to avoid new
        # static variable in the model implementation.
        # TODO: Can we reduce the static variables needed for BNB based on
        #  model information?
        self.unsharded_weights_modules = [
            name for name, module in model.named_modules()
            if isinstance(module, (ReplicatedLinear, ))
        ]

        self.model_type = type(model).__name__

        logger.info("Loading weights with BitsAndBytes quantization. "
                    " May take a while ...")

        quant_config = getattr(model_config.hf_config, "quantization_config",
                               None)

        pre_quant = False
        if quant_config is not None:
            quant_method = quant_config.get('quant_method')
            if quant_method == "bitsandbytes":
                pre_quant = True
            else:
                raise ValueError(
                    f"BitsAndBytes loader does not support {quant_method} "
                    "quantization")

        # The quant_states in pre_quantized models cannot work with a split
        # weight tensor. So TP does not work with pre_quantized bnb models.
        if pre_quant and get_tensor_model_parallel_world_size() > 1:
            raise ValueError(
                "Prequant BitsAndBytes models with TP is not supported. "
                "Please try with PP.")

        load_8bit = False
        if pre_quant:
            load_8bit = quant_config.get('load_in_8bit', False)

        qweight_iterator, quant_state_dict = \
            self._get_quantized_weights_iterator(
            model_config.model, model_config.revision, pre_quant, load_8bit)

        # Consuming the iterator is what fills `quant_state_dict`.
        model.load_weights(qweight_iterator)

        torch.cuda.empty_cache()

        param_dict = dict(model.named_parameters())
        stacked_quant_state_dict: Dict[str, Dict[int, Any]] = {}
        for quant_param_name in quant_state_dict:
            non_stacked_param_name = quant_param_name

            shard_index = 0
            for shard_name, (
                    weight_name, index
            ) in model.bitsandbytes_stacked_params_mapping.items():
                shard_pos = quant_param_name.find(shard_name)
                # Some models, such as MiniCPM V2.5/2.6, contain both
                # module names 'kv_proj' and 'qkv_proj'. To prevent 'kv_proj'
                # from being incorrectly identified as being present in
                # 'vpm.encoder.layers.0.self_attn.qkv_proj.qweight
                if shard_pos > 0 and quant_param_name[shard_pos - 1] == ".":
                    shard_index = index
                    quant_param_name = quant_param_name.replace(
                        shard_name, weight_name)
                    break

            if quant_param_name not in param_dict:
                raise ValueError(
                    f"Parameter {quant_param_name} not found in the model.")

            if quant_param_name not in stacked_quant_state_dict:
                stacked_quant_state_dict[quant_param_name] = {}

            stacked_quant_state_dict[quant_param_name][shard_index] = (
                quant_state_dict[non_stacked_param_name])

        # save quant_states and offsets as the attributes of the parameters
        for param_name, param in param_dict.items():
            if param_name in stacked_quant_state_dict:
                quant_states = stacked_quant_state_dict[param_name]
                set_weight_attrs(param, {"bnb_quant_state": quant_states})

                pack_ratio = getattr(param, "pack_factor", -1)
                if pack_ratio == -1:
                    raise ValueError(
                        f"pack_factor not set for parameter {param_name}.")

                # Packed element count of each shard, indexed by shard index;
                # the cumulative sum gives each shard's start offset.
                num_elements = [0] * len(quant_states)
                for seq, quant_state in quant_states.items():
                    num_elements[seq] = math.prod(
                        quant_state.shape) // pack_ratio

                offsets = np.concatenate(([0], np.cumsum(num_elements)))
                set_weight_attrs(param, {"bnb_shard_offsets": offsets})

                if load_8bit:
                    set_weight_attrs(
                        param, {"matmul_state": [None] * len(quant_states)})

    def download_model(self, model_config: ModelConfig) -> None:
        """Make the weight files available locally, downloading if needed."""
        self._prepare_weights(model_config.model, model_config.revision)

    def load_model(self, vllm_config: VllmConfig) -> nn.Module:
        """Build the model on the configured device and dtype, load its
        weights with BitsAndBytes quantization, and return it in eval mode."""
        device_config = vllm_config.device_config
        model_config = vllm_config.model_config

        with set_default_torch_dtype(model_config.dtype):
            with torch.device(device_config.device):
                model = _initialize_model(vllm_config=vllm_config)

                self._load_weights(model_config, model)

        return model.eval()


1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
class GGUFModelLoader(BaseModelLoader):
    """
    Loader for checkpoints stored in the GGUF format, e.g. GGUF-quantized
    models. ``model_config.model`` must be the path of a local GGUF file.
    GGUF tensor names are translated back to Hugging Face parameter names
    before the weights reach ``model.load_weights``.
    """

    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)
        extra_config = load_config.model_loader_extra_config
        if extra_config:
            raise ValueError(f"Model loader extra config is not supported for "
                             f"load format {load_config.load_format}")

    def _prepare_weights(self, model_name_or_path: str):
        # Only a path to an existing local file is accepted; it is returned
        # unchanged.
        if not os.path.isfile(model_name_or_path):
            raise ValueError(f"{model_name_or_path} is not a file.")
        return model_name_or_path

    def _get_gguf_weights_map(self, model_config: ModelConfig):
        """
        Build a ``{gguf_tensor_name: hf_parameter_name}`` mapping.

        GGUF names tensors of a converted HF checkpoint as
        `blk.N.BB.weight` / `blk.N.BB.bias`, where N is the block (layer)
        index and BB names the attention/mlp component. See "Standardized
        tensor names" in
        https://github.com/ggerganov/ggml/blob/master/docs/gguf.md.

        The HF parameter names come from a weightless, meta-device instance
        of the model built from ``hf_config``.
        """
        hf_config = model_config.hf_config
        # GGUF's architecture table calls Cohere models "command-r".
        if hf_config.model_type == "cohere":
            model_type = "command-r"
        else:
            model_type = hf_config.model_type

        arch = next((key for key, value in gguf.MODEL_ARCH_NAMES.items()
                     if value == model_type), None)
        if arch is None:
            raise RuntimeError(f"Unknown gguf model_type: {model_type}")

        tensor_name_map = gguf.get_tensor_name_map(arch,
                                                   hf_config.num_hidden_layers)
        with torch.device("meta"):
            meta_model = AutoModelForCausalLM.from_config(hf_config)
        hf_param_names = meta_model.state_dict()

        mapping = {}
        for hf_param_name in hf_param_names:
            module_path, param_kind = hf_param_name.rsplit(".", 1)
            gguf_module = tensor_name_map.get_name(module_path)
            mapping[f"{gguf_module}.{param_kind}"] = hf_param_name
        return mapping

    def _get_weights_iterator(
        self, model_name_or_path: str, gguf_to_hf_name_map: Dict[str, str]
    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
        # Thin wrapper over gguf_quant_weights_iterator.
        weights = gguf_quant_weights_iterator(model_name_or_path,
                                              gguf_to_hf_name_map)
        return weights

    def download_model(self, model_config: ModelConfig) -> None:
        # Nothing to download: this only checks that the file exists.
        self._prepare_weights(model_config.model)

    def load_model(self, vllm_config: VllmConfig) -> nn.Module:
        """Build the model and load its weights from the GGUF file."""
        device_config = vllm_config.device_config
        model_config = vllm_config.model_config

        gguf_path = self._prepare_weights(model_config.model)
        name_map = self._get_gguf_weights_map(model_config)

        # Whether the embeddings are tied is only known after mapping:
        # if the file has no tensor for lm_head.weight, tie them.
        missing_hf_names = get_gguf_extra_tensor_names(gguf_path, name_map)
        if "lm_head.weight" in missing_hf_names:
            model_config.hf_config.update({"tie_word_embeddings": True})

        with set_default_torch_dtype(model_config.dtype):
            with torch.device(device_config.device):
                model = _initialize_model(vllm_config=vllm_config)
            weights = self._get_weights_iterator(gguf_path, name_map)
            model.load_weights(weights)
        return model


1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
    """Return the model loader that matches ``load_config.load_format``.

    If ``load_format`` is itself a class, it is instantiated directly with
    ``load_config``. Otherwise the format is compared, in order, against the
    known ``LoadFormat`` values. An unmatched format falls back to
    ``DefaultModelLoader``.
    """
    load_format = load_config.load_format

    if isinstance(load_format, type):
        return load_format(load_config)

    # Ordered (format, loader class) pairs. They are scanned with `==`
    # rather than looked up in a dict, so the comparison semantics match a
    # chain of equality checks.
    format_to_loader = (
        (LoadFormat.DUMMY, DummyModelLoader),
        (LoadFormat.TENSORIZER, TensorizerLoader),
        (LoadFormat.SHARDED_STATE, ShardedStateLoader),
        (LoadFormat.BITSANDBYTES, BitsAndBytesModelLoader),
        (LoadFormat.GGUF, GGUFModelLoader),
    )
    for candidate_format, loader_cls in format_to_loader:
        if load_format == candidate_format:
            return loader_cls(load_config)

    return DefaultModelLoader(load_config)