# ruff: noqa: SIM117
import collections
import copy
import dataclasses
import fnmatch
import glob
import inspect
import itertools
import json
import math
import os
from abc import ABC, abstractmethod
from contextlib import contextmanager
from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, cast

import gguf
import huggingface_hub
import numpy as np
import torch
from huggingface_hub import HfApi, hf_hub_download
from torch import nn
from transformers import AutoModelForCausalLM
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME

from vllm.config import (LoadConfig, LoadFormat, ModelConfig, ParallelConfig,
                         VllmConfig)
from vllm.distributed import (get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size)
from vllm.envs import VLLM_USE_MODELSCOPE
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               ReplicatedLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.quantization.base_config import (
    QuantizeMethodBase)
from vllm.model_executor.model_loader.tensorizer import (
    TensorizerConfig, is_vllm_tensorized, load_with_tensorizer,
    serialize_vllm_model, tensorizer_weights_iterator)
from vllm.model_executor.model_loader.utils import (get_model_architecture,
                                                    set_default_torch_dtype)
from vllm.model_executor.model_loader.weight_utils import (
    download_safetensors_index_file_from_hf, download_weights_from_hf,
    filter_duplicate_safetensors_files, filter_files_not_needed_for_inference,
    get_gguf_extra_tensor_names, gguf_quant_weights_iterator,
    initialize_dummy_weights, np_cache_weights_iterator, pt_weights_iterator,
    safetensors_weights_iterator)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.plugins import set_current_vllm_config
from vllm.utils import is_pin_memory_available


@contextmanager
def device_loading_context(module: torch.nn.Module,
                           target_device: torch.device):
    """Temporarily place a module's CPU-resident parameters on a device.

    On entry, every parameter of ``module`` that currently lives on the CPU
    is moved to ``target_device``. On exit (including on error) those same
    parameters are copied back into freshly allocated CPU tensors, pinned
    when ``is_pin_memory_available()`` says so. Parameters that were not on
    the CPU on entry, and parameters created inside the ``with`` body, are
    left alone.

    When ``target_device`` is itself the CPU nothing is moved at all.

    Args:
        module: The module whose parameters should be relocated.
        target_device: The device the parameters are needed on.

    Yields:
        The same ``module`` object.
    """
    if target_device.type == "cpu":
        # Nothing to relocate when the work happens on the CPU anyway.
        yield module
        return

    # Maps parameter name -> device it lived on before we moved it.
    moved_from: Dict[str, torch.device] = {}

    for param_name, param in module.named_parameters():
        if param.device.type != "cpu":
            # Already off the CPU: do not touch.
            continue
        moved_from[param_name] = param.device
        param.data = param.data.to(target_device)

    try:
        yield module

    finally:
        use_pinned = is_pin_memory_available()
        for param_name, param in module.named_parameters():
            if param_name not in moved_from:
                # Parameter is new, or was never moved by us.
                continue
            home_device: torch.device = moved_from[param_name]
            if home_device.type != "cpu":
                param.data = param.data.to(home_device)
                continue
            # `torch.empty_like` does not support `pin_memory` argument,
            # so the destination buffer is built with `empty_strided`.
            host_copy = torch.empty_strided(size=param.data.size(),
                                            stride=param.data.stride(),
                                            dtype=param.data.dtype,
                                            layout=param.data.layout,
                                            device="cpu",
                                            pin_memory=use_pinned)
            host_copy.copy_(param.data)
            param.data = host_copy

94
95
96
97

# Module-level logger, named after this module.
logger = init_logger(__name__)


98
def _initialize_model(vllm_config: VllmConfig, prefix: str = "") -> nn.Module:
    """Initialize a model with the given configurations.

    The model class is resolved from ``vllm_config.model_config`` and its
    ``__init__`` signature is inspected:

    * New-style classes accept both ``vllm_config`` and ``prefix`` and are
      constructed with exactly those two keyword arguments.
    * Anything else is treated as an old-style class. A warning is logged
      and only the keyword arguments the constructor actually declares
      (``prefix``, ``config``, ``cache_config``, ``quant_config``,
      ``lora_config``, ``scheduler_config``) are passed.

    In both cases construction happens inside
    ``set_current_vllm_config(vllm_config)``.

    Args:
        vllm_config: The full vLLM configuration.
        prefix: Prefix handed to the model constructor when it accepts one.

    Returns:
        The freshly constructed model instance.
    """
    model_config = vllm_config.model_config
    model_class, _ = get_model_architecture(model_config)
    signatures = inspect.signature(model_class.__init__)
    # Only membership is ever tested, so a set of parameter names suffices.
    all_params = set(signatures.parameters)
    if "vllm_config" in all_params and "prefix" in all_params:
        # new-style model class
        with set_current_vllm_config(vllm_config):
            return model_class(vllm_config=vllm_config, prefix=prefix)

    msg = ("vLLM model class should accept `vllm_config` and `prefix` as "
           "input arguments. Possibly you have an old-style model class"
           " registered from out of tree and it is used for new vLLM version. "
           "Check https://docs.vllm.ai/en/latest/design/arch_overview.html "
           "for the design and update the model class accordingly.")
    logger.warning(msg)
    logger.warning(
        "Trying to guess the arguments for old-style model class %s",
        model_class)

    # try to be compatible with old-style model class
    kwargs: Dict[str, Any] = {}
    if "prefix" in all_params:
        kwargs["prefix"] = prefix
    if "config" in all_params:
        kwargs["config"] = model_config.hf_config
    # These four are forwarded verbatim from the attribute of the same name
    # on `vllm_config`; the attribute is only read when it is needed.
    for field_name in ("cache_config", "quant_config", "lora_config",
                       "scheduler_config"):
        if field_name in all_params:
            kwargs[field_name] = getattr(vllm_config, field_name)

    with set_current_vllm_config(vllm_config):
        return model_class(**kwargs)
133
134
135
136
137
138
139
140


class BaseModelLoader(ABC):
    """Base class for model loaders.

    A loader keeps the ``LoadConfig`` it was built from and must implement
    two operations: fetching the model files and constructing the model.
    """

    def __init__(self, load_config: LoadConfig):
        # Kept for subclasses, which read options from it.
        self.load_config = load_config

    @abstractmethod
    def download_model(self, model_config: ModelConfig) -> None:
        """Download a model so that it can be immediately loaded."""
        raise NotImplementedError

    @abstractmethod
    def load_model(self, *, vllm_config: VllmConfig) -> nn.Module:
        """Load a model with the given configurations."""
        raise NotImplementedError
150
151
152
153
154


class DefaultModelLoader(BaseModelLoader):
    """Model loader that can load different file types from disk.

    Weights are located (and downloaded when the model is not a local
    directory), turned into an iterator of ``(name, tensor)`` pairs and fed
    to the model's ``load_weights`` method.
    """

    @dataclasses.dataclass
    class Source:
        """A source for weights."""

        model_or_path: str
        """The model ID or path."""

        revision: Optional[str]
        """The optional model revision."""

        prefix: str = ""
        """A prefix to prepend to all weights."""

        fall_back_to_pt: bool = True
        """Whether .pt weights can be used."""

    def __init__(self, load_config: LoadConfig):
        """Store the load config; extra loader config is rejected.

        Raises:
            ValueError: If ``load_config.model_loader_extra_config`` is set.
        """
        super().__init__(load_config)
        if load_config.model_loader_extra_config:
            raise ValueError(f"Model loader extra config is not supported for "
                             f"load format {load_config.load_format}")

    def _maybe_download_from_modelscope(
            self, model: str, revision: Optional[str]) -> Optional[str]:
        """Download model from ModelScope hub if VLLM_USE_MODELSCOPE is True.

        Returns the path to the downloaded model, or None if the model is not
        downloaded from ModelScope."""
        if VLLM_USE_MODELSCOPE:
            # download model from ModelScope hub,
            # lazy import so that modelscope is not required for normal use.
            # pylint: disable=C.
            from modelscope.hub.snapshot_download import snapshot_download

            if not os.path.exists(model):
                model_path = snapshot_download(
                    model_id=model,
                    cache_dir=self.load_config.download_dir,
                    local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
                    revision=revision,
                    ignore_file_pattern=self.load_config.ignore_patterns,
                )
            else:
                # An existing local path is used as-is.
                model_path = model
            return model_path
        return None

    def _prepare_weights(self, model_name_or_path: str,
                         revision: Optional[str],
                         fall_back_to_pt: bool) -> Tuple[str, List[str], bool]:
        """Prepare weights for the model.

        If the model is not local, it will be downloaded.

        Args:
            model_name_or_path: Model ID or local directory.
            revision: Optional model revision.
            fall_back_to_pt: Whether ``*.pt`` files may be used when no file
                matches the patterns of the configured load format.

        Returns:
            A tuple ``(hf_folder, hf_weights_files, use_safetensors)``: the
            folder holding the weights, the weight files selected from it,
            and whether those files are safetensors files.

        Raises:
            ValueError: If the configured load format is not handled here.
            RuntimeError: If no weight file is found.
        """
        model_name_or_path = self._maybe_download_from_modelscope(
            model_name_or_path, revision) or model_name_or_path

        is_local = os.path.isdir(model_name_or_path)
        load_format = self.load_config.load_format
        use_safetensors = False
        index_file = SAFE_WEIGHTS_INDEX_NAME
        if load_format == LoadFormat.AUTO:
            allow_patterns = ["*.safetensors", "*.bin"]
        elif load_format == LoadFormat.SAFETENSORS:
            use_safetensors = True
            allow_patterns = ["*.safetensors"]
        elif load_format == LoadFormat.MISTRAL:
            use_safetensors = True
            allow_patterns = ["consolidated*.safetensors"]
            index_file = "consolidated.safetensors.index.json"
        elif load_format == LoadFormat.PT:
            allow_patterns = ["*.pt"]
        elif load_format == LoadFormat.NPCACHE:
            allow_patterns = ["*.bin"]
        else:
            raise ValueError(f"Unknown load_format: {load_format}")

        # Some quantized models use .pt files for storing the weights.
        if fall_back_to_pt:
            allow_patterns += ["*.pt"]

        if not is_local:
            hf_folder = download_weights_from_hf(
                model_name_or_path,
                self.load_config.download_dir,
                allow_patterns,
                revision,
                ignore_patterns=self.load_config.ignore_patterns,
            )
        else:
            hf_folder = model_name_or_path

        # Patterns are tried in order; the first one that matches any file
        # wins and later patterns are not consulted.
        hf_weights_files: List[str] = []
        for pattern in allow_patterns:
            hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
            if len(hf_weights_files) > 0:
                if pattern == "*.safetensors":
                    use_safetensors = True
                break

        if use_safetensors:
            # For models like Mistral-7B-Instruct-v0.3
            # there are both sharded safetensors files and a consolidated
            # safetensors file. Using both breaks.
            # Here, we download the `model.safetensors.index.json` and filter
            # any files not found in the index.
            if not is_local:
                download_safetensors_index_file_from_hf(
                    model_name_or_path, index_file,
                    self.load_config.download_dir, revision)
            hf_weights_files = filter_duplicate_safetensors_files(
                hf_weights_files, hf_folder, index_file)
        else:
            hf_weights_files = filter_files_not_needed_for_inference(
                hf_weights_files)

        if len(hf_weights_files) == 0:
            raise RuntimeError(
                f"Cannot find any model weights with `{model_name_or_path}`")

        return hf_folder, hf_weights_files, use_safetensors

    def _get_weights_iterator(
            self, source: "Source"
    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
        """Get an iterator for the model weights based on the load format.

        Every yielded name has ``source.prefix`` prepended.
        """
        hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
            source.model_or_path, source.revision, source.fall_back_to_pt)
        if self.load_config.load_format == LoadFormat.NPCACHE:
            # Currently np_cache only support *.bin checkpoints
            assert use_safetensors is False
            weights_iterator = np_cache_weights_iterator(
                source.model_or_path, self.load_config.download_dir, hf_folder,
                hf_weights_files)
        elif use_safetensors:
            weights_iterator = safetensors_weights_iterator(hf_weights_files)
        else:
            weights_iterator = pt_weights_iterator(hf_weights_files)

        if current_platform.is_tpu():
            # In PyTorch XLA, we should call `xm.mark_step` frequently so that
            # not too many ops are accumulated in the XLA program.
            import torch_xla.core.xla_model as xm

            def _xla_weights_iterator(iterator: Generator):
                for weights in iterator:
                    yield weights
                    xm.mark_step()

            weights_iterator = _xla_weights_iterator(weights_iterator)

        # Apply the prefix.
        return ((source.prefix + name, tensor)
                for (name, tensor) in weights_iterator)

    def _get_all_weights(
        self,
        model_config: ModelConfig,
        model: nn.Module,
    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
        """Yield the primary weights, then any secondary weight sources.

        The primary source is ``model_config.model`` at
        ``model_config.revision``. The model may opt out of the ``.pt``
        fallback through a ``fall_back_to_pt_during_load`` attribute and may
        list further sources in a ``secondary_weights`` attribute.
        """
        primary_weights = DefaultModelLoader.Source(
            model_config.model,
            model_config.revision,
            prefix="",
            fall_back_to_pt=getattr(model, "fall_back_to_pt_during_load",
                                    True))
        yield from self._get_weights_iterator(primary_weights)

        secondary_weights = cast(Iterable[DefaultModelLoader.Source],
                                 getattr(model, "secondary_weights", ()))
        for source in secondary_weights:
            yield from self._get_weights_iterator(source)

    def download_model(self, model_config: ModelConfig) -> None:
        """Locate (and, if needed, download) the primary model weights."""
        self._prepare_weights(model_config.model,
                              model_config.revision,
                              fall_back_to_pt=True)

    def load_model(self, vllm_config: VllmConfig) -> nn.Module:
        """Build the model, load its weights and return it in eval mode.

        Raises:
            ValueError: For non-quantized models whose ``load_weights``
                reports what it loaded, if some parameter was not loaded.
        """
        device_config = vllm_config.device_config
        model_config = vllm_config.model_config

        target_device = torch.device(device_config.device)
        with set_default_torch_dtype(model_config.dtype):
            with target_device:
                model = _initialize_model(vllm_config=vllm_config)

            weights_to_load = {name for name, _ in model.named_parameters()}
            loaded_weights = model.load_weights(
                self._get_all_weights(model_config, model))
            # We only enable strict check for non-quantized models
            # that have loaded weights tracking currently.
            if model_config.quantization is None and loaded_weights is not None:
                weights_not_loaded = weights_to_load - loaded_weights
                if weights_not_loaded:
                    raise ValueError(
                        "Following weights were not initialized from "
                        f"checkpoint: {weights_not_loaded}")

            for _, module in model.named_modules():
                quant_method = getattr(module, "quant_method", None)
                if isinstance(quant_method, QuantizeMethodBase):
                    # When quant methods need to process weights after loading
                    # (for repacking, quantizing, etc), they expect parameters
                    # to be on the global target device. This scope is for the
                    # case where cpu offloading is used, where we will move the
                    # parameters onto device for processing and back off after.
                    with device_loading_context(module, target_device):
                        quant_method.process_weights_after_loading(module)
        return model.eval()


class DummyModelLoader(BaseModelLoader):
    """Model loader that will set model weights to random values."""

    def __init__(self, load_config: LoadConfig):
        """Store the load config; extra loader config is rejected.

        Raises:
            ValueError: If ``load_config.model_loader_extra_config`` is set.
        """
        super().__init__(load_config)
        if load_config.model_loader_extra_config:
            raise ValueError(f"Model loader extra config is not supported for "
                             f"load format {load_config.load_format}")

    def download_model(self, model_config: ModelConfig) -> None:
        pass  # Nothing to download

    def load_model(self, vllm_config: VllmConfig) -> nn.Module:
        """Build the model with dummy weights and return it in eval mode."""
        device_config = vllm_config.device_config
        model_config = vllm_config.model_config
        # Built once and reused for construction and post-processing.
        target_device = torch.device(device_config.device)
        with set_default_torch_dtype(model_config.dtype):
            with target_device:
                model = _initialize_model(vllm_config=vllm_config)
            # NOTE(woosuk): For accurate performance evaluation, we assign
            # random values to the weights.
            initialize_dummy_weights(model)

            for _, module in model.named_modules():
                quant_method = getattr(module, "quant_method", None)
                if quant_method is not None:
                    # When quant methods need to process weights after loading
                    # (for repacking, quantizing, etc), they expect parameters
                    # to be on the global target device. This scope is for the
                    # case where cpu offloading is used, where we will move the
                    # parameters onto device for processing and back off after.
                    with device_loading_context(module, target_device):
                        quant_method.process_weights_after_loading(module)
        return model.eval()


class TensorizerLoader(BaseModelLoader):
    """Model loader using CoreWeave's tensorizer library."""

    def __init__(self, load_config: LoadConfig):
        """Take the tensorizer settings from the loader's extra config.

        ``load_config.model_loader_extra_config`` is used directly when it
        already is a ``TensorizerConfig``; otherwise it is unpacked as
        keyword arguments into a new ``TensorizerConfig``.
        """
        super().__init__(load_config)
        if isinstance(load_config.model_loader_extra_config, TensorizerConfig):
            self.tensorizer_config = load_config.model_loader_extra_config
        else:
            self.tensorizer_config = TensorizerConfig(
                **load_config.model_loader_extra_config)

    def _verify_config(self, model_config: ModelConfig,
                       parallel_config: ParallelConfig):
        """Check the tensorizer config against model and parallel configs."""
        self.tensorizer_config.verify_with_model_config(model_config)
        self.tensorizer_config.verify_with_parallel_config(parallel_config)

    def _get_weights_iterator(
            self) -> Generator[Tuple[str, torch.Tensor], None, None]:
        """Return the tensorizer weights iterator for the current config."""
        tensorizer_args = self.tensorizer_config._construct_tensorizer_args()
        return tensorizer_weights_iterator(tensorizer_args)

    def _load_model_serialized_cpu(
        self,
        vllm_config: VllmConfig,
    ) -> nn.Module:
        """Load a serialized model with tensorizer to the CPU.

        This is only necessary when the model isn't vLLM-tensorized (see
        examples/tensorize_vllm_model.py) This should still be faster than
        default HuggingFace loading, but will be slower than loading a
        vLLM-tensorized model.
        """
        device_config = vllm_config.device_config
        model_config = vllm_config.model_config
        with set_default_torch_dtype(model_config.dtype):
            with torch.device(device_config.device):
                model = _initialize_model(vllm_config=vllm_config)

            model.load_weights(self._get_weights_iterator())
        return model.eval()

    def _load_model_serialized(
        self,
        vllm_config: VllmConfig,
    ) -> nn.Module:
        """Load a serialized model with tensorizer.

        Expects a vLLM-tensorized model. See the
        examples/tensorize_vllm_model.py example script
        for serializing vLLM models."""

        device_config = vllm_config.device_config
        model_config = vllm_config.model_config

        with set_default_torch_dtype(model_config.dtype):
            with torch.device(device_config.device):
                model_class = get_model_architecture(model_config)[0]

                # Work on a shallow copy so the loader's own config is not
                # overwritten with per-model fields.
                tensorizer_config = copy.copy(self.tensorizer_config)
                tensorizer_config.model_class = model_class
                tensorizer_config.hf_config = model_config.hf_config
                tensorizer_config.dtype = model_config.dtype

                model = load_with_tensorizer(tensorizer_config,
                                             vllm_config=vllm_config)
        return model.eval()

    def download_model(self, model_config: ModelConfig) -> None:
        """Verify the config and open (then close) the tensorizer stream."""
        self.tensorizer_config.verify_with_model_config(model_config)

        with self.tensorizer_config.open_stream():
            pass

    def load_model(self, vllm_config: VllmConfig) -> nn.Module:
        """Load the model through tensorizer and return it in eval mode.

        With tensor parallelism enabled, the configured ``tensorizer_uri``
        is treated as a ``%``-format template and filled with this worker's
        tensor-parallel rank.
        """
        model_config = vllm_config.model_config
        parallel_config = vllm_config.parallel_config
        self._verify_config(model_config, parallel_config)

        if parallel_config.tensor_parallel_size > 1:
            from vllm.distributed import get_tensor_model_parallel_rank

            # NOTE(review): this rewrites the stored URI in place, so calling
            # load_model a second time on the same loader would apply the
            # `%` formatting to an already-formatted string - confirm that
            # loaders are never reused.
            self.tensorizer_config.tensorizer_uri = \
                self.tensorizer_config.tensorizer_uri \
                    % get_tensor_model_parallel_rank()

        if is_vllm_tensorized(self.tensorizer_config):
            return self._load_model_serialized(vllm_config=vllm_config)
        return self._load_model_serialized_cpu(vllm_config=vllm_config)

    @staticmethod
    def save_model(
        model: torch.nn.Module,
        tensorizer_config: TensorizerConfig,
    ) -> None:
        """Serialize ``model`` with tensorizer using ``tensorizer_config``."""
        serialize_vllm_model(
            model=model,
            tensorizer_config=tensorizer_config,
        )

501

502
503
504
505
506
class ShardedStateLoader(BaseModelLoader):
    """
    Model loader that directly loads each worker's model state dict, which
    enables a fast load path for large tensor-parallel models where each worker
    only needs to read its own shard rather than the entire checkpoint. See
    `examples/save_sharded_state.py` for creating a sharded checkpoint.
    """

    DEFAULT_PATTERN = "model-rank-{rank}-part-{part}.safetensors"

    def __init__(self, load_config: LoadConfig):
        """Read the optional ``pattern`` key from the loader's extra config.

        Raises:
            ValueError: If the extra config has any key besides ``pattern``.
        """
        super().__init__(load_config)
        # Copy so that popping "pattern" does not mutate the caller's config.
        extra_config = ({} if load_config.model_loader_extra_config is None
                        else load_config.model_loader_extra_config.copy())
        self.pattern = extra_config.pop("pattern", self.DEFAULT_PATTERN)
        if extra_config:
            # Report only what is left after removing the supported key, so
            # a valid "pattern" entry is not listed as unexpected.
            raise ValueError(f"Unexpected extra config keys for load format "
                             f"{load_config.load_format}: "
                             f"{extra_config.keys()}")

    @staticmethod
    def _filter_subtensors(
            tensors: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Filter out all tensors that share the same memory or a subset of the
        memory of another tensor.
        """
        # Group non-empty tensors by (device, base pointer of their storage).
        same_storage_groups: Dict[Any, List[Tuple[
            str, torch.Tensor]]] = collections.defaultdict(list)
        for key, tensor in tensors.items():
            if tensor.numel():
                ptr = tensor.untyped_storage().data_ptr()
                same_storage_groups[tensor.device, ptr].append((key, tensor))

        def get_end_ptr(tensor: torch.Tensor) -> int:
            # Address one past the last element of the tensor.
            return tensor.view(-1)[-1].data_ptr() + tensor.element_size()

        result: Dict[str, torch.Tensor] = {}
        for group in same_storage_groups.values():
            for k, t in group:
                a, b = t.data_ptr(), get_end_ptr(t)
                for k2, t2 in group:
                    if not t2.is_contiguous():
                        continue
                    a2, b2 = t2.data_ptr(), get_end_ptr(t2)
                    if a < a2 or b2 < b:
                        # t2 does not cover t's address range.
                        continue
                    if a2 < a or b < b2 or not t.is_contiguous():
                        break  # t2 covers strictly more memory than t.
                    if k2 < k:
                        # Same tensors, keep the one with the smaller key.
                        break
                else:
                    # No other tensor in the group supersedes t.
                    result[k] = t
        return result

    def _prepare_weights(self, model_name_or_path: str,
                         revision: Optional[str]):
        """Return a local directory holding the safetensors files.

        A local directory is returned unchanged; anything else is downloaded
        with ``download_weights_from_hf``.
        """
        if os.path.isdir(model_name_or_path):
            return model_name_or_path
        else:
            allow_patterns = ["*.safetensors"]
            return download_weights_from_hf(
                model_name_or_path,
                self.load_config.download_dir,
                allow_patterns,
                revision,
                ignore_patterns=self.load_config.ignore_patterns,
            )

    def download_model(self, model_config: ModelConfig) -> None:
        """Make sure the model's safetensors files are available locally."""
        self._prepare_weights(model_config.model, model_config.revision)

    def load_model(self, vllm_config: VllmConfig) -> nn.Module:
        """Build the model and fill it from this rank's shard files.

        Raises:
            ValueError: If no file matches this rank's pattern, or if some
                entry of the model's state dict is absent from the files.
        """
        device_config = vllm_config.device_config
        model_config = vllm_config.model_config
        from safetensors.torch import safe_open

        from vllm.distributed import get_tensor_model_parallel_rank

        local_model_path = self._prepare_weights(model_config.model,
                                                 model_config.revision)

        with set_default_torch_dtype(model_config.dtype):
            with torch.device(device_config.device):
                model = _initialize_model(vllm_config=vllm_config)
                # Post-processing runs before the tensors are copied in
                # below; the copy then targets the resulting state dict.
                for _, module in model.named_modules():
                    quant_method = getattr(module, "quant_method", None)
                    if quant_method is not None:
                        quant_method.process_weights_after_loading(module)
            rank = get_tensor_model_parallel_rank()
            pattern = os.path.join(
                local_model_path,
                self.pattern.format(rank=rank, part="*"),
            )
            filepaths = glob.glob(pattern)
            if not filepaths:
                # TODO: support un-sharded checkpoints too
                raise ValueError(
                    f"Could not find checkpoint files '{pattern}', only "
                    f"pre-sharded checkpoints are currently supported!")
            state_dict = self._filter_subtensors(model.state_dict())
            for path in filepaths:
                with safe_open(path, framework="pt") as f:
                    for key in f.keys():  # noqa: SIM118
                        tensor = f.get_tensor(key)
                        # If loading with LoRA enabled, additional padding may
                        # be added to certain parameters. We only load into a
                        # narrowed view of the parameter data.
                        param_data = state_dict[key].data
                        param_shape = state_dict[key].shape
                        for dim, size in enumerate(tensor.shape):
                            if size < param_shape[dim]:
                                param_data = param_data.narrow(dim, 0, size)
                        if tensor.shape != param_shape:
                            logger.warning(
                                "loading tensor of shape %s into "
                                "parameter '%s' of shape %s", tensor.shape,
                                key, param_shape)
                        param_data.copy_(tensor)
                        # Whatever is left at the end was never loaded.
                        state_dict.pop(key)
            if state_dict:
                raise ValueError(
                    f"Missing keys {tuple(state_dict)} in loaded state!")
        return model.eval()

    @staticmethod
    def save_model(
        model: torch.nn.Module,
        path: str,
        pattern: Optional[str] = None,
        max_size: Optional[int] = None,
    ) -> None:
        """Write this rank's state dict to one or more safetensors files.

        Args:
            model: The model whose state dict is saved.
            path: Directory the files are written into.
            pattern: File name pattern with ``{rank}`` and ``{part}``
                placeholders; defaults to ``DEFAULT_PATTERN``.
            max_size: Optional size limit in bytes per file. A new part is
                started when adding the next tensor would exceed it. A
                single tensor larger than the limit is still written, in a
                part of its own.
        """
        from safetensors.torch import save_file

        from vllm.distributed import get_tensor_model_parallel_rank
        if pattern is None:
            pattern = ShardedStateLoader.DEFAULT_PATTERN
        rank = get_tensor_model_parallel_rank()

        def _write_part(part: Dict[str, torch.Tensor], index: int) -> None:
            # Save one shard under the name derived from rank and part index.
            filename = pattern.format(rank=rank, part=index)
            save_file(part, os.path.join(path, filename))

        part_idx = 0
        total_size = 0
        state_dict = ShardedStateLoader._filter_subtensors(model.state_dict())
        state_dict_part: Dict[str, torch.Tensor] = {}
        for key, tensor in state_dict.items():
            param_size = tensor.nelement() * tensor.element_size()
            # Only roll over when the current part already holds something;
            # otherwise an oversized first tensor would emit an empty file.
            if (max_size is not None and state_dict_part
                    and total_size + param_size > max_size):
                _write_part(state_dict_part, part_idx)
                part_idx += 1
                total_size = 0
                state_dict_part = {}
            state_dict_part[key] = tensor
            total_size += param_size
        if len(state_dict_part) > 0:
            _write_part(state_dict_part, part_idx)


666
667
668
class BitsAndBytesModelLoader(BaseModelLoader):
    """Model loader to load model weights with BitAndBytes quantization."""

669
670
    # File names probed, in order, when looking for the QLoRA adapter config.
    possible_config_file_names = ["adapter_config.json"]

    # Name fragments of the modules quantized by default. `_load_weights`
    # falls back to this list when no adapter config supplied target modules
    # and the model does not define `default_bitsandbytes_target_modules`.
    default_target_modules = [
        ".gate_proj.",
        ".down_proj.",
        ".up_proj.",
        ".q_proj.",
        ".k_proj.",
        ".v_proj.",
        ".o_proj.",
        ".fc1.",
        ".fc2.",
        ".dense.",
        ".query_key_value.",
        ".qkv_proj.",
        ".dense_h_to_4h.",
        ".dense_4h_to_h.",
        ".out_proj.",
    ]

    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)

692
693
694
695
        # Save the module names without sharding.
        self.unsharded_weights_modules: List[str] = []
        # Save the module names that are sharded by column.
        self.column_sharded_weights_modules: List[str] = []
696
697
698
699
700
701
        # we don't need to quantize the whole model, only the target modules
        # that are specified in the adapter config file. If the adapter config
        # file is not provided, we will quantize the default modules.
        if (not load_config.model_loader_extra_config
                or "qlora_adapter_name_or_path"
                not in load_config.model_loader_extra_config):
702
            self.target_modules = []
703
704
705
706
707
708
709
            return

        qlora_adapter = load_config.model_loader_extra_config[
            "qlora_adapter_name_or_path"]

        config_file_path = self._get_config_file(qlora_adapter)

710
        with open(config_file_path) as f:
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
            config = json.load(f)
            self.target_modules = config["target_modules"]

    def _get_config_file(self, qlora_adapter: str) -> str:
        is_local = os.path.isdir(qlora_adapter)
        config_file_path = None
        if is_local:
            for file in self.possible_config_file_names:
                config_file_path = os.path.join(qlora_adapter, file)
                if os.path.exists(config_file_path):
                    break
        else:
            hf_api = HfApi()
            repo_files = hf_api.list_repo_files(repo_id=qlora_adapter)
            for file in self.possible_config_file_names:
                if file in repo_files:
                    config_file_path = hf_hub_download(repo_id=qlora_adapter,
                                                       filename=file)
                    break

        if not config_file_path:
            raise ValueError(
                f"Cannot find adapter config file in {qlora_adapter}")

        return config_file_path

    def _get_weight_files(
            self,
            model_name_or_path: str,
            allowed_patterns: List[str],
            revision: Optional[str] = None) -> Tuple[List[str], str]:
        """Retrieve weight files. Download the files if necessary. 
        
        Return the weight files and the file pattern."""
        is_local = os.path.isdir(model_name_or_path)

        if is_local:
            for pattern in allowed_patterns:
                weight_files = glob.glob(
                    os.path.join(model_name_or_path, pattern))
                if weight_files:
                    return weight_files, pattern
        else:
            hf_api = HfApi()
            repo_files = hf_api.list_repo_files(repo_id=model_name_or_path)
            for pattern in allowed_patterns:
                matching_files = fnmatch.filter(repo_files, pattern)
                if matching_files:
                    hf_folder = download_weights_from_hf(
760
761
762
763
764
765
                        model_name_or_path,
                        self.load_config.download_dir,
                        [pattern],
                        revision,
                        ignore_patterns=self.load_config.ignore_patterns,
                    )
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
                    return glob.glob(os.path.join(hf_folder, pattern)), pattern

        raise RuntimeError(
            f"No model weights found in: `{model_name_or_path}`")

    def _prepare_weights(self, model_name_or_path: str,
                         revision: Optional[str]) -> Tuple[List[str], bool]:
        """Prepare weight files for the model."""

        allowed_patterns = ["*.safetensors", "*.bin", "*.pt"]

        hf_weights_files, matched_pattern = self._get_weight_files(
            model_name_or_path, allowed_patterns, revision)

        if matched_pattern != "*.safetensors":
            hf_weights_files = filter_files_not_needed_for_inference(
                hf_weights_files)

        if len(hf_weights_files) == 0:
            raise RuntimeError(
                f"Cannot find any model weights with `{model_name_or_path}`")

        return hf_weights_files, matched_pattern == "*.safetensors"

790
791
792
793
794
795
    def _hf_weight_iter(self, hf_weights_files, use_safetensors: bool):
        """Return the weight iterator matching the checkpoint file format."""
        make_iterator = (safetensors_weights_iterator
                         if use_safetensors else pt_weights_iterator)
        return make_iterator(hf_weights_files)

796
    def _get_quantized_weights_iterator(
797
798
799
800
801
        self,
        model_name_or_path: str,
        revision: Optional[str],
        pre_quant: bool,
        load_8bit: bool,
802
803
804
805
806
807
808
809
    ) -> Tuple[Generator[Tuple[str, torch.Tensor], None, None], Dict[str,
                                                                     Any]]:
        """Get an iterator to the model weights with bitsandbytes quantization,
        as well as the quantization state dictionary."""

        # only load the bitsandbytes module when needed
        try:
            import bitsandbytes
810
            if bitsandbytes.__version__ < "0.44.0":
811
                raise ImportError("bitsandbytes version is wrong. Please "
812
                                  "install bitsandbytes>=0.44.0.")
813
        except ImportError as err:
814
815
            raise ImportError("Please install bitsandbytes>=0.44.0 via "
                              "`pip install bitsandbytes>=0.44.0` to use "
816
817
818
819
820
                              "bitsandbytes quantizer.") from err

        hf_weights_files, use_safetensors = self._prepare_weights(
            model_name_or_path, revision)

821
        quant_state_dict: Dict[str, Any] = {}
822

823
824
825
826
827
828
829
830
831
        if pre_quant:
            if load_8bit:
                return self._quantized_8bit_generator(
                    hf_weights_files, use_safetensors,
                    quant_state_dict), quant_state_dict
            else:
                return self._quantized_4bit_generator(
                    hf_weights_files, use_safetensors,
                    quant_state_dict), quant_state_dict
832

833
834
        return self._unquantized_generator(hf_weights_files, use_safetensors,
                                           quant_state_dict), quant_state_dict
835

836
837
838
839
840
841
842
843
844
845
846
847
848
    def _is_8bit_weight_name(self, weight_name: str):
        quantized_suffix = {".scb", ".weight_format"}
        return any(weight_name.lower().endswith(suffix)
                   for suffix in quantized_suffix)

    def _is_4bit_weight_name(self, weight_name: str):
        quantized_suffix = {
            "absmax", "quant_map", "nested_absmax", "nested_quant_map",
            "bitsandbytes"
        }
        suffix = weight_name.split(".")[-1]
        return any(q_suffix in suffix for q_suffix in quantized_suffix)

849
850
851
852
853
854
855
    def _quantized_8bit_generator(self, hf_weights_files, use_safetensors,
                                  quant_state_dict) -> Generator:
        for weight_name, weight_tensor in self._hf_weight_iter(
                hf_weights_files, use_safetensors):
            if not weight_name.lower().endswith(".scb"):
                continue

856
            weight_key = weight_name.lower().replace(".scb", ".weight")
857
858
859
860
861
            quant_state_dict[weight_key] = weight_tensor

        for weight_name, weight_tensor in self._hf_weight_iter(
                hf_weights_files, use_safetensors):

862
            if self._is_8bit_weight_name(weight_name):
863
864
                continue

865
            if weight_name in quant_state_dict:
866
                set_weight_attrs(weight_tensor, {"load_in_8bit": True})
867
                yield weight_name, weight_tensor
868
869
870
871
872
873
874
875
876
877
878
879
            else:
                yield weight_name, weight_tensor

    def _quantized_4bit_generator(self, hf_weights_files, use_safetensors,
                                  quant_state_dict) -> Generator:
        """Yield the weights of a pre-quantized 4-bit checkpoint.

        A first pass over the checkpoint collects every quantization state
        tensor. A second pass yields the remaining tensors; for each one that
        has an ``nf4`` or ``fp4`` state entry, a ``QuantState`` is built and
        stored in ``quant_state_dict`` under the weight's name before the
        weight is yielded.
        """
        from bitsandbytes.functional import QuantState

        # Pass 1: gather the quantization state tensors.
        state_tensors = {}
        for name, tensor in self._hf_weight_iter(hf_weights_files,
                                                 use_safetensors):
            if not self._is_4bit_weight_name(name):
                continue
            # bitsandbytes library requires
            # weight.quant_state.bitsandbytes__* in CPU
            state_tensors[name] = (tensor.cpu().data
                                   if "quant_state.bitsandbytes" in name else
                                   tensor)

        def _build_quant_state(param_name: str) -> QuantState:
            # Collect every state tensor whose name contains "<param_name>."
            marker = param_name + "."
            members = {
                key: value
                for key, value in state_tensors.items() if marker in key
            }
            return QuantState.from_dict(members, device="cuda")

        # Pass 2: emit the pre-quantized and the ordinary weights.
        for name, tensor in self._hf_weight_iter(hf_weights_files,
                                                 use_safetensors):
            if self._is_4bit_weight_name(name):
                continue

            nf4_key = f"{name}.quant_state.bitsandbytes__nf4"
            fp4_key = f"{name}.quant_state.bitsandbytes__fp4"
            if nf4_key in state_tensors or fp4_key in state_tensors:
                quant_state_dict[name] = _build_quant_state(name)
            yield name, tensor

    def _unquantized_generator(self, hf_weights_files, use_safetensors,
                               quant_state_dict) -> Generator:
        """Yield checkpoint weights, quantizing the target modules in flight.

        A weight is quantized when its name contains one of
        ``self.target_modules`` and ends with ``.weight``. Such a weight is
        first reduced to this tensor-parallel rank's slice, moved to the GPU,
        quantized with ``quantize_4bit`` (``nf4``), and its quant state is
        stored in ``quant_state_dict``. All other weights pass through
        unchanged.
        """
        from bitsandbytes.functional import quantize_4bit

        world_size = get_tensor_model_parallel_world_size()
        rank = get_tensor_model_parallel_rank()

        def _has_prefix_in(name: str, prefixes) -> bool:
            return any(name.startswith(prefix) for prefix in prefixes)

        for weight_name, weight_tensor in self._hf_weight_iter(
                hf_weights_files, use_safetensors):

            is_target = any(
                target in weight_name
                for target in self.target_modules
            ) and weight_name.endswith(".weight")
            if not is_target:
                yield weight_name, weight_tensor
                continue

            if _has_prefix_in(weight_name, self.unsharded_weights_modules):
                # Without sharding
                local_shard = weight_tensor
            elif _has_prefix_in(weight_name,
                                self.column_sharded_weights_modules):
                # Shard by column
                width = weight_tensor.size(-1)
                begin = width // world_size * rank
                end = width // world_size * (rank + 1)
                local_shard = weight_tensor[..., begin:end]
            elif _has_prefix_in(weight_name,
                                self.maybe_fused_weights_modules):
                # Weights have fused on disk. In this case, we assume that the
                # weight and module use same name.
                # special case for fused weights
                # get the size of each shard weight tensor
                fused_sizes = next(
                    sizes for module, sizes in
                    self.maybe_fused_weights_modules.items()
                    if weight_name.startswith(module))
                rows = weight_tensor.size(0)
                assert rows == sum(fused_sizes)
                # get the start/end index of each shard weight tensor
                fused_starts = list(
                    itertools.accumulate([0] + fused_sizes))[:-1]
                row_ranges = [(offset + size // world_size * rank,
                               offset + size // world_size * (rank + 1))
                              for offset, size in zip(fused_starts,
                                                      fused_sizes)]
                # slice and reorder the weight tensor
                pieces = [
                    weight_tensor[begin:end, ...] for begin, end in row_ranges
                ]
                local_shard = torch.cat(pieces, dim=0)
            else:
                # Shard by row
                rows = weight_tensor.size(0)
                begin = rows // world_size * rank
                end = rows // world_size * (rank + 1)
                local_shard = weight_tensor[begin:end, ...]

            # bitsandbytes requires data in GPU
            gpu_weight = (local_shard
                          if local_shard.is_cuda else local_shard.cuda())

            # remove the following after the issue is fixed:
            # https://github.com/bitsandbytes-foundation/bitsandbytes/issues/1342
            if not gpu_weight.is_contiguous():
                gpu_weight = gpu_weight.contiguous()

            with set_default_torch_dtype(torch.float32):
                quantized_weight, quant_state = quantize_4bit(
                    gpu_weight, compress_statistics=True, quant_type="nf4")

            quant_state_dict[weight_name] = quant_state
            yield weight_name, quantized_weight
1000
1001
1002
1003
1004
1005

    def _load_weights(self, model_config: ModelConfig,
                      model: nn.Module) -> None:
        if not hasattr(model, 'load_weights'):
            raise AttributeError(
                "The required method 'load_weights' is not defined in class"
1006
                f" {type(model).__name__}.")
1007
1008
1009

        if not hasattr(model, 'bitsandbytes_stacked_params_mapping'):
            raise AttributeError(
1010
                f"Model {type(model).__name__} does not support BitsAndBytes "
1011
1012
                "quantization yet.")

1013
1014
1015
1016
1017
1018
        if len(self.target_modules) == 0:
            if hasattr(model, 'default_bitsandbytes_target_modules'):
                self.target_modules = model.default_bitsandbytes_target_modules
            else:
                self.target_modules = self.default_target_modules

1019
1020
1021
1022
        # Modules whose weights might have fused on disk
        # we need their output_sizes to make shard in flight correctly with TP
        self.maybe_fused_weights_modules: Dict[str, List[int]] = {}

1023
1024
1025
1026
1027
1028
        for name, module in model.named_modules():
            # Some modules like `ReplicatedLinear` should not have their weights
            # sharded. The reason for implementing it this way is to avoid new
            # static variable in the model implementation.
            if isinstance(module, (ReplicatedLinear, )):
                self.unsharded_weights_modules.append(name)
1029
1030
1031
1032
1033
1034
            # `QKVParallelLinear` and `MergedColumnParallelLinear` might have
            # fused weights on disk. We need to use the output sizes of these
            # modules to shard the weights correctly.
            elif isinstance(module,
                            (QKVParallelLinear, MergedColumnParallelLinear)):
                self.maybe_fused_weights_modules[name] = module.output_sizes
1035
1036
1037
1038
1039
            # In TP, these weights are partitioned along the column
            # dimension (dim=-1)
            elif isinstance(module, (RowParallelLinear, )):
                self.column_sharded_weights_modules.append(name)

1040
1041
        self.model_type = type(model).__name__

1042
1043
1044
        logger.info("Loading weights with BitsAndBytes quantization. "
                    " May take a while ...")

1045
1046
        quant_config = getattr(model_config.hf_config, "quantization_config",
                               None)
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057

        pre_quant = False
        if quant_config is not None:
            quant_method = quant_config.get('quant_method')
            if quant_method == "bitsandbytes":
                pre_quant = True
            else:
                raise ValueError(
                    f"BitsAndBytes loader does not support {quant_method} "
                    "quantization")

1058
1059
1060
1061
1062
1063
1064
        # The quant_states in pre_quantized models cannot work with a split
        # weight tensor. So TP does not work with pre_quantized bnb models.
        if pre_quant and get_tensor_model_parallel_world_size() > 1:
            raise ValueError(
                "Prequant BitsAndBytes models with TP is not supported."
                "Please try with PP.")

1065
1066
1067
        load_8bit = False
        if pre_quant:
            load_8bit = quant_config.get('load_in_8bit', False)
1068
1069
1070

        qweight_iterator, quant_state_dict = \
            self._get_quantized_weights_iterator(
1071
            model_config.model, model_config.revision, pre_quant, load_8bit)
1072
1073
1074

        model.load_weights(qweight_iterator)

1075
1076
        torch.cuda.empty_cache()

1077
1078
        param_dict = dict(model.named_parameters())
        stacked_quant_state_dict: Dict[str, Dict[int, Any]] = {}
1079
1080
1081
        # TODO: Change this lazy import to normal import
        # after the checks are updated to run on a new version
        from vllm.model_executor.models.utils import is_pp_missing_parameter
1082
        for quant_param_name in quant_state_dict:
1083
1084
1085
            if is_pp_missing_parameter(quant_param_name, model):
                continue

1086
1087
1088
1089
1090
1091
            non_stacked_param_name = quant_param_name

            shard_index = 0
            for shard_name, (
                    weight_name, index
            ) in model.bitsandbytes_stacked_params_mapping.items():
1092
1093
1094
1095
1096

                shard_pos = quant_param_name.find(shard_name)
                # Some models, such as MiniCPM V2.5/2.6, contain both
                # module names 'kv_proj' and 'qkv_proj'. To prevent 'kv_proj'
                # from being incorrectly identified as being present in
1097
                # 'vpm.encoder.layers.0.self_attn.qkv_proj.weight
1098
                if shard_pos > 0 and quant_param_name[shard_pos - 1] == ".":
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
                    shard_index = index
                    quant_param_name = quant_param_name.replace(
                        shard_name, weight_name)
                    break

            if quant_param_name not in param_dict:
                raise ValueError(
                    f"Parameter {quant_param_name} not found in the model.")

            if quant_param_name not in stacked_quant_state_dict:
                stacked_quant_state_dict[quant_param_name] = {}

            stacked_quant_state_dict[quant_param_name][shard_index] = (
                quant_state_dict[non_stacked_param_name])

        # save quant_states and offsets as the attributes of the parameters
        for param_name, param in param_dict.items():
            if param_name in stacked_quant_state_dict:
                quant_states = stacked_quant_state_dict[param_name]
                set_weight_attrs(param, {"bnb_quant_state": quant_states})

                pack_ratio = getattr(param, "pack_factor", -1)
                if pack_ratio == -1:
                    raise ValueError(
                        f"pack_factor not set for parameter {param_name}.")

                num_elements = [0] * len(quant_states)
1126
                for seq, quant_state in quant_states.items():
1127
                    num_elements[seq] = math.prod(
1128
                        quant_state.shape) // pack_ratio
1129
1130
1131
1132

                offsets = np.concatenate(([0], np.cumsum(num_elements)))
                set_weight_attrs(param, {"bnb_shard_offsets": offsets})

1133
1134
1135
1136
                if load_8bit:
                    set_weight_attrs(
                        param, {"matmul_state": [None] * len(quant_states)})

1137
1138
1139
    def download_model(self, model_config: ModelConfig) -> None:
        self._prepare_weights(model_config.model, model_config.revision)

1140
1141
1142
    def load_model(self, vllm_config: VllmConfig) -> nn.Module:
        """Build the model on the configured device, load its weights through
        ``_load_weights`` and return it in eval mode."""
        model_config = vllm_config.model_config
        # Both contexts stay active for construction and for weight loading.
        with set_default_torch_dtype(model_config.dtype), \
                torch.device(vllm_config.device_config.device):
            model = _initialize_model(vllm_config=vllm_config)
            self._load_weights(model_config, model)

        return model.eval()


1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
class GGUFModelLoader(BaseModelLoader):
    """
    Model loader that can load GGUF files. This is useful for loading models
    that are quantized with GGUF and saved in the GGUF format. This loader
    supports loading both full models and sharded models.
    """

    def __init__(self, load_config: LoadConfig):
        """Reject any extra loader config; this loader accepts none."""
        super().__init__(load_config)
        extra_config = load_config.model_loader_extra_config
        if extra_config:
            raise ValueError("Model loader extra config is not supported for "
                             f"load format {load_config.load_format}")

    def _prepare_weights(self, model_name_or_path: str):
        """Return the path unchanged if it is an existing file.

        Raises:
            ValueError: If the path is not a file.
        """
        if not os.path.isfile(model_name_or_path):
            raise ValueError(f"{model_name_or_path} is not a file.")
        return model_name_or_path

    def _get_gguf_weights_map(self, model_config: ModelConfig):
        """
        GGUF uses this naming convention for their tensors from HF checkpoint:
        `blk.N.BB.weight` and `blk.N.BB.bias`
        where N signifies the block number of a layer, and BB signifies the
        attention/mlp layer components.
        See "Standardized tensor names" in
        https://github.com/ggerganov/ggml/blob/master/docs/gguf.md for details.

        Returns a dict mapping each GGUF tensor name to the HF parameter name.
        """
        hf_config = model_config.hf_config
        # hack: ggufs have a different name than transformers
        gguf_model_type = ("command-r" if hf_config.model_type == "cohere"
                           else hf_config.model_type)
        model_type = gguf_model_type

        arch = next((key for key, value in gguf.MODEL_ARCH_NAMES.items()
                     if value == model_type), None)
        if arch is None:
            raise RuntimeError(f"Unknown gguf model_type: {model_type}")

        name_map = gguf.get_tensor_name_map(arch,
                                            hf_config.num_hidden_layers)
        # A model on the meta device is enough to enumerate the HF names.
        with torch.device("meta"):
            dummy_model = AutoModelForCausalLM.from_config(hf_config)

        gguf_to_hf_name_map = {}
        for hf_name in dummy_model.state_dict():
            module_path, suffix = hf_name.rsplit(".", 1)
            gguf_module = name_map.get_name(module_path)
            gguf_to_hf_name_map[f"{gguf_module}.{suffix}"] = hf_name
        return gguf_to_hf_name_map

    def _get_weights_iterator(
        self, model_name_or_path: str, gguf_to_hf_name_map: Dict[str, str]
    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
        """Return the GGUF weights iterator for the given file and name map."""
        weights = gguf_quant_weights_iterator(model_name_or_path,
                                              gguf_to_hf_name_map)
        return weights

    def download_model(self, model_config: ModelConfig) -> None:
        """Validate that the configured model path is a local file."""
        self._prepare_weights(model_config.model)

    def load_model(self, vllm_config: VllmConfig) -> nn.Module:
        """Build the model and load its weights from the GGUF file."""
        model_config = vllm_config.model_config
        device_config = vllm_config.device_config

        gguf_path = self._prepare_weights(model_config.model)
        name_map = self._get_gguf_weights_map(model_config)
        # we can only know if tie word embeddings after mapping weights
        extra_names = get_gguf_extra_tensor_names(gguf_path, name_map)
        if "lm_head.weight" in extra_names:
            model_config.hf_config.update({"tie_word_embeddings": True})

        with set_default_torch_dtype(model_config.dtype):
            # Only construction runs under the device context; the weights
            # are loaded after it has been left.
            with torch.device(device_config.device):
                model = _initialize_model(vllm_config=vllm_config)
            weights = self._get_weights_iterator(gguf_path, name_map)
            model.load_weights(weights)
        return model


1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
    """Get a model loader based on the load format."""

    if isinstance(load_config.load_format, type):
        return load_config.load_format(load_config)

    if load_config.load_format == LoadFormat.DUMMY:
        return DummyModelLoader(load_config)

    if load_config.load_format == LoadFormat.TENSORIZER:
        return TensorizerLoader(load_config)

1244
1245
1246
    if load_config.load_format == LoadFormat.SHARDED_STATE:
        return ShardedStateLoader(load_config)

1247
1248
1249
    if load_config.load_format == LoadFormat.BITSANDBYTES:
        return BitsAndBytesModelLoader(load_config)

1250
1251
1252
    if load_config.load_format == LoadFormat.GGUF:
        return GGUFModelLoader(load_config)

1253
    return DefaultModelLoader(load_config)