# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import os
from collections import OrderedDict
from typing import List, Optional, Union

import safetensors
import torch

from ..utils import (
    SAFETENSORS_FILE_EXTENSION,
    is_accelerate_available,
    is_torch_version,
    logging,
)


logger = logging.get_logger(__name__)


if is_accelerate_available():
    from accelerate import infer_auto_device_map
    from accelerate.utils import get_balanced_memory, get_max_memory, set_module_tensor_to_device


# Adapted from `transformers` (see modeling_utils.py)
def _determine_device_map(model: torch.nn.Module, device_map, max_memory, torch_dtype):
    """
    Resolve a string `device_map` strategy into a concrete device map.

    When `device_map` is a string, the memory budget is computed with
    `get_max_memory` for `"sequential"` and with `get_balanced_memory` for every
    other strategy (`low_zero` is enabled only for `"balanced_low_0"`), and the
    result is passed to `infer_auto_device_map`. Any non-string `device_map` is
    returned unchanged.
    """
    if isinstance(device_map, str):
        no_split_modules = model._get_no_split_modules(device_map)
        device_map_kwargs = {"no_split_module_classes": no_split_modules}

        if device_map != "sequential":
            max_memory = get_balanced_memory(
                model,
                dtype=torch_dtype,
                low_zero=(device_map == "balanced_low_0"),
                max_memory=max_memory,
                **device_map_kwargs,
            )
        else:
            max_memory = get_max_memory(max_memory)

        device_map_kwargs["max_memory"] = max_memory
        device_map = infer_auto_device_map(model, dtype=torch_dtype, **device_map_kwargs)

    return device_map


def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None):
    """
    Reads a checkpoint file, returning properly formatted errors if they arise.

    Files whose extension equals `SAFETENSORS_FILE_EXTENSION` are loaded with
    `safetensors.torch.load_file`; everything else goes through `torch.load`
    (with `weights_only=True` when torch >= 1.13). Tensors are mapped to CPU.

    Args:
        checkpoint_file: Path of the checkpoint to read.
        variant: Accepted for interface compatibility; not used by this function.

    Raises:
        OSError: If loading fails. A dedicated message is used when the file looks
            like a git-lfs pointer (its text starts with `"version"`). Errors
            raised while opening the file itself (e.g. `FileNotFoundError`)
            propagate as-is.
    """
    try:
        file_extension = os.path.basename(checkpoint_file).split(".")[-1]
        if file_extension == SAFETENSORS_FILE_EXTENSION:
            return safetensors.torch.load_file(checkpoint_file, device="cpu")
        else:
            weights_only_kwarg = {"weights_only": True} if is_torch_version(">=", "1.13") else {}
            return torch.load(
                checkpoint_file,
                map_location="cpu",
                **weights_only_kwarg,
            )
    except Exception as e:
        lfs_marker = "version"
        try:
            with open(checkpoint_file) as f:
                # Only the leading characters are needed to spot a git-lfs pointer;
                # avoid reading a potentially huge checkpoint into memory.
                looks_like_lfs_pointer = f.read(len(lfs_marker)).startswith(lfs_marker)
        except (UnicodeDecodeError, ValueError):
            # Binary content (or an unusable path value): not an lfs pointer.
            looks_like_lfs_pointer = False

        if looks_like_lfs_pointer:
            raise OSError(
                "You seem to have cloned a repository without having git-lfs installed. Please install "
                "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
                "you cloned."
            ) from e

        # Chain explicitly so the underlying loader failure stays visible.
        raise OSError(
            f"Unable to load weights from checkpoint file for '{checkpoint_file}' "
            f"at '{checkpoint_file}'. "
        ) from e


def load_model_dict_into_meta(
    model,
    state_dict: OrderedDict,
    device: Optional[Union[str, torch.device]] = None,
    dtype: Optional[Union[str, torch.dtype]] = None,
    model_name_or_path: Optional[str] = None,
) -> List[str]:
    """
    Copy the tensors of `state_dict` into `model` via `set_module_tensor_to_device`.

    Args:
        model: Module whose `state_dict()` defines the expected keys and shapes.
        state_dict: Mapping from parameter name to the tensor to load.
        device: Target device; defaults to CPU when falsy.
        dtype: Target dtype; defaults to `torch.float32` when falsy. Only forwarded
            if the installed `set_module_tensor_to_device` accepts a `dtype` argument.
        model_name_or_path: Optional identifier included in the shape-mismatch error.

    Returns:
        The keys of `state_dict` that do not exist in `model.state_dict()`.

    Raises:
        ValueError: If a tensor's shape differs from the shape the model expects.
    """
    device = device or torch.device("cpu")
    dtype = dtype or torch.float32

    # Older accelerate versions may not support the `dtype` keyword.
    accepts_dtype = "dtype" in inspect.signature(set_module_tensor_to_device).parameters

    unexpected_keys = []
    empty_state_dict = model.state_dict()
    for param_name, param in state_dict.items():
        if param_name not in empty_state_dict:
            unexpected_keys.append(param_name)
            continue

        expected_shape = empty_state_dict[param_name].shape
        if expected_shape != param.shape:
            model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else ""
            # Report the expected *shape*, not the full tensor repr.
            raise ValueError(
                f"Cannot load {model_name_or_path_str}because {param_name} expected shape {expected_shape}, "
                f"but got {param.shape}. "
                "If you want to instead overwrite randomly initialized weights, please make sure to pass both "
                "`low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: "
                "https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
            )

        if accepts_dtype:
            set_module_tensor_to_device(model, param_name, device, value=param, dtype=dtype)
        else:
            set_module_tensor_to_device(model, param_name, device, value=param)
    return unexpected_keys


def _load_state_dict_into_model(model_to_load, state_dict: OrderedDict) -> List[str]:
    """
    Load `state_dict` into `model_to_load` module by module.

    Returns:
        The list of error messages collected by the `_load_from_state_dict` calls.
    """
    # Convert old format to new format if needed from a PyTorch state_dict
    # copy state_dict so _load_from_state_dict can modify it
    state_dict = state_dict.copy()
    error_msgs = []

    # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
    # so we need to apply the function recursively.
    def load(module: torch.nn.Module, prefix: str = ""):
        # Fresh metadata dict and key lists per call; only `error_msgs` is shared.
        args = (state_dict, prefix, {}, True, [], [], error_msgs)
        module._load_from_state_dict(*args)

        for name, child in module._modules.items():
            if child is not None:
                load(child, prefix + name + ".")

    load(model_to_load)

    return error_msgs