loader.py 61.2 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
# ruff: noqa: SIM117
4
import collections
5
import copy
6
import dataclasses
7
import fnmatch
8
import glob
9
import inspect
10
import itertools
11
import math
12
import os
13
import warnings
14
from abc import ABC, abstractmethod
15
from contextlib import contextmanager
16
17
from typing import (Any, Callable, Dict, Generator, Iterable, List, Optional,
                    Tuple, cast)
18

19
import gguf
20
import huggingface_hub
21
import numpy as np
22
import torch
23
from huggingface_hub import HfApi
24
from torch import nn
25
from transformers import AutoModelForCausalLM
26
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
27

28
from vllm.attention import Attention
29
from vllm.config import (LoadConfig, LoadFormat, ModelConfig, ParallelConfig,
30
                         VllmConfig, set_current_vllm_config)
31
32
from vllm.distributed import (get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size)
33
from vllm.envs import VLLM_USE_MODELSCOPE
34
from vllm.logger import init_logger
35
36
from vllm.model_executor.layers.linear import (LinearBase,
                                               MergedColumnParallelLinear,
37
38
                                               QKVParallelLinear,
                                               ReplicatedLinear,
39
                                               RowParallelLinear)
40
from vllm.model_executor.layers.quantization.base_config import (
41
    QuantizeMethodBase)
42
from vllm.model_executor.model_loader.tensorizer import (
43
    TensorizerConfig, is_vllm_tensorized, load_with_tensorizer,
44
    serialize_vllm_model, tensorizer_weights_iterator)
45
from vllm.model_executor.model_loader.utils import (ParamMapping,
46
                                                    configure_quant_config,
47
                                                    get_model_architecture,
48
49
                                                    set_default_torch_dtype)
from vllm.model_executor.model_loader.weight_utils import (
50
51
    download_safetensors_index_file_from_hf, download_weights_from_hf,
    filter_duplicate_safetensors_files, filter_files_not_needed_for_inference,
52
    get_gguf_extra_tensor_names, gguf_quant_weights_iterator,
53
    initialize_dummy_weights, np_cache_weights_iterator, pt_weights_iterator,
54
    runai_safetensors_weights_iterator, safetensors_weights_iterator)
55
from vllm.model_executor.utils import set_weight_attrs
56
from vllm.platforms import current_platform
57
from vllm.transformers_utils.s3_utils import glob as s3_glob
58
from vllm.transformers_utils.utils import is_s3
59
from vllm.utils import is_pin_memory_available
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89


@contextmanager
def device_loading_context(module: torch.nn.Module,
                           target_device: torch.device):
    """Temporarily stage a module's CPU-resident parameters on a device.

    On entry, every parameter of ``module`` that currently lives on the CPU
    is moved to ``target_device`` and its name is remembered. On exit (also
    when the ``with`` body raises) exactly those parameters are copied back
    to freshly allocated CPU buffers, pinned when
    ``is_pin_memory_available()`` reports that pinning is possible.
    Parameters that were not on the CPU at entry, and parameters that only
    appear during the ``with`` body, are never moved. If ``target_device``
    is the CPU, nothing is moved at all.

    Yields:
        ``module`` itself.
    """
    # Staging onto the CPU would be a no-op, so skip all bookkeeping.
    if target_device.type == "cpu":
        yield module
        return

    staged: Dict[str, torch.device] = {}
    for pname, param in module.named_parameters():
        if param.device.type != "cpu":
            continue  # already off the CPU: leave it untouched
        staged[pname] = param.device
        param.data = param.data.to(target_device)

    try:
        yield module
    finally:
        use_pinned = is_pin_memory_available()
        for pname, param in module.named_parameters():
            home = staged.get(pname)
            if home is None:
                continue  # never staged, e.g. created inside the body
            if home.type != "cpu":
                param.data = param.data.to(home)
                continue
            # `torch.empty_like` has no `pin_memory` argument, so build the
            # host buffer explicitly with the same size/stride/dtype/layout.
            host_buf = torch.empty_strided(
                size=param.data.size(),
                stride=param.data.stride(),
                dtype=param.data.dtype,
                layout=param.data.layout,
                device="cpu",
                pin_memory=use_pinned,
            )
            host_buf.copy_(param.data)
            param.data = host_buf

104
105
106
107

# Module-level logger created through vLLM's `init_logger` helper; used below
# e.g. to warn about old-style model classes and shape mismatches.
logger = init_logger(__name__)


108
def _initialize_model(
    vllm_config: VllmConfig,
    *,
    prefix: str = "",
) -> nn.Module:
    """Construct the model class selected by ``vllm_config``.

    The class is resolved from ``vllm_config.model_config``. When a
    quantization config is present it is configured for that class first.
    A class whose ``__init__`` declares both ``vllm_config`` and ``prefix``
    is built directly. Any other ("old-style") class triggers a
    ``DeprecationWarning`` and a log warning. It is then built with whichever
    of the legacy keyword arguments its ``__init__`` declares. Construction
    always happens inside ``set_current_vllm_config(..., check_compile=True)``.

    Args:
        vllm_config: The complete engine configuration.
        prefix: Weight-name prefix forwarded to the model constructor.

    Returns:
        The newly constructed model.
    """
    model_config = vllm_config.model_config
    model_class, _arch_name = get_model_architecture(model_config)

    if vllm_config.quant_config is not None:
        configure_quant_config(vllm_config.quant_config, model_class)

    init_signature = inspect.signature(model_class.__init__)
    init_param_names = [p.name for p in init_signature.parameters.values()]
    is_new_style = ("vllm_config" in init_param_names
                    and "prefix" in init_param_names)

    if is_new_style:
        with set_current_vllm_config(vllm_config, check_compile=True):
            return model_class(vllm_config=vllm_config, prefix=prefix)

    msg = ("vLLM model class should accept `vllm_config` and `prefix` as "
           "input arguments. Possibly you have an old-style model class"
           " registered from out of tree and it is used for new vLLM version. "
           "Check https://docs.vllm.ai/en/latest/design/arch_overview.html "
           "for the design and update the model class accordingly.")
    warnings.warn(msg, DeprecationWarning, stacklevel=2)

    logger.warning(
        "Trying to guess the arguments for old-style model class %s",
        model_class,
    )

    # Best-effort compatibility: forward only the legacy keyword arguments
    # the old-style constructor declares. Each entry is
    # (keyword, owning object, attribute name); attributes are read lazily.
    legacy_sources = (
        ("config", model_config, "hf_config"),
        ("cache_config", vllm_config, "cache_config"),
        ("quant_config", vllm_config, "quant_config"),
        ("lora_config", vllm_config, "lora_config"),
        ("scheduler_config", vllm_config, "scheduler_config"),
        ("parallel_config", vllm_config, "parallel_config"),
    )
    legacy_kwargs = {}
    if "prefix" in init_param_names:
        legacy_kwargs["prefix"] = prefix
    for keyword, owner, attr_name in legacy_sources:
        if keyword in init_param_names:
            legacy_kwargs[keyword] = getattr(owner, attr_name)

    with set_current_vllm_config(vllm_config, check_compile=True):
        return model_class(**legacy_kwargs)
156
157
158
159
160
161
162
163


class BaseModelLoader(ABC):
    """Abstract interface shared by every model loader in this module.

    A loader is constructed from a ``LoadConfig`` and must implement
    ``download_model`` and ``load_model``.
    """

    def __init__(self, load_config: LoadConfig):
        # Stored on the instance so concrete loaders can consult it.
        self.load_config = load_config

    @abstractmethod
    def download_model(self, model_config: ModelConfig) -> None:
        """Fetch everything ``model_config`` needs so loading can start
        right away."""
        raise NotImplementedError

    @abstractmethod
    def load_model(self, *, vllm_config: VllmConfig) -> nn.Module:
        """Build and return the model described by ``vllm_config``."""
        raise NotImplementedError
173
174
175
176
177


class DefaultModelLoader(BaseModelLoader):
    """Model loader that can load different file types from disk.

    Supported formats follow ``load_config.load_format``: AUTO, SAFETENSORS,
    MISTRAL, PT and NPCACHE. Remote models are downloaded first, from
    ModelScope when ``VLLM_USE_MODELSCOPE`` is set, otherwise from the
    Hugging Face Hub.
    """

    @dataclasses.dataclass
    class Source:
        """A source for weights."""

        model_or_path: str
        """The model ID or path."""

        revision: Optional[str]
        """The optional model revision."""

        prefix: str = ""
        """A prefix to prepend to all weights."""

        fall_back_to_pt: bool = True
        """Whether .pt weights can be used."""

        allow_patterns_overrides: Optional[list[str]] = None
        """If defined, weights will load exclusively using these patterns."""

    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)
        if load_config.model_loader_extra_config:
            raise ValueError(f"Model loader extra config is not supported for "
                             f"load format {load_config.load_format}")

    def _maybe_download_from_modelscope(
            self, model: str, revision: Optional[str]) -> Optional[str]:
        """Download model from ModelScope hub if VLLM_USE_MODELSCOPE is True.

        Returns the path to the downloaded model (or ``model`` itself when it
        is already an existing path), or None if ModelScope is not in use."""
        if VLLM_USE_MODELSCOPE:
            # download model from ModelScope hub,
            # lazy import so that modelscope is not required for normal use.
            # pylint: disable=C.
            from modelscope.hub.snapshot_download import snapshot_download

            if not os.path.exists(model):
                model_path = snapshot_download(
                    model_id=model,
                    cache_dir=self.load_config.download_dir,
                    local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
                    revision=revision,
                    ignore_file_pattern=self.load_config.ignore_patterns,
                )
            else:
                model_path = model
            return model_path
        return None

    def _prepare_weights(
        self,
        model_name_or_path: str,
        revision: Optional[str],
        fall_back_to_pt: bool,
        allow_patterns_overrides: Optional[list[str]],
    ) -> Tuple[str, List[str], bool]:
        """Prepare weights for the model.

        If the model is not local, it will be downloaded.

        Args:
            model_name_or_path: Hub model ID or local directory.
            revision: Optional hub revision.
            fall_back_to_pt: Also accept ``*.pt`` files as a last resort.
            allow_patterns_overrides: If not None, replaces the candidate
                glob patterns entirely.

        Returns:
            ``(folder, weight_files, use_safetensors)``. Only the files
            matching the first candidate pattern that matches anything are
            returned.

        Raises:
            ValueError: If ``load_config.load_format`` is not supported.
            RuntimeError: If no weight files are found.
        """
        model_name_or_path = (self._maybe_download_from_modelscope(
            model_name_or_path, revision) or model_name_or_path)

        is_local = os.path.isdir(model_name_or_path)
        load_format = self.load_config.load_format
        use_safetensors = False
        index_file = SAFE_WEIGHTS_INDEX_NAME
        # Some quantized models use .pt files for storing the weights.
        if load_format == LoadFormat.AUTO:
            allow_patterns = ["*.safetensors", "*.bin"]
        elif load_format == LoadFormat.SAFETENSORS:
            use_safetensors = True
            allow_patterns = ["*.safetensors"]
        elif load_format == LoadFormat.MISTRAL:
            use_safetensors = True
            allow_patterns = ["consolidated*.safetensors"]
            index_file = "consolidated.safetensors.index.json"
        elif load_format == LoadFormat.PT:
            allow_patterns = ["*.pt"]
        elif load_format == LoadFormat.NPCACHE:
            allow_patterns = ["*.bin"]
        else:
            raise ValueError(f"Unknown load_format: {load_format}")

        if fall_back_to_pt:
            allow_patterns += ["*.pt"]

        if allow_patterns_overrides is not None:
            allow_patterns = allow_patterns_overrides

        if not is_local:
            hf_folder = download_weights_from_hf(
                model_name_or_path,
                self.load_config.download_dir,
                allow_patterns,
                revision,
                ignore_patterns=self.load_config.ignore_patterns,
            )
        else:
            hf_folder = model_name_or_path

        hf_weights_files: List[str] = []
        for pattern in allow_patterns:
            hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
            if len(hf_weights_files) > 0:
                # Any safetensors pattern must select the safetensors reader,
                # not only the default "*.safetensors". Otherwise an override
                # such as "model-*.safetensors" would be read with the
                # PyTorch (.bin/.pt) iterator.
                if pattern.endswith(".safetensors"):
                    use_safetensors = True
                break

        if use_safetensors:
            # For models like Mistral-7B-Instruct-v0.3
            # there are both sharded safetensors files and a consolidated
            # safetensors file. Using both breaks.
            # Here, we download the `model.safetensors.index.json` and filter
            # any files not found in the index.
            if not is_local:
                download_safetensors_index_file_from_hf(
                    model_name_or_path,
                    index_file,
                    self.load_config.download_dir,
                    revision,
                )
            hf_weights_files = filter_duplicate_safetensors_files(
                hf_weights_files, hf_folder, index_file)
        else:
            hf_weights_files = filter_files_not_needed_for_inference(
                hf_weights_files)

        if len(hf_weights_files) == 0:
            raise RuntimeError(
                f"Cannot find any model weights with `{model_name_or_path}`")

        return hf_folder, hf_weights_files, use_safetensors

    def _get_weights_iterator(
            self, source: "Source"
    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
        """Get an iterator for the model weights based on the load format.

        Every yielded name is prefixed with ``source.prefix``.
        """
        hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
            source.model_or_path, source.revision, source.fall_back_to_pt,
            source.allow_patterns_overrides)
        if self.load_config.load_format == LoadFormat.NPCACHE:
            # Currently np_cache only support *.bin checkpoints
            assert use_safetensors is False
            weights_iterator = np_cache_weights_iterator(
                source.model_or_path,
                self.load_config.download_dir,
                hf_folder,
                hf_weights_files,
            )
        elif use_safetensors:
            weights_iterator = safetensors_weights_iterator(hf_weights_files)
        else:
            weights_iterator = pt_weights_iterator(hf_weights_files)

        if current_platform.is_tpu():
            # In PyTorch XLA, we should call `xm.mark_step` frequently so that
            # not too many ops are accumulated in the XLA program.
            import torch_xla.core.xla_model as xm

            def _xla_weights_iterator(iterator: Generator):
                for weights in iterator:
                    yield weights
                    xm.mark_step()

            weights_iterator = _xla_weights_iterator(weights_iterator)

        # Apply the prefix.
        return ((source.prefix + name, tensor)
                for (name, tensor) in weights_iterator)

    def _get_all_weights(
        self,
        model_config: ModelConfig,
        model: nn.Module,
    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
        """Yield the primary checkpoint's weights, then any secondary ones.

        The model may customise loading through the optional attributes
        ``fall_back_to_pt_during_load``, ``allow_patterns_overrides`` and
        ``secondary_weights`` (an iterable of ``Source``).
        """
        primary_weights = DefaultModelLoader.Source(
            model_config.model,
            model_config.revision,
            prefix="",
            fall_back_to_pt=getattr(model, "fall_back_to_pt_during_load",
                                    True),
            allow_patterns_overrides=getattr(model, "allow_patterns_overrides",
                                             None),
        )
        yield from self._get_weights_iterator(primary_weights)

        secondary_weights = cast(
            Iterable[DefaultModelLoader.Source],
            getattr(model, "secondary_weights", ()),
        )
        for source in secondary_weights:
            yield from self._get_weights_iterator(source)

    def download_model(self, model_config: ModelConfig) -> None:
        self._prepare_weights(model_config.model,
                              model_config.revision,
                              fall_back_to_pt=True,
                              allow_patterns_overrides=None)

    def load_model(self, vllm_config: VllmConfig) -> nn.Module:
        """Build the model on the configured device and load its weights.

        Raises:
            ValueError: For non-quantized models that report loaded weights,
                if any parameter was not initialized from the checkpoint.
        """
        device_config = vllm_config.device_config
        model_config = vllm_config.model_config

        target_device = torch.device(device_config.device)
        with set_default_torch_dtype(model_config.dtype):
            with target_device:
                model = _initialize_model(vllm_config=vllm_config)

            weights_to_load = {name for name, _ in model.named_parameters()}
            loaded_weights = model.load_weights(
                self._get_all_weights(model_config, model))
            # We only enable strict check for non-quantized models
            # that have loaded weights tracking currently.
            if model_config.quantization is None and loaded_weights is not None:
                weights_not_loaded = weights_to_load - loaded_weights
                if weights_not_loaded:
                    raise ValueError(
                        "Following weights were not initialized from "
                        f"checkpoint: {weights_not_loaded}")

            for _, module in model.named_modules():
                quant_method = getattr(module, "quant_method", None)
                # NOTE(review): this condition previously also tested
                # `quant_method != "awq"` / `"gptq"` / `"compressed_tensors"`.
                # `quant_method` is a QuantizeMethodBase instance here, so
                # (assuming it does not override __eq__) those string tests
                # were always true and had no effect. If a quant method must
                # really be skipped, compare `model_config.quantization`.
                if isinstance(quant_method, QuantizeMethodBase):
                    # When quant methods need to process weights after loading
                    # (for repacking, quantizing, etc), they expect parameters
                    # to be on the global target device. This scope is for the
                    # case where cpu offloading is used, where we will move the
                    # parameters onto device for processing and back off after.
                    with device_loading_context(module, target_device):
                        quant_method.process_weights_after_loading(module)
                if isinstance(module, Attention) and \
                    hasattr(module, "process_weights_after_loading"):
                    # When attention modules need to process weights after
                    # currently only used by MLA
                    # TODO(lucas): see if there is a way to unify the signatures
                    # of process_weights_after_loading
                    module.process_weights_after_loading(model_config.dtype)
        return model.eval()


class DummyModelLoader(BaseModelLoader):
    """Model loader that will set model weights to random values.

    No checkpoint is read. Every weight is filled by
    ``initialize_dummy_weights``, which is meant for performance evaluation
    where the actual values do not matter.
    """

    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)
        extra_config = load_config.model_loader_extra_config
        if extra_config:
            raise ValueError(f"Model loader extra config is not supported for "
                             f"load format {load_config.load_format}")

    def download_model(self, model_config: ModelConfig) -> None:
        """No-op: random weights need no files."""
        return None

    def load_model(self, vllm_config: VllmConfig) -> nn.Module:
        """Build the model on the configured device with random weights."""
        model_config = vllm_config.model_config
        target_device = torch.device(vllm_config.device_config.device)

        with set_default_torch_dtype(model_config.dtype):
            with target_device:
                model = _initialize_model(vllm_config=vllm_config)
            # NOTE(woosuk): For accurate performance evaluation, we assign
            # random values to the weights.
            initialize_dummy_weights(model)

            for _, mod in model.named_modules():
                quant_method = getattr(mod, "quant_method", None)
                if quant_method is not None:
                    # Quant post-processing (repacking, quantizing, ...)
                    # expects parameters on the target device. With CPU
                    # offloading, this scope moves them there and back.
                    with device_loading_context(mod, target_device):
                        quant_method.process_weights_after_loading(mod)
                if (isinstance(mod, Attention)
                        and hasattr(mod, "process_weights_after_loading")):
                    # Attention-level post-processing; currently only used
                    # by MLA.
                    mod.process_weights_after_loading(model_config.dtype)
        return model.eval()


class TensorizerLoader(BaseModelLoader):
    """Model loader using CoreWeave's tensorizer library.

    ``load_config.model_loader_extra_config`` is either a ready
    ``TensorizerConfig`` or the keyword arguments used to build one.
    """

    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)
        extra = load_config.model_loader_extra_config
        self.tensorizer_config = (extra
                                  if isinstance(extra, TensorizerConfig) else
                                  TensorizerConfig(**extra))

    def _verify_config(self, model_config: ModelConfig,
                       parallel_config: ParallelConfig):
        """Validate the tensorizer config against model and parallel setup."""
        cfg = self.tensorizer_config
        cfg.verify_with_model_config(model_config)
        cfg.verify_with_parallel_config(parallel_config)

    def _get_weights_iterator(
            self) -> Generator[Tuple[str, torch.Tensor], None, None]:
        """Return ``(name, tensor)`` pairs streamed by tensorizer."""
        args = self.tensorizer_config._construct_tensorizer_args()
        return tensorizer_weights_iterator(args)

    def _load_model_serialized_cpu(
        self,
        vllm_config: VllmConfig,
    ) -> nn.Module:
        """Load a tensorizer-serialized model that is not vLLM-tensorized.

        The model is constructed on ``device_config.device`` (despite the
        method name) and its weights are then fed through the model's own
        ``load_weights``. This path is only needed when the checkpoint was
        not produced by vLLM's tensorizer serialization (see
        examples/other/tensorize_vllm_model.py). It should still be faster
        than default HuggingFace loading, but slower than loading a
        vLLM-tensorized model.
        """
        model_config = vllm_config.model_config
        with set_default_torch_dtype(model_config.dtype):
            with torch.device(vllm_config.device_config.device):
                model = _initialize_model(vllm_config=vllm_config)
            model.load_weights(self._get_weights_iterator())
        return model.eval()

    def _load_model_serialized(
        self,
        vllm_config: VllmConfig,
    ) -> nn.Module:
        """Load a vLLM-tensorized model directly via ``load_with_tensorizer``.

        A shallow copy of the tensorizer config is annotated with the model
        class, HF config and dtype, so the loader's own config stays
        untouched. See the examples/other/tensorize_vllm_model.py example
        script for serializing vLLM models.
        """
        model_config = vllm_config.model_config
        with set_default_torch_dtype(model_config.dtype):
            with torch.device(vllm_config.device_config.device):
                model_class = get_model_architecture(model_config)[0]

                per_load_config = copy.copy(self.tensorizer_config)
                per_load_config.model_class = model_class
                per_load_config.hf_config = model_config.hf_config
                per_load_config.dtype = model_config.dtype

                model = load_with_tensorizer(per_load_config,
                                             vllm_config=vllm_config)
        return model.eval()

    def download_model(self, model_config: ModelConfig) -> None:
        """Validate the config and check that the tensorizer stream opens."""
        self.tensorizer_config.verify_with_model_config(model_config)

        with self.tensorizer_config.open_stream():
            pass

    def load_model(self, vllm_config: VllmConfig) -> nn.Module:
        """Load the model, picking the fast path for vLLM-tensorized files."""
        model_config = vllm_config.model_config
        parallel_config = vllm_config.parallel_config
        self._verify_config(model_config, parallel_config)

        if parallel_config.tensor_parallel_size > 1:
            from vllm.distributed import get_tensor_model_parallel_rank

            # With TP, the URI is a %-template filled with this rank.
            # NOTE(review): this overwrites `tensorizer_uri` on the shared
            # config, so a second call on the same loader would re-apply `%`
            # to the already formatted URI.
            self.tensorizer_config.tensorizer_uri = (
                self.tensorizer_config.tensorizer_uri %
                get_tensor_model_parallel_rank())

        if not is_vllm_tensorized(self.tensorizer_config):
            return self._load_model_serialized_cpu(vllm_config=vllm_config)
        return self._load_model_serialized(vllm_config=vllm_config)

    @staticmethod
    def save_model(
        model: torch.nn.Module,
        tensorizer_config: TensorizerConfig,
    ) -> None:
        """Serialize ``model`` with tensorizer according to the config."""
        serialize_vllm_model(
            model=model,
            tensorizer_config=tensorizer_config,
        )

559

560
561
562
563
564
class ShardedStateLoader(BaseModelLoader):
    """
    Model loader that directly loads each worker's model state dict, which
    enables a fast load path for large tensor-parallel models where each worker
    only needs to read its own shard rather than the entire checkpoint. See
    `examples/offline_inference/save_sharded_state.py` for creating a sharded
    checkpoint.

    The only accepted extra-config key is ``pattern``: a file-name template
    with ``{rank}`` and ``{part}`` fields (default ``DEFAULT_PATTERN``).
    """

    DEFAULT_PATTERN = "model-rank-{rank}-part-{part}.safetensors"

    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)
        extra_config = ({} if load_config.model_loader_extra_config is None
                        else load_config.model_loader_extra_config.copy())
        self.pattern = extra_config.pop("pattern", self.DEFAULT_PATTERN)
        if extra_config:
            # Report only the rejected keys; the accepted "pattern" key has
            # already been popped and must not appear in the message.
            raise ValueError(f"Unexpected extra config keys for load format "
                             f"{load_config.load_format}: "
                             f"{list(extra_config)}")

    @staticmethod
    def _filter_subtensors(
        tensors: Dict[str, torch.Tensor], ) -> Dict[str, torch.Tensor]:
        """
        Filter out all tensors that share the same memory or a subset of the
        memory of another tensor.

        Tensors are grouped by (device, storage data pointer), and
        zero-element tensors are dropped. Within a group, a tensor is kept
        unless some contiguous tensor covers a strictly larger byte range,
        or covers the identical range under a smaller key. A non-contiguous
        tensor is dropped as soon as any contiguous tensor covers its range.
        """
        same_storage_groups: Dict[Any, List[Tuple[str, torch.Tensor]]] = (
            collections.defaultdict(list))
        for key, tensor in tensors.items():
            if tensor.numel():
                ptr = tensor.untyped_storage().data_ptr()
                same_storage_groups[tensor.device, ptr].append((key, tensor))

        def get_end_ptr(tensor: torch.Tensor) -> int:
            # Address one past the last element of the flattened tensor.
            # NOTE(review): `view(-1)` can raise for some non-contiguous
            # layouts; presumably state_dict tensors are viewable — confirm.
            return tensor.view(-1)[-1].data_ptr() + tensor.element_size()

        result: Dict[str, torch.Tensor] = {}
        for group in same_storage_groups.values():
            for k, t in group:
                a, b = t.data_ptr(), get_end_ptr(t)
                for k2, t2 in group:
                    if not t2.is_contiguous():
                        continue
                    a2, b2 = t2.data_ptr(), get_end_ptr(t2)
                    if a < a2 or b2 < b:
                        continue  # t2 does not cover t
                    if a2 < a or b < b2 or not t.is_contiguous():
                        break  # t2 covers strictly more memory than t.
                    if k2 < k:
                        # Same tensors, keep the one with the smaller key.
                        break
                else:
                    result[k] = t
        return result

    def _prepare_weights(self, model_name_or_path: str,
                         revision: Optional[str]) -> str:
        """Return a local directory holding the shards, downloading the
        ``*.safetensors`` files from the Hub when the path is not local."""
        if os.path.isdir(model_name_or_path):
            return model_name_or_path
        else:
            allow_patterns = ["*.safetensors"]
            return download_weights_from_hf(
                model_name_or_path,
                self.load_config.download_dir,
                allow_patterns,
                revision,
                ignore_patterns=self.load_config.ignore_patterns,
            )

    def download_model(self, model_config: ModelConfig) -> None:
        self._prepare_weights(model_config.model, model_config.revision)

    def load_model(self, vllm_config: VllmConfig) -> nn.Module:
        """Build the model and copy in this TP rank's pre-sharded weights.

        Quantization and attention (MLA) weight post-processing run before
        the shards are read. Presumably this is because the shards were saved
        from an already post-processed model (see ``save_model``) — confirm.

        Raises:
            ValueError: If no shard files match this rank's pattern, or if
                some model tensors are missing from every shard.
        """
        device_config = vllm_config.device_config
        model_config = vllm_config.model_config
        from safetensors.torch import safe_open

        from vllm.distributed import get_tensor_model_parallel_rank

        local_model_path = self._prepare_weights(model_config.model,
                                                 model_config.revision)

        with set_default_torch_dtype(model_config.dtype):
            with torch.device(device_config.device):
                model = _initialize_model(vllm_config=vllm_config)
                for _, module in model.named_modules():
                    quant_method = getattr(module, "quant_method", None)
                    if quant_method is not None:
                        quant_method.process_weights_after_loading(module)
                    if isinstance(module, Attention) and \
                        hasattr(module, "process_weights_after_loading"):
                        # When attention modules need to process weights after
                        # currently only used by MLA
                        module.process_weights_after_loading(
                            model_config.dtype)
            rank = get_tensor_model_parallel_rank()
            pattern = os.path.join(
                local_model_path,
                self.pattern.format(rank=rank, part="*"),
            )
            filepaths = glob.glob(pattern)
            if not filepaths:
                # TODO: support un-sharded checkpoints too
                raise ValueError(
                    f"Could not find checkpoint files '{pattern}', only "
                    f"pre-sharded checkpoints are currently supported!")
            # Keys still present after all files are read are missing keys.
            state_dict = self._filter_subtensors(model.state_dict())
            for path in filepaths:
                with safe_open(path, framework="pt") as f:
                    for key in f.keys():  # noqa: SIM118
                        tensor = f.get_tensor(key)
                        # If loading with LoRA enabled, additional padding may
                        # be added to certain parameters. We only load into a
                        # narrowed view of the parameter data.
                        param_data = state_dict[key].data
                        param_shape = state_dict[key].shape
                        for dim, size in enumerate(tensor.shape):
                            if size < param_shape[dim]:
                                param_data = param_data.narrow(dim, 0, size)
                        if tensor.shape != param_shape:
                            logger.warning(
                                "loading tensor of shape %s into "
                                "parameter '%s' of shape %s",
                                tensor.shape,
                                key,
                                param_shape,
                            )
                        param_data.copy_(tensor)
                        state_dict.pop(key)
            if state_dict:
                raise ValueError(
                    f"Missing keys {tuple(state_dict)} in loaded state!")
        return model.eval()

    @staticmethod
    def save_model(
        model: torch.nn.Module,
        path: str,
        pattern: Optional[str] = None,
        max_size: Optional[int] = None,
    ) -> None:
        """Save this TP rank's deduplicated state dict as safetensors shards.

        Args:
            model: Model whose ``state_dict()`` is saved (after
                ``_filter_subtensors`` drops aliased tensors).
            path: Directory the shard files are written into (assumed to
                exist).
            pattern: File-name template with ``{rank}`` and ``{part}``
                fields; defaults to ``DEFAULT_PATTERN``.
            max_size: Optional cap, in bytes, on each shard's tensor payload.
                A single tensor larger than the cap gets a shard of its own.
        """
        from safetensors.torch import save_file

        from vllm.distributed import get_tensor_model_parallel_rank

        if pattern is None:
            pattern = ShardedStateLoader.DEFAULT_PATTERN
        rank = get_tensor_model_parallel_rank()
        part_idx = 0
        total_size = 0
        state_dict = ShardedStateLoader._filter_subtensors(model.state_dict())
        state_dict_part: Dict[str, torch.Tensor] = {}
        for key, tensor in state_dict.items():
            param_size = tensor.nelement() * tensor.element_size()
            # Flush the current shard before it would exceed `max_size`, but
            # never flush an empty one. Without the emptiness check, a first
            # tensor larger than `max_size` produced an empty part file.
            if (state_dict_part and max_size is not None
                    and total_size + param_size > max_size):
                filename = pattern.format(rank=rank, part=part_idx)
                save_file(
                    state_dict_part,
                    os.path.join(path, filename),
                )
                part_idx += 1
                total_size = 0
                state_dict_part = {}
            state_dict_part[key] = tensor
            total_size += param_size
        if state_dict_part:
            filename = pattern.format(rank=rank, part=part_idx)
            save_file(
                state_dict_part,
                os.path.join(path, filename),
            )
733
734


735
736
737
738
739
740
741
742
class BitsAndBytesModelLoader(BaseModelLoader):
    """Model loader to load model weights with BitAndBytes quantization.

    Two paths are supported:

    * pre-quantized checkpoints (``quantization_config.quant_method ==
      "bitsandbytes"``), either 4-bit or 8-bit (``load_in_8bit``), whose
      quantization state is read from the checkpoint itself;
    * unquantized checkpoints, whose target linear weights are sliced for the
      local tensor-parallel rank and quantized to NF4 while loading.
    """

    possible_config_file_names = ["adapter_config.json"]

    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)

        # Save the module names without sharding.
        self.unsharded_weights_modules: List[str] = []
        # Save the module names that are sharded by column.
        self.column_sharded_weights_modules: List[str] = []
        # Store all module names (from transformers) that support
        # BNB quantization.
        self.target_modules: List[str] = []
        # mapping weight names from transformers to vllm.
        self.weight_mapper: Callable = lambda name: name

    @staticmethod
    def _parse_version(version: str) -> Tuple[int, ...]:
        """Turn a version string such as ``"0.45.0.dev0"`` into ``(0, 45, 0)``.

        Only the first three dot-separated components are used and each is
        reduced to its leading digits (no digits counts as 0), so suffixes
        like ``rc1``/``dev0`` do not break the numeric comparison.
        """
        return tuple(
            int("".join(itertools.takewhile(str.isdigit, part)) or 0)
            for part in version.split(".")[:3])

    def _get_weight_files(
        self,
        model_name_or_path: str,
        allowed_patterns: List[str],
        revision: Optional[str] = None,
    ) -> Tuple[List[str], str]:
        """Retrieve weight files. Download the files if necessary.

        ``allowed_patterns`` are tried in order and the first pattern that
        matches any file is used.

        Returns:
            The weight files and the file pattern that matched them.

        Raises:
            RuntimeError: if no pattern matches any file.
        """
        is_local = os.path.isdir(model_name_or_path)

        if is_local:
            for pattern in allowed_patterns:
                weight_files = glob.glob(
                    os.path.join(model_name_or_path, pattern))
                if weight_files:
                    return weight_files, pattern
        else:
            hf_api = HfApi()
            # List the files of the requested revision, so the pattern chosen
            # here agrees with what download_weights_from_hf fetches below.
            repo_files = hf_api.list_repo_files(repo_id=model_name_or_path,
                                                revision=revision)
            for pattern in allowed_patterns:
                matching_files = fnmatch.filter(repo_files, pattern)
                if matching_files:
                    hf_folder = download_weights_from_hf(
                        model_name_or_path,
                        self.load_config.download_dir,
                        [pattern],
                        revision,
                        ignore_patterns=self.load_config.ignore_patterns,
                    )
                    return glob.glob(os.path.join(hf_folder, pattern)), pattern

        raise RuntimeError(
            f"No model weights found in: `{model_name_or_path}`")

    def _prepare_weights(self, model_name_or_path: str,
                         revision: Optional[str]) -> Tuple[List[str], bool]:
        """Prepare weight files for the model.

        Returns:
            The weight files and whether they are safetensors files.

        Raises:
            RuntimeError: if no usable weight file is found.
        """

        allowed_patterns = ["*.safetensors", "*.bin", "*.pt"]

        hf_weights_files, matched_pattern = self._get_weight_files(
            model_name_or_path, allowed_patterns, revision)

        if matched_pattern != "*.safetensors":
            hf_weights_files = filter_files_not_needed_for_inference(
                hf_weights_files)

        if len(hf_weights_files) == 0:
            raise RuntimeError(
                f"Cannot find any model weights with `{model_name_or_path}`")

        return hf_weights_files, matched_pattern == "*.safetensors"

    def _hf_weight_iter(self, hf_weights_files, use_safetensors: bool):
        """Yield ``(original_name, mapped_name, tensor)`` for every weight.

        ``mapped_name`` is ``original_name`` passed through
        ``self.weight_mapper``. The original name is kept because the
        generators below hand it to ``model.load_weights``.
        """
        if use_safetensors:
            iterator = safetensors_weights_iterator(hf_weights_files)
        else:
            iterator = pt_weights_iterator(hf_weights_files)
        for org_name, param in iterator:
            # mapping weight names from transformers to vllm while preserving
            # original names.
            mapped_name = self.weight_mapper(org_name)
            yield org_name, mapped_name, param

    def _get_quantized_weights_iterator(
        self,
        model_name_or_path: str,
        revision: Optional[str],
        pre_quant: bool,
        load_8bit: bool,
    ) -> Tuple[Generator[Tuple[str, torch.Tensor], None, None], Dict[str,
                                                                     Any]]:
        """Get an iterator to the model weights with bitsandbytes quantization,
        as well as the quantization state dictionary.

        The returned dictionary is filled in by the generator as it runs, so
        it is only complete once the iterator has been exhausted.

        Raises:
            ImportError: if bitsandbytes is missing or older than 0.45.0.
        """

        # only load the bitsandbytes module when needed
        try:
            import bitsandbytes

            # Compare numerically: a plain string comparison would order
            # e.g. "0.100.0" before "0.45.0".
            if self._parse_version(bitsandbytes.__version__) < (0, 45, 0):
                raise ImportError("bitsandbytes version is wrong. Please "
                                  "install bitsandbytes>=0.45.0.")
        except ImportError as err:
            raise ImportError("Please install bitsandbytes>=0.45.0 via "
                              "`pip install bitsandbytes>=0.45.0` to use "
                              "bitsandbytes quantizer.") from err

        hf_weights_files, use_safetensors = self._prepare_weights(
            model_name_or_path, revision)

        quant_state_dict: Dict[str, Any] = {}

        if not pre_quant:
            generator = self._unquantized_generator
        elif load_8bit:
            generator = self._quantized_8bit_generator
        else:
            generator = self._quantized_4bit_generator
        return generator(hf_weights_files, use_safetensors,
                         quant_state_dict), quant_state_dict

    def _is_8bit_weight_name(self, weight_name: str):
        """Whether ``weight_name`` is 8-bit quantization metadata (SCB or
        weight_format), compared case-insensitively."""
        return weight_name.lower().endswith((".scb", ".weight_format"))

    def _is_4bit_weight_name(self, weight_name: str):
        """Whether the last dotted component of ``weight_name`` contains one
        of the 4-bit quant-state markers."""
        quantized_suffix = {
            "absmax",
            "quant_map",
            "nested_absmax",
            "nested_quant_map",
            "bitsandbytes",
        }
        suffix = weight_name.split(".")[-1]
        return any(q_suffix in suffix for q_suffix in quantized_suffix)

    def _quantized_8bit_generator(self, hf_weights_files, use_safetensors,
                                  quant_state_dict) -> Generator:
        """Yield the weights of a pre-quantized 8-bit checkpoint.

        The checkpoint is read twice. Pass one stores every ``*.SCB`` tensor
        in ``quant_state_dict`` under the name of the weight it belongs to.
        Pass two yields all non-metadata tensors; weights that have an SCB
        entry are tagged with ``load_in_8bit``.
        """
        scb_suffix = ".scb"
        for _, mapped_weight_name, weight_tensor in self._hf_weight_iter(
                hf_weights_files, use_safetensors):
            if not mapped_weight_name.lower().endswith(scb_suffix):
                continue
            # Keep the name's original casing: the key is looked up
            # case-sensitively below and against param_dict in _load_weights.
            weight_key = mapped_weight_name[:-len(scb_suffix)] + ".weight"
            quant_state_dict[weight_key] = weight_tensor

        for org_weight_name, mapped_weight_name, weight_tensor in (
                self._hf_weight_iter(hf_weights_files, use_safetensors)):
            if self._is_8bit_weight_name(mapped_weight_name):
                continue
            if mapped_weight_name in quant_state_dict:
                set_weight_attrs(weight_tensor, {"load_in_8bit": True})
            yield org_weight_name, weight_tensor

    def _quantized_4bit_generator(self, hf_weights_files, use_safetensors,
                                  quant_state_dict) -> Generator:
        """Yield the weights of a pre-quantized 4-bit checkpoint.

        The checkpoint is read twice. Pass one collects all quant-state
        tensors. Pass two yields the remaining tensors; for every weight with
        an nf4/fp4 packed quant state, a ``QuantState`` is rebuilt and stored
        in ``quant_state_dict``.
        """
        from bitsandbytes.functional import QuantState

        # First iterate over all quant state weights
        temp_state_dict = {}
        for _, mapped_weight_name, weight_tensor in self._hf_weight_iter(
                hf_weights_files, use_safetensors):
            if not self._is_4bit_weight_name(mapped_weight_name):
                continue
            # bitsandbytes library requires
            # weight.quant_state.bitsandbytes__* in CPU
            if "quant_state.bitsandbytes" in mapped_weight_name:
                temp_state_dict[mapped_weight_name] = weight_tensor.cpu().data
            else:
                temp_state_dict[mapped_weight_name] = weight_tensor

        # Closure to parse quant_state for each prequant weight
        def _parse_quant_state(param_name: str,
                               temp_state_dict: Dict) -> QuantState:
            # Quant-state entries are named "<param_name>.<component>". Match
            # on the prefix: a substring test would also pick up the entries
            # of any parameter whose name merely contains param_name (e.g.
            # "model.x" inside "language_model.x").
            prefix = param_name + "."
            quant_state = {
                k: v
                for k, v in temp_state_dict.items() if k.startswith(prefix)
            }
            return QuantState.from_dict(quant_state, device="cuda")

        # Second iterate over all prequant and normal weights
        # pre quantized weights would have a quant_state
        for org_weight_name, mapped_weight_name, weight_tensor in (
                self._hf_weight_iter(hf_weights_files, use_safetensors)):
            if self._is_4bit_weight_name(mapped_weight_name):
                continue

            if (f"{mapped_weight_name}.quant_state.bitsandbytes__nf4"
                    in temp_state_dict) or (
                        f"{mapped_weight_name}.quant_state.bitsandbytes__fp4"
                        in temp_state_dict):
                quant_state_dict[mapped_weight_name] = _parse_quant_state(
                    mapped_weight_name, temp_state_dict)
            yield org_weight_name, weight_tensor

    def _unquantized_generator(self, hf_weights_files, use_safetensors,
                               quant_state_dict) -> Generator:
        """Yield weights of an unquantized checkpoint, quantizing targets.

        Weights whose name contains a target module and ends in ``.weight``
        are sliced for this tensor-parallel rank, moved to CUDA, quantized to
        NF4, and their quant state is stored in ``quant_state_dict``. All
        other weights are yielded unchanged.
        """
        from bitsandbytes.functional import quantize_4bit

        tp_size = get_tensor_model_parallel_world_size()
        tp_rank = get_tensor_model_parallel_rank()

        for org_weight_name, mapped_weight_name, weight_tensor in (
                self._hf_weight_iter(hf_weights_files, use_safetensors)):
            is_target = any(target_module in mapped_weight_name
                            for target_module in self.target_modules
                            ) and mapped_weight_name.endswith(".weight")
            if not is_target:
                yield org_weight_name, weight_tensor
                continue

            # Without sharding
            if any(
                    mapped_weight_name.startswith(module)
                    for module in self.unsharded_weights_modules):
                weight_sub_tensor = weight_tensor
            # Shard by column
            elif any(
                    mapped_weight_name.startswith(module)
                    for module in self.column_sharded_weights_modules):
                total_size = weight_tensor.size(-1)
                start_index = total_size // tp_size * tp_rank
                end_index = total_size // tp_size * (tp_rank + 1)
                weight_sub_tensor = weight_tensor[..., start_index:end_index]
            # Weights have fused on disk. In this case, we assume that the
            # weight and module use same name.
            elif any(
                    mapped_weight_name.startswith(module)
                    for module in self.maybe_fused_weights_modules):
                # special case for fused weights
                # get the size of each shard weight tensor
                total_shard_sizes = next(
                    (sizes for module, sizes in
                     self.maybe_fused_weights_modules.items()
                     if mapped_weight_name.startswith(module)))
                total_size = weight_tensor.size(0)
                assert total_size == sum(total_shard_sizes)
                # get the start/end index of each shard weight tensor
                total_start_index = list(
                    itertools.accumulate([0] + total_shard_sizes))[:-1]
                shard_weights_index = [(
                    idx + size // tp_size * tp_rank,
                    idx + size // tp_size * (tp_rank + 1),
                ) for idx, size in zip(total_start_index, total_shard_sizes)]
                # slice this rank's part out of every fused shard and
                # concatenate them back in the original shard order
                shard_slices = [
                    weight_tensor[start_index:end_index, ...]
                    for start_index, end_index in shard_weights_index
                ]
                weight_sub_tensor = torch.cat(shard_slices, dim=0)
            # Shard by row
            else:
                total_size = weight_tensor.size(0)
                start_index = total_size // tp_size * tp_rank
                end_index = total_size // tp_size * (tp_rank + 1)
                weight_sub_tensor = weight_tensor[start_index:end_index, ...]

            # bitsandbytes requires data in GPU
            if weight_sub_tensor.is_cuda:
                loaded_weight = weight_sub_tensor
            else:
                loaded_weight = weight_sub_tensor.cuda()

            # remove the following after the issue is fixed:
            # https://github.com/bitsandbytes-foundation/bitsandbytes/issues/1342
            if not loaded_weight.is_contiguous():
                loaded_weight = loaded_weight.contiguous()

            with set_default_torch_dtype(torch.float32):
                processed_weight, quant_state = quantize_4bit(
                    loaded_weight,
                    compress_statistics=True,
                    quant_type="nf4",
                )

            quant_state_dict[mapped_weight_name] = quant_state
            yield org_weight_name, processed_weight

    def _get_bnb_target_modules(self, model: nn.Module) -> None:
        """Append to ``self.target_modules`` the name of every ``LinearBase``
        module, plus the per-shard names its packed mapping expands to.

        Raises:
            AssertionError: if the model has no such module.
        """
        for name, module in model.named_modules():
            if not isinstance(module, LinearBase):
                continue
            if modules_info := self.modules_mapping.get_sub_modules(name):
                # Map vllm's names to transformers's names.
                rep_name, sub_modules = modules_info
                for sub_name in sub_modules:
                    self.target_modules.append(
                        name.replace(rep_name, sub_name))
            # Add original module name even if the module has stacked map,
            # in case model has a mixture of disk-merged and disk-splitted
            # weights with same last name.
            self.target_modules.append(name)

        # The message is one parenthesized expression so that the model name
        # is actually part of it.
        assert self.target_modules, (
            "vllm currently does not support BNB quantization for"
            f" {type(model).__name__}")

    def _load_weights(self, model_config: ModelConfig,
                      model: nn.Module) -> None:
        """Load the checkpoint into ``model`` with BNB quantization.

        Steps:
        1. Record target modules and how each linear module is sharded.
        2. Stream (and, for unquantized checkpoints, quantize) the weights
           into ``model.load_weights``.
        3. Attach the collected quant states to the matching parameters as
           ``bnb_quant_state`` / ``bnb_shard_offsets`` (plus
           ``matmul_state`` for 8-bit), grouping the states of packed
           sub-modules by shard index.

        Raises:
            AttributeError: if the model lacks ``load_weights`` or
                ``packed_modules_mapping``.
            ValueError: on a non-bitsandbytes ``quantization_config``, a
                pre-quantized checkpoint under tensor parallelism, weights
                missing from the checkpoint, or a quantized parameter with
                no ``pack_factor``.
        """
        if not hasattr(model, "load_weights"):
            raise AttributeError(
                "The required method 'load_weights' is not defined in class"
                f" {type(model).__name__}.")

        if not hasattr(model, "packed_modules_mapping"):
            raise AttributeError(
                f"Model {type(model).__name__} does not support BitsAndBytes "
                "quantization yet. No 'packed_modules_mapping' found.")

        self.modules_mapping = ParamMapping(
            copy.deepcopy(model.packed_modules_mapping))

        # For some models like Molmo, we need to use hf_to_vllm_mapper
        # to ensure correct loading of weights.
        if hf_to_vllm_mapper := getattr(model, "hf_to_vllm_mapper", None):
            self.weight_mapper = lambda name: hf_to_vllm_mapper._map_name(name)

        # Modules whose weights might have fused on disk
        # we need their output_sizes to make shard in flight correctly with TP
        self.maybe_fused_weights_modules: Dict[str, List[int]] = {}
        self._get_bnb_target_modules(model)
        for name, module in model.named_modules():
            # Some modules like `ReplicatedLinear` should not have their weights
            # sharded. The reason for implementing it this way is to avoid new
            # static variable in the model implementation.
            if isinstance(module, ReplicatedLinear):
                self.unsharded_weights_modules.append(name)
            # `QKVParallelLinear` and `MergedColumnParallelLinear` might have
            # fused weights on disk. We need to use the output sizes of these
            # modules to shard the weights correctly.
            elif isinstance(module,
                            (QKVParallelLinear, MergedColumnParallelLinear)):
                self.maybe_fused_weights_modules[name] = module.output_sizes
            # In TP, these weights are partitioned along the column
            # dimension (dim=-1)
            elif isinstance(module, RowParallelLinear):
                self.column_sharded_weights_modules.append(name)

        self.model_type = type(model).__name__

        logger.info("Loading weights with BitsAndBytes quantization. "
                    " May take a while ...")

        quant_config = getattr(model_config.hf_config, "quantization_config",
                               None)

        pre_quant = False
        if quant_config is not None:
            quant_method = quant_config.get("quant_method")
            if quant_method == "bitsandbytes":
                pre_quant = True
            else:
                raise ValueError(
                    f"BitsAndBytes loader does not support {quant_method} "
                    "quantization")

        # The quant_states in pre_quantized models cannot work with a split
        # weight tensor. So TP does not work with pre_quantized bnb models.
        if pre_quant and get_tensor_model_parallel_world_size() > 1:
            raise ValueError(
                "Prequant BitsAndBytes models with tensor parallelism is not "
                "supported. Please try with pipeline parallelism.")

        load_8bit = False
        if pre_quant:
            load_8bit = quant_config.get("load_in_8bit", False)

        qweight_iterator, quant_state_dict = (
            self._get_quantized_weights_iterator(model_config.model,
                                                 model_config.revision,
                                                 pre_quant, load_8bit))

        weights_to_load = {name for name, _ in model.named_parameters()}
        loaded_weights = model.load_weights(qweight_iterator)
        # Some models may have weights loading tracker unimplemented.
        if loaded_weights is not None:
            weights_not_loaded = weights_to_load - loaded_weights
            if weights_not_loaded:
                raise ValueError("Following weights were not initialized from "
                                 f"checkpoint: {weights_not_loaded}")

        torch.cuda.empty_cache()

        param_dict = dict(model.named_parameters())
        # param name -> {shard index -> quant state}
        stacked_quant_state_dict: Dict[str, Dict[int, Any]] = {}
        # TODO: Change this lazy import to normal import
        # after the checks are updated to run on a new version
        from vllm.model_executor.models.utils import is_pp_missing_parameter

        for quant_param_name in quant_state_dict:
            if is_pp_missing_parameter(quant_param_name, model):
                continue

            non_stacked_param_name = quant_param_name

            shard_index = 0
            for shard_name, (weight_name, index) in (
                    self.modules_mapping.inverse_packed_mapping.items()):
                # Some models, such as MiniCPM V2.5/2.6, contain both
                # module names 'kv_proj' and 'qkv_proj'. To prevent 'kv_proj'
                # from being incorrectly identified as being present in
                # 'vpm.encoder.layers.0.self_attn.qkv_proj.weight
                shard_pos = quant_param_name.find(shard_name)
                can_correct_rename = (shard_pos > 0) and (
                    quant_param_name[shard_pos - 1] == ".")
                # If the quant_param_name is packed, it won't occur in the
                # param_dict before renaming.
                new_quant_param_name = quant_param_name.replace(
                    shard_name, weight_name)
                need_rename = (quant_param_name not in param_dict) and (
                    new_quant_param_name in param_dict)
                if can_correct_rename and need_rename:
                    shard_index = index
                    quant_param_name = new_quant_param_name
                    break

            # Models like Clip/Siglip may skip some layers in initialization,
            # causing unused quant_param_name in state_dict.
            if quant_param_name not in param_dict:
                continue

            stacked_quant_state_dict.setdefault(quant_param_name, {})[
                shard_index] = quant_state_dict[non_stacked_param_name]

        # save quant_states and offsets as the attributes of the parameters
        for param_name, param in param_dict.items():
            if param_name not in stacked_quant_state_dict:
                continue
            quant_states = stacked_quant_state_dict[param_name]
            set_weight_attrs(param, {"bnb_quant_state": quant_states})

            pack_ratio = getattr(param, "pack_factor", -1)
            if pack_ratio == -1:
                raise ValueError(
                    f"pack_factor not set for parameter {param_name}.")

            # Packed element count of each shard, indexed by shard index.
            num_elements = [0] * len(quant_states)
            for seq, quant_state in quant_states.items():
                num_elements[seq] = (math.prod(quant_state.shape) //
                                     pack_ratio)

            offsets = np.concatenate(([0], np.cumsum(num_elements)))
            set_weight_attrs(param, {"bnb_shard_offsets": offsets})

            if load_8bit:
                set_weight_attrs(
                    param, {"matmul_state": [None] * len(quant_states)})

    def download_model(self, model_config: ModelConfig) -> None:
        """Fetch (if remote) and validate the model's weight files."""
        self._prepare_weights(model_config.model, model_config.revision)

    def load_model(self, vllm_config: VllmConfig) -> nn.Module:
        """Build the model on the configured device and load BNB weights.

        Returns:
            The loaded model, switched to eval mode.
        """
        device_config = vllm_config.device_config
        model_config = vllm_config.model_config

        with set_default_torch_dtype(model_config.dtype):
            with torch.device(device_config.device):
                model = _initialize_model(vllm_config=vllm_config)

                self._load_weights(model_config, model)

        return model.eval()


1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
class GGUFModelLoader(BaseModelLoader):
    """
    Model loader that can load GGUF files. This is useful for loading models
    that are quantized with GGUF and saved in the GGUF format. This loader
    supports loading both full models and sharded models.
    """

    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)
        if load_config.model_loader_extra_config:
            raise ValueError(f"Model loader extra config is not supported for "
                             f"load format {load_config.load_format}")

    def _prepare_weights(self, model_name_or_path: str):
        """Return ``model_name_or_path`` unchanged if it is an existing file.

        Raises:
            ValueError: if the path is not a file.
        """
        if not os.path.isfile(model_name_or_path):
            raise ValueError(f"{model_name_or_path} is not a file.")
        return model_name_or_path

    def _get_gguf_weights_map(self, model_config: ModelConfig):
        """
        GGUF uses this naming convention for their tensors from HF checkpoint:
        `blk.N.BB.weight` and `blk.N.BB.bias`
        where N signifies the block number of a layer, and BB signifies the
        attention/mlp layer components.
        See "Standardized tensor names" in
        https://github.com/ggerganov/ggml/blob/master/docs/gguf.md for details.

        Returns:
            A dict mapping GGUF tensor names to the HF state-dict names of a
            meta-device instance of the HF model built from the config.

        Raises:
            RuntimeError: if the model type has no GGUF architecture.
        """
        config = model_config.hf_config
        model_type = config.model_type
        # hack: ggufs have a different name than transformers
        if model_type == "cohere":
            model_type = "command-r"
        arch = next((key for key, value in gguf.MODEL_ARCH_NAMES.items()
                     if value == model_type), None)
        if arch is None:
            raise RuntimeError(f"Unknown gguf model_type: {model_type}")
        num_layers = config.num_hidden_layers
        name_map = gguf.get_tensor_name_map(arch, num_layers)
        # Instantiate on "meta" only to enumerate parameter names without
        # allocating real weights.
        with torch.device("meta"):
            dummy_model = AutoModelForCausalLM.from_config(config)
        state_dict = dummy_model.state_dict()

        gguf_to_hf_name_map = {}
        for hf_name in state_dict:
            name, suffix = hf_name.rsplit(".", 1)
            # NOTE(review): if get_name() returns None for an unmapped name,
            # the key becomes "None.<suffix>" and such entries overwrite each
            # other — confirm this is intended.
            gguf_name = name_map.get_name(name)
            gguf_to_hf_name_map[f"{gguf_name}.{suffix}"] = hf_name
        return gguf_to_hf_name_map

    def _get_weights_iterator(
        self, model_name_or_path: str, gguf_to_hf_name_map: Dict[str, str]
    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
        """Iterate over the GGUF file's tensors using the name map."""
        return gguf_quant_weights_iterator(model_name_or_path,
                                           gguf_to_hf_name_map)

    def download_model(self, model_config: ModelConfig) -> None:
        """Only validates that the model path is a local file."""
        self._prepare_weights(model_config.model)

    def load_model(self, vllm_config: VllmConfig) -> nn.Module:
        """Build the model and load its weights from the GGUF file.

        If ``lm_head.weight`` is among the names returned by
        ``get_gguf_extra_tensor_names``, ``tie_word_embeddings`` is set on
        the HF config before the model is constructed.
        """
        device_config = vllm_config.device_config
        model_config = vllm_config.model_config
        local_model_path = self._prepare_weights(model_config.model)
        gguf_weights_map = self._get_gguf_weights_map(model_config)
        # we can only know if tie word embeddings after mapping weights
        if "lm_head.weight" in get_gguf_extra_tensor_names(
                local_model_path, gguf_weights_map):
            model_config.hf_config.update({"tie_word_embeddings": True})

        with set_default_torch_dtype(model_config.dtype):
            with torch.device(device_config.device):
                model = _initialize_model(vllm_config=vllm_config)
            model.load_weights(
                self._get_weights_iterator(local_model_path, gguf_weights_map))
        return model


1312
1313
class RunaiModelStreamerLoader(BaseModelLoader):
    """
        Model loader that can load safetensors
        files from local FS or S3 bucket.
    """

    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)
        extra_config = load_config.model_loader_extra_config
        if extra_config:
            # Forward integer tuning knobs to the Run:ai streamer through its
            # environment variables; values of any other type are ignored.
            concurrency = extra_config.get("concurrency")
            if isinstance(concurrency, int):
                os.environ["RUNAI_STREAMER_CONCURRENCY"] = str(concurrency)

            memory_limit = extra_config.get("memory_limit")
            if isinstance(memory_limit, int):
                os.environ["RUNAI_STREAMER_MEMORY_LIMIT"] = str(memory_limit)

        # Fall back to the standard AWS endpoint variable when no
        # streamer-specific S3 endpoint is set. This does not depend on the
        # extra config, so it is applied whether or not one was given.
        runai_streamer_s3_endpoint = os.getenv('RUNAI_STREAMER_S3_ENDPOINT')
        aws_endpoint_url = os.getenv('AWS_ENDPOINT_URL')
        if runai_streamer_s3_endpoint is None and aws_endpoint_url is not None:
            os.environ["RUNAI_STREAMER_S3_ENDPOINT"] = aws_endpoint_url

    def _prepare_weights(self, model_name_or_path: str,
                         revision: Optional[str]) -> List[str]:
        """Prepare weights for the model.

        If the model is not local, it will be downloaded.

        Returns:
            The safetensors files: globbed locally, listed via ``s3_glob``
            for S3 paths, or downloaded from the HF hub otherwise.

        Raises:
            RuntimeError: if no safetensors file is found.
        """
        is_s3_path = is_s3(model_name_or_path)
        is_local = os.path.isdir(model_name_or_path)
        safetensors_pattern = "*.safetensors"
        index_file = SAFE_WEIGHTS_INDEX_NAME

        hf_folder = (model_name_or_path if
                     (is_local or is_s3_path) else download_weights_from_hf(
                         model_name_or_path,
                         self.load_config.download_dir,
                         [safetensors_pattern],
                         revision,
                         ignore_patterns=self.load_config.ignore_patterns,
                     ))

        if is_s3_path:
            hf_weights_files = s3_glob(path=hf_folder,
                                       allow_pattern=[safetensors_pattern])
        else:
            hf_weights_files = glob.glob(
                os.path.join(hf_folder, safetensors_pattern))

        if not is_local and not is_s3_path:
            download_safetensors_index_file_from_hf(
                model_name_or_path, index_file, self.load_config.download_dir,
                revision)

        if not hf_weights_files:
            raise RuntimeError(
                f"Cannot find any safetensors model weights with "
                f"`{model_name_or_path}`")

        return hf_weights_files

    def _get_weights_iterator(
        self, model_or_path: str, revision: Optional[str]
    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
        """Get an iterator for the model weights based on the load format."""
        hf_weights_files = self._prepare_weights(model_or_path, revision)
        return runai_safetensors_weights_iterator(hf_weights_files)

    def download_model(self, model_config: ModelConfig) -> None:
        """Download model if necessary"""
        self._prepare_weights(model_config.model, model_config.revision)

    def load_model(self, vllm_config: VllmConfig) -> nn.Module:
        """Perform streaming of the model to destination.

        Weights come from ``model_config.model_weights`` when that attribute
        exists, else from ``model_config.model``. After loading, each
        module's ``quant_method`` post-processing runs on the target device,
        and ``Attention`` modules that define
        ``process_weights_after_loading`` are post-processed too.
        """
        device_config = vllm_config.device_config
        model_config = vllm_config.model_config

        target_device = torch.device(device_config.device)
        with set_default_torch_dtype(model_config.dtype):
            with target_device:
                model = _initialize_model(vllm_config=vllm_config)

            model_weights = model_config.model
            if hasattr(model_config, "model_weights"):
                model_weights = model_config.model_weights
            model.load_weights(
                self._get_weights_iterator(model_weights,
                                           model_config.revision))

            for _, module in model.named_modules():
                quant_method = getattr(module, "quant_method", None)
                if quant_method is not None:
                    with device_loading_context(module, target_device):
                        quant_method.process_weights_after_loading(module)
                if isinstance(module, Attention) and hasattr(
                        module, "process_weights_after_loading"):
                    # When attention modules need to process weights after
                    # currently only used by MLA
                    module.process_weights_after_loading(model_config.dtype)
        return model.eval()


1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
    """Get a model loader based on the load format.

    If ``load_config.load_format`` is itself a class, it is instantiated with
    ``load_config``. Otherwise it is compared with ``==`` against the
    ``LoadFormat`` members below, and ``DefaultModelLoader`` is used when
    none matches.
    """

    # A custom loader class may be passed in place of a LoadFormat value.
    if isinstance(load_config.load_format, type):
        return load_config.load_format(load_config)

    # Equality checks (rather than a dict lookup) are used on purpose:
    # anything that compares equal to a member matches, regardless of how it
    # hashes.
    if load_config.load_format == LoadFormat.DUMMY:
        return DummyModelLoader(load_config)

    if load_config.load_format == LoadFormat.TENSORIZER:
        return TensorizerLoader(load_config)

    if load_config.load_format == LoadFormat.SHARDED_STATE:
        return ShardedStateLoader(load_config)

    if load_config.load_format == LoadFormat.BITSANDBYTES:
        return BitsAndBytesModelLoader(load_config)

    if load_config.load_format == LoadFormat.GGUF:
        return GGUFModelLoader(load_config)

    if load_config.load_format == LoadFormat.RUNAI_STREAMER:
        return RunaiModelStreamerLoader(load_config)

    return DefaultModelLoader(load_config)