loader.py 57.8 KB
Newer Older
1
# ruff: noqa: SIM117
2
import collections
3
import copy
4
import dataclasses
5
import fnmatch
6
import glob
7
import inspect
8
import itertools
9
import math
10
import os
11
import warnings
12
from abc import ABC, abstractmethod
13
from contextlib import contextmanager
14
15
from typing import (Any, Callable, Dict, Generator, Iterable, List, Optional,
                    Tuple, cast)
16

17
import gguf
18
import huggingface_hub
19
import numpy as np
20
import torch
21
from huggingface_hub import HfApi
22
from torch import nn
23
from transformers import AutoModelForCausalLM
24
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
25

26
from vllm.config import (LoadConfig, LoadFormat, ModelConfig, ParallelConfig,
27
                         VllmConfig, set_current_vllm_config)
28
29
from vllm.distributed import (get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size)
30
from vllm.envs import VLLM_USE_MODELSCOPE
31
from vllm.logger import init_logger
32
33
from vllm.model_executor.layers.linear import (LinearBase,
                                               MergedColumnParallelLinear,
34
35
                                               QKVParallelLinear,
                                               ReplicatedLinear,
36
                                               RowParallelLinear)
37
38
from vllm.model_executor.layers.quantization.base_config import (
    QuantizeMethodBase)
39
from vllm.model_executor.model_loader.tensorizer import (
40
    TensorizerConfig, is_vllm_tensorized, load_with_tensorizer,
41
    serialize_vllm_model, tensorizer_weights_iterator)
42
43
44
from vllm.model_executor.model_loader.utils import (get_model_architecture,
                                                    set_default_torch_dtype)
from vllm.model_executor.model_loader.weight_utils import (
45
46
    download_safetensors_index_file_from_hf, download_weights_from_hf,
    filter_duplicate_safetensors_files, filter_files_not_needed_for_inference,
47
    get_gguf_extra_tensor_names, gguf_quant_weights_iterator,
48
    initialize_dummy_weights, np_cache_weights_iterator, pt_weights_iterator,
49
    runai_safetensors_weights_iterator, safetensors_weights_iterator)
50
from vllm.model_executor.utils import set_weight_attrs
51
from vllm.platforms import current_platform
52
from vllm.transformers_utils.s3_utils import glob as s3_glob
53
from vllm.transformers_utils.utils import is_s3
54
from vllm.utils import is_pin_memory_available
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84


@contextmanager
def device_loading_context(module: torch.nn.Module,
                           target_device: torch.device):
    """Temporarily move a module's CPU parameters onto ``target_device``.

    If ``target_device`` is a CPU device the module is yielded untouched.
    Otherwise every parameter that is on the CPU at entry is moved to
    ``target_device`` for the duration of the ``with`` body, and afterwards
    copied back into a freshly allocated CPU tensor (allocated with
    ``pin_memory=is_pin_memory_available()``). Parameters that were not on
    the CPU at entry, and parameters added inside the body, are left alone.

    Args:
        module: Module whose parameters may need to be moved.
        target_device: Device the parameters should be on inside the body.

    Yields:
        The same ``module`` object.
    """
    if target_device.type == "cpu":
        # If target is CPU, no need to move anything
        yield module
        return

    # Names of the parameters that this context moved off the CPU.
    moved_from_cpu = set()

    try:
        # The move is done inside the `try` so that a failure part-way
        # through (or inside the `with` body) still restores whatever was
        # already moved.
        for name, p in module.named_parameters():
            if p.device.type == "cpu":
                p.data = p.data.to(target_device)
                moved_from_cpu.add(name)
            # Parameters that are not on the CPU are not touched

        yield module

    finally:
        # Restore moved parameters to the CPU, ignoring new parameters
        pin_memory = is_pin_memory_available()
        for name, p in module.named_parameters():
            if name not in moved_from_cpu:
                continue
            # `torch.empty_like` does not support `pin_memory` argument
            cpu_data = torch.empty_strided(
                size=p.data.size(),
                stride=p.data.stride(),
                dtype=p.data.dtype,
                layout=p.data.layout,
                device="cpu",
                pin_memory=pin_memory,
            )
            cpu_data.copy_(p.data)
            p.data = cpu_data

99
100
101
102

# Module-level logger, named after this module via `__name__`.
logger = init_logger(__name__)


103
104
105
106
107
def _initialize_model(
    vllm_config: VllmConfig,
    *,
    prefix: str = "",
) -> nn.Module:
    """Instantiate the model class resolved for ``vllm_config.model_config``.

    The constructor signature of the resolved class decides how it is called:

    * If it accepts both ``vllm_config`` and ``prefix`` ("new-style"), it is
      called with exactly those two keyword arguments.
    * Otherwise a ``DeprecationWarning`` is issued, a warning is logged, and
      the constructor is called with whichever of ``prefix``, ``config``,
      ``cache_config``, ``quant_config``, ``lora_config`` and
      ``scheduler_config`` it declares.

    In both cases construction happens inside
    ``set_current_vllm_config(vllm_config)``.

    Args:
        vllm_config: Configuration the model is built from.
        prefix: Passed through to constructors that accept ``prefix``.

    Returns:
        The newly constructed model.
    """
    model_config = vllm_config.model_config
    model_class, _ = get_model_architecture(model_config)

    # Mapping of constructor parameter names; membership tests check names.
    init_params = inspect.signature(model_class.__init__).parameters

    is_new_style = "vllm_config" in init_params and "prefix" in init_params
    if is_new_style:
        with set_current_vllm_config(vllm_config):
            return model_class(vllm_config=vllm_config, prefix=prefix)

    msg = ("vLLM model class should accept `vllm_config` and `prefix` as "
           "input arguments. Possibly you have an old-style model class"
           " registered from out of tree and it is used for new vLLM version. "
           "Check https://docs.vllm.ai/en/latest/design/arch_overview.html "
           "for the design and update the model class accordingly.")
    warnings.warn(msg, DeprecationWarning, stacklevel=2)

    logger.warning(
        "Trying to guess the arguments for old-style model class %s",
        model_class,
    )

    # Try to be compatible with old-style model classes: pass only the
    # arguments the constructor declares. Each entry is
    # (constructor parameter name, object holding the value, attribute name);
    # the attribute is read only when the parameter is actually declared.
    legacy_sources = (
        ("config", model_config, "hf_config"),
        ("cache_config", vllm_config, "cache_config"),
        ("quant_config", vllm_config, "quant_config"),
        ("lora_config", vllm_config, "lora_config"),
        ("scheduler_config", vllm_config, "scheduler_config"),
    )
    kwargs = {}
    if "prefix" in init_params:
        kwargs["prefix"] = prefix
    for param_name, owner, attr_name in legacy_sources:
        if param_name in init_params:
            kwargs[param_name] = getattr(owner, attr_name)

    with set_current_vllm_config(vllm_config):
        return model_class(**kwargs)
146
147
148
149
150
151
152
153


class BaseModelLoader(ABC):
    """Abstract interface shared by all model loaders.

    Concrete loaders keep the ``LoadConfig`` they were created with in
    ``self.load_config`` and must implement both ``download_model`` and
    ``load_model``.
    """

    def __init__(self, load_config: LoadConfig):
        # Kept for subclasses to read.
        self.load_config = load_config

    @abstractmethod
    def download_model(self, model_config: ModelConfig) -> None:
        """Download a model so that it can be immediately loaded.

        Args:
            model_config: Configuration identifying the model to download.
        """
        raise NotImplementedError

    @abstractmethod
    def load_model(self, *, vllm_config: VllmConfig) -> nn.Module:
        """Load a model with the given configurations.

        Args:
            vllm_config: Configuration to build and load the model from
                (keyword-only in this base declaration).

        Returns:
            The loaded model.
        """
        raise NotImplementedError
163
164
165
166
167


class DefaultModelLoader(BaseModelLoader):
    """Model loader that can load different file types from disk."""

    @dataclasses.dataclass
    class Source:
        """A source for weights."""

        model_or_path: str
        """The model ID or path."""

        revision: Optional[str]
        """The optional model revision."""

        prefix: str = ""
        """A prefix to prepend to all weights."""

        fall_back_to_pt: bool = True
        """Whether .pt weights can be used."""

    def __init__(self, load_config: LoadConfig):
        """Store ``load_config``; this loader accepts no extra config.

        Raises:
            ValueError: If ``load_config.model_loader_extra_config`` is set.
        """
        super().__init__(load_config)
        if load_config.model_loader_extra_config:
            raise ValueError(f"Model loader extra config is not supported for "
                             f"load format {load_config.load_format}")

    def _maybe_download_from_modelscope(
            self, model: str, revision: Optional[str]) -> Optional[str]:
        """Download model from ModelScope hub if VLLM_USE_MODELSCOPE is True.

        Returns the path to the downloaded model, or None if the model is not
        downloaded from ModelScope. If ``model`` already exists as a local
        path it is returned unchanged."""
        if not VLLM_USE_MODELSCOPE:
            return None

        # download model from ModelScope hub,
        # lazy import so that modelscope is not required for normal use.
        # pylint: disable=C.
        from modelscope.hub.snapshot_download import snapshot_download

        if os.path.exists(model):
            return model
        return snapshot_download(
            model_id=model,
            cache_dir=self.load_config.download_dir,
            local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
            revision=revision,
            ignore_file_pattern=self.load_config.ignore_patterns,
        )

    def _prepare_weights(
        self,
        model_name_or_path: str,
        revision: Optional[str],
        fall_back_to_pt: bool,
    ) -> Tuple[str, List[str], bool]:
        """Prepare weights for the model.

        If the model is not local, it will be downloaded.

        Args:
            model_name_or_path: Local directory or model ID.
            revision: Optional model revision used when downloading.
            fall_back_to_pt: Whether ``*.pt`` files may be used when no file
                matches the patterns of the configured load format.

        Returns:
            A tuple ``(folder, weight_files, use_safetensors)``.

        Raises:
            ValueError: If the configured load format is not handled here.
            RuntimeError: If no weight file is found.
        """
        model_name_or_path = (self._maybe_download_from_modelscope(
            model_name_or_path, revision) or model_name_or_path)

        is_local = os.path.isdir(model_name_or_path)
        load_format = self.load_config.load_format
        use_safetensors = False
        index_file = SAFE_WEIGHTS_INDEX_NAME
        if load_format == LoadFormat.AUTO:
            allow_patterns = ["*.safetensors", "*.bin"]
        elif load_format == LoadFormat.SAFETENSORS:
            use_safetensors = True
            allow_patterns = ["*.safetensors"]
        elif load_format == LoadFormat.MISTRAL:
            use_safetensors = True
            allow_patterns = ["consolidated*.safetensors"]
            index_file = "consolidated.safetensors.index.json"
        elif load_format == LoadFormat.PT:
            allow_patterns = ["*.pt"]
        elif load_format == LoadFormat.NPCACHE:
            allow_patterns = ["*.bin"]
        else:
            raise ValueError(f"Unknown load_format: {load_format}")

        # Some quantized models use .pt files for storing the weights.
        # Skip the append when the pattern is already present (PT format) so
        # the same pattern is not listed twice.
        if fall_back_to_pt and "*.pt" not in allow_patterns:
            allow_patterns.append("*.pt")

        if not is_local:
            hf_folder = download_weights_from_hf(
                model_name_or_path,
                self.load_config.download_dir,
                allow_patterns,
                revision,
                ignore_patterns=self.load_config.ignore_patterns,
            )
        else:
            hf_folder = model_name_or_path

        # Patterns are tried in order; the first one that matches any file
        # wins and later patterns are not consulted.
        hf_weights_files: List[str] = []
        for pattern in allow_patterns:
            hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
            if hf_weights_files:
                if pattern == "*.safetensors":
                    use_safetensors = True
                break

        if use_safetensors:
            # For models like Mistral-7B-Instruct-v0.3
            # there are both sharded safetensors files and a consolidated
            # safetensors file. Using both breaks.
            # Here, we download the `model.safetensors.index.json` and filter
            # any files not found in the index.
            if not is_local:
                download_safetensors_index_file_from_hf(
                    model_name_or_path,
                    index_file,
                    self.load_config.download_dir,
                    revision,
                )
            hf_weights_files = filter_duplicate_safetensors_files(
                hf_weights_files, hf_folder, index_file)
        else:
            hf_weights_files = filter_files_not_needed_for_inference(
                hf_weights_files)

        if not hf_weights_files:
            raise RuntimeError(
                f"Cannot find any model weights with `{model_name_or_path}`")

        return hf_folder, hf_weights_files, use_safetensors

    def _get_weights_iterator(
            self, source: "Source"
    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
        """Get an iterator for the model weights based on the load format.

        Every yielded name has ``source.prefix`` prepended.
        """
        hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
            source.model_or_path, source.revision, source.fall_back_to_pt)
        if self.load_config.load_format == LoadFormat.NPCACHE:
            # Currently np_cache only support *.bin checkpoints
            assert use_safetensors is False
            weights_iterator = np_cache_weights_iterator(
                source.model_or_path,
                self.load_config.download_dir,
                hf_folder,
                hf_weights_files,
            )
        elif use_safetensors:
            weights_iterator = safetensors_weights_iterator(hf_weights_files)
        else:
            weights_iterator = pt_weights_iterator(hf_weights_files)

        if current_platform.is_tpu():
            # In PyTorch XLA, we should call `xm.mark_step` frequently so that
            # not too many ops are accumulated in the XLA program.
            import torch_xla.core.xla_model as xm

            def _xla_weights_iterator(iterator: Generator):
                for weights in iterator:
                    yield weights
                    xm.mark_step()

            weights_iterator = _xla_weights_iterator(weights_iterator)

        # Apply the prefix.
        return ((source.prefix + name, tensor)
                for (name, tensor) in weights_iterator)

    def _get_all_weights(
        self,
        model_config: ModelConfig,
        model: nn.Module,
    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
        """Yield the primary weights, then any secondary weight sources.

        The primary source is ``model_config.model`` at
        ``model_config.revision``. Two optional attributes of ``model`` are
        consulted: ``fall_back_to_pt_during_load`` (default ``True``) and
        ``secondary_weights`` (default empty), an iterable of ``Source``.
        """
        primary_weights = DefaultModelLoader.Source(
            model_config.model,
            model_config.revision,
            prefix="",
            fall_back_to_pt=getattr(model, "fall_back_to_pt_during_load",
                                    True),
        )
        yield from self._get_weights_iterator(primary_weights)

        secondary_weights = cast(
            Iterable[DefaultModelLoader.Source],
            getattr(model, "secondary_weights", ()),
        )
        for source in secondary_weights:
            yield from self._get_weights_iterator(source)

    def download_model(self, model_config: ModelConfig) -> None:
        """Resolve (and download if needed) the weight files of the model."""
        self._prepare_weights(model_config.model,
                              model_config.revision,
                              fall_back_to_pt=True)

    def load_model(self, vllm_config: VllmConfig) -> nn.Module:
        """Build the model, load its weights and return it in eval mode.

        Raises:
            ValueError: For a non-quantized model whose ``load_weights``
                reports the loaded names, if some parameters were not loaded.
        """
        device_config = vllm_config.device_config
        model_config = vllm_config.model_config

        target_device = torch.device(device_config.device)
        with set_default_torch_dtype(model_config.dtype):
            with target_device:
                model = _initialize_model(vllm_config=vllm_config)

            weights_to_load = {name for name, _ in model.named_parameters()}
            loaded_weights = model.load_weights(
                self._get_all_weights(model_config, model))
            # We only enable strict check for non-quantized models
            # that have loaded weights tracking currently.
            if model_config.quantization is None and loaded_weights is not None:
                weights_not_loaded = weights_to_load - loaded_weights
                if weights_not_loaded:
                    raise ValueError(
                        "Following weights were not initialized from "
                        f"checkpoint: {weights_not_loaded}")

            for _, module in model.named_modules():
                quant_method = getattr(module, "quant_method", None)
                if isinstance(quant_method, QuantizeMethodBase):
                    # When quant methods need to process weights after loading
                    # (for repacking, quantizing, etc), they expect parameters
                    # to be on the global target device. This scope is for the
                    # case where cpu offloading is used, where we will move the
                    # parameters onto device for processing and back off after.
                    with device_loading_context(module, target_device):
                        quant_method.process_weights_after_loading(module)
        return model.eval()


class DummyModelLoader(BaseModelLoader):
    """Model loader that will set model weights to random values."""

    def __init__(self, load_config: LoadConfig):
        """Store ``load_config``; this loader accepts no extra config.

        Raises:
            ValueError: If ``load_config.model_loader_extra_config`` is set.
        """
        super().__init__(load_config)
        if load_config.model_loader_extra_config:
            raise ValueError(f"Model loader extra config is not supported for "
                             f"load format {load_config.load_format}")

    def download_model(self, model_config: ModelConfig) -> None:
        """No-op: dummy weights need no files."""
        pass  # Nothing to download

    def load_model(self, vllm_config: VllmConfig) -> nn.Module:
        """Build the model, fill it with dummy weights, return it in eval mode.
        """
        device_config = vllm_config.device_config
        model_config = vllm_config.model_config

        # Built once and reused for model construction and for the
        # post-load processing below.
        target_device = torch.device(device_config.device)
        with set_default_torch_dtype(model_config.dtype):
            with target_device:
                model = _initialize_model(vllm_config=vllm_config)
            # NOTE(woosuk): For accurate performance evaluation, we assign
            # random values to the weights.
            initialize_dummy_weights(model)

            for _, module in model.named_modules():
                quant_method = getattr(module, "quant_method", None)
                # Same check as DefaultModelLoader.load_model: only objects
                # that are QuantizeMethodBase instances are processed.
                if isinstance(quant_method, QuantizeMethodBase):
                    # When quant methods need to process weights after loading
                    # (for repacking, quantizing, etc), they expect parameters
                    # to be on the global target device. This scope is for the
                    # case where cpu offloading is used, where we will move the
                    # parameters onto device for processing and back off after.
                    with device_loading_context(module, target_device):
                        quant_method.process_weights_after_loading(module)
        return model.eval()


class TensorizerLoader(BaseModelLoader):
    """Model loader using CoreWeave's tensorizer library."""

    def __init__(self, load_config: LoadConfig):
        """Take the tensorizer config from the loader's extra config.

        A ``TensorizerConfig`` instance is used as is; anything else is
        expanded as keyword arguments into a new ``TensorizerConfig``.
        """
        super().__init__(load_config)
        if isinstance(load_config.model_loader_extra_config, TensorizerConfig):
            tensorizer_config = load_config.model_loader_extra_config
        else:
            tensorizer_config = TensorizerConfig(
                **load_config.model_loader_extra_config)
        self.tensorizer_config = tensorizer_config

    def _verify_config(self, model_config: ModelConfig,
                       parallel_config: ParallelConfig):
        """Check the tensorizer config against model and parallel configs."""
        config = self.tensorizer_config
        config.verify_with_model_config(model_config)
        config.verify_with_parallel_config(parallel_config)

    def _get_weights_iterator(
        self, ) -> Generator[Tuple[str, torch.Tensor], None, None]:
        """Return the weights iterator built from the tensorizer arguments."""
        return tensorizer_weights_iterator(
            self.tensorizer_config._construct_tensorizer_args())

    def _load_model_serialized_cpu(
        self,
        vllm_config: VllmConfig,
    ) -> nn.Module:
        """Load a serialized model with tensorizer to the CPU.

        This is only necessary when the model isn't vLLM-tensorized (see
        examples/other/tensorize_vllm_model.py). This should still
        be faster than default HuggingFace loading, but will be slower than
        loading a vLLM-tensorized model.
        """
        dtype = vllm_config.model_config.dtype
        device = vllm_config.device_config.device
        with set_default_torch_dtype(dtype):
            with torch.device(device):
                model = _initialize_model(vllm_config=vllm_config)

            model.load_weights(self._get_weights_iterator())
        return model.eval()

    def _load_model_serialized(
        self,
        vllm_config: VllmConfig,
    ) -> nn.Module:
        """Load a serialized model with tensorizer.

        Expects a vLLM-tensorized model. See the
        examples/other/tensorize_vllm_model.py example script
        for serializing vLLM models."""
        model_config = vllm_config.model_config
        device = vllm_config.device_config.device

        with set_default_torch_dtype(model_config.dtype):
            with torch.device(device):
                # Work on a shallow copy so the per-load fields set below do
                # not end up on the loader's own config object.
                load_config = copy.copy(self.tensorizer_config)
                load_config.model_class = get_model_architecture(
                    model_config)[0]
                load_config.hf_config = model_config.hf_config
                load_config.dtype = model_config.dtype

                model = load_with_tensorizer(load_config,
                                             vllm_config=vllm_config)
        return model.eval()

    def download_model(self, model_config: ModelConfig) -> None:
        """Verify the config against the model and open the stream once."""
        self.tensorizer_config.verify_with_model_config(model_config)

        # Nothing is read here; the stream is only opened and closed.
        with self.tensorizer_config.open_stream():
            pass

    def load_model(self, vllm_config: VllmConfig) -> nn.Module:
        """Load the model through tensorizer and return it in eval mode.

        The vLLM-tensorized path is used when ``is_vllm_tensorized`` accepts
        the config; otherwise weights are loaded through the weights
        iterator.
        """
        parallel_config = vllm_config.parallel_config
        self._verify_config(vllm_config.model_config, parallel_config)

        if parallel_config.tensor_parallel_size > 1:
            from vllm.distributed import get_tensor_model_parallel_rank

            # The URI is treated as a %-format template that receives this
            # worker's tensor-parallel rank.
            # NOTE(review): this rewrites `tensorizer_uri` in place, so a
            # second call would presumably format an already-formatted
            # string - confirm load_model is only called once per loader.
            uri_template = self.tensorizer_config.tensorizer_uri
            self.tensorizer_config.tensorizer_uri = (
                uri_template % get_tensor_model_parallel_rank())

        load = (self._load_model_serialized
                if is_vllm_tensorized(self.tensorizer_config) else
                self._load_model_serialized_cpu)
        return load(vllm_config=vllm_config)

    @staticmethod
    def save_model(
        model: torch.nn.Module,
        tensorizer_config: TensorizerConfig,
    ) -> None:
        """Serialize ``model`` using ``tensorizer_config``."""
        serialize_vllm_model(
            model=model,
            tensorizer_config=tensorizer_config,
        )

526

527
528
529
530
531
class ShardedStateLoader(BaseModelLoader):
    """
    Model loader that directly loads each worker's model state dict, which
    enables a fast load path for large tensor-parallel models where each worker
    only needs to read its own shard rather than the entire checkpoint. See
    `examples/offline_inference/save_sharded_state.py` for creating a sharded
    checkpoint.
    """

    DEFAULT_PATTERN = "model-rank-{rank}-part-{part}.safetensors"

    def __init__(self, load_config: LoadConfig):
        """Read the optional ``pattern`` key from the loader's extra config.

        Raises:
            ValueError: If the extra config has keys other than ``pattern``.
        """
        super().__init__(load_config)
        extra_config = ({} if load_config.model_loader_extra_config is None
                        else load_config.model_loader_extra_config.copy())
        self.pattern = extra_config.pop("pattern", self.DEFAULT_PATTERN)
        if extra_config:
            # Report only the keys that are left after removing the
            # supported ones, not the whole original config.
            raise ValueError(f"Unexpected extra config keys for load format "
                             f"{load_config.load_format}: "
                             f"{extra_config.keys()}")

    @staticmethod
    def _filter_subtensors(
        tensors: Dict[str, torch.Tensor], ) -> Dict[str, torch.Tensor]:
        """
        Filter out all tensors that share the same memory or a subset of the
        memory of another tensor.

        Tensors with zero elements are never included in the result. When two
        tensors cover exactly the same memory, the one with the smaller key
        is kept.
        """
        # Group candidates by (device, storage pointer); only tensors in the
        # same group can overlap.
        same_storage_groups: Dict[Any, List[Tuple[str, torch.Tensor]]] = (
            collections.defaultdict(list))
        for key, tensor in tensors.items():
            if tensor.numel():
                ptr = tensor.untyped_storage().data_ptr()
                same_storage_groups[tensor.device, ptr].append((key, tensor))

        def get_end_ptr(tensor: torch.Tensor) -> int:
            # Address just past the last element of the flattened view.
            return tensor.view(-1)[-1].data_ptr() + tensor.element_size()

        result: Dict[str, torch.Tensor] = {}
        for group in same_storage_groups.values():
            for k, t in group:
                a, b = t.data_ptr(), get_end_ptr(t)
                for k2, t2 in group:
                    # Only contiguous tensors are considered as "covering".
                    if not t2.is_contiguous():
                        continue
                    a2, b2 = t2.data_ptr(), get_end_ptr(t2)
                    # t2 does not span all of t: it cannot replace t.
                    if a < a2 or b2 < b:
                        continue
                    if a2 < a or b < b2 or not t.is_contiguous():
                        break  # t2 covers strictly more memory than t.
                    if k2 < k:
                        # Same tensors, keep the one with the smaller key.
                        break
                else:
                    # No other tensor in the group supersedes t.
                    result[k] = t
        return result

    def _prepare_weights(self, model_name_or_path: str,
                         revision: Optional[str]):
        """Return a local folder holding the ``*.safetensors`` files.

        A local directory is returned unchanged; anything else is passed to
        ``download_weights_from_hf``.
        """
        if os.path.isdir(model_name_or_path):
            return model_name_or_path
        allow_patterns = ["*.safetensors"]
        return download_weights_from_hf(
            model_name_or_path,
            self.load_config.download_dir,
            allow_patterns,
            revision,
            ignore_patterns=self.load_config.ignore_patterns,
        )

    def download_model(self, model_config: ModelConfig) -> None:
        """Resolve (and download if needed) the checkpoint folder."""
        self._prepare_weights(model_config.model, model_config.revision)

    def load_model(self, vllm_config: VllmConfig) -> nn.Module:
        """Build the model and copy this rank's shard files into its state.

        Raises:
            ValueError: If no file matches this rank's pattern, or if some
                state-dict entries are still unfilled after all files were
                read.
        """
        device_config = vllm_config.device_config
        model_config = vllm_config.model_config
        from safetensors.torch import safe_open

        from vllm.distributed import get_tensor_model_parallel_rank

        local_model_path = self._prepare_weights(model_config.model,
                                                 model_config.revision)

        with set_default_torch_dtype(model_config.dtype):
            with torch.device(device_config.device):
                model = _initialize_model(vllm_config=vllm_config)
                # Post-load processing runs before the state is copied in.
                # Same check as DefaultModelLoader.load_model: only objects
                # that are QuantizeMethodBase instances are processed.
                for _, module in model.named_modules():
                    quant_method = getattr(module, "quant_method", None)
                    if isinstance(quant_method, QuantizeMethodBase):
                        quant_method.process_weights_after_loading(module)
            rank = get_tensor_model_parallel_rank()
            pattern = os.path.join(
                local_model_path,
                self.pattern.format(rank=rank, part="*"),
            )
            filepaths = glob.glob(pattern)
            if not filepaths:
                # TODO: support un-sharded checkpoints too
                raise ValueError(
                    f"Could not find checkpoint files '{pattern}', only "
                    f"pre-sharded checkpoints are currently supported!")
            state_dict = self._filter_subtensors(model.state_dict())
            for path in filepaths:
                with safe_open(path, framework="pt") as f:
                    for key in f.keys():  # noqa: SIM118
                        tensor = f.get_tensor(key)
                        # If loading with LoRA enabled, additional padding may
                        # be added to certain parameters. We only load into a
                        # narrowed view of the parameter data.
                        param_data = state_dict[key].data
                        param_shape = state_dict[key].shape
                        for dim, size in enumerate(tensor.shape):
                            if size < param_shape[dim]:
                                param_data = param_data.narrow(dim, 0, size)
                        if tensor.shape != param_shape:
                            logger.warning(
                                "loading tensor of shape %s into "
                                "parameter '%s' of shape %s",
                                tensor.shape,
                                key,
                                param_shape,
                            )
                        param_data.copy_(tensor)
                        # Entries are removed as they are filled so that
                        # whatever remains is known to be missing.
                        state_dict.pop(key)
            if state_dict:
                raise ValueError(
                    f"Missing keys {tuple(state_dict)} in loaded state!")
        return model.eval()

    @staticmethod
    def save_model(
        model: torch.nn.Module,
        path: str,
        pattern: Optional[str] = None,
        max_size: Optional[int] = None,
    ) -> None:
        """Write this rank's state dict to ``path`` as safetensors parts.

        Args:
            model: Model whose (sub-tensor filtered) state dict is saved.
            path: Directory the part files are written into.
            pattern: File name pattern with ``{rank}`` and ``{part}`` fields;
                defaults to ``DEFAULT_PATTERN``.
            max_size: Optional size limit in bytes per part. A new part is
                started before a tensor that would push the current part
                over the limit. A single tensor larger than the limit still
                gets written, in a part of its own.
        """
        from safetensors.torch import save_file

        from vllm.distributed import get_tensor_model_parallel_rank

        if pattern is None:
            pattern = ShardedStateLoader.DEFAULT_PATTERN
        rank = get_tensor_model_parallel_rank()
        part_idx = 0
        total_size = 0
        state_dict = ShardedStateLoader._filter_subtensors(model.state_dict())
        state_dict_part: Dict[str, torch.Tensor] = {}
        for key, tensor in state_dict.items():
            param_size = tensor.nelement() * tensor.element_size()
            # Only flush a non-empty part: otherwise a first tensor that is
            # larger than `max_size` would produce an empty part file.
            if (max_size is not None and state_dict_part
                    and total_size + param_size > max_size):
                filename = pattern.format(rank=rank, part=part_idx)
                save_file(
                    state_dict_part,
                    os.path.join(path, filename),
                )
                part_idx += 1
                total_size = 0
                state_dict_part = {}
            state_dict_part[key] = tensor
            total_size += param_size
        if state_dict_part:
            filename = pattern.format(rank=rank, part=part_idx)
            save_file(
                state_dict_part,
                os.path.join(path, filename),
            )


696
697
698
class BitsAndBytesModelLoader(BaseModelLoader):
    """Model loader to load model weights with BitAndBytes quantization."""

699
700
    possible_config_file_names = ["adapter_config.json"]

701
702
703
    def __init__(self, load_config: LoadConfig):
        """Create the loader and its (initially empty) bookkeeping state.

        Args:
            load_config: Load configuration, forwarded to the base class.
        """
        super().__init__(load_config)

        # Translates a transformers weight name into the vllm name; the
        # default leaves names untouched.
        self.weight_mapper: Callable = lambda name: name
        # Names of all modules (transformers naming) that support BNB
        # quantization.
        self.target_modules: List[str] = []
        # Names of modules whose weights are sharded by column.
        self.column_sharded_weights_modules: List[str] = []
        # Names of modules whose weights are not sharded at all.
        self.unsharded_weights_modules: List[str] = []
713
714

    def _get_weight_files(
715
716
717
718
719
720
721
        self,
        model_name_or_path: str,
        allowed_patterns: List[str],
        revision: Optional[str] = None,
    ) -> Tuple[List[str], str]:
        """Retrieve weight files. Download the files if necessary.

722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
        Return the weight files and the file pattern."""
        is_local = os.path.isdir(model_name_or_path)

        if is_local:
            for pattern in allowed_patterns:
                weight_files = glob.glob(
                    os.path.join(model_name_or_path, pattern))
                if weight_files:
                    return weight_files, pattern
        else:
            hf_api = HfApi()
            repo_files = hf_api.list_repo_files(repo_id=model_name_or_path)
            for pattern in allowed_patterns:
                matching_files = fnmatch.filter(repo_files, pattern)
                if matching_files:
                    hf_folder = download_weights_from_hf(
738
739
740
741
742
743
                        model_name_or_path,
                        self.load_config.download_dir,
                        [pattern],
                        revision,
                        ignore_patterns=self.load_config.ignore_patterns,
                    )
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
                    return glob.glob(os.path.join(hf_folder, pattern)), pattern

        raise RuntimeError(
            f"No model weights found in: `{model_name_or_path}`")

    def _prepare_weights(self, model_name_or_path: str,
                         revision: Optional[str]) -> Tuple[List[str], bool]:
        """Prepare weight files for the model."""

        allowed_patterns = ["*.safetensors", "*.bin", "*.pt"]

        hf_weights_files, matched_pattern = self._get_weight_files(
            model_name_or_path, allowed_patterns, revision)

        if matched_pattern != "*.safetensors":
            hf_weights_files = filter_files_not_needed_for_inference(
                hf_weights_files)

        if len(hf_weights_files) == 0:
            raise RuntimeError(
                f"Cannot find any model weights with `{model_name_or_path}`")

        return hf_weights_files, matched_pattern == "*.safetensors"

768
769
    def _hf_weight_iter(self, hf_weights_files, use_safetensors: bool):
        """Yield ``(name, tensor)`` pairs read from the checkpoint files.

        The files are read with the safetensors iterator when
        ``use_safetensors`` is true and with the pt iterator otherwise.
        Every name goes through ``self.weight_mapper`` (transformers name to
        vllm name) before it is yielded.
        """
        read_weights = (safetensors_weights_iterator
                        if use_safetensors else pt_weights_iterator)
        for raw_name, tensor in read_weights(hf_weights_files):
            yield self.weight_mapper(raw_name), tensor
776

777
    def _get_quantized_weights_iterator(
        self,
        model_name_or_path: str,
        revision: Optional[str],
        pre_quant: bool,
        load_8bit: bool,
    ) -> Tuple[Generator[Tuple[str, torch.Tensor], None, None], Dict[str,
                                                                     Any]]:
        """Get an iterator to the model weights with bitsandbytes quantization,
        as well as the quantization state dictionary.

        Args:
            model_name_or_path: Local directory or Hub repo id of the model.
            revision: Hub revision to use, if any.
            pre_quant: Whether the checkpoint is already quantized with
                bitsandbytes.
            load_8bit: For a pre-quantized checkpoint, whether it is 8-bit
                (otherwise it is treated as 4-bit).

        Returns:
            A weight generator and the dictionary that the generator fills
            with quantization states while it is being consumed.

        Raises:
            ImportError: If bitsandbytes is missing or older than 0.45.0.
        """
        min_version = (0, 45, 0)

        def _release_tuple(version: str) -> Tuple[int, ...]:
            # Numeric release tuple of a version string. Comparing the raw
            # strings is lexicographic, so "0.100.0" would sort before
            # "0.45.0" and "0.9.0" after it.
            release = []
            for piece in version.split(".")[:len(min_version)]:
                digits = "".join(itertools.takewhile(str.isdigit, piece))
                release.append(int(digits) if digits else 0)
            release.extend([0] * (len(min_version) - len(release)))
            return tuple(release)

        # only load the bitsandbytes module when needed
        try:
            import bitsandbytes

            if _release_tuple(bitsandbytes.__version__) < min_version:
                raise ImportError("bitsandbytes version is wrong. Please "
                                  "install bitsandbytes>=0.45.0.")
        except ImportError as err:
            raise ImportError("Please install bitsandbytes>=0.45.0 via "
                              "`pip install bitsandbytes>=0.45.0` to use "
                              "bitsandbytes quantizer.") from err

        hf_weights_files, use_safetensors = self._prepare_weights(
            model_name_or_path, revision)

        quant_state_dict: Dict[str, Any] = {}

        if pre_quant:
            if load_8bit:
                return self._quantized_8bit_generator(
                    hf_weights_files, use_safetensors,
                    quant_state_dict), quant_state_dict
            else:
                return self._quantized_4bit_generator(
                    hf_weights_files, use_safetensors,
                    quant_state_dict), quant_state_dict

        return self._unquantized_generator(hf_weights_files, use_safetensors,
                                           quant_state_dict), quant_state_dict
817

818
819
820
821
822
823
824
    def _is_8bit_weight_name(self, weight_name: str):
        quantized_suffix = {".scb", ".weight_format"}
        return any(weight_name.lower().endswith(suffix)
                   for suffix in quantized_suffix)

    def _is_4bit_weight_name(self, weight_name: str):
        quantized_suffix = {
825
826
827
828
829
            "absmax",
            "quant_map",
            "nested_absmax",
            "nested_quant_map",
            "bitsandbytes",
830
831
832
833
        }
        suffix = weight_name.split(".")[-1]
        return any(q_suffix in suffix for q_suffix in quantized_suffix)

834
835
836
837
838
839
840
    def _quantized_8bit_generator(self, hf_weights_files, use_safetensors,
                                  quant_state_dict) -> Generator:
        """Yield the weights of a pre-quantized 8-bit checkpoint.

        A first pass over the files stores every ``.scb`` tensor in
        ``quant_state_dict`` under the name of the weight it belongs to. A
        second pass yields all remaining tensors (``.scb`` and
        ``.weight_format`` entries are skipped) and tags those that have an
        entry in ``quant_state_dict`` with ``load_in_8bit``.
        """
        scb_suffix = ".scb"
        for weight_name, weight_tensor in self._hf_weight_iter(
                hf_weights_files, use_safetensors):
            if not weight_name.lower().endswith(scb_suffix):
                continue

            # Swap only the trailing suffix and keep the rest of the name as
            # is. Lower-casing the whole name would produce a key that can
            # never equal a weight name containing upper-case characters in
            # the lookup below.
            weight_key = weight_name[:-len(scb_suffix)] + ".weight"
            quant_state_dict[weight_key] = weight_tensor

        for weight_name, weight_tensor in self._hf_weight_iter(
                hf_weights_files, use_safetensors):
            if self._is_8bit_weight_name(weight_name):
                continue

            if weight_name in quant_state_dict:
                set_weight_attrs(weight_tensor, {"load_in_8bit": True})
            yield weight_name, weight_tensor

    def _quantized_4bit_generator(self, hf_weights_files, use_safetensors,
                                  quant_state_dict) -> Generator:
        """Yield the weights of a pre-quantized 4-bit checkpoint.

        A first pass collects all quant-state tensors (names accepted by
        ``_is_4bit_weight_name``). A second pass yields every other tensor;
        for a weight that has an nf4 or fp4 quant-state entry, a
        ``QuantState`` is built and stored in ``quant_state_dict`` under the
        weight name before the weight is yielded.
        """
        from bitsandbytes.functional import QuantState

        # First iterate over all quant state weights
        temp_state_dict = {}
        for weight_name, weight_tensor in self._hf_weight_iter(
                hf_weights_files, use_safetensors):
            if not self._is_4bit_weight_name(weight_name):
                continue
            # bitsandbytes library requires
            # weight.quant_state.bitsandbytes__* in CPU
            if "quant_state.bitsandbytes" in weight_name:
                temp_state_dict[weight_name] = weight_tensor.cpu().data
            else:
                temp_state_dict[weight_name] = weight_tensor

        # Closure to parse quant_state for each prequant weight
        def _parse_quant_state(param_name: str,
                               temp_state_dict: Dict) -> QuantState:
            # Quant-state entries are named "<param_name>.<field>", so match
            # on the prefix. A substring test would also pick up entries of
            # another parameter whose name merely contains this one.
            prefix = param_name + "."
            quant_state = {
                k: v
                for k, v in temp_state_dict.items() if k.startswith(prefix)
            }
            return QuantState.from_dict(quant_state, device="cuda")

        # Second iterate over all prequant and normal weights
        # pre quantized weights would have a quant_state
        for weight_name, weight_tensor in self._hf_weight_iter(
                hf_weights_files, use_safetensors):
            if self._is_4bit_weight_name(weight_name):
                continue

            if (f"{weight_name}.quant_state.bitsandbytes__nf4"
                    in temp_state_dict) or (
                        f"{weight_name}.quant_state.bitsandbytes__fp4"
                        in temp_state_dict):
                quant_state_dict[weight_name] = _parse_quant_state(
                    weight_name, temp_state_dict)
            yield weight_name, weight_tensor

    def _unquantized_generator(self, hf_weights_files, use_safetensors,
                               quant_state_dict) -> Generator:
        """Yield checkpoint weights, quantizing the target ones to nf4.

        A tensor is quantized when its name ends with ``.weight`` and
        contains one of ``self.target_modules``. Such a tensor is first cut
        down to the slice belonging to the current tensor parallel rank,
        moved to CUDA, quantized with ``quantize_4bit`` and its quant state
        is stored in ``quant_state_dict`` under the weight name. Every other
        tensor is yielded unchanged.
        """
        from bitsandbytes.functional import quantize_4bit

        world_size = get_tensor_model_parallel_world_size()
        rank = get_tensor_model_parallel_rank()

        def _rank_bounds(extent, base=0):
            # [start, end) of this rank's slice inside a dimension of size
            # `extent` that begins at offset `base`.
            per_rank = extent // world_size
            return base + per_rank * rank, base + per_rank * (rank + 1)

        def _starts_with_any(weight_name, modules):
            return any(weight_name.startswith(module) for module in modules)

        for weight_name, weight_tensor in self._hf_weight_iter(
                hf_weights_files, use_safetensors):
            is_target = weight_name.endswith(".weight") and any(
                target_module in weight_name
                for target_module in self.target_modules)
            if not is_target:
                yield weight_name, weight_tensor
                continue

            if _starts_with_any(weight_name, self.unsharded_weights_modules):
                # Without sharding
                shard = weight_tensor
            elif _starts_with_any(weight_name,
                                  self.column_sharded_weights_modules):
                # Shard by column
                lo, hi = _rank_bounds(weight_tensor.size(-1))
                shard = weight_tensor[..., lo:hi]
            elif _starts_with_any(weight_name,
                                  self.maybe_fused_weights_modules):
                # Weights have fused on disk. In this case, we assume that
                # the weight and module use same name.
                shard_sizes = next(
                    sizes for module, sizes in
                    self.maybe_fused_weights_modules.items()
                    if weight_name.startswith(module))
                assert weight_tensor.size(0) == sum(shard_sizes)
                # Offset at which each fused sub-weight starts.
                shard_starts = list(
                    itertools.accumulate([0] + shard_sizes))[:-1]
                # Take this rank's slice out of every sub-weight and put the
                # slices back together in the same order.
                pieces = []
                for base, size in zip(shard_starts, shard_sizes):
                    lo, hi = _rank_bounds(size, base)
                    pieces.append(weight_tensor[lo:hi, ...])
                shard = torch.cat(pieces, dim=0)
            else:
                # Shard by row
                lo, hi = _rank_bounds(weight_tensor.size(0))
                shard = weight_tensor[lo:hi, ...]

            # bitsandbytes requires data in GPU
            loaded_weight = shard if shard.is_cuda else shard.cuda()

            # remove the following after the issue is fixed:
            # https://github.com/bitsandbytes-foundation/bitsandbytes/issues/1342
            if not loaded_weight.is_contiguous():
                loaded_weight = loaded_weight.contiguous()

            with set_default_torch_dtype(torch.float32):
                processed_weight, quant_state = quantize_4bit(
                    loaded_weight,
                    compress_statistics=True,
                    quant_type="nf4",
                )

            quant_state_dict[weight_name] = quant_state
            yield weight_name, processed_weight
983

984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
    def _get_bnb_target_modules(self, model: nn.Module) -> None:

        # TODO: Maybe we can replace bitsandbytes_stacked_params_mapping with
        # packed_modules_mapping.
        inverse_stacked_mapping: Dict[str, List[str]] = {}
        for orig, (
                packed,
                idx,
        ) in model.bitsandbytes_stacked_params_mapping.items():
            if packed not in inverse_stacked_mapping:
                inverse_stacked_mapping[packed] = []
            inverse_stacked_mapping[packed].insert(idx, orig)

        for name, module in model.named_modules():
            if isinstance(module, (LinearBase, )):
                last_name = name.split(".")[-1]
                if sub_modules := inverse_stacked_mapping.get(last_name, []):
1001
                    # Map vllm's names to transformers's names.
1002
                    for sub_name in sub_modules:
1003
                        self.target_modules.append(
1004
                            name.replace(last_name, sub_name))
1005
1006
1007
1008
1009
                # Add original module name even if the module has stacked map,
                # in case model has a mixture of disk-merged and disk-splitted
                # weights with same last name.
                self.target_modules.append(name)

1010
1011
1012
1013
        assert (self.target_modules
                ), "vllm currently does not support BNB quantization for"
        f" {type(model).__name__}"

1014
1015
    def _load_weights(self, model_config: ModelConfig,
                      model: nn.Module) -> None:
        """Load the checkpoint into ``model`` with bitsandbytes quantization.

        Steps, in order:

        1. Check that ``model`` provides ``load_weights`` and
           ``bitsandbytes_stacked_params_mapping``.
        2. Collect the BNB target modules and classify modules as unsharded,
           possibly fused on disk, or column sharded.
        3. Stream the weights through ``model.load_weights``.
        4. Attach ``bnb_quant_state`` and ``bnb_shard_offsets`` (plus
           ``matmul_state`` for 8-bit checkpoints) to every parameter that
           has quant states.

        Args:
            model_config: Supplies the model path, the revision and the HF
                config whose optional ``quantization_config`` is inspected.
            model: The initialized model that receives the weights.

        Raises:
            AttributeError: If ``model`` lacks ``load_weights`` or
                ``bitsandbytes_stacked_params_mapping``.
            ValueError: If the checkpoint declares a quant method other than
                "bitsandbytes", if a pre-quantized checkpoint is used with a
                tensor parallel world size above 1, if ``load_weights``
                reports parameters that were not loaded, or if a quantized
                parameter has no ``pack_factor``.
        """
        if not hasattr(model, "load_weights"):
            raise AttributeError(
                "The required method 'load_weights' is not defined in class"
                f" {type(model).__name__}.")

        if not hasattr(model, "bitsandbytes_stacked_params_mapping"):
            raise AttributeError(
                f"Model {type(model).__name__} does not support BitsAndBytes "
                "quantization yet.")

        # For some models like Molmo, we need to use hf_to_vllm_mapper
        # to ensure correct loading of weights.
        if hf_to_vllm_mapper := getattr(model, "hf_to_vllm_mapper", None):
            self.weight_mapper = lambda name: hf_to_vllm_mapper._map_name(name)

        # Modules whose weights might have fused on disk
        # we need their output_sizes to make shard in flight correctly with TP
        self.maybe_fused_weights_modules: Dict[str, List[int]] = {}
        self._get_bnb_target_modules(model)
        for name, module in model.named_modules():
            # Some modules like `ReplicatedLinear` should not have their weights
            # sharded. The reason for implementing it this way is to avoid new
            # static variable in the model implementation.
            if isinstance(module, (ReplicatedLinear, )):
                self.unsharded_weights_modules.append(name)
            # `QKVParallelLinear` and `MergedColumnParallelLinear` might have
            # fused weights on disk. We need to use the output sizes of these
            # modules to shard the weights correctly.
            elif isinstance(module,
                            (QKVParallelLinear, MergedColumnParallelLinear)):
                self.maybe_fused_weights_modules[name] = module.output_sizes
            # In TP, these weights are partitioned along the column
            # dimension (dim=-1)
            elif isinstance(module, (RowParallelLinear, )):
                self.column_sharded_weights_modules.append(name)

        self.model_type = type(model).__name__

        logger.info("Loading weights with BitsAndBytes quantization. "
                    " May take a while ...")

        quant_config = getattr(model_config.hf_config, "quantization_config",
                               None)

        # A checkpoint counts as pre-quantized only when its config declares
        # the bitsandbytes quant method; any other declared method is
        # rejected.
        pre_quant = False
        if quant_config is not None:
            quant_method = quant_config.get("quant_method")
            if quant_method == "bitsandbytes":
                pre_quant = True
            else:
                raise ValueError(
                    f"BitsAndBytes loader does not support {quant_method} "
                    "quantization")

        # The quant_states in pre_quantized models cannot work with a split
        # weight tensor. So TP does not work with pre_quantized bnb models.
        if pre_quant and get_tensor_model_parallel_world_size() > 1:
            raise ValueError(
                "Prequant BitsAndBytes models with TP is not supported. "
                "Please try with PP.")

        load_8bit = False
        if pre_quant:
            load_8bit = quant_config.get("load_in_8bit", False)

        qweight_iterator, quant_state_dict = (
            self._get_quantized_weights_iterator(model_config.model,
                                                 model_config.revision,
                                                 pre_quant, load_8bit))

        weights_to_load = {name for name, _ in model.named_parameters()}
        loaded_weights = model.load_weights(qweight_iterator)
        # Some models may have weights loading tracker unimplemented.
        if loaded_weights is not None:
            weights_not_loaded = weights_to_load - loaded_weights
            if weights_not_loaded:
                raise ValueError("Following weights were not initialized from "
                                 f"checkpoint: {weights_not_loaded}")

        torch.cuda.empty_cache()

        param_dict = dict(model.named_parameters())
        # Parameter name -> {shard index -> quant state}.
        stacked_quant_state_dict: Dict[str, Dict[int, Any]] = {}
        # TODO: Change this lazy import to normal import
        # after the checks are updated to run on a new version
        from vllm.model_executor.models.utils import is_pp_missing_parameter

        for quant_param_name in quant_state_dict:
            if is_pp_missing_parameter(quant_param_name, model):
                continue

            # Key under which the quant state was stored, kept because
            # quant_param_name may be rewritten to the stacked name below.
            non_stacked_param_name = quant_param_name

            shard_index = 0
            for shard_name, (
                    weight_name,
                    index,
            ) in model.bitsandbytes_stacked_params_mapping.items():
                shard_pos = quant_param_name.find(shard_name)
                # Some models, such as MiniCPM V2.5/2.6, contain both
                # module names 'kv_proj' and 'qkv_proj'. To prevent 'kv_proj'
                # from being incorrectly identified as being present in
                # 'vpm.encoder.layers.0.self_attn.qkv_proj.weight
                if shard_pos > 0 and quant_param_name[shard_pos - 1] == ".":
                    shard_index = index
                    quant_param_name = quant_param_name.replace(
                        shard_name, weight_name)
                    break

            # Models like Clip/Siglip may skip some layers in initialization,
            # causing unused quant_param_name in state_dict.
            if quant_param_name not in param_dict:
                continue

            if quant_param_name not in stacked_quant_state_dict:
                stacked_quant_state_dict[quant_param_name] = {}

            stacked_quant_state_dict[quant_param_name][shard_index] = (
                quant_state_dict[non_stacked_param_name])

        # save quant_states and offsets as the attributes of the parameters
        for param_name, param in param_dict.items():
            if param_name in stacked_quant_state_dict:
                quant_states = stacked_quant_state_dict[param_name]
                set_weight_attrs(param, {"bnb_quant_state": quant_states})

                pack_ratio = getattr(param, "pack_factor", -1)
                if pack_ratio == -1:
                    raise ValueError(
                        f"pack_factor not set for parameter {param_name}.")

                # Packed element count of each shard, indexed by shard index.
                num_elements = [0] * len(quant_states)
                for seq, quant_state in quant_states.items():
                    num_elements[seq] = (math.prod(quant_state.shape) //
                                         pack_ratio)

                # Cumulative offsets: entry i is where shard i starts, the
                # last entry is the total.
                offsets = np.concatenate(([0], np.cumsum(num_elements)))
                set_weight_attrs(param, {"bnb_shard_offsets": offsets})

                if load_8bit:
                    set_weight_attrs(
                        param, {"matmul_state": [None] * len(quant_states)})

1158
1159
1160
    def download_model(self, model_config: ModelConfig) -> None:
        """Make sure the weight files of ``model_config.model`` are available.

        Delegates to ``_prepare_weights`` with the configured model and
        revision and discards the returned file list.
        """
        self._prepare_weights(model_config.model, model_config.revision)

1161
1162
1163
    def load_model(self, vllm_config: VllmConfig) -> nn.Module:
        """Initialize the model on the configured device, load its weights
        with bitsandbytes quantization and return it in eval mode."""
        model_config = vllm_config.model_config
        target_device = torch.device(vllm_config.device_config.device)

        with set_default_torch_dtype(model_config.dtype):
            with target_device:
                model = _initialize_model(vllm_config=vllm_config)

                self._load_weights(model_config, model)

        return model.eval()


1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
class GGUFModelLoader(BaseModelLoader):
    """
    Model loader that can load GGUF files. This is useful for loading models
    that are quantized with GGUF and saved in the GGUF format. This loader
    supports loading both full models and sharded models.

    The model path handed to this loader must be an existing local file.
    """

    def __init__(self, load_config: LoadConfig):
        """Create the loader.

        Raises:
            ValueError: If ``load_config.model_loader_extra_config`` is set;
                this loader accepts no extra configuration.
        """
        super().__init__(load_config)
        if load_config.model_loader_extra_config:
            raise ValueError(f"Model loader extra config is not supported for "
                             f"load format {load_config.load_format}")

    def _prepare_weights(self, model_name_or_path: str):
        """Return ``model_name_or_path`` if it is an existing file.

        Raises:
            ValueError: If the path is not a file.
        """
        if os.path.isfile(model_name_or_path):
            return model_name_or_path
        else:
            raise ValueError(f"{model_name_or_path} is not a file.")

    def _get_gguf_weights_map(self, model_config: ModelConfig):
        """
        GGUF uses this naming convention for their tensors from HF checkpoint:
        `blk.N.BB.weight` and `blk.N.BB.bias`
        where N signifies the block number of a layer, and BB signifies the
        attention/mlp layer components.
        See "Standardized tensor names" in
        https://github.com/ggerganov/ggml/blob/master/docs/gguf.md for details.

        Returns a dict that maps each GGUF tensor name to the matching HF
        parameter name.

        Raises:
            RuntimeError: If the HF ``model_type`` has no entry in
                ``gguf.MODEL_ARCH_NAMES``.
        """
        config = model_config.hf_config
        model_type = config.model_type
        # hack: ggufs have a different name than transformers
        if model_type == "cohere":
            model_type = "command-r"
        arch = None
        for key, value in gguf.MODEL_ARCH_NAMES.items():
            if value == model_type:
                arch = key
                break
        if arch is None:
            raise RuntimeError(f"Unknown gguf model_type: {model_type}")
        num_layers = config.num_hidden_layers
        name_map = gguf.get_tensor_name_map(arch, num_layers)
        # Build the HF model on the meta device: only the parameter names of
        # its state dict are needed here.
        with torch.device("meta"):
            dummy_model = AutoModelForCausalLM.from_config(config)
        state_dict = dummy_model.state_dict()

        gguf_to_hf_name_map = {}
        for hf_name in state_dict:
            # Split off the trailing component (e.g. "weight" / "bias") and
            # translate the rest of the name.
            name, suffix = hf_name.rsplit(".", 1)
            gguf_name = name_map.get_name(name)
            gguf_to_hf_name_map[f"{gguf_name}.{suffix}"] = hf_name
        return gguf_to_hf_name_map

    def _get_weights_iterator(
        self, model_name_or_path: str, gguf_to_hf_name_map: Dict[str, str]
    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
        """Return the weights iterator over the GGUF file, using the given
        GGUF-to-HF name map."""
        return gguf_quant_weights_iterator(model_name_or_path,
                                           gguf_to_hf_name_map)

    def download_model(self, model_config: ModelConfig) -> None:
        """Check that ``model_config.model`` is an existing file; nothing is
        downloaded here."""
        self._prepare_weights(model_config.model)

    def load_model(self, vllm_config: VllmConfig) -> nn.Module:
        """Initialize the model, load the GGUF weights and return the model
        in eval mode, like the other loaders in this file do."""
        device_config = vllm_config.device_config
        model_config = vllm_config.model_config
        local_model_path = self._prepare_weights(model_config.model)
        gguf_weights_map = self._get_gguf_weights_map(model_config)
        # we can only know if tie word embeddings after mapping weights
        if "lm_head.weight" in get_gguf_extra_tensor_names(
                local_model_path, gguf_weights_map):
            model_config.hf_config.update({"tie_word_embeddings": True})

        with set_default_torch_dtype(model_config.dtype):
            with torch.device(device_config.device):
                model = _initialize_model(vllm_config=vllm_config)
            model.load_weights(
                self._get_weights_iterator(local_model_path, gguf_weights_map))
        return model.eval()


1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
class RunaiModelStreamerLoader(BaseModelLoader):
    """
    Model loader that can load safetensors
    files from local FS or S3 bucket.
    """

    def __init__(self, load_config: LoadConfig):
        """Create the loader and export streamer settings to the environment.

        When ``load_config.model_loader_extra_config`` is set, integer
        ``concurrency`` / ``memory_limit`` entries are written to
        ``RUNAI_STREAMER_CONCURRENCY`` / ``RUNAI_STREAMER_MEMORY_LIMIT``, and
        ``RUNAI_STREAMER_S3_ENDPOINT`` is filled from ``AWS_ENDPOINT_URL``
        if it is unset and the latter is set.
        """
        super().__init__(load_config)
        extra_config = load_config.model_loader_extra_config
        if not extra_config:
            return

        env_for_option = (
            ("concurrency", "RUNAI_STREAMER_CONCURRENCY"),
            ("memory_limit", "RUNAI_STREAMER_MEMORY_LIMIT"),
        )
        for option, env_name in env_for_option:
            if (option in extra_config
                    and isinstance(extra_config.get(option), int)):
                os.environ[env_name] = str(extra_config.get(option))

        # NOTE(review): this fallback only runs when an extra config was
        # given; confirm that is intended.
        streamer_endpoint = os.getenv('RUNAI_STREAMER_S3_ENDPOINT')
        aws_endpoint = os.getenv('AWS_ENDPOINT_URL')
        if streamer_endpoint is None and aws_endpoint is not None:
            os.environ["RUNAI_STREAMER_S3_ENDPOINT"] = aws_endpoint

    def _prepare_weights(self, model_name_or_path: str,
                         revision: Optional[str]) -> List[str]:
        """Prepare weights for the model.

        If the model is neither a local directory nor an S3 path, it is
        downloaded from the Hugging Face Hub, together with the safetensors
        index file.

        Raises:
            RuntimeError: If no safetensors file is found.
        """
        is_s3_path = is_s3(model_name_or_path)
        is_local = os.path.isdir(model_name_or_path)
        needs_download = not (is_local or is_s3_path)
        safetensors_pattern = "*.safetensors"

        if needs_download:
            hf_folder = download_weights_from_hf(
                model_name_or_path,
                self.load_config.download_dir,
                [safetensors_pattern],
                revision,
                ignore_patterns=self.load_config.ignore_patterns,
            )
        else:
            hf_folder = model_name_or_path

        if is_s3_path:
            hf_weights_files = s3_glob(path=hf_folder,
                                       allow_pattern=[safetensors_pattern])
        else:
            hf_weights_files = glob.glob(
                os.path.join(hf_folder, safetensors_pattern))

        if needs_download:
            download_safetensors_index_file_from_hf(
                model_name_or_path, SAFE_WEIGHTS_INDEX_NAME,
                self.load_config.download_dir, revision)

        if not hf_weights_files:
            raise RuntimeError(
                f"Cannot find any safetensors model weights with "
                f"`{model_name_or_path}`")

        return hf_weights_files

    def _get_weights_iterator(
            self, model_or_path: str,
            revision: str) -> Generator[Tuple[str, torch.Tensor], None, None]:
        """Get an iterator for the model weights based on the load format."""
        return runai_safetensors_weights_iterator(
            self._prepare_weights(model_or_path, revision))

    def download_model(self, model_config: ModelConfig) -> None:
        """Download model if necessary"""
        self._prepare_weights(model_config.model, model_config.revision)

    def load_model(self, vllm_config: VllmConfig) -> nn.Module:
        """Perform streaming of the model to destination"""
        model_config = vllm_config.model_config
        target_device = torch.device(vllm_config.device_config.device)

        with set_default_torch_dtype(model_config.dtype):
            with target_device:
                model = _initialize_model(vllm_config=vllm_config)

            # Prefer a dedicated ``model_weights`` location when the config
            # has one, otherwise stream from ``model``.
            weights_source = getattr(model_config, "model_weights",
                                     model_config.model)
            model.load_weights(
                self._get_weights_iterator(weights_source,
                                           model_config.revision))

            for _, module in model.named_modules():
                quant_method = getattr(module, "quant_method", None)
                if quant_method is None:
                    continue
                with device_loading_context(module, target_device):
                    quant_method.process_weights_after_loading(module)
        return model.eval()


1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
    """Get a model loader based on the load format.

    If ``load_config.load_format`` is itself a class, it is instantiated
    with ``load_config``. Otherwise the loader registered for the format is
    used, falling back to ``DefaultModelLoader``.
    """
    load_format = load_config.load_format
    if isinstance(load_format, type):
        return load_format(load_config)

    loader_for_format = (
        (LoadFormat.DUMMY, DummyModelLoader),
        (LoadFormat.TENSORIZER, TensorizerLoader),
        (LoadFormat.SHARDED_STATE, ShardedStateLoader),
        (LoadFormat.BITSANDBYTES, BitsAndBytesModelLoader),
        (LoadFormat.GGUF, GGUFModelLoader),
        (LoadFormat.RUNAI_STREAMER, RunaiModelStreamerLoader),
    )
    for known_format, loader_cls in loader_for_format:
        if load_format == known_format:
            return loader_cls(load_config)

    return DefaultModelLoader(load_config)