loader.py 36 KB
Newer Older
1
# ruff: noqa: SIM117
2
import collections
3
import copy
4
import fnmatch
5
import glob
6
7
import json
import math
8
9
import os
from abc import ABC, abstractmethod
10
from typing import Any, Dict, Generator, List, Optional, Tuple, Type
11

12
import huggingface_hub
13
import numpy as np
14
import torch
15
from huggingface_hub import HfApi, hf_hub_download
16
17
from torch import nn

18
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoadFormat,
19
20
                         LoRAConfig, ModelConfig, MultiModalConfig,
                         ParallelConfig, SchedulerConfig)
21
from vllm.envs import VLLM_USE_MODELSCOPE
22
from vllm.logger import init_logger
23
24
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
25
from vllm.model_executor.model_loader.tensorizer import (
26
    TensorizerConfig, is_vllm_tensorized, load_with_tensorizer,
27
    serialize_vllm_model, tensorizer_weights_iterator)
28
29
30
from vllm.model_executor.model_loader.utils import (get_model_architecture,
                                                    set_default_torch_dtype)
from vllm.model_executor.model_loader.weight_utils import (
31
32
    download_safetensors_index_file_from_hf, download_weights_from_hf,
    filter_duplicate_safetensors_files, filter_files_not_needed_for_inference,
33
34
    get_quant_config, initialize_dummy_weights, np_cache_weights_iterator,
    pt_weights_iterator, safetensors_weights_iterator)
35
36
from vllm.model_executor.models.interfaces import (has_inner_state,
                                                   supports_lora,
37
                                                   supports_vision)
38
from vllm.model_executor.utils import set_weight_attrs
39
40
from vllm.platforms import current_platform
from vllm.utils import is_tpu
41
42
43
44

logger = init_logger(__name__)


45
def _get_quantization_config(
        model_config: ModelConfig,
        load_config: LoadConfig) -> Optional[QuantizationConfig]:
    """Return the validated quantization config, or None if unquantized.

    The config is obtained from ``get_quant_config`` and then checked against
    the current device's compute capability and the model's activation dtype.

    Raises:
        ValueError: If the device capability is below the quantization
            method's minimum, or if ``model_config.dtype`` is not one of the
            method's supported activation dtypes.
    """
    method = model_config.quantization
    if method is None:
        return None

    quant_config = get_quant_config(model_config, load_config)

    # Fold the first two capability components into one integer
    # (first * 10 + second) for comparison with get_min_capability().
    device_cap = current_platform.get_device_capability()
    device_cap_int = device_cap[0] * 10 + device_cap[1]
    min_cap = quant_config.get_min_capability()
    if device_cap_int < min_cap:
        raise ValueError(
            f"The quantization method {method} is not "
            "supported for the current GPU. "
            f"Minimum capability: {min_cap}. "
            f"Current capability: {device_cap_int}.")

    act_dtypes = quant_config.get_supported_act_dtypes()
    if model_config.dtype in act_dtypes:
        return quant_config
    raise ValueError(
        f"{model_config.dtype} is not supported for quantization "
        f"method {method}. Supported dtypes: "
        f"{act_dtypes}")
67
68
69


def _get_model_initialization_kwargs(
        model_class: Type[nn.Module],
        lora_config: Optional[LoRAConfig],
        multimodal_config: Optional[MultiModalConfig],
        scheduler_config: Optional[SchedulerConfig] = None) -> Dict[str, Any]:
    """Collect the optional constructor kwargs that ``model_class`` accepts.

    The interface checks ``supports_lora``, ``supports_vision`` and
    ``has_inner_state`` decide which of the given configs are forwarded.

    Raises:
        ValueError: If LoRA is enabled for a class without LoRA support, or
            if a vision-capable class is given no multimodal config.
    """
    kwargs: Dict[str, Any] = {}

    lora_capable = supports_lora(model_class)
    if not lora_capable and lora_config:
        raise ValueError(
            f"Model {model_class.__name__} does not support LoRA, "
            "but LoRA is enabled. Support for this model may "
            "be added in the future. If this is important to you, "
            "please open an issue on github.")
    if lora_capable:
        # Forwarded even when None: lora_config=None disables LoRA.
        kwargs["lora_config"] = lora_config

    if supports_vision(model_class):
        if multimodal_config is None:
            raise ValueError("Provide vision related configurations "
                             "through LLM entrypoint or engine arguments.")
        kwargs["multimodal_config"] = multimodal_config

    stateful = has_inner_state(model_class)
    if stateful and scheduler_config:
        kwargs["scheduler_config"] = scheduler_config

    return kwargs


100
101
102
103
104
105
106
def _initialize_model(
        model_config: ModelConfig,
        load_config: LoadConfig,
        lora_config: Optional[LoRAConfig],
        multimodal_config: Optional[MultiModalConfig],
        cache_config: CacheConfig,
        scheduler_config: Optional[SchedulerConfig] = None) -> nn.Module:
    """Instantiate the architecture selected by ``model_config``.

    This only constructs the module; no checkpoint weights are loaded here.
    """
    model_cls = get_model_architecture(model_config)[0]
    quant_config = _get_quantization_config(model_config, load_config)
    extra_kwargs = _get_model_initialization_kwargs(model_cls, lora_config,
                                                    multimodal_config,
                                                    scheduler_config)
    return model_cls(config=model_config.hf_config,
                     cache_config=cache_config,
                     quant_config=quant_config,
                     **extra_kwargs)
117
118
119
120
121
122
123
124
125
126
127
128


class BaseModelLoader(ABC):
    """Abstract interface shared by all model loaders.

    A loader is constructed with a ``LoadConfig``; subclasses implement
    ``load_model`` to build and return the model.
    """

    def __init__(self, load_config: LoadConfig):
        self.load_config = load_config

    @abstractmethod
    def load_model(
        self,
        *,
        model_config: ModelConfig,
        device_config: DeviceConfig,
        lora_config: Optional[LoRAConfig],
        multimodal_config: Optional[MultiModalConfig],
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        cache_config: CacheConfig,
    ) -> nn.Module:
        """Build the model described by the given configurations.

        All arguments are keyword-only.
        """
        ...


class DefaultModelLoader(BaseModelLoader):
    """Loads weights from checkpoint files on disk or from a model hub.

    Handles safetensors, PyTorch ``*.bin``/``*.pt`` files, and the numpy
    cache (``LoadFormat.NPCACHE``).
    """

    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)
        # This loader has no options of its own, so any extra config is an
        # error rather than something to silently ignore.
        if load_config.model_loader_extra_config:
            raise ValueError(f"Model loader extra config is not supported for "
                             f"load format {load_config.load_format}")

    def _maybe_download_from_modelscope(
            self, model: str, revision: Optional[str]) -> Optional[str]:
        """Resolve ``model`` via ModelScope when VLLM_USE_MODELSCOPE is set.

        Returns:
            ``None`` when ModelScope is disabled. Otherwise ``model`` itself if
            that path already exists locally, else the path returned by
            ModelScope's ``snapshot_download``.
        """
        if not VLLM_USE_MODELSCOPE:
            return None

        # Imported lazily so modelscope is only required when it is enabled.
        # pylint: disable=C.
        from modelscope.hub.snapshot_download import snapshot_download

        if os.path.exists(model):
            return model
        return snapshot_download(
            model_id=model,
            cache_dir=self.load_config.download_dir,
            local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
            revision=revision,
            ignore_patterns=self.load_config.ignore_patterns,
        )

    def _prepare_weights(self, model_name_or_path: str,
                         revision: Optional[str],
                         fall_back_to_pt: bool) -> Tuple[str, List[str], bool]:
        """Find the weight files to load, downloading them if necessary.

        Args:
            model_name_or_path: Local directory or hub model id.
            revision: Hub revision used when downloading.
            fall_back_to_pt: Also accept ``*.pt`` files, tried after the
                load format's own patterns.

        Returns:
            ``(hf_folder, hf_weights_files, use_safetensors)``: the folder
            holding the weights, the selected weight files, and whether those
            files are safetensors.

        Raises:
            ValueError: If the load format is not handled by this loader.
            RuntimeError: If no weight files are found.
        """
        model_name_or_path = self._maybe_download_from_modelscope(
            model_name_or_path, revision) or model_name_or_path

        is_local = os.path.isdir(model_name_or_path)
        load_format = self.load_config.load_format
        # An explicit SAFETENSORS format starts out as safetensors; other
        # formats may still switch to it below depending on what is found.
        use_safetensors = load_format == LoadFormat.SAFETENSORS

        # Glob patterns per format, in priority order. Some quantized models
        # store their weights in *.pt files.
        format_patterns = (
            (LoadFormat.AUTO, ["*.safetensors", "*.bin"]),
            (LoadFormat.SAFETENSORS, ["*.safetensors"]),
            (LoadFormat.PT, ["*.pt"]),
            (LoadFormat.NPCACHE, ["*.bin"]),
        )
        for candidate_format, patterns in format_patterns:
            if load_format == candidate_format:
                allow_patterns = patterns
                break
        else:
            raise ValueError(f"Unknown load_format: {load_format}")

        if fall_back_to_pt:
            allow_patterns += ["*.pt"]

        if is_local:
            hf_folder = model_name_or_path
        else:
            hf_folder = download_weights_from_hf(
                model_name_or_path,
                self.load_config.download_dir,
                allow_patterns,
                revision,
                ignore_patterns=self.load_config.ignore_patterns,
            )

        # Use the files of the first pattern that matches anything; later
        # patterns only act as fallbacks.
        hf_weights_files: List[str] = []
        for pattern in allow_patterns:
            matched = glob.glob(os.path.join(hf_folder, pattern))
            if matched:
                hf_weights_files = matched
                if pattern == "*.safetensors":
                    use_safetensors = True
                break

        if use_safetensors:
            # Some repos (e.g. Mistral-7B-Instruct-v0.3) ship both sharded
            # safetensors files and a consolidated one, and loading both
            # breaks. Fetch `model.safetensors.index.json` and drop any file
            # not listed in it.
            if not is_local:
                download_safetensors_index_file_from_hf(
                    model_name_or_path, self.load_config.download_dir,
                    revision)
            hf_weights_files = filter_duplicate_safetensors_files(
                hf_weights_files, hf_folder)
        else:
            hf_weights_files = filter_files_not_needed_for_inference(
                hf_weights_files)

        if not hf_weights_files:
            raise RuntimeError(
                f"Cannot find any model weights with `{model_name_or_path}`")

        return hf_folder, hf_weights_files, use_safetensors

    def _get_weights_iterator(
        self, model_name_or_path: str, revision: Optional[str],
        fall_back_to_pt: bool
    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
        """Return an iterator of ``(name, tensor)`` checkpoint entries.

        The reader is picked by load format and file type: the numpy cache for
        NPCACHE, safetensors when safetensors files were selected, and the
        PyTorch reader otherwise. On TPU the iterator is wrapped so that
        ``xm.mark_step()`` runs after every yielded entry.
        """
        hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
            model_name_or_path, revision, fall_back_to_pt)

        if self.load_config.load_format == LoadFormat.NPCACHE:
            # np_cache currently only supports *.bin checkpoints.
            assert use_safetensors is False
            weights = np_cache_weights_iterator(model_name_or_path,
                                                self.load_config.download_dir,
                                                hf_folder, hf_weights_files)
        elif use_safetensors:
            weights = safetensors_weights_iterator(hf_weights_files)
        else:
            weights = pt_weights_iterator(hf_weights_files)

        if not is_tpu():
            return weights

        # PyTorch XLA accumulates ops lazily; calling xm.mark_step() after
        # each tensor keeps the pending XLA program from growing too large.
        import torch_xla.core.xla_model as xm

        def _xla_weights_iterator(iterator: Generator):
            for entry in iterator:
                yield entry
                xm.mark_step()

        return _xla_weights_iterator(weights)

    def load_model(self, *, model_config: ModelConfig,
                   device_config: DeviceConfig,
                   lora_config: Optional[LoRAConfig],
                   multimodal_config: Optional[MultiModalConfig],
                   parallel_config: ParallelConfig,
                   scheduler_config: SchedulerConfig,
                   cache_config: CacheConfig) -> nn.Module:
        """Build the model on the target device and load its checkpoint.

        After ``model.load_weights`` runs, every submodule carrying a
        ``quant_method`` gets ``process_weights_after_loading`` called on it.
        The model is returned in eval mode.
        """
        with set_default_torch_dtype(model_config.dtype):
            with torch.device(device_config.device):
                model = _initialize_model(model_config, self.load_config,
                                          lora_config, multimodal_config,
                                          cache_config, scheduler_config)

            # A model can set `fall_back_to_pt_during_load` to opt out of
            # accepting *.pt files; by default they are accepted.
            fall_back_to_pt = getattr(model, "fall_back_to_pt_during_load",
                                      True)
            weights = self._get_weights_iterator(
                model_config.model,
                model_config.revision,
                fall_back_to_pt=fall_back_to_pt)
            model.load_weights(weights)

            for _, submodule in model.named_modules():
                quant_method = getattr(submodule, "quant_method", None)
                if quant_method is None:
                    continue
                quant_method.process_weights_after_loading(submodule)
        return model.eval()


class DummyModelLoader(BaseModelLoader):
    """Builds the model and fills its weights with random values.

    No checkpoint is read.
    """

    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)
        # Nothing here is configurable, so reject any extra config.
        if load_config.model_loader_extra_config:
            raise ValueError(f"Model loader extra config is not supported for "
                             f"load format {load_config.load_format}")

    def load_model(self, *, model_config: ModelConfig,
                   device_config: DeviceConfig,
                   lora_config: Optional[LoRAConfig],
                   multimodal_config: Optional[MultiModalConfig],
                   parallel_config: ParallelConfig,
                   scheduler_config: SchedulerConfig,
                   cache_config: CacheConfig) -> nn.Module:
        """Build the model, randomize its weights, and return it in eval mode."""
        with set_default_torch_dtype(model_config.dtype):
            with torch.device(device_config.device):
                model = _initialize_model(model_config, self.load_config,
                                          lora_config, multimodal_config,
                                          cache_config, scheduler_config)
            # NOTE(woosuk): random weights (rather than e.g. zeros) are used
            # so that performance evaluation stays accurate.
            initialize_dummy_weights(model)
        return model.eval()


class TensorizerLoader(BaseModelLoader):
    """Model loader using CoreWeave's tensorizer library.

    If the artifact at the tensorizer URI is vLLM-tensorized (see
    examples/tensorize_vllm_model.py), the model is deserialized directly.
    Otherwise the model is initialized normally and the tensorized weights
    are streamed into it.
    """

    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)
        # Accept either a ready TensorizerConfig or the keyword arguments
        # needed to build one.
        if isinstance(load_config.model_loader_extra_config, TensorizerConfig):
            self.tensorizer_config = load_config.model_loader_extra_config
        else:
            self.tensorizer_config = TensorizerConfig(
                **load_config.model_loader_extra_config)

    def _verify_config(self, model_config: ModelConfig,
                       parallel_config: ParallelConfig):
        """Check the tensorizer config against the model/parallel configs."""
        self.tensorizer_config.verify_with_model_config(model_config)
        self.tensorizer_config.verify_with_parallel_config(parallel_config)

    def _get_weights_iterator(
        self,
        tensorizer_config: Optional[TensorizerConfig] = None,
    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
        """Return an iterator over the tensors stored at the tensorizer URI.

        Args:
            tensorizer_config: Config to read from; defaults to
                ``self.tensorizer_config``.
        """
        if tensorizer_config is None:
            tensorizer_config = self.tensorizer_config
        tensorizer_args = tensorizer_config._construct_tensorizer_args()
        return tensorizer_weights_iterator(tensorizer_args)

    def _load_model_serialized_cpu(
        self,
        model_config: ModelConfig,
        device_config: DeviceConfig,
        lora_config: Optional[LoRAConfig],
        multimodal_config: Optional[MultiModalConfig],
        cache_config: CacheConfig,
        scheduler_config: Optional[SchedulerConfig] = None,
        tensorizer_config: Optional[TensorizerConfig] = None,
    ) -> nn.Module:
        """Initialize the model, then stream tensorizer weights into it.

        This is only necessary when the model isn't vLLM-tensorized (see
        examples/tensorize_vllm_model.py). It should still be faster than
        default HuggingFace loading, but will be slower than loading a
        vLLM-tensorized model.

        Args:
            scheduler_config: Forwarded to model initialization, as the other
                loaders do.
            tensorizer_config: Config to read weights from; defaults to
                ``self.tensorizer_config``.
        """
        with set_default_torch_dtype(model_config.dtype):
            with torch.device(device_config.device):
                model = _initialize_model(model_config, self.load_config,
                                          lora_config, multimodal_config,
                                          cache_config, scheduler_config)

            model.load_weights(self._get_weights_iterator(tensorizer_config))
        return model.eval()

    def _load_model_serialized(
        self,
        model_config: ModelConfig,
        device_config: DeviceConfig,
        lora_config: Optional[LoRAConfig],
        multimodal_config: Optional[MultiModalConfig],
        cache_config: CacheConfig,
        scheduler_config: Optional[SchedulerConfig] = None,
        tensorizer_config: Optional[TensorizerConfig] = None,
    ) -> nn.Module:
        """Load a serialized model with tensorizer.

        Expects a vLLM-tensorized model. See the
        examples/tensorize_vllm_model.py example script
        for serializing vLLM models.

        Args:
            scheduler_config: Forwarded to the model-initialization kwargs,
                as the other loaders do.
            tensorizer_config: Config to load from; defaults to
                ``self.tensorizer_config``. It is copied, not modified.
        """
        if tensorizer_config is None:
            tensorizer_config = self.tensorizer_config
        with set_default_torch_dtype(model_config.dtype):
            with torch.device(device_config.device):
                model_class = get_model_architecture(model_config)[0]
                quant_config = _get_quantization_config(
                    model_config, self.load_config)
                extra_kwargs = _get_model_initialization_kwargs(
                    model_class, lora_config, multimodal_config,
                    scheduler_config)
                extra_kwargs["quant_config"] = quant_config
                extra_kwargs["cache_config"] = cache_config

                # Copy so the per-load fields set below don't leak into the
                # caller's config.
                load_config_copy = copy.copy(tensorizer_config)
                load_config_copy.model_class = model_class
                load_config_copy.hf_config = model_config.hf_config
                load_config_copy.dtype = model_config.dtype

                model = load_with_tensorizer(load_config_copy, **extra_kwargs)
        return model.eval()

    def load_model(self, *, model_config: ModelConfig,
                   device_config: DeviceConfig,
                   lora_config: Optional[LoRAConfig],
                   multimodal_config: Optional[MultiModalConfig],
                   parallel_config: ParallelConfig,
                   scheduler_config: SchedulerConfig,
                   cache_config: CacheConfig) -> nn.Module:
        """Verify the config, then load the model from the tensorizer URI."""
        self._verify_config(model_config, parallel_config)

        tensorizer_config = self.tensorizer_config
        if parallel_config.tensor_parallel_size > 1:
            from vllm.distributed import get_tensor_model_parallel_rank

            # With tensor parallelism the URI is a %-style template that is
            # formatted with this worker's rank. Format a copy rather than
            # overwriting self.tensorizer_config: that object may be the
            # caller's own extra config, and once overwritten the template
            # is gone, so a second load would fail.
            tensorizer_config = copy.copy(tensorizer_config)
            tensorizer_config.tensorizer_uri = (
                tensorizer_config.tensorizer_uri %
                get_tensor_model_parallel_rank())

        if is_vllm_tensorized(tensorizer_config):
            return self._load_model_serialized(model_config, device_config,
                                               lora_config, multimodal_config,
                                               cache_config, scheduler_config,
                                               tensorizer_config)
        return self._load_model_serialized_cpu(model_config, device_config,
                                               lora_config, multimodal_config,
                                               cache_config, scheduler_config,
                                               tensorizer_config)

    @staticmethod
    def save_model(
        model: torch.nn.Module,
        tensorizer_config: TensorizerConfig,
    ) -> None:
        """Serialize ``model`` using the settings in ``tensorizer_config``."""
        serialize_vllm_model(
            model=model,
            tensorizer_config=tensorizer_config,
        )

434

435
436
437
438
439
class ShardedStateLoader(BaseModelLoader):
    """
    Model loader that directly loads each worker's model state dict, which
    enables a fast load path for large tensor-parallel models where each worker
    only needs to read its own shard rather than the entire checkpoint. See
    `examples/save_sharded_state.py` for creating a sharded checkpoint.
    """

    DEFAULT_PATTERN = "model-rank-{rank}-part-{part}.safetensors"

    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)
        extra_config = ({} if load_config.model_loader_extra_config is None
                        else load_config.model_loader_extra_config.copy())
        # "pattern" is the only supported option: a file-name format string
        # with {rank} and {part} placeholders.
        self.pattern = extra_config.pop("pattern", self.DEFAULT_PATTERN)
        if extra_config:
            # Report only the keys that were not recognized.
            raise ValueError(f"Unexpected extra config keys for load format "
                             f"{load_config.load_format}: "
                             f"{extra_config.keys()}")

    @staticmethod
    def _filter_subtensors(
            tensors: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Filter out all tensors that share the same memory or a subset of the
        memory of another tensor.

        Tensors are grouped by (device, storage pointer). Within a group, a
        tensor is dropped when some contiguous tensor covers a strictly larger
        memory range, or covers exactly the same range while either this
        tensor is non-contiguous or the other has a smaller key. Tensors with
        no elements are dropped entirely.
        """
        same_storage_groups: Dict[Any, List[Tuple[
            str, torch.Tensor]]] = collections.defaultdict(list)
        for key, tensor in tensors.items():
            if tensor.numel():
                ptr = tensor.untyped_storage().data_ptr()
                same_storage_groups[tensor.device, ptr].append((key, tensor))

        # One-past-the-end address of the flattened tensor.
        # NOTE(review): view(-1) raises for layouts that cannot be flattened
        # without a copy; confirm state dicts never contain such tensors.
        def get_end_ptr(tensor: torch.Tensor) -> int:
            return tensor.view(-1)[-1].data_ptr() + tensor.element_size()

        result: Dict[str, torch.Tensor] = {}
        for group in same_storage_groups.values():
            for k, t in group:
                a, b = t.data_ptr(), get_end_ptr(t)
                for k2, t2 in group:
                    if not t2.is_contiguous():
                        continue
                    a2, b2 = t2.data_ptr(), get_end_ptr(t2)
                    if a < a2 or b2 < b:
                        continue  # t2 does not cover t's range.
                    if a2 < a or b < b2 or not t.is_contiguous():
                        break  # t2 covers strictly more memory than t.
                    if k2 < k:
                        # Same tensors, keep the one with the smaller key.
                        break
                else:
                    result[k] = t
        return result

    def _prepare_weights(self, model_name_or_path: str,
                         revision: Optional[str]):
        """Return a local directory holding the checkpoint.

        A local directory is returned unchanged; otherwise the *.safetensors
        files are downloaded from the hub.
        """
        if os.path.isdir(model_name_or_path):
            return model_name_or_path
        else:
            allow_patterns = ["*.safetensors"]
            return download_weights_from_hf(
                model_name_or_path,
                self.load_config.download_dir,
                allow_patterns,
                revision,
                ignore_patterns=self.load_config.ignore_patterns,
            )

    def load_model(self, *, model_config: ModelConfig,
                   device_config: DeviceConfig,
                   lora_config: Optional[LoRAConfig],
                   multimodal_config: Optional[MultiModalConfig],
                   parallel_config: ParallelConfig,
                   scheduler_config: SchedulerConfig,
                   cache_config: CacheConfig) -> nn.Module:
        """Build the model and fill it from this rank's pre-sharded files.

        Raises:
            ValueError: If no shard files match this rank's pattern, or if
                some model tensors were not present in the loaded shards.
        """
        from safetensors.torch import safe_open

        from vllm.distributed import get_tensor_model_parallel_rank

        local_model_path = self._prepare_weights(model_config.model,
                                                 model_config.revision)

        with set_default_torch_dtype(model_config.dtype):
            with torch.device(device_config.device):
                # Forward scheduler_config like the Default/Dummy loaders, so
                # models with inner state get it at construction time.
                model = _initialize_model(model_config, self.load_config,
                                          lora_config, multimodal_config,
                                          cache_config, scheduler_config)
            rank = get_tensor_model_parallel_rank()
            pattern = os.path.join(
                local_model_path,
                self.pattern.format(rank=rank, part="*"),
            )
            filepaths = glob.glob(pattern)
            if not filepaths:
                # TODO: support un-sharded checkpoints too
                raise ValueError(
                    f"Could not find checkpoint files '{pattern}', only "
                    f"pre-sharded checkpoints are currently supported!")
            state_dict = self._filter_subtensors(model.state_dict())
            for path in filepaths:
                with safe_open(path, framework="pt") as f:
                    for key in f.keys():  # noqa: SIM118
                        tensor = f.get_tensor(key)
                        # If loading with LoRA enabled, additional padding may
                        # be added to certain parameters. We only load into a
                        # narrowed view of the parameter data.
                        param_data = state_dict[key].data
                        param_shape = state_dict[key].shape
                        for dim, size in enumerate(tensor.shape):
                            if size < param_shape[dim]:
                                param_data = param_data.narrow(dim, 0, size)
                        if tensor.shape != param_shape:
                            logger.warning(
                                "loading tensor of shape %s into "
                                "parameter '%s' of shape %s", tensor.shape,
                                key, param_shape)
                        param_data.copy_(tensor)
                        # Whatever remains afterwards was never loaded.
                        state_dict.pop(key)
            if state_dict:
                raise ValueError(
                    f"Missing keys {tuple(state_dict)} in loaded state!")
        return model.eval()

    @staticmethod
    def save_model(
        model: torch.nn.Module,
        path: str,
        pattern: Optional[str] = None,
        max_size: Optional[int] = None,
    ) -> None:
        """Save this rank's deduplicated state dict as safetensors shards.

        Args:
            model: Model whose ``state_dict()`` is saved.
            path: Directory the shard files are written into.
            pattern: File-name pattern with ``{rank}`` and ``{part}``
                placeholders; defaults to ``DEFAULT_PATTERN``.
            max_size: Soft per-file size limit in bytes. A tensor larger than
                this on its own is written to a shard by itself. ``None``
                writes everything into a single file.
        """
        from safetensors.torch import save_file

        from vllm.distributed import get_tensor_model_parallel_rank
        if pattern is None:
            pattern = ShardedStateLoader.DEFAULT_PATTERN
        rank = get_tensor_model_parallel_rank()
        part_idx = 0
        total_size = 0
        state_dict = ShardedStateLoader._filter_subtensors(model.state_dict())
        state_dict_part: Dict[str, torch.Tensor] = {}
        for key, tensor in state_dict.items():
            param_size = tensor.nelement() * tensor.element_size()
            # Flush the current part before it would exceed max_size. The
            # non-empty check avoids writing an empty shard file when a
            # single tensor is larger than max_size.
            if (max_size is not None and state_dict_part
                    and total_size + param_size > max_size):
                filename = pattern.format(rank=rank, part=part_idx)
                save_file(
                    state_dict_part,
                    os.path.join(path, filename),
                )
                part_idx += 1
                total_size = 0
                state_dict_part = {}
            state_dict_part[key] = tensor
            total_size += param_size
        if len(state_dict_part) > 0:
            filename = pattern.format(rank=rank, part=part_idx)
            save_file(
                state_dict_part,
                os.path.join(path, filename),
            )


598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
class BitsAndBytesModelLoader(BaseModelLoader):
    """Model loader to load model weights with BitAndBytes quantization.

    ``.weight`` tensors whose names contain one of ``self.target_modules`` are
    quantized to 4-bit NF4 with ``bitsandbytes`` while being loaded; every
    other tensor is passed to the model unchanged.
    """

    default_target_modules = [
        "gate_proj", "down_proj", "up_proj", "q_proj", "k_proj", "v_proj",
        "o_proj"
    ]

    possible_config_file_names = ["adapter_config.json"]

    # Oldest supported bitsandbytes release, compared as a numeric tuple.
    _MIN_BNB_VERSION = (0, 42)

    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)

        # we don't need to quantize the whole model, only the target modules
        # that are specified in the adapter config file. If the adapter config
        # file is not provided, we will quantize the default modules.
        if (not load_config.model_loader_extra_config
                or "qlora_adapter_name_or_path"
                not in load_config.model_loader_extra_config):
            # Copy so instances never alias (and could never mutate) the
            # class-level default list.
            self.target_modules = list(self.default_target_modules)
            return

        qlora_adapter = load_config.model_loader_extra_config[
            "qlora_adapter_name_or_path"]

        config_file_path = self._get_config_file(qlora_adapter)

        with open(config_file_path, "r", encoding="utf-8") as f:
            config = json.load(f)
        # NOTE(review): target_modules is matched by substring below, which
        # assumes it is a list of module names. A plain-string value would be
        # iterated character by character - confirm adapters use list form.
        self.target_modules = config["target_modules"]

    def _get_config_file(self, qlora_adapter: str) -> str:
        """Return the path of the adapter config file for ``qlora_adapter``.

        ``qlora_adapter`` is a local directory, or otherwise a Hugging Face
        Hub repo id, in which case the config file is downloaded.

        Raises:
            ValueError: if none of ``possible_config_file_names`` is found.
        """
        config_file_path = None
        if os.path.isdir(qlora_adapter):
            for file in self.possible_config_file_names:
                candidate = os.path.join(qlora_adapter, file)
                # Only keep paths that exist; otherwise the last candidate
                # would slip past the check below and fail later in open().
                if os.path.exists(candidate):
                    config_file_path = candidate
                    break
        else:
            hf_api = HfApi()
            repo_files = hf_api.list_repo_files(repo_id=qlora_adapter)
            for file in self.possible_config_file_names:
                if file in repo_files:
                    config_file_path = hf_hub_download(repo_id=qlora_adapter,
                                                       filename=file)
                    break

        if not config_file_path:
            raise ValueError(
                f"Cannot find adapter config file in {qlora_adapter}")

        return config_file_path

    def _get_weight_files(
            self,
            model_name_or_path: str,
            allowed_patterns: List[str],
            revision: Optional[str] = None) -> Tuple[List[str], str]:
        """Retrieve weight files. Download the files if necessary.

        Patterns are tried in order and the first one matching any file wins.

        Return the weight files and the file pattern.

        Raises:
            RuntimeError: if no file matches any of ``allowed_patterns``.
        """
        is_local = os.path.isdir(model_name_or_path)

        if is_local:
            for pattern in allowed_patterns:
                weight_files = glob.glob(
                    os.path.join(model_name_or_path, pattern))
                if weight_files:
                    return weight_files, pattern
        else:
            hf_api = HfApi()
            # List the same revision that is downloaded below so the chosen
            # pattern reflects that revision's files.
            repo_files = hf_api.list_repo_files(repo_id=model_name_or_path,
                                                revision=revision)
            for pattern in allowed_patterns:
                matching_files = fnmatch.filter(repo_files, pattern)
                if matching_files:
                    hf_folder = download_weights_from_hf(
                        model_name_or_path,
                        self.load_config.download_dir,
                        [pattern],
                        revision,
                        ignore_patterns=self.load_config.ignore_patterns,
                    )
                    return glob.glob(os.path.join(hf_folder, pattern)), pattern

        raise RuntimeError(
            f"No model weights found in: `{model_name_or_path}`")

    def _prepare_weights(self, model_name_or_path: str,
                         revision: Optional[str]) -> Tuple[List[str], bool]:
        """Prepare weight files for the model.

        Returns the weight file paths and whether they are safetensors files.
        Non-safetensors files are passed through
        ``filter_files_not_needed_for_inference``.

        Raises:
            RuntimeError: if no weight files remain.
        """

        allowed_patterns = ["*.safetensors", "*.bin", "*.pt"]

        hf_weights_files, matched_pattern = self._get_weight_files(
            model_name_or_path, allowed_patterns, revision)

        if matched_pattern != "*.safetensors":
            hf_weights_files = filter_files_not_needed_for_inference(
                hf_weights_files)

        if len(hf_weights_files) == 0:
            raise RuntimeError(
                f"Cannot find any model weights with `{model_name_or_path}`")

        return hf_weights_files, matched_pattern == "*.safetensors"

    @staticmethod
    def _version_tuple(version: str) -> Tuple[int, ...]:
        """Parse the leading numeric release components of ``version``.

        ``"0.43.1"`` -> ``(0, 43, 1)``; ``"0.42.0.dev0"`` -> ``(0, 42, 0)``.
        Tuples compare numerically, unlike version strings (as strings,
        ``"0.100.0" < "0.42.0"``).
        """
        parts: List[int] = []
        for piece in version.split("."):
            digits = ""
            for ch in piece:
                if not "0" <= ch <= "9":
                    break
                digits += ch
            if not digits:
                break
            parts.append(int(digits))
            if len(digits) != len(piece):
                # e.g. "0rc1": keep the numeric prefix, drop the suffix.
                break
        return tuple(parts)

    def _get_quantized_weights_iterator(
        self, model_name_or_path: str, revision: Optional[str]
    ) -> Tuple[Generator[Tuple[str, torch.Tensor], None, None], Dict[str,
                                                                     Any]]:
        """Get an iterator to the model weights with bitsandbytes quantization,
        as well as the quantization state dictionary.

        The generator is lazy: the returned dict is filled with one quant
        state per quantized weight (keyed by its ``.qweight`` name) only as
        the generator is consumed.

        Raises:
            ImportError: if bitsandbytes>=0.42.0 is not available.
        """

        # only load the bitsandbytes module when needed
        try:
            import bitsandbytes
            if (self._version_tuple(bitsandbytes.__version__)
                    < self._MIN_BNB_VERSION):
                raise ImportError("bitsandbytes version is wrong. Please "
                                  "install bitsandbytes>=0.42.0.")
            from bitsandbytes.functional import quantize_4bit
        except ImportError as err:
            raise ImportError("Please install bitsandbytes>=0.42.0 via "
                              "`pip install bitsandbytes>=0.42.0` to use "
                              "bitsandbytes quantizer.") from err

        hf_weights_files, use_safetensors = self._prepare_weights(
            model_name_or_path, revision)

        quant_state_dict = {}
        if use_safetensors:
            weight_iterator = safetensors_weights_iterator(hf_weights_files)
        else:
            weight_iterator = pt_weights_iterator(hf_weights_files)

        suffix = ".weight"

        def generator():
            for weight_name, weight_tensor in weight_iterator:
                # Only ".weight" tensors of target modules are quantized; other
                # tensors of those modules (e.g. "q_proj.bias") stay as loaded.
                if weight_name.endswith(suffix) and any(
                        target_module in weight_name
                        for target_module in self.target_modules):
                    weight_name = weight_name[:-len(suffix)] + ".qweight"
                    #  bitsandbytes requires data in GPU
                    loaded_weight = weight_tensor.cuda().data
                    with set_default_torch_dtype(torch.float32):
                        processed_weight, quant_state = quantize_4bit(
                            loaded_weight,
                            compress_statistics=True,
                            quant_type="nf4")

                    quant_state_dict[weight_name] = quant_state
                else:
                    processed_weight = weight_tensor

                yield weight_name, processed_weight

        return generator(), quant_state_dict

    def _load_weights(self, model_config: ModelConfig,
                      model: nn.Module) -> None:
        """Load quantized weights into ``model`` and attach bnb metadata.

        After ``model.load_weights`` consumes the quantizing iterator, every
        parameter that received quantized data gets two attributes:
        ``bnb_quant_state`` (dict: shard index -> quant state) and
        ``bnb_shard_offsets`` (cumulative packed element counts per shard,
        ordered by shard index).

        Raises:
            AttributeError: if the model lacks ``load_weights`` or
                ``bitsandbytes_stacked_params_mapping``.
            ValueError: if a quantized weight maps to no model parameter, or a
                quantized parameter has no ``pack_factor``.
        """
        if not hasattr(model, 'load_weights'):
            raise AttributeError(
                "The required method 'load_weights' is not defined in class"
                f" {type(model).__name__}.")

        if not hasattr(model, 'bitsandbytes_stacked_params_mapping'):
            raise AttributeError(
                f"Model {type(model).__name__} does not support BitsAndBytes "
                "quantization yet.")

        logger.info("Loading weights with BitsAndBytes quantization. "
                    " May take a while ...")

        qweight_iterator, quant_state_dict = (
            self._get_quantized_weights_iterator(model_config.model,
                                                 model_config.revision))

        # Consumes the generator, which fills quant_state_dict.
        model.load_weights(qweight_iterator)

        param_dict = dict(model.named_parameters())
        stacked_quant_state_dict: Dict[str, Dict[int, Any]] = {}
        for non_stacked_param_name in quant_state_dict:
            quant_param_name = non_stacked_param_name

            # Map e.g. "q_proj" onto its stacked parameter ("qkv_proj") and
            # record which shard of that parameter it is.
            shard_index = 0
            for shard_name, (
                    weight_name, index
            ) in model.bitsandbytes_stacked_params_mapping.items():
                if shard_name in quant_param_name:
                    shard_index = index
                    quant_param_name = quant_param_name.replace(
                        shard_name, weight_name)
                    break

            if quant_param_name not in param_dict:
                raise ValueError(
                    f"Parameter {quant_param_name} not found in the model.")

            if quant_param_name not in stacked_quant_state_dict:
                stacked_quant_state_dict[quant_param_name] = {}

            stacked_quant_state_dict[quant_param_name][shard_index] = (
                quant_state_dict[non_stacked_param_name])

        # save quant_states and offsets as the attributes of the parameters
        for param_name, param in param_dict.items():
            if param_name in stacked_quant_state_dict:
                quant_states = stacked_quant_state_dict[param_name]
                set_weight_attrs(param, {"bnb_quant_state": quant_states})

                pack_ratio = getattr(param, "pack_factor", -1)
                if pack_ratio == -1:
                    raise ValueError(
                        f"pack_factor not set for parameter {param_name}.")

                # Lay shard sizes out by shard index, not by the order the
                # checkpoint yielded the shards (dict insertion order).
                # Assumes shard indices are 0..len(quant_states)-1.
                num_elements = [0] * len(quant_states)
                for shard_index, quant_state in quant_states.items():
                    num_elements[shard_index] = math.prod(
                        quant_state.shape) // pack_ratio

                offsets = np.concatenate(([0], np.cumsum(num_elements)))
                set_weight_attrs(param, {"bnb_shard_offsets": offsets})

    def load_model(self, *, model_config: ModelConfig,
                   device_config: DeviceConfig,
                   lora_config: Optional[LoRAConfig],
                   multimodal_config: Optional[MultiModalConfig],
                   parallel_config: ParallelConfig,
                   scheduler_config: SchedulerConfig,
                   cache_config: CacheConfig) -> nn.Module:
        """Build the model on ``device_config.device`` with
        ``model_config.dtype`` as default dtype, load BitsAndBytes-quantized
        weights into it, and return it in eval mode."""
        with set_default_torch_dtype(model_config.dtype):
            with torch.device(device_config.device):
                model = _initialize_model(model_config, self.load_config,
                                          lora_config, multimodal_config,
                                          cache_config)

                self._load_weights(model_config, model)

        return model.eval()


837
838
839
840
841
842
843
844
845
846
847
848
def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
    """Get a model loader based on the load format.

    If ``load_config.load_format`` is itself a class, it is instantiated with
    ``load_config`` (custom loader). Otherwise the loader is selected from the
    ``LoadFormat`` value, falling back to ``DefaultModelLoader``.
    """

    if isinstance(load_config.load_format, type):
        return load_config.load_format(load_config)

    if load_config.load_format == LoadFormat.DUMMY:
        return DummyModelLoader(load_config)

    if load_config.load_format == LoadFormat.TENSORIZER:
        return TensorizerLoader(load_config)

    if load_config.load_format == LoadFormat.SHARDED_STATE:
        return ShardedStateLoader(load_config)

    if load_config.load_format == LoadFormat.BITSANDBYTES:
        return BitsAndBytesModelLoader(load_config)

    return DefaultModelLoader(load_config)