model_loading_utils.py 15.9 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
# Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

17
import importlib
18
19
20
import inspect
import os
from collections import OrderedDict
21
from pathlib import Path
22
23
24
25
from typing import List, Optional, Union

import safetensors
import torch
26
from huggingface_hub.utils import EntryNotFoundError
27

28
from ..quantizers.quantization_config import QuantizationMethod
29
from ..utils import (
30
    SAFE_WEIGHTS_INDEX_NAME,
31
    SAFETENSORS_FILE_EXTENSION,
32
33
34
    WEIGHTS_INDEX_NAME,
    _add_variant,
    _get_model_file,
35
    deprecate,
36
37
38
39
40
41
42
43
    is_accelerate_available,
    is_torch_version,
    logging,
)


logger = logging.get_logger(__name__)

44
45
46
47
48
49
50
_CLASS_REMAPPING_DICT = {
    "Transformer2DModel": {
        "ada_norm_zero": "DiTTransformer2DModel",
        "ada_norm_single": "PixArtTransformer2DModel",
    }
}

51
52
53
54
55
56
57

if is_accelerate_available():
    from accelerate import infer_auto_device_map
    from accelerate.utils import get_balanced_memory, get_max_memory, set_module_tensor_to_device


# Adapted from `transformers` (see modeling_utils.py)
58
59
60
def _determine_device_map(
    model: torch.nn.Module, device_map, max_memory, torch_dtype, keep_in_fp32_modules=[], hf_quantizer=None
):
61
    if isinstance(device_map, str):
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
        special_dtypes = {}
        if hf_quantizer is not None:
            special_dtypes.update(hf_quantizer.get_special_dtypes_update(model, torch_dtype))
        special_dtypes.update(
            {
                name: torch.float32
                for name, _ in model.named_parameters()
                if any(m in name for m in keep_in_fp32_modules)
            }
        )

        target_dtype = torch_dtype
        if hf_quantizer is not None:
            target_dtype = hf_quantizer.adjust_target_dtype(target_dtype)

77
78
79
        no_split_modules = model._get_no_split_modules(device_map)
        device_map_kwargs = {"no_split_module_classes": no_split_modules}

80
81
82
83
84
85
86
87
        if "special_dtypes" in inspect.signature(infer_auto_device_map).parameters:
            device_map_kwargs["special_dtypes"] = special_dtypes
        elif len(special_dtypes) > 0:
            logger.warning(
                "This model has some weights that should be kept in higher precision, you need to upgrade "
                "`accelerate` to properly deal with them (`pip install --upgrade accelerate`)."
            )

88
89
90
91
92
93
94
95
96
97
98
        if device_map != "sequential":
            max_memory = get_balanced_memory(
                model,
                dtype=torch_dtype,
                low_zero=(device_map == "balanced_low_0"),
                max_memory=max_memory,
                **device_map_kwargs,
            )
        else:
            max_memory = get_max_memory(max_memory)

99
100
101
        if hf_quantizer is not None:
            max_memory = hf_quantizer.adjust_max_memory(max_memory)

102
        device_map_kwargs["max_memory"] = max_memory
103
104
105
106
        device_map = infer_auto_device_map(model, dtype=target_dtype, **device_map_kwargs)

        if hf_quantizer is not None:
            hf_quantizer.validate_environment(device_map=device_map)
107
108
109
110

    return device_map


111
112
def _fetch_remapped_cls_from_config(config, old_class):
    """
    Return the class that should be instantiated for `old_class`, taking `_CLASS_REMAPPING_DICT` into account.

    Args:
        config: Model config; its `"norm_type"` entry selects the replacement class.
        old_class: The class that was originally requested.

    Returns:
        The replacement class, looked up by name on the top-level package, when a remapping exists for
        (`old_class.__name__`, `config["norm_type"]`); otherwise `old_class` unchanged.
    """
    previous_class_name = old_class.__name__
    remapping = _CLASS_REMAPPING_DICT.get(previous_class_name)
    if not remapping:
        # Classes without a remapping entry are returned as-is rather than failing on `None.get(...)`.
        return old_class

    try:
        norm_type = config["norm_type"]
    except KeyError:
        # Without a `norm_type` there is nothing to select a replacement class with.
        return old_class

    remapped_class_name = remapping.get(norm_type, None)
    if not remapped_class_name:
        return old_class

    # Details:
    # https://github.com/huggingface/diffusers/pull/7647#discussion_r1621344818
    # Import the top-level package of this module so the replacement class can be resolved by name.
    diffusers_library = importlib.import_module(__name__.split(".")[0])
    remapped_class = getattr(diffusers_library, remapped_class_name)
    logger.info(
        f"Changing class object to be of `{remapped_class_name}` type from `{previous_class_name}` type. "
        f"This is because `{previous_class_name}` is scheduled to be deprecated in a future version. Note that this"
        " DOESN'T affect the final results."
    )
    return remapped_class
129
130


131
132
133
134
def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None):
    """
    Reads a checkpoint file, returning properly formatted errors if they arise.

    Args:
        checkpoint_file (`str` or `os.PathLike` or `dict`):
            Path to a safetensors file (selected by its extension) or a `torch.save` checkpoint. An already
            loaded state dict (e.g. produced by merging sharded checkpoints) is returned as-is.
        variant (`str`, *optional*):
            Not used by this function; kept for signature compatibility.

    Returns:
        The loaded state dict, mapped to the CPU.

    Raises:
        OSError: If the checkpoint cannot be loaded. When the file looks like a git-lfs pointer, the message
            explains how to fetch the real weights. The underlying loading error is attached as `__cause__`.
    """
    # TODO: We merge the sharded checkpoints in case we're doing quantization. We can revisit this change
    # when refactoring the _merge_sharded_checkpoints() method later.
    if isinstance(checkpoint_file, dict):
        return checkpoint_file

    try:
        file_extension = os.path.basename(checkpoint_file).split(".")[-1]
        if file_extension == SAFETENSORS_FILE_EXTENSION:
            return safetensors.torch.load_file(checkpoint_file, device="cpu")
        weights_only_kwarg = {"weights_only": True} if is_torch_version(">=", "1.13") else {}
        return torch.load(
            checkpoint_file,
            map_location="cpu",
            **weights_only_kwarg,
        )
    except Exception as e:
        try:
            with open(checkpoint_file) as f:
                # git-lfs pointer files start with "version". Read only that prefix instead of the whole
                # (possibly multi-GB) checkpoint.
                is_lfs_pointer = f.read(len("version")) == "version"
        except (UnicodeDecodeError, ValueError):
            # Binary content or an unusable path: not an lfs pointer.
            is_lfs_pointer = False

        if is_lfs_pointer:
            raise OSError(
                "You seem to have cloned a repository without having git-lfs installed. Please install "
                "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
                "you cloned."
            ) from e
        raise OSError(
            f"Unable to load weights from checkpoint file for '{checkpoint_file}' " f"at '{checkpoint_file}'. "
        ) from e


def load_model_dict_into_meta(
    model,
    state_dict: OrderedDict,
    device: Optional[Union[str, torch.device]] = None,
    dtype: Optional[Union[str, torch.dtype]] = None,
    model_name_or_path: Optional[str] = None,
    hf_quantizer=None,
    keep_in_fp32_modules=None,
) -> List[str]:
    """
    Materialize the tensors of `state_dict` into `model`, one parameter/buffer at a time.

    Floating-point tensors are cast to `dtype` (default `torch.float32`). When `dtype` is `torch.float16`, tensors
    whose dotted name contains one of `keep_in_fp32_modules` as a component are kept in `torch.float32` instead.
    Non-floating tensors (int/uint/bool) keep their dtype. Tensors that `hf_quantizer` reports as quantizable are
    handed to `hf_quantizer.create_quantized_param`; all others are placed with
    `accelerate.utils.set_module_tensor_to_device`.

    Args:
        model: The model to fill.
        state_dict: Mapping of parameter names to tensors.
        device: Target device. Defaults to CPU when no quantizer is given.
        dtype: Target floating dtype. Defaults to `torch.float32`.
        model_name_or_path: Only used in error messages.
        hf_quantizer (*optional*): Quantizer used for quantizable parameters.
        keep_in_fp32_modules (`List[str]`, *optional*): Module names to keep in float32 when `dtype` is float16.

    Returns:
        Keys of `state_dict` that are not present in `model.state_dict()`; these are skipped.

    Raises:
        ValueError: If `device` is neither `str` nor `torch.device`, or a parameter's shape does not match the
            model and the model is not bitsandbytes-quantized.
    """
    if device is not None and not isinstance(device, (str, torch.device)):
        raise ValueError(f"Expected device to have type `str` or `torch.device`, but got {type(device)=}.")

    if hf_quantizer is None:
        device = device or torch.device("cpu")
    dtype = dtype or torch.float32

    is_quantized = hf_quantizer is not None
    is_quant_method_bnb = getattr(model, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES

    # Older `accelerate` versions of `set_module_tensor_to_device` do not take a `dtype` argument.
    accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys())
    empty_state_dict = model.state_dict()
    unexpected_keys = [param_name for param_name in state_dict if param_name not in empty_state_dict]

    for param_name, param in state_dict.items():
        if param_name not in empty_state_dict:
            continue

        set_module_kwargs = {}
        # We convert floating dtypes to the `dtype` passed. We also want to keep the buffers/params
        # in int/uint/bool and not cast them.
        # TODO: revisit cases when param.dtype == torch.float8_e4m3fn
        if torch.is_floating_point(param):
            name_parts = param_name.split(".")
            keep_in_fp32 = (
                keep_in_fp32_modules is not None
                and any(module_to_keep_in_fp32 in name_parts for module_to_keep_in_fp32 in keep_in_fp32_modules)
                and dtype == torch.float16
            )
            param_dtype = torch.float32 if keep_in_fp32 else dtype
            param = param.to(param_dtype)
            if accepts_dtype:
                set_module_kwargs["dtype"] = param_dtype

        # bnb params are flattened.
        expected_shape = empty_state_dict[param_name].shape
        if expected_shape != param.shape:
            if (
                is_quant_method_bnb
                and hf_quantizer.pre_quantized
                and hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=device)
            ):
                hf_quantizer.check_quantized_param_shape(param_name, expected_shape, param.shape)
            elif not is_quant_method_bnb:
                model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else ""
                raise ValueError(
                    f"Cannot load {model_name_or_path_str}because {param_name} expected shape {expected_shape}, "
                    f"but got {param.shape}. If you want to instead overwrite randomly initialized weights, please "
                    "make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more "
                    "information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 "
                    "as an example."
                )

        if is_quantized and hf_quantizer.check_if_quantized_param(
            model, param, param_name, state_dict, param_device=device
        ):
            hf_quantizer.create_quantized_param(model, param, param_name, device, state_dict, unexpected_keys)
        else:
            # `set_module_kwargs` is only populated when `accepts_dtype` is true, so one call covers both cases.
            set_module_tensor_to_device(model, param_name, device, value=param, **set_module_kwargs)

    return unexpected_keys


def _load_state_dict_into_model(model_to_load, state_dict: OrderedDict) -> List[str]:
    # Convert old format to new format if needed from a PyTorch state_dict
    # copy state_dict so _load_from_state_dict can modify it
    state_dict = state_dict.copy()
    error_msgs = []

    # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
    # so we need to apply the function recursively.
    def load(module: torch.nn.Module, prefix: str = ""):
        args = (state_dict, prefix, {}, True, [], [], error_msgs)
        module._load_from_state_dict(*args)

        for name, child in module._modules.items():
            if child is not None:
                load(child, prefix + name + ".")

    load(model_to_load)

    return error_msgs
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298


def _fetch_index_file(
    is_local,
    pretrained_model_name_or_path,
    subfolder,
    use_safetensors,
    cache_dir,
    variant,
    force_download,
    proxies,
    local_files_only,
    token,
    revision,
    user_agent,
    commit_hash,
):
    """
    Locate the index file of a sharded checkpoint.

    The file name is the safetensors or PyTorch weights index name with `variant` applied via `_add_variant`.

    Returns:
        `Path`: For a local model, the index path built under `pretrained_model_name_or_path`/`subfolder`. Its
        existence is not checked here. For a Hub model, the resolved path, or `None` when `_get_model_file` raises
        `EntryNotFoundError` or `EnvironmentError`.
    """
    index_name = _add_variant(SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME, variant)
    relative_dir = subfolder or ""

    if is_local:
        return Path(pretrained_model_name_or_path, relative_dir, index_name)

    # The subfolder is folded into the requested file name, so `_get_model_file` gets `subfolder=None`.
    name_in_repo = Path(relative_dir, index_name).as_posix()
    try:
        resolved = _get_model_file(
            pretrained_model_name_or_path,
            weights_name=name_in_repo,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            local_files_only=local_files_only,
            token=token,
            revision=revision,
            subfolder=None,
            user_agent=user_agent,
            commit_hash=commit_hash,
        )
    except (EntryNotFoundError, EnvironmentError):
        return None
    return Path(resolved)
308
309


310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
# Adapted from
# https://github.com/bghira/SimpleTuner/blob/cea2457ab063f6dedb9e697830ae68a96be90641/helpers/training/save_hooks.py#L64
def _merge_sharded_checkpoints(sharded_ckpt_cached_folder, sharded_metadata):
    weight_map = sharded_metadata.get("weight_map", None)
    if weight_map is None:
        raise KeyError("'weight_map' key not found in the shard index file.")

    # Collect all unique safetensors files from weight_map
    files_to_load = set(weight_map.values())
    is_safetensors = all(f.endswith(".safetensors") for f in files_to_load)
    merged_state_dict = {}

    # Load tensors from each unique file
    for file_name in files_to_load:
        part_file_path = os.path.join(sharded_ckpt_cached_folder, file_name)
        if not os.path.exists(part_file_path):
            raise FileNotFoundError(f"Part file {file_name} not found.")

        if is_safetensors:
            with safetensors.safe_open(part_file_path, framework="pt", device="cpu") as f:
                for tensor_key in f.keys():
                    if tensor_key in weight_map:
                        merged_state_dict[tensor_key] = f.get_tensor(tensor_key)
        else:
            merged_state_dict.update(torch.load(part_file_path, weights_only=True, map_location="cpu"))

    return merged_state_dict


339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
def _fetch_index_file_legacy(
    is_local,
    pretrained_model_name_or_path,
    subfolder,
    use_safetensors,
    cache_dir,
    variant,
    force_download,
    proxies,
    local_files_only,
    token,
    revision,
    user_agent,
    commit_hash,
):
    if is_local:
        index_file = Path(
            pretrained_model_name_or_path,
            subfolder or "",
            SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME,
        ).as_posix()
        splits = index_file.split(".")
        split_index = -3 if ".cache" in index_file else -2
        splits = splits[:-split_index] + [variant] + splits[-split_index:]
        index_file = ".".join(splits)
        if os.path.exists(index_file):
            deprecation_message = f"This serialization format is now deprecated to standardize the serialization format between `transformers` and `diffusers`. We recommend you to remove the existing files associated with the current variant ({variant}) and re-obtain them by running a `save_pretrained()`."
            deprecate("legacy_sharded_ckpts_with_variant", "1.0.0", deprecation_message, standard_warn=False)
            index_file = Path(index_file)
        else:
            index_file = None
    else:
        if variant is not None:
            index_file_in_repo = Path(
                subfolder or "",
                SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME,
            ).as_posix()
            splits = index_file_in_repo.split(".")
            split_index = -2
            splits = splits[:-split_index] + [variant] + splits[-split_index:]
            index_file_in_repo = ".".join(splits)
            try:
                index_file = _get_model_file(
                    pretrained_model_name_or_path,
                    weights_name=index_file_in_repo,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    local_files_only=local_files_only,
                    token=token,
                    revision=revision,
                    subfolder=None,
                    user_agent=user_agent,
                    commit_hash=commit_hash,
                )
                index_file = Path(index_file)
                deprecation_message = f"This serialization format is now deprecated to standardize the serialization format between `transformers` and `diffusers`. We recommend you to remove the existing files associated with the current variant ({variant}) and re-obtain them by running a `save_pretrained()`."
                deprecate("legacy_sharded_ckpts_with_variant", "1.0.0", deprecation_message, standard_warn=False)
            except (EntryNotFoundError, EnvironmentError):
                index_file = None

    return index_file