# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team.
# Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

17
import functools
18
import importlib
19
20
import inspect
import os
21
from array import array
22
from collections import OrderedDict, defaultdict
23
from concurrent.futures import ThreadPoolExecutor, as_completed
24
from pathlib import Path
25
26
from typing import Dict, List, Optional, Union
from zipfile import is_zipfile
27
28
29

import safetensors
import torch
Marc Sun's avatar
Marc Sun committed
30
from huggingface_hub import DDUFEntry
31
from huggingface_hub.utils import EntryNotFoundError
32

33
from ..quantizers import DiffusersQuantizer
34
from ..utils import (
35
    DEFAULT_HF_PARALLEL_LOADING_WORKERS,
36
    GGUF_FILE_EXTENSION,
37
    SAFE_WEIGHTS_INDEX_NAME,
38
    SAFETENSORS_FILE_EXTENSION,
39
40
41
    WEIGHTS_INDEX_NAME,
    _add_variant,
    _get_model_file,
42
    deprecate,
43
    is_accelerate_available,
44
    is_accelerate_version,
45
46
    is_gguf_available,
    is_torch_available,
47
48
49
50
51
52
53
    is_torch_version,
    logging,
)


logger = logging.get_logger(__name__)

# Legacy model classes whose configs are redirected to a dedicated class, keyed by the value of the
# config's `norm_type`. Read by `_fetch_remapped_cls_from_config`.
_CLASS_REMAPPING_DICT = dict(
    Transformer2DModel=dict(
        ada_norm_zero="DiTTransformer2DModel",
        ada_norm_single="PixArtTransformer2DModel",
    ),
)

61
62
63

if is_accelerate_available():
    from accelerate import infer_auto_device_map
64
    from accelerate.utils import get_balanced_memory, get_max_memory, offload_weight, set_module_tensor_to_device
65
66
67


# Adapted from `transformers` (see modeling_utils.py)
68
69
70
def _determine_device_map(
    model: torch.nn.Module, device_map, max_memory, torch_dtype, keep_in_fp32_modules=[], hf_quantizer=None
):
71
    if isinstance(device_map, str):
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
        special_dtypes = {}
        if hf_quantizer is not None:
            special_dtypes.update(hf_quantizer.get_special_dtypes_update(model, torch_dtype))
        special_dtypes.update(
            {
                name: torch.float32
                for name, _ in model.named_parameters()
                if any(m in name for m in keep_in_fp32_modules)
            }
        )

        target_dtype = torch_dtype
        if hf_quantizer is not None:
            target_dtype = hf_quantizer.adjust_target_dtype(target_dtype)

87
88
89
        no_split_modules = model._get_no_split_modules(device_map)
        device_map_kwargs = {"no_split_module_classes": no_split_modules}

90
91
92
93
94
95
96
97
        if "special_dtypes" in inspect.signature(infer_auto_device_map).parameters:
            device_map_kwargs["special_dtypes"] = special_dtypes
        elif len(special_dtypes) > 0:
            logger.warning(
                "This model has some weights that should be kept in higher precision, you need to upgrade "
                "`accelerate` to properly deal with them (`pip install --upgrade accelerate`)."
            )

98
99
100
101
102
103
104
105
106
107
108
        if device_map != "sequential":
            max_memory = get_balanced_memory(
                model,
                dtype=torch_dtype,
                low_zero=(device_map == "balanced_low_0"),
                max_memory=max_memory,
                **device_map_kwargs,
            )
        else:
            max_memory = get_max_memory(max_memory)

109
110
111
        if hf_quantizer is not None:
            max_memory = hf_quantizer.adjust_max_memory(max_memory)

112
        device_map_kwargs["max_memory"] = max_memory
113
114
115
116
        device_map = infer_auto_device_map(model, dtype=target_dtype, **device_map_kwargs)

        if hf_quantizer is not None:
            hf_quantizer.validate_environment(device_map=device_map)
117
118
119
120

    return device_map


121
122
def _fetch_remapped_cls_from_config(config, old_class):
    """
    Return the class that should actually be instantiated for `old_class`, given its `config`.

    `old_class.__name__` is looked up in `_CLASS_REMAPPING_DICT`. If `config["norm_type"]` maps to a replacement
    class name, that class is fetched from the top-level package (the first component of this module's `__name__`)
    and returned. Otherwise, including when the class has no remapping entry or the config has no `norm_type`,
    `old_class` is returned unchanged.
    """
    previous_class_name = old_class.__name__
    # Classes without a remapping entry, or configs without `norm_type`, keep their class instead of failing
    # with `AttributeError` / `KeyError`.
    remapped_class_name = _CLASS_REMAPPING_DICT.get(previous_class_name, {}).get(config.get("norm_type"), None)

    # Details:
    # https://github.com/huggingface/diffusers/pull/7647#discussion_r1621344818
    if not remapped_class_name:
        return old_class

    # Resolve the replacement class from the top-level package.
    diffusers_library = importlib.import_module(__name__.split(".")[0])
    remapped_class = getattr(diffusers_library, remapped_class_name)
    logger.info(
        f"Changing class object to be of `{remapped_class_name}` type from `{previous_class_name}` type. "
        f"This is because `{previous_class_name}` is scheduled to be deprecated in a future version. Note that this"
        " DOESN'T affect the final results."
    )
    return remapped_class
139
140


141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
def _determine_param_device(param_name: str, device_map: Optional[Dict[str, Union[int, str, torch.device]]]):
    """
    Find the device of param_name from the device_map.
    """
    if device_map is None:
        return "cpu"
    else:
        module_name = param_name
        # find next higher level module that is defined in device_map:
        # bert.lm_head.weight -> bert.lm_head -> bert -> ''
        while len(module_name) > 0 and module_name not in device_map:
            module_name = ".".join(module_name.split(".")[:-1])
        if module_name == "" and "" not in device_map:
            raise ValueError(f"{param_name} doesn't have any device set.")
        return device_map[module_name]


158
def load_state_dict(
    checkpoint_file: Union[str, os.PathLike],
    dduf_entries: Optional[Dict[str, DDUFEntry]] = None,
    disable_mmap: bool = False,
    map_location: Union[str, torch.device] = "cpu",
):
    """
    Reads a checkpoint file, returning properly formatted errors if they arise.

    Args:
        checkpoint_file: Path to a safetensors, GGUF or `torch.save` checkpoint. A `dict` is treated as an already
            loaded state dict and returned as-is.
        dduf_entries: Optional DDUF entries; when given, safetensors checkpoints are read from the entry keyed by
            `checkpoint_file` (tensors are loaded on CPU).
        disable_mmap: If `True`, do not memory-map the checkpoint; safetensors files are read fully into memory
            and `torch.load` is called without `mmap=True`.
        map_location: Device to load tensors to. Not applied on the DDUF and `disable_mmap` safetensors paths, nor
            for GGUF files.

    Returns:
        The loaded state dict.

    Raises:
        OSError: If the checkpoint cannot be loaded (including when it looks like a git-lfs pointer file).
    """
    # TODO: maybe refactor a bit this part where we pass a dict here
    if isinstance(checkpoint_file, dict):
        return checkpoint_file
    try:
        file_extension = os.path.basename(checkpoint_file).split(".")[-1]
        if file_extension == SAFETENSORS_FILE_EXTENSION:
            if dduf_entries:
                # tensors are loaded on cpu
                with dduf_entries[checkpoint_file].as_mmap() as mm:
                    return safetensors.torch.load(mm)
            if disable_mmap:
                # Use a context manager so the file handle is closed deterministically.
                with open(checkpoint_file, "rb") as f:
                    return safetensors.torch.load(f.read())
            return safetensors.torch.load_file(checkpoint_file, device=map_location)
        elif file_extension == GGUF_FILE_EXTENSION:
            return load_gguf_checkpoint(checkpoint_file)
        else:
            extra_args = {}
            weights_only_kwarg = {"weights_only": True} if is_torch_version(">=", "1.13") else {}
            # mmap can only be used with files serialized with zipfile-based format.
            # `str()` so that a `torch.device("meta")` is recognized as well as the string "meta".
            if (
                isinstance(checkpoint_file, str)
                and str(map_location) != "meta"
                and is_torch_version(">=", "2.1.0")
                and is_zipfile(checkpoint_file)
                and not disable_mmap
            ):
                extra_args = {"mmap": True}
            return torch.load(checkpoint_file, map_location=map_location, **weights_only_kwarg, **extra_args)
    except Exception as e:
        try:
            with open(checkpoint_file) as f:
                # A git-lfs pointer file starts with "version"; only peek at the first characters instead of
                # decoding a potentially multi-GB binary checkpoint as text.
                if f.read(7) == "version":
                    raise OSError(
                        "You seem to have cloned a repository without having git-lfs installed. Please install "
                        "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
                        "you cloned."
                    ) from e
                else:
                    raise ValueError(
                        f"Unable to locate the file {checkpoint_file} which is necessary to load this pretrained "
                        "model. Make sure you have saved the model properly."
                    ) from e
        except (UnicodeDecodeError, ValueError):
            # Chain to the original loading error so its cause is not lost.
            raise OSError(
                f"Unable to load weights from checkpoint file for '{checkpoint_file}' at '{checkpoint_file}'. "
            ) from e


def load_model_dict_into_meta(
    model,
    state_dict: OrderedDict,
    dtype: Optional[Union[str, torch.dtype]] = None,
    model_name_or_path: Optional[str] = None,
    hf_quantizer: Optional[DiffusersQuantizer] = None,
    keep_in_fp32_modules: Optional[List] = None,
    device_map: Optional[Dict[str, Union[int, str, torch.device]]] = None,
    unexpected_keys: Optional[List[str]] = None,
    offload_folder: Optional[Union[str, os.PathLike]] = None,
    offload_index: Optional[Dict] = None,
    state_dict_index: Optional[Dict] = None,
    state_dict_folder: Optional[Union[str, os.PathLike]] = None,
) -> tuple:
    """
    This is somewhat similar to `_load_state_dict_into_model`, but deals with a model that has some or all of its
    params on a `meta` device. It replaces the model params with the data from the `state_dict`.

    Keys of `state_dict` that the model does not have are skipped. Each remaining tensor is placed in one of four
    ways:
    - offloaded with `offload_weight`, if its device is `"disk"`, or `"cpu"` while `state_dict_index` is given;
    - handed to `hf_quantizer` to create a quantized parameter;
    - set on its device with `set_module_tensor_to_device`.

    Args:
        model: The model whose parameters/buffers are populated in place.
        state_dict: Checkpoint tensors keyed by dotted parameter name.
        dtype: Target dtype for floating-point tensors. When `None`, each tensor is cast to the dtype of the model's
            existing parameter/buffer.
        model_name_or_path: Only used to make the shape-mismatch error message more informative.
        hf_quantizer: Optional quantizer for (pre-)quantized checkpoints.
        keep_in_fp32_modules: Names matched against the dot-separated components of a parameter name; matching
            floating-point tensors are kept in `torch.float32`.
        device_map: Module name -> device mapping; `None` places every tensor on `"cpu"`.
        unexpected_keys: Forwarded to `hf_quantizer.create_quantized_param`.
        offload_folder / offload_index: Destination and index for tensors mapped to `"disk"`.
        state_dict_folder / state_dict_index: Destination and index for tensors offloaded from `"cpu"`.

    Returns:
        `(offload_index, state_dict_index)` as updated by the offloads performed in this call.

    Raises:
        ValueError: If a tensor's shape differs from the model's and it is not a pre-quantized parameter accepted by
            `hf_quantizer`.
    """
    is_quantized = hf_quantizer is not None
    empty_state_dict = model.state_dict()
    # The installed accelerate version cannot change during loading; check it once instead of per tensor.
    pass_non_blocking_kwargs = is_accelerate_version(">", "1.8.1")

    for param_name, param in state_dict.items():
        if param_name not in empty_state_dict:
            continue

        set_module_kwargs = {}
        # We convert floating dtypes to the `dtype` passed. We also want to keep the buffers/params
        # in int/uint/bool and not cast them.
        # TODO: revisit cases when param.dtype == torch.float8_e4m3fn
        if dtype is not None and torch.is_floating_point(param):
            if keep_in_fp32_modules is not None and any(
                module_to_keep_in_fp32 in param_name.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules
            ):
                param = param.to(torch.float32)
                set_module_kwargs["dtype"] = torch.float32
            # For quantizers have save weights using torch.float8_e4m3fn
            elif hf_quantizer is not None and param.dtype == getattr(torch, "float8_e4m3fn", None):
                pass
            else:
                param = param.to(dtype)
                set_module_kwargs["dtype"] = dtype

        if pass_non_blocking_kwargs:
            set_module_kwargs["non_blocking"] = True
            set_module_kwargs["clear_cache"] = False

        # For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model, and which
        # uses `param.copy_(input_param)` that preserves the contiguity of the parameter in the model.
        # Reference: https://github.com/pytorch/pytorch/blob/db79ceb110f6646523019a59bbd7b838f43d4a86/torch/nn/modules/module.py#L2040C29-L2040C29
        old_param = model
        for split in param_name.split("."):
            old_param = getattr(old_param, split)

        if not isinstance(old_param, (torch.nn.Parameter, torch.Tensor)):
            old_param = None

        if old_param is not None:
            if dtype is None:
                param = param.to(old_param.dtype)

            if old_param.is_contiguous():
                param = param.contiguous()

        param_device = _determine_param_device(param_name, device_map)

        # bnb params are flattened.
        # gguf quants have a different shape based on the type of quantization applied
        if empty_state_dict[param_name].shape != param.shape:
            if (
                is_quantized
                and hf_quantizer.pre_quantized
                and hf_quantizer.check_if_quantized_param(
                    model, param, param_name, state_dict, param_device=param_device
                )
            ):
                hf_quantizer.check_quantized_param_shape(param_name, empty_state_dict[param_name], param)
            else:
                # `model_name_or_path_str` already carries its trailing space.
                model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else ""
                raise ValueError(
                    f"Cannot load {model_name_or_path_str}because {param_name} expected shape {empty_state_dict[param_name].shape}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
                )

        if param_device == "disk":
            offload_index = offload_weight(param, param_name, offload_folder, offload_index)
        elif param_device == "cpu" and state_dict_index is not None:
            state_dict_index = offload_weight(param, param_name, state_dict_folder, state_dict_index)
        elif is_quantized and (
            hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=param_device)
        ):
            hf_quantizer.create_quantized_param(
                model, param, param_name, param_device, state_dict, unexpected_keys, dtype=dtype
            )
        else:
            set_module_tensor_to_device(model, param_name, param_device, value=param, **set_module_kwargs)

    return offload_index, state_dict_index
313
314


315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
def check_support_param_buffer_assignment(model_to_load, state_dict, start_prefix=""):
    """
    Checks if `model_to_load` supports param buffer assignment (such as when loading in empty weights).

    Returns `False` in any of these cases:
    - the model is on the `meta` device;
    - no key of `state_dict` starts with `start_prefix`;
    - the model sets `_supports_param_buffer_assignment = False`;
    - the model has no parameters/buffers to compare against.

    Otherwise, the dtype of the model's first state-dict entry is compared with the checkpoint tensor stored under
    `start_prefix + <that key>`. The function returns `True` only if that key exists and the dtypes match.
    """
    if model_to_load.device.type == "meta":
        return False

    if not any(key.startswith(start_prefix) for key in state_dict):
        return False

    # Some models explicitly do not support param buffer assignment
    if not getattr(model_to_load, "_supports_param_buffer_assignment", True):
        logger.debug(
            f"{model_to_load.__class__.__name__} does not support param buffer assignment, loading will be slower"
        )
        return False

    # `Module.state_dict()` builds a fresh mapping on every call; build it once.
    model_state_dict = model_to_load.state_dict()
    if not model_state_dict:
        # Nothing to compare dtypes against (previously raised StopIteration).
        return False

    # If the model does, the incoming `state_dict` and the `model_to_load` must be the same dtype
    first_key = next(iter(model_state_dict))
    checkpoint_key = start_prefix + first_key
    if checkpoint_key in state_dict:
        return state_dict[checkpoint_key].dtype == model_state_dict[first_key].dtype

    return False


def _load_shard_file(
    shard_file,
    model,
    model_state_dict,
    device_map=None,
    dtype=None,
    hf_quantizer=None,
    keep_in_fp32_modules=None,
    dduf_entries=None,
    loaded_keys=None,
    unexpected_keys=None,
    offload_index=None,
    offload_folder=None,
    state_dict_index=None,
    state_dict_folder=None,
    ignore_mismatched_sizes=False,
    low_cpu_mem_usage=False,
):
    """
    Load one checkpoint shard into `model`.

    The shard is read with `load_state_dict`. Shape-mismatched entries are collected, and removed from the shard
    when `ignore_mismatched_sizes` is set. The remaining tensors are then loaded in one of two ways:
    - with `load_model_dict_into_meta` when `low_cpu_mem_usage` is true;
    - with `_load_state_dict_into_model` otherwise.

    Returns:
        `(offload_index, state_dict_index, mismatched_keys, error_msgs)`.
    """
    shard_state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries)
    mismatched = _find_mismatched_keys(
        shard_state_dict,
        model_state_dict,
        loaded_keys,
        ignore_mismatched_sizes,
    )

    if not low_cpu_mem_usage:
        can_assign = check_support_param_buffer_assignment(model, shard_state_dict)
        errors = list(_load_state_dict_into_model(model, shard_state_dict, can_assign))
        return offload_index, state_dict_index, mismatched, errors

    offload_index, state_dict_index = load_model_dict_into_meta(
        model,
        shard_state_dict,
        device_map=device_map,
        dtype=dtype,
        hf_quantizer=hf_quantizer,
        keep_in_fp32_modules=keep_in_fp32_modules,
        unexpected_keys=unexpected_keys,
        offload_folder=offload_folder,
        offload_index=offload_index,
        state_dict_index=state_dict_index,
        state_dict_folder=state_dict_folder,
    )
    return offload_index, state_dict_index, mismatched, []


def _load_shard_files_with_threadpool(
    shard_files,
    model,
    model_state_dict,
    device_map=None,
    dtype=None,
    hf_quantizer=None,
    keep_in_fp32_modules=None,
    dduf_entries=None,
    loaded_keys=None,
    unexpected_keys=None,
    offload_index=None,
    offload_folder=None,
    state_dict_index=None,
    state_dict_folder=None,
    ignore_mismatched_sizes=False,
    low_cpu_mem_usage=False,
):
    # Do not spawn anymore workers than you need
    num_workers = min(len(shard_files), DEFAULT_HF_PARALLEL_LOADING_WORKERS)

    logger.info(f"Loading model weights in parallel with {num_workers} workers...")

    error_msgs = []
    mismatched_keys = []

    load_one = functools.partial(
        _load_shard_file,
        model=model,
        model_state_dict=model_state_dict,
        device_map=device_map,
        dtype=dtype,
        hf_quantizer=hf_quantizer,
        keep_in_fp32_modules=keep_in_fp32_modules,
        dduf_entries=dduf_entries,
        loaded_keys=loaded_keys,
        unexpected_keys=unexpected_keys,
        offload_index=offload_index,
        offload_folder=offload_folder,
        state_dict_index=state_dict_index,
        state_dict_folder=state_dict_folder,
        ignore_mismatched_sizes=ignore_mismatched_sizes,
        low_cpu_mem_usage=low_cpu_mem_usage,
    )

    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        with logging.tqdm(total=len(shard_files), desc="Loading checkpoint shards") as pbar:
            futures = [executor.submit(load_one, shard_file) for shard_file in shard_files]
            for future in as_completed(futures):
                result = future.result()
                offload_index, state_dict_index, _mismatched_keys, _error_msgs = result
                error_msgs += _error_msgs
                mismatched_keys += _mismatched_keys
                pbar.update(1)

    return offload_index, state_dict_index, mismatched_keys, error_msgs


def _find_mismatched_keys(
    state_dict,
    model_state_dict,
    loaded_keys,
    ignore_mismatched_sizes,
):
    mismatched_keys = []
    if ignore_mismatched_sizes:
        for checkpoint_key in loaded_keys:
            model_key = checkpoint_key
            # If the checkpoint is sharded, we may not have the key here.
            if checkpoint_key not in state_dict:
                continue

            if model_key in model_state_dict and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape:
                mismatched_keys.append(
                    (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
                )
                del state_dict[checkpoint_key]
    return mismatched_keys


470
471
472
def _load_state_dict_into_model(
    model_to_load, state_dict: OrderedDict, assign_to_params_buffers: bool = False
) -> List[str]:
473
474
475
476
477
478
479
    # Convert old format to new format if needed from a PyTorch state_dict
    # copy state_dict so _load_from_state_dict can modify it
    state_dict = state_dict.copy()
    error_msgs = []

    # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
    # so we need to apply the function recursively.
480
481
482
483
484
485
    def load(module: torch.nn.Module, prefix: str = "", assign_to_params_buffers: bool = False):
        local_metadata = {}
        local_metadata["assign_to_params_buffers"] = assign_to_params_buffers
        if assign_to_params_buffers and not is_torch_version(">=", "2.1"):
            logger.info("You need to have torch>=2.1 in order to load the model with assign_to_params_buffers=True")
        args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
486
487
488
489
        module._load_from_state_dict(*args)

        for name, child in module._modules.items():
            if child is not None:
490
                load(child, prefix + name + ".", assign_to_params_buffers)
491

492
    load(model_to_load, assign_to_params_buffers=assign_to_params_buffers)
493
494

    return error_msgs
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510


def _fetch_index_file(
    is_local,
    pretrained_model_name_or_path,
    subfolder,
    use_safetensors,
    cache_dir,
    variant,
    force_download,
    proxies,
    local_files_only,
    token,
    revision,
    user_agent,
    commit_hash,
    dduf_entries: Optional[Dict[str, DDUFEntry]] = None,
):
    """
    Locate the index file of a sharded checkpoint.

    The file name is the safetensors or PyTorch weights index name (picked by `use_safetensors`) with `variant`
    applied through `_add_variant`.

    Returns:
        For a local model, the expected `Path`; whether the file exists is not checked here. Otherwise, the file
        resolved by `_get_model_file`, wrapped in `Path` unless `dduf_entries` is given. Returns `None` if it could
        not be fetched.
    """
    index_name = _add_variant(SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME, variant)

    if is_local:
        return Path(pretrained_model_name_or_path, subfolder or "", index_name)

    index_file_in_repo = Path(subfolder or "", index_name).as_posix()
    try:
        resolved_index_file = _get_model_file(
            pretrained_model_name_or_path,
            weights_name=index_file_in_repo,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            local_files_only=local_files_only,
            token=token,
            revision=revision,
            subfolder=None,
            user_agent=user_agent,
            commit_hash=commit_hash,
            dduf_entries=dduf_entries,
        )
    except (EntryNotFoundError, EnvironmentError):
        return None

    # With DDUF entries the resolved value is returned as-is; otherwise it is wrapped in `Path`.
    return resolved_index_file if dduf_entries else Path(resolved_index_file)
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560


def _legacy_variant_weights_name(weights_name, variant):
    """Insert `variant` after the first two dot-separated components of a weights *file name* (legacy layout)."""
    parts = weights_name.split(".")
    return ".".join(parts[:2] + [variant] + parts[2:])


def _fetch_index_file_legacy(
    is_local,
    pretrained_model_name_or_path,
    subfolder,
    use_safetensors,
    cache_dir,
    variant,
    force_download,
    proxies,
    local_files_only,
    token,
    revision,
    user_agent,
    commit_hash,
    dduf_entries: Optional[Dict[str, DDUFEntry]] = None,
):
    """
    Locate a sharded-checkpoint index file saved with the legacy (deprecated) variant naming scheme.

    The variant is spliced into the file name only (never into directory names, which may contain dots). A
    deprecation warning is emitted when such a file is found.

    Returns:
        The index file (`Path`, or the raw value from `_get_model_file` when `dduf_entries` is given). Returns `None`
        in two cases:
        - `variant` is `None`, since the legacy scheme only applies to variants;
        - the file does not exist locally or cannot be fetched.
    """
    if variant is None:
        # Previously a TypeError (local) or UnboundLocalError (remote).
        return None

    index_name = _legacy_variant_weights_name(
        SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME, variant
    )
    deprecation_message = f"This serialization format is now deprecated to standardize the serialization format between `transformers` and `diffusers`. We recommend you to remove the existing files associated with the current variant ({variant}) and re-obtain them by running a `save_pretrained()`."

    if is_local:
        index_file = Path(pretrained_model_name_or_path, subfolder or "", index_name)
        if not os.path.exists(index_file):
            return None
        deprecate("legacy_sharded_ckpts_with_variant", "1.0.0", deprecation_message, standard_warn=False)
        return index_file

    index_file_in_repo = Path(subfolder or "", index_name).as_posix()
    try:
        index_file = _get_model_file(
            pretrained_model_name_or_path,
            weights_name=index_file_in_repo,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            local_files_only=local_files_only,
            token=token,
            revision=revision,
            subfolder=None,
            user_agent=user_agent,
            commit_hash=commit_hash,
            dduf_entries=dduf_entries,
        )
    except (EntryNotFoundError, EnvironmentError):
        return None

    # Same convention as `_fetch_index_file`: keep the raw value for DDUF entries.
    if not dduf_entries:
        index_file = Path(index_file)
    deprecate("legacy_sharded_ckpts_with_variant", "1.0.0", deprecation_message, standard_warn=False)
    return index_file
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685


def _gguf_parse_value(_value, data_type):
    if not isinstance(data_type, list):
        data_type = [data_type]
    if len(data_type) == 1:
        data_type = data_type[0]
        array_data_type = None
    else:
        if data_type[0] != 9:
            raise ValueError("Received multiple types, therefore expected the first type to indicate an array.")
        data_type, array_data_type = data_type

    if data_type in [0, 1, 2, 3, 4, 5, 10, 11]:
        _value = int(_value[0])
    elif data_type in [6, 12]:
        _value = float(_value[0])
    elif data_type in [7]:
        _value = bool(_value[0])
    elif data_type in [8]:
        _value = array("B", list(_value)).tobytes().decode()
    elif data_type in [9]:
        _value = _gguf_parse_value(_value, array_data_type)
    return _value


def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):
    """
    Load a GGUF file and return a dictionary mapping tensor names to tensors.

    F32 and F16 tensors are returned as plain `torch.Tensor`s. Tensors stored with any other supported GGUF
    quantization type are wrapped in a `GGUFParameter` that records their `quant_type`.

    Args:
        gguf_checkpoint_path (`str`):
            The path to the GGUF file to load.
        return_tensors (`bool`, defaults to `False`):
            Currently unused: tensors are always read and returned. Kept for backward compatibility.

    Returns:
        `Dict[str, torch.Tensor]`: The tensors keyed by their GGUF tensor name.

    Raises:
        ImportError: If `torch` or `gguf` is not available.
        ValueError: If a tensor uses a quantization type not listed in `SUPPORTED_GGUF_QUANT_TYPES`.
    """

    if is_gguf_available() and is_torch_available():
        import gguf
        from gguf import GGUFReader

        from ..quantizers.gguf.utils import SUPPORTED_GGUF_QUANT_TYPES, GGUFParameter
    else:
        logger.error(
            "Loading a GGUF checkpoint in PyTorch, requires both PyTorch and GGUF>=0.10.0 to be installed. Please see "
            "https://pytorch.org/ and https://github.com/ggerganov/llama.cpp/tree/master/gguf-py for installation instructions."
        )
        raise ImportError("Please install torch and gguf>=0.10.0 to load a GGUF checkpoint in PyTorch.")

    reader = GGUFReader(gguf_checkpoint_path)
    # Tensors of these types are a torch supported dtype, so they do not need a GGUFParameter wrapper.
    unquantized_types = (gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16)

    parsed_parameters = {}
    for tensor in reader.tensors:
        name = tensor.name
        quant_type = tensor.tensor_type

        is_gguf_quant = quant_type not in unquantized_types
        if is_gguf_quant and quant_type not in SUPPORTED_GGUF_QUANT_TYPES:
            _supported_quants_str = "\n".join(str(supported) for supported in SUPPORTED_GGUF_QUANT_TYPES)
            raise ValueError(
                f"{name} has a quantization type: {str(quant_type)} which is unsupported."
                "\n\nCurrently the following quantization types are supported: \n\n"
                f"{_supported_quants_str}"
                "\n\nTo request support for this quantization type please open an issue here: https://github.com/huggingface/diffusers"
            )

        weights = torch.from_numpy(tensor.data.copy())
        parsed_parameters[name] = GGUFParameter(weights, quant_type=quant_type) if is_gguf_quant else weights

    return parsed_parameters
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718


# NOTE: `_find_mismatched_keys` is defined once, earlier in this module. A second, behaviorally identical
# definition here silently shadowed that one, so the duplicate was removed to keep a single source of truth.


def _expand_device_map(device_map, param_names):
    """
    Expand a device map to return the correspondence parameter name to device.
    """
    new_device_map = {}
    for module, device in device_map.items():
        new_device_map.update(
            {p: device for p in param_names if p == module or p.startswith(f"{module}.") or module == ""}
        )
    return new_device_map


# Adapted from: https://github.com/huggingface/transformers/blob/0687d481e2c71544501ef9cb3eef795a6e79b1de/src/transformers/modeling_utils.py#L5859
def _caching_allocator_warmup(
    model, expanded_device_map: Dict[str, torch.device], dtype: torch.dtype, hf_quantizer: Optional[DiffusersQuantizer]
) -> None:
    """
    This function warm-ups the caching allocator based on the size of the model tensors that will reside on each
    device. It allows to have one large call to Malloc, instead of recursively calling it later when loading the model,
    which is actually the loading speed bottleneck. Calling this function allows to cut the model loading time by a
    very large margin.

    Args:
        model: Model whose parameters/buffers (looked up by name) determine the byte counts.
        expanded_device_map: Parameter name -> device mapping; `"cpu"` and `"disk"` entries are ignored.
        dtype: dtype of the warm-up buffer; `None` falls back to `torch.get_default_dtype()`.
        hf_quantizer: If given, provides the warm-up `factor` via `get_cuda_warm_up_factor()`; otherwise the factor
            is 2.
    """
    factor = 2 if hf_quantizer is None else hf_quantizer.get_cuda_warm_up_factor()

    # Remove disk and cpu devices, and cast to proper torch.device
    accelerator_device_map = {
        param: torch.device(device)
        for param, device in expanded_device_map.items()
        if str(device) not in ["cpu", "disk"]
    }
    total_byte_count = defaultdict(lambda: 0)
    for param_name, device in accelerator_device_map.items():
        try:
            param = model.get_parameter(param_name)
        except AttributeError:
            param = model.get_buffer(param_name)
        # The dtype of different parameters may be different with composite models or `keep_in_fp32_modules`
        param_byte_count = param.numel() * param.element_size()
        # TODO: account for TP when needed.
        total_byte_count[device] += param_byte_count

    # NOTE(review): `factor` follows the transformers convention of counting 2-byte (fp16) elements, i.e.
    # factor 2 reserves the full byte count and factor 4 reserves half. Convert that byte budget into an element
    # count for the actual warm-up dtype. Otherwise fp32 (or `dtype=None`) would reserve twice the intended memory.
    # For 2-byte dtypes this equals the previous `byte_count // factor` elements exactly.
    warmup_dtype = dtype if dtype is not None else torch.get_default_dtype()
    element_size = torch.empty(0, dtype=warmup_dtype).element_size()

    # This will kick off the caching allocator to avoid having to Malloc afterwards
    for device, byte_count in total_byte_count.items():
        num_elements = (byte_count * 2 // factor) // element_size
        _ = torch.empty(num_elements, dtype=warmup_dtype, device=device, requires_grad=False)