model_loading_utils.py 21.6 KB
Newer Older
1
# coding=utf-8
2
# Copyright 2025 The HuggingFace Inc. team.
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

17
import importlib
import inspect
import os
from array import array
from collections import OrderedDict
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
from zipfile import is_zipfile

import safetensors
import torch
from huggingface_hub import DDUFEntry
from huggingface_hub.utils import EntryNotFoundError

from ..quantizers import DiffusersQuantizer
from ..utils import (
    GGUF_FILE_EXTENSION,
    SAFE_WEIGHTS_INDEX_NAME,
    SAFETENSORS_FILE_EXTENSION,
    WEIGHTS_INDEX_NAME,
    _add_variant,
    _get_model_file,
    deprecate,
    is_accelerate_available,
    is_gguf_available,
    is_torch_available,
    is_torch_version,
    logging,
)


logger = logging.get_logger(__name__)

50
51
52
53
54
55
56
# Maps a class name to a `{norm_type config value: replacement class name}` table.
# `_fetch_remapped_cls_from_config` looks a class up here by its `__name__` and then selects the
# replacement by the `norm_type` entry of the model config.
_CLASS_REMAPPING_DICT = {
    "Transformer2DModel": {
        "ada_norm_zero": "DiTTransformer2DModel",
        "ada_norm_single": "PixArtTransformer2DModel",
    }
}

57
58
59

if is_accelerate_available():
    from accelerate import infer_auto_device_map
60
    from accelerate.utils import get_balanced_memory, get_max_memory, offload_weight, set_module_tensor_to_device
61
62
63


# Adapted from `transformers` (see modeling_utils.py)
64
65
66
def _determine_device_map(
    model: torch.nn.Module, device_map, max_memory, torch_dtype, keep_in_fp32_modules=None, hf_quantizer=None
):
    """
    Turn a device-map strategy string into a concrete device map computed by `accelerate`.

    Args:
        model (`torch.nn.Module`):
            The model to place. It must provide `_get_no_split_modules`.
        device_map:
            A strategy name (`str`) or anything else. Non-string values are returned untouched. `"sequential"`
            skips memory balancing and `"balanced_low_0"` enables the `low_zero` balancing mode.
        max_memory:
            Memory budget forwarded to `get_balanced_memory` / `get_max_memory`.
        torch_dtype:
            Dtype used for memory balancing and, after the optional quantizer adjustment, for the device map
            inference.
        keep_in_fp32_modules (`list`, *optional*):
            Substrings of parameter names that are accounted for as `torch.float32`. `None` means no such module.
        hf_quantizer (*optional*):
            Quantizer that may contribute special dtypes, adjust the target dtype and the memory budget, and
            validate the resulting device map.

    Returns:
        The original `device_map` when it is not a string, otherwise the map returned by `infer_auto_device_map`.
    """
    # Only strategy names need resolving; anything else is passed through as-is.
    if not isinstance(device_map, str):
        return device_map

    # `None` (the default) and an empty list both mean "no module is pinned to fp32".
    fp32_patterns = keep_in_fp32_modules or ()

    special_dtypes = {}
    if hf_quantizer is not None:
        special_dtypes.update(hf_quantizer.get_special_dtypes_update(model, torch_dtype))
    # fp32 entries are applied last so they override whatever the quantizer requested for the same parameter.
    special_dtypes.update(
        {
            name: torch.float32
            for name, _ in model.named_parameters()
            if any(pattern in name for pattern in fp32_patterns)
        }
    )

    target_dtype = torch_dtype
    if hf_quantizer is not None:
        target_dtype = hf_quantizer.adjust_target_dtype(target_dtype)

    no_split_modules = model._get_no_split_modules(device_map)
    device_map_kwargs = {"no_split_module_classes": no_split_modules}

    # `special_dtypes` is only forwarded when the installed `accelerate` accepts that argument.
    if "special_dtypes" in inspect.signature(infer_auto_device_map).parameters:
        device_map_kwargs["special_dtypes"] = special_dtypes
    elif len(special_dtypes) > 0:
        logger.warning(
            "This model has some weights that should be kept in higher precision, you need to upgrade "
            "`accelerate` to properly deal with them (`pip install --upgrade accelerate`)."
        )

    if device_map != "sequential":
        max_memory = get_balanced_memory(
            model,
            dtype=torch_dtype,
            low_zero=(device_map == "balanced_low_0"),
            max_memory=max_memory,
            **device_map_kwargs,
        )
    else:
        max_memory = get_max_memory(max_memory)

    if hf_quantizer is not None:
        max_memory = hf_quantizer.adjust_max_memory(max_memory)

    device_map_kwargs["max_memory"] = max_memory
    device_map = infer_auto_device_map(model, dtype=target_dtype, **device_map_kwargs)

    if hf_quantizer is not None:
        hf_quantizer.validate_environment(device_map=device_map)

    return device_map


117
118
def _fetch_remapped_cls_from_config(config, old_class):
    """
    Return the class that should replace `old_class` for the given config, if a remapping is registered.

    Args:
        config:
            Mapping-like model config. Its `"norm_type"` entry selects the replacement class; it is only read
            when `old_class` has an entry in `_CLASS_REMAPPING_DICT`.
        old_class:
            The class that was originally requested.

    Returns:
        The remapped class looked up on the top-level package, or `old_class` when no remapping applies.
    """
    previous_class_name = old_class.__name__

    # Classes without a remapping table are returned unchanged instead of failing on `None.get(...)`.
    norm_type_to_class_name = _CLASS_REMAPPING_DICT.get(previous_class_name)
    if norm_type_to_class_name is None:
        return old_class

    remapped_class_name = norm_type_to_class_name.get(config["norm_type"], None)
    if not remapped_class_name:
        return old_class

    # Details:
    # https://github.com/huggingface/diffusers/pull/7647#discussion_r1621344818
    # Import the top-level package of this module and resolve the replacement class from it.
    diffusers_library = importlib.import_module(__name__.split(".")[0])
    remapped_class = getattr(diffusers_library, remapped_class_name)
    logger.info(
        f"Changing class object to be of `{remapped_class_name}` type from `{previous_class_name}` type. "
        f"This is because `{previous_class_name}` is scheduled to be deprecated in a future version. Note that this"
        " DOESN'T affect the final results."
    )
    return remapped_class
135
136


137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
def _check_archive_and_maybe_raise_error(checkpoint_file, format_list):
    """
    Validate the `format` metadata entry of a safetensors archive.

    The archive is accepted when it carries no metadata at all, or when its `format` value is one of
    `format_list`.

    Raises:
        OSError: If metadata is present and its `format` value is not in `format_list`.
    """
    with safetensors.safe_open(checkpoint_file, framework="pt") as archive:
        archive_metadata = archive.metadata()
        format_is_accepted = archive_metadata is None or archive_metadata.get("format") in format_list
        if not format_is_accepted:
            raise OSError(
                f"The safetensors archive passed at {checkpoint_file} does not contain the valid metadata. Make sure "
                "you save your model with the `save_pretrained` method."
            )


def _determine_param_device(param_name: str, device_map: Optional[Dict[str, Union[int, str, torch.device]]]):
    """
    Look up the device assigned to `param_name` in `device_map`.

    Without a device map everything lives on `"cpu"`. Otherwise the dotted name is shortened one component at a
    time (`a.b.weight` -> `a.b` -> `a` -> `""`) until a key of `device_map` is found.

    Raises:
        ValueError: If neither the name, any of its parents, nor the root key `""` is present in `device_map`.
    """
    if device_map is None:
        return "cpu"

    candidate = param_name
    while candidate and candidate not in device_map:
        # Drop the last dotted component; a name without dots collapses to "".
        candidate, _, _ = candidate.rpartition(".")

    if candidate == "" and "" not in device_map:
        raise ValueError(f"{param_name} doesn't have any device set.")
    return device_map[candidate]


167
def load_state_dict(
    checkpoint_file: Union[str, os.PathLike],
    dduf_entries: Optional[Dict[str, DDUFEntry]] = None,
    disable_mmap: bool = False,
    map_location: Union[str, torch.device] = "cpu",
):
    """
    Reads a checkpoint file, returning properly formatted errors if they arise.

    Args:
        checkpoint_file (`str` or `os.PathLike`):
            Path of the checkpoint. A `dict` is also accepted and returned unchanged. The loader is chosen from
            the file extension: safetensors, GGUF, or `torch.load` for everything else.
        dduf_entries (`Dict[str, DDUFEntry]`, *optional*):
            When given, safetensors checkpoints are read from `dduf_entries[checkpoint_file]` instead of from
            disk.
        disable_mmap (`bool`, defaults to `False`):
            Read safetensors files fully into memory and never request `mmap` from `torch.load`.
        map_location (`str` or `torch.device`, defaults to `"cpu"`):
            Target device passed to `safetensors.torch.load_file` / `torch.load`.

    Raises:
        OSError: If the file looks like a git-lfs pointer or cannot be loaded.
    """
    # TODO: maybe refactor a bit this part where we pass a dict here
    if isinstance(checkpoint_file, dict):
        return checkpoint_file
    try:
        file_extension = os.path.basename(checkpoint_file).split(".")[-1]
        if file_extension == SAFETENSORS_FILE_EXTENSION:
            if dduf_entries:
                # tensors are loaded on cpu
                with dduf_entries[checkpoint_file].as_mmap() as mm:
                    return safetensors.torch.load(mm)
            _check_archive_and_maybe_raise_error(checkpoint_file, format_list=["pt", "flax"])
            if disable_mmap:
                # Close the handle deterministically instead of relying on garbage collection.
                with open(checkpoint_file, "rb") as checkpoint_handle:
                    return safetensors.torch.load(checkpoint_handle.read())
            return safetensors.torch.load_file(checkpoint_file, device=map_location)

        if file_extension == GGUF_FILE_EXTENSION:
            return load_gguf_checkpoint(checkpoint_file)

        extra_args = {}
        weights_only_kwarg = {"weights_only": True} if is_torch_version(">=", "1.13") else {}
        # mmap can only be used with files serialized with zipfile-based format.
        if (
            isinstance(checkpoint_file, str)
            and map_location != "meta"
            and is_torch_version(">=", "2.1.0")
            and is_zipfile(checkpoint_file)
            and not disable_mmap
        ):
            extra_args = {"mmap": True}
        return torch.load(checkpoint_file, map_location=map_location, **weights_only_kwarg, **extra_args)
    except Exception as e:
        # Re-read the file as text to tell a git-lfs pointer apart from a genuinely unreadable checkpoint.
        try:
            with open(checkpoint_file) as f:
                if f.read().startswith("version"):
                    raise OSError(
                        "You seem to have cloned a repository without having git-lfs installed. Please install "
                        "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
                        "you cloned."
                    )
                else:
                    raise ValueError(
                        f"Unable to locate the file {checkpoint_file} which is necessary to load this pretrained "
                        "model. Make sure you have saved the model properly."
                    ) from e
        except (UnicodeDecodeError, ValueError):
            # Chain to the original loading error so the root cause stays visible in the traceback.
            raise OSError(
                f"Unable to load weights from checkpoint file for '{checkpoint_file}' " f"at '{checkpoint_file}'. "
            ) from e


def load_model_dict_into_meta(
    model,
    state_dict: OrderedDict,
    dtype: Optional[Union[str, torch.dtype]] = None,
    model_name_or_path: Optional[str] = None,
    hf_quantizer: Optional[DiffusersQuantizer] = None,
    keep_in_fp32_modules: Optional[List] = None,
    device_map: Optional[Dict[str, Union[int, str, torch.device]]] = None,
    unexpected_keys: Optional[List[str]] = None,
    offload_folder: Optional[Union[str, os.PathLike]] = None,
    offload_index: Optional[Dict] = None,
    state_dict_index: Optional[Dict] = None,
    state_dict_folder: Optional[Union[str, os.PathLike]] = None,
) -> Tuple[Optional[Dict], Optional[Dict]]:
    """
    This is somewhat similar to `_load_state_dict_into_model`, but deals with a model that has some or all of its
    params on a `meta` device. It replaces the model params with the data from the `state_dict`

    Entries of `state_dict` that are not part of `model.state_dict()` are skipped. Every other entry is cast,
    checked against the expected shape and then, depending on its target device, offloaded to disk, offloaded to
    the CPU state-dict folder, handed to the quantizer, or set on the model.

    Args:
        model: The model whose parameters and buffers are populated.
        state_dict: Mapping of parameter names to loaded tensors.
        dtype: Dtype that floating point tensors are cast to. When `None`, tensors take the dtype of the
            parameter they replace.
        model_name_or_path: Only used in the shape-mismatch error message.
        hf_quantizer: Optional quantizer consulted for quantized parameters.
        keep_in_fp32_modules: Module names (matched against the dotted components of a parameter name) whose
            floating point tensors are cast to `torch.float32` instead of `dtype`.
        device_map: Mapping used to resolve the target device of every parameter; `None` means `"cpu"`.
        unexpected_keys: Forwarded to `hf_quantizer.create_quantized_param`.
        offload_folder: Destination for parameters mapped to `"disk"`.
        offload_index: Index updated for parameters mapped to `"disk"`.
        state_dict_index: When not `None`, parameters mapped to `"cpu"` are offloaded and recorded here.
        state_dict_folder: Destination for the `"cpu"` offload.

    Returns:
        The tuple `(offload_index, state_dict_index)` after all parameters were processed.

    Raises:
        ValueError: If a tensor's shape differs from the model's and no pre-quantized handling applies.
    """
    is_quantized = hf_quantizer is not None
    empty_state_dict = model.state_dict()

    for param_name, param in state_dict.items():
        if param_name not in empty_state_dict:
            continue

        set_module_kwargs = {}
        # We convert floating dtypes to the `dtype` passed. We also want to keep the buffers/params
        # in int/uint/bool and not cast them.
        # TODO: revisit cases when param.dtype == torch.float8_e4m3fn
        if dtype is not None and torch.is_floating_point(param):
            keep_in_fp32 = keep_in_fp32_modules is not None and any(
                module_to_keep_in_fp32 in param_name.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules
            )
            cast_dtype = torch.float32 if keep_in_fp32 else dtype
            param = param.to(cast_dtype)
            set_module_kwargs["dtype"] = cast_dtype

        # For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model, and which
        # uses `param.copy_(input_param)` that preserves the contiguity of the parameter in the model.
        # Reference: https://github.com/pytorch/pytorch/blob/db79ceb110f6646523019a59bbd7b838f43d4a86/torch/nn/modules/module.py#L2040C29-L2040C29
        old_param = model
        for attribute_name in param_name.split("."):
            old_param = getattr(old_param, attribute_name)

        if not isinstance(old_param, (torch.nn.Parameter, torch.Tensor)):
            old_param = None

        if old_param is not None:
            if dtype is None:
                param = param.to(old_param.dtype)

            if old_param.is_contiguous():
                param = param.contiguous()

        param_device = _determine_param_device(param_name, device_map)

        # bnb params are flattened.
        # gguf quants have a different shape based on the type of quantization applied
        if empty_state_dict[param_name].shape != param.shape:
            if (
                is_quantized
                and hf_quantizer.pre_quantized
                and hf_quantizer.check_if_quantized_param(
                    model, param, param_name, state_dict, param_device=param_device
                )
            ):
                hf_quantizer.check_quantized_param_shape(param_name, empty_state_dict[param_name], param)
            else:
                model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else ""
                raise ValueError(
                    f"Cannot load {model_name_or_path_str} because {param_name} expected shape {empty_state_dict[param_name].shape}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
                )

        # Dispatch on the target device: disk offload, cpu offload, quantized creation, or a plain assignment.
        if param_device == "disk":
            offload_index = offload_weight(param, param_name, offload_folder, offload_index)
        elif param_device == "cpu" and state_dict_index is not None:
            state_dict_index = offload_weight(param, param_name, state_dict_folder, state_dict_index)
        elif is_quantized and (
            hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=param_device)
        ):
            hf_quantizer.create_quantized_param(model, param, param_name, param_device, state_dict, unexpected_keys)
        else:
            set_module_tensor_to_device(model, param_name, param_device, value=param, **set_module_kwargs)

    return offload_index, state_dict_index
314
315


316
317
318
def _load_state_dict_into_model(
    model_to_load, state_dict: OrderedDict, assign_to_params_buffers: bool = False
) -> List[str]:
    """
    Load `state_dict` into `model_to_load` module by module and collect the loading errors.

    Args:
        model_to_load (`torch.nn.Module`):
            Root module that receives the weights.
        state_dict (`OrderedDict`):
            Weights keyed by fully qualified parameter/buffer name. A shallow copy is used, so the caller's
            mapping is not modified.
        assign_to_params_buffers (`bool`, defaults to `False`):
            Passed to every module through `local_metadata["assign_to_params_buffers"]`.

    Returns:
        `List[str]`: The error messages reported by the modules' `_load_from_state_dict` calls.
    """
    # copy state_dict so _load_from_state_dict can modify it
    state_dict = state_dict.copy()
    error_msgs = []

    # The version check does not depend on the module, so it is done (and logged) once rather than per submodule.
    if assign_to_params_buffers and not is_torch_version(">=", "2.1"):
        logger.info("You need to have torch>=2.1 in order to load the model with assign_to_params_buffers=True")

    # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants, so every module is
    # visited explicitly. The stack reproduces a depth-first, parent-before-children traversal.
    pending = [(model_to_load, "")]
    while pending:
        module, prefix = pending.pop()
        local_metadata = {"assign_to_params_buffers": assign_to_params_buffers}
        module._load_from_state_dict(state_dict, prefix, local_metadata, True, [], [], error_msgs)

        children = [
            (child, prefix + name + ".") for name, child in module._modules.items() if child is not None
        ]
        # Reversed so that the first child is the next one popped.
        pending.extend(reversed(children))

    return error_msgs
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356


def _fetch_index_file(
    is_local,
    pretrained_model_name_or_path,
    subfolder,
    use_safetensors,
    cache_dir,
    variant,
    force_download,
    proxies,
    local_files_only,
    token,
    revision,
    user_agent,
    commit_hash,
    dduf_entries: Optional[Dict[str, DDUFEntry]] = None,
):
    """
    Locate the index file of a sharded checkpoint.

    The file name is the safetensors or regular weights index name (chosen by `use_safetensors`) with `variant`
    applied through `_add_variant`. For local models the path is only built from
    `pretrained_model_name_or_path` and `subfolder`. Otherwise the file is resolved with `_get_model_file`.

    Returns:
        A `Path` to the index file (the raw `_get_model_file` result when `dduf_entries` is given), or `None`
        when the remote lookup raises `EntryNotFoundError` or `EnvironmentError`.
    """
    base_index_name = SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME
    index_name = _add_variant(base_index_name, variant)
    relative_folder = subfolder or ""

    if is_local:
        return Path(pretrained_model_name_or_path, relative_folder, index_name)

    index_file_in_repo = Path(relative_folder, index_name).as_posix()
    try:
        resolved = _get_model_file(
            pretrained_model_name_or_path,
            weights_name=index_file_in_repo,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            local_files_only=local_files_only,
            token=token,
            revision=revision,
            subfolder=None,
            user_agent=user_agent,
            commit_hash=commit_hash,
            dduf_entries=dduf_entries,
        )
        if not dduf_entries:
            resolved = Path(resolved)
    except (EntryNotFoundError, EnvironmentError):
        resolved = None

    return resolved
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406


def _fetch_index_file_legacy(
    is_local,
    pretrained_model_name_or_path,
    subfolder,
    use_safetensors,
    cache_dir,
    variant,
    force_download,
    proxies,
    local_files_only,
    token,
    revision,
    user_agent,
    commit_hash,
    dduf_entries: Optional[Dict[str, DDUFEntry]] = None,
):
    """
    Locate the index file of a sharded checkpoint saved with the deprecated variant naming.

    In this naming the variant is inserted as a dot-separated component in the middle of the index file path
    rather than through `_add_variant`. Finding such a file emits a deprecation warning.

    Returns:
        A `Path` to the legacy index file, or `None` when `variant` is `None`, the local file does not exist, or
        the remote lookup raises `EntryNotFoundError` or `EnvironmentError`.
    """
    # The legacy naming is built around the variant; without one there is nothing to look for. Returning early
    # also avoids joining a `None` component (local case) and an unbound result (remote case).
    if variant is None:
        return None

    base_index_name = SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME
    deprecation_message = f"This serialization format is now deprecated to standardize the serialization format between `transformers` and `diffusers`. We recommend you to remove the existing files associated with the current variant ({variant}) and re-obtain them by running a `save_pretrained()`."

    if is_local:
        index_file = Path(pretrained_model_name_or_path, subfolder or "", base_index_name).as_posix()
        splits = index_file.split(".")
        # The variant goes after the second dot-separated component, or after the third when the path
        # contains ".cache" (which contributes one extra dot).
        insert_at = 3 if ".cache" in index_file else 2
        index_file = ".".join(splits[:insert_at] + [variant] + splits[insert_at:])
        if not os.path.exists(index_file):
            return None
        deprecate("legacy_sharded_ckpts_with_variant", "1.0.0", deprecation_message, standard_warn=False)
        return Path(index_file)

    index_file_in_repo = Path(subfolder or "", base_index_name).as_posix()
    splits = index_file_in_repo.split(".")
    index_file_in_repo = ".".join(splits[:2] + [variant] + splits[2:])
    try:
        index_file = _get_model_file(
            pretrained_model_name_or_path,
            weights_name=index_file_in_repo,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            local_files_only=local_files_only,
            token=token,
            revision=revision,
            subfolder=None,
            user_agent=user_agent,
            commit_hash=commit_hash,
            dduf_entries=dduf_entries,
        )
        index_file = Path(index_file)
        deprecate("legacy_sharded_ckpts_with_variant", "1.0.0", deprecation_message, standard_warn=False)
    except (EntryNotFoundError, EnvironmentError):
        index_file = None

    return index_file
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531


def _gguf_parse_value(_value, data_type):
    """
    Convert a raw GGUF field value into a Python value according to its numeric type code(s).

    `data_type` is either a single type code or a list of codes. A two-element list must start with `9` (array)
    followed by the element type, in which case the value is parsed with the element type. Codes
    0-5, 10 and 11 yield `int(_value[0])`, 6 and 12 yield `float(_value[0])`, 7 yields `bool(_value[0])` and 8
    decodes the values as bytes into a string. Any other code returns `_value` unchanged.

    Raises:
        ValueError: If several type codes are given and the first one is not `9`.
    """
    type_codes = data_type if isinstance(data_type, list) else [data_type]

    if len(type_codes) == 1:
        kind = type_codes[0]
        element_kind = None
    else:
        if type_codes[0] != 9:
            raise ValueError("Received multiple types, therefore expected the first type to indicate an array.")
        kind, element_kind = type_codes

    integer_kinds = (0, 1, 2, 3, 4, 5, 10, 11)
    float_kinds = (6, 12)

    if kind in integer_kinds:
        return int(_value[0])
    if kind in float_kinds:
        return float(_value[0])
    if kind == 7:
        return bool(_value[0])
    if kind == 8:
        return array("B", list(_value)).tobytes().decode()
    if kind == 9:
        return _gguf_parse_value(_value, element_kind)
    return _value


def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):
    """
    Load a GGUF file and return a dictionary mapping each tensor name to its tensor.

    Tensors stored as F32 or F16 are returned as plain `torch.Tensor` objects. Every other quantization type is
    wrapped in a `GGUFParameter` that records the quantization type.

    Args:
        gguf_checkpoint_path (`str`):
            The path to the GGUF file to load.
        return_tensors (`bool`, defaults to `False`):
            Currently unused; the tensors are always read and returned.

    Returns:
        `dict`: The parsed tensors keyed by tensor name.

    Raises:
        ImportError: If `gguf` or `torch` is not available.
        ValueError: If a tensor uses a quantization type that is not in `SUPPORTED_GGUF_QUANT_TYPES`.
    """
    if not (is_gguf_available() and is_torch_available()):
        logger.error(
            "Loading a GGUF checkpoint in PyTorch, requires both PyTorch and GGUF>=0.10.0 to be installed. Please see "
            "https://pytorch.org/ and https://github.com/ggerganov/llama.cpp/tree/master/gguf-py for installation instructions."
        )
        raise ImportError("Please install torch and gguf>=0.10.0 to load a GGUF checkpoint in PyTorch.")

    # Imported lazily so that the module can be imported without `gguf` installed.
    import gguf
    from gguf import GGUFReader

    from ..quantizers.gguf.utils import SUPPORTED_GGUF_QUANT_TYPES, GGUFParameter

    reader = GGUFReader(gguf_checkpoint_path)
    torch_native_types = [gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16]

    parsed_parameters = {}
    for tensor in reader.tensors:
        name = tensor.name
        quant_type = tensor.tensor_type

        # if the tensor is a torch supported dtype do not use GGUFParameter
        is_gguf_quant = quant_type not in torch_native_types
        if is_gguf_quant and quant_type not in SUPPORTED_GGUF_QUANT_TYPES:
            _supported_quants_str = "\n".join([str(supported_type) for supported_type in SUPPORTED_GGUF_QUANT_TYPES])
            raise ValueError(
                (
                    f"{name} has a quantization type: {str(quant_type)} which is unsupported."
                    "\n\nCurrently the following quantization types are supported: \n\n"
                    f"{_supported_quants_str}"
                    "\n\nTo request support for this quantization type please open an issue here: https://github.com/huggingface/diffusers"
                )
            )

        # Copy the reader's array before handing it to torch.
        weights = torch.from_numpy(tensor.data.copy())
        parsed_parameters[name] = GGUFParameter(weights, quant_type=quant_type) if is_gguf_quant else weights

    return parsed_parameters