Unverified Commit 37cb819d authored by Patrick von Platen, committed by GitHub

[Lora] Speed up lora loading (#4994)

* speed up lora loading

* Apply suggestions from code review

* up

* up

* Fix more

* Correct more

* Apply suggestions from code review

* up

* Fix more

* Fix more -

* up

* up
parent f64d52db
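In practice the change is transparent to users: `load_lora_weights` (and `load_attn_procs`) now accept a `low_cpu_mem_usage` flag that defaults to the same value used for regular model loading. A minimal sketch of the call path this PR speeds up; the repo ids below are placeholders, not taken from this commit:

```python
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)

# With this change, LoRA layers are created on the `meta` device and the
# checkpoint tensors are assigned to them directly, instead of initializing
# random weights first and overwriting them via `load_state_dict`.
pipe.load_lora_weights("some-user/some-lora", low_cpu_mem_usage=True)
```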
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import copy
 import os
 import re
 import warnings
@@ -27,6 +26,7 @@ import torch
 from huggingface_hub import hf_hub_download, model_info
 from torch import nn

+from .models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
 from .utils import (
     DIFFUSERS_CACHE,
     HF_HUB_OFFLINE,
@@ -46,7 +46,6 @@ if is_transformers_available():
 if is_accelerate_available():
     from accelerate import init_empty_weights
     from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
-    from accelerate.utils import set_module_tensor_to_device

 logger = logging.get_logger(__name__)
@@ -137,7 +136,6 @@ class PatchedLoraProjection(nn.Module):
         self.w_down = None

     def forward(self, input):
-        # print(f"{self.__class__.__name__} has a lora_scale of {self.lora_scale}")
         if self.lora_scale is None:
             self.lora_scale = 1.0

         if self.lora_linear_layer is None:
@@ -274,6 +272,11 @@ class UNet2DConditionLoadersMixin:
             use_auth_token (`str` or *bool*, *optional*):
                 The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
                 `diffusers-cli login` (stored in `~/.huggingface`) is used.
+            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
+                Speed up model loading by only loading the pretrained weights and not initializing the weights. This
+                also tries to not use more than 1x the model size in CPU memory (including peak memory) while loading
+                the model. Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting
+                this argument to `True` will raise an error.
             revision (`str`, *optional*, defaults to `"main"`):
                 The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
                 allowed by Git.
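For the UNet-only entry point, the new argument is simply picked up from `**kwargs`. A hedged example of how it might be passed; the weight file name below is the conventional default, not something this diff pins down:

```python
# `pipe.unet` is assumed to be a UNet2DConditionModel loaded beforehand.
pipe.unet.load_attn_procs(
    "some-user/some-lora",                           # repo id, local dir, or an in-memory state dict
    weight_name="pytorch_lora_weights.safetensors",  # assumed file name
    low_cpu_mem_usage=True,                          # skip random init of the LoRA layers
)
```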
@@ -300,6 +303,7 @@ class UNet2DConditionLoadersMixin:
         subfolder = kwargs.pop("subfolder", None)
         weight_name = kwargs.pop("weight_name", None)
         use_safetensors = kwargs.pop("use_safetensors", None)
+        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
         # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
         # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
         network_alphas = kwargs.pop("network_alphas", None)
@@ -316,6 +320,15 @@ class UNet2DConditionLoadersMixin:
             "framework": "pytorch",
         }

+        if low_cpu_mem_usage and not is_accelerate_available():
+            low_cpu_mem_usage = False
+            logger.warning(
+                "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
+                " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
+                " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
+                " install accelerate\n```\n."
+            )
+
         model_file = None
         if not isinstance(pretrained_model_name_or_path_or_dict, dict):
             # Let's first try to load .safetensors weights
@@ -370,6 +383,10 @@ class UNet2DConditionLoadersMixin:
             # correct keys
             state_dict, network_alphas = self.convert_state_dict_legacy_attn_format(state_dict, network_alphas)

+            if network_alphas is not None:
+                network_alphas_keys = list(network_alphas.keys())
+                used_network_alphas_keys = set()
+
             lora_grouped_dict = defaultdict(dict)
             mapped_network_alphas = {}
@@ -381,13 +398,13 @@ class UNet2DConditionLoadersMixin:
                 # Create another `mapped_network_alphas` dictionary so that we can properly map them.
                 if network_alphas is not None:
-                    network_alphas_ = copy.deepcopy(network_alphas)
-                    for k in network_alphas_:
+                    for k in network_alphas_keys:
                         if k.replace(".alpha", "") in key:
-                            mapped_network_alphas.update({attn_processor_key: network_alphas.pop(k)})
+                            mapped_network_alphas.update({attn_processor_key: network_alphas.get(k)})
+                            used_network_alphas_keys.add(k)

             if not is_network_alphas_none:
-                if len(network_alphas) > 0:
+                if len(set(network_alphas_keys) - used_network_alphas_keys) > 0:
                     raise ValueError(
                         f"The `network_alphas` has to be empty at this point but has the following keys \n\n {', '.join(network_alphas.keys())}"
                     )
@@ -411,6 +428,8 @@ class UNet2DConditionLoadersMixin:
                     out_features = attn_processor.out_channels
                     kernel_size = attn_processor.kernel_size

+                    ctx = init_empty_weights if low_cpu_mem_usage else nullcontext
+                    with ctx():
                         lora = LoRAConv2dLayer(
                             in_features=in_features,
                             out_features=out_features,
@@ -421,6 +440,8 @@ class UNet2DConditionLoadersMixin:
                             network_alpha=mapped_network_alphas.get(key),
                         )
                 elif isinstance(attn_processor, LoRACompatibleLinear):
+                    ctx = init_empty_weights if low_cpu_mem_usage else nullcontext
+                    with ctx():
                         lora = LoRALinearLayer(
                             attn_processor.in_features,
                             attn_processor.out_features,
@@ -431,9 +452,14 @@ class UNet2DConditionLoadersMixin:
                     raise ValueError(f"Module {key} is not a LoRACompatibleConv or LoRACompatibleLinear module.")

                 value_dict = {k.replace("lora.", ""): v for k, v in value_dict.items()}
-                lora.load_state_dict(value_dict)
                 lora_layers_list.append((attn_processor, lora))

+                if low_cpu_mem_usage:
+                    device = next(iter(value_dict.values())).device
+                    dtype = next(iter(value_dict.values())).dtype
+                    load_model_dict_into_meta(lora, value_dict, device=device, dtype=dtype)
+                else:
+                    lora.load_state_dict(value_dict)
+
         elif is_custom_diffusion:
             attn_processors = {}
             custom_diffusion_grouped_dict = defaultdict(dict)
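The two hunks above are the heart of the speed-up: each LoRA layer is constructed inside `init_empty_weights()` (so its parameters live on the `meta` device and cost nothing), and the checkpoint tensors are then written into it directly. A minimal sketch of that pattern with a plain `nn.Linear` standing in for `LoRALinearLayer`; note that older `accelerate` releases do not accept the `dtype` argument, which is why the helper in this PR inspects the signature first:

```python
import torch
from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device

with init_empty_weights():
    lora = torch.nn.Linear(768, 4, bias=False)  # parameters on the `meta` device, nothing allocated

value_dict = {"weight": torch.randn(4, 768, dtype=torch.float16)}  # pretend checkpoint tensors

for name, tensor in value_dict.items():
    # Materialize each parameter straight from the checkpoint on the desired device/dtype.
    set_module_tensor_to_device(lora, name, tensor.device, value=tensor, dtype=tensor.dtype)
```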
@@ -470,13 +496,12 @@ class UNet2DConditionLoadersMixin:
                 f"{model_file} does not seem to be in the correct format expected by LoRA or Custom Diffusion training."
             )

-        # set correct dtype & device
-        lora_layers_list = [(t, l.to(device=self.device, dtype=self.dtype)) for t, l in lora_layers_list]
-
         # set lora layers
         for target_module, lora_layer in lora_layers_list:
             target_module.set_lora_layer(lora_layer)

+        self.to(dtype=self.dtype, device=self.device)
+
     def convert_state_dict_legacy_attn_format(self, state_dict, network_alphas):
         is_new_lora_format = all(
             key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in state_dict.keys()
@@ -999,13 +1024,18 @@ class LoraLoaderMixin:
                     recurive = is_sequential_cpu_offload
                     remove_hook_from_module(component, recurse=recurive)

+        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
         state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
-        self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet)
+        self.load_lora_into_unet(
+            state_dict, network_alphas=network_alphas, unet=self.unet, low_cpu_mem_usage=low_cpu_mem_usage
+        )
         self.load_lora_into_text_encoder(
             state_dict,
             network_alphas=network_alphas,
             text_encoder=self.text_encoder,
             lora_scale=self.lora_scale,
+            low_cpu_mem_usage=low_cpu_mem_usage,
         )

         # Offload back.
@@ -1065,6 +1095,11 @@ class LoraLoaderMixin:
                 allowed by Git.
             subfolder (`str`, *optional*, defaults to `""`):
                 The subfolder location of a model file within a larger model repository on the Hub or locally.
+            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
+                Speed up model loading by only loading the pretrained weights and not initializing the weights. This
+                also tries to not use more than 1x the model size in CPU memory (including peak memory) while loading
+                the model. Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting
+                this argument to `True` will raise an error.
             mirror (`str`, *optional*):
                 Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
                 guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
@@ -1305,7 +1340,7 @@ class LoraLoaderMixin:
         return new_state_dict

     @classmethod
-    def load_lora_into_unet(cls, state_dict, network_alphas, unet):
+    def load_lora_into_unet(cls, state_dict, network_alphas, unet, low_cpu_mem_usage=None):
         """
         This will load the LoRA layers specified in `state_dict` into `unet`.
@@ -1318,7 +1353,13 @@ class LoraLoaderMixin:
                 See `LoRALinearLayer` for more details.
             unet (`UNet2DConditionModel`):
                 The UNet model to load the LoRA layers into.
+            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
+                Speed up model loading by only loading the pretrained weights and not initializing the weights. This
+                also tries to not use more than 1x the model size in CPU memory (including peak memory) while loading
+                the model. Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting
+                this argument to `True` will raise an error.
         """
+        low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT
         # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
         # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
         # their prefixes.
@@ -1343,11 +1384,12 @@ class LoraLoaderMixin:
             warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet.{module_name}': params for module_name, params in old_state_dict.items()}`."
             warnings.warn(warn_message)

-        # load loras into unet
-        unet.load_attn_procs(state_dict, network_alphas=network_alphas)
+        unet.load_attn_procs(state_dict, network_alphas=network_alphas, low_cpu_mem_usage=low_cpu_mem_usage)

     @classmethod
-    def load_lora_into_text_encoder(cls, state_dict, network_alphas, text_encoder, prefix=None, lora_scale=1.0):
+    def load_lora_into_text_encoder(
+        cls, state_dict, network_alphas, text_encoder, prefix=None, lora_scale=1.0, low_cpu_mem_usage=None
+    ):
         """
         This will load the LoRA layers specified in `state_dict` into `text_encoder`
@@ -1364,7 +1406,13 @@ class LoraLoaderMixin:
             lora_scale (`float`):
                 How much to scale the output of the lora linear layer before it is added with the output of the regular
                 lora layer.
+            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
+                Speed up model loading by only loading the pretrained weights and not initializing the weights. This
+                also tries to not use more than 1x the model size in CPU memory (including peak memory) while loading
+                the model. Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting
+                this argument to `True` will raise an error.
         """
+        low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT
         # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
         # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
@@ -1447,6 +1495,7 @@ class LoraLoaderMixin:
                 network_alphas,
                 rank=rank,
                 patch_mlp=patch_mlp,
+                low_cpu_mem_usage=low_cpu_mem_usage,
             )

             # set correct dtype & device
@@ -1454,12 +1503,23 @@ class LoraLoaderMixin:
                 k: v.to(device=text_encoder.device, dtype=text_encoder.dtype)
                 for k, v in text_encoder_lora_state_dict.items()
             }

+            if low_cpu_mem_usage:
+                device = next(iter(text_encoder_lora_state_dict.values())).device
+                dtype = next(iter(text_encoder_lora_state_dict.values())).dtype
+                unexpected_keys = load_model_dict_into_meta(
+                    text_encoder, text_encoder_lora_state_dict, device=device, dtype=dtype
+                )
+            else:
                 load_state_dict_results = text_encoder.load_state_dict(text_encoder_lora_state_dict, strict=False)
-            if len(load_state_dict_results.unexpected_keys) != 0:
+                unexpected_keys = load_state_dict_results.unexpected_keys
+
+            if len(unexpected_keys) != 0:
                 raise ValueError(
                     f"failed to load text encoder state dict, unexpected keys: {load_state_dict_results.unexpected_keys}"
                 )
+
+            text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)

     @property
     def lora_scale(self) -> float:
         # property function that returns the lora scale which can be set at run time by the pipeline.
@@ -1492,11 +1552,21 @@ class LoraLoaderMixin:
         rank: Union[Dict[str, int], int] = 4,
         dtype=None,
         patch_mlp=False,
+        low_cpu_mem_usage=False,
     ):
         r"""
         Monkey-patches the forward passes of attention modules of the text encoder.
         """

+        def create_patched_linear_lora(model, network_alpha, rank, dtype, lora_parameters):
+            linear_layer = model.regular_linear_layer if isinstance(model, PatchedLoraProjection) else model
+            ctx = init_empty_weights if low_cpu_mem_usage else nullcontext
+            with ctx():
+                model = PatchedLoraProjection(linear_layer, lora_scale, network_alpha, rank, dtype=dtype)
+
+            lora_parameters.extend(model.lora_linear_layer.parameters())
+            return model
+
         # First, remove any monkey-patch that might have been applied before
         cls._remove_text_encoder_monkey_patch_classmethod(text_encoder)
@@ -1515,45 +1585,18 @@ class LoraLoaderMixin:
                 else:
                     current_rank = rank

-                q_linear_layer = (
-                    attn_module.q_proj.regular_linear_layer
-                    if isinstance(attn_module.q_proj, PatchedLoraProjection)
-                    else attn_module.q_proj
-                )
-                attn_module.q_proj = PatchedLoraProjection(
-                    q_linear_layer, lora_scale, network_alpha=query_alpha, rank=current_rank, dtype=dtype
-                )
-                lora_parameters.extend(attn_module.q_proj.lora_linear_layer.parameters())
-
-                k_linear_layer = (
-                    attn_module.k_proj.regular_linear_layer
-                    if isinstance(attn_module.k_proj, PatchedLoraProjection)
-                    else attn_module.k_proj
-                )
-                attn_module.k_proj = PatchedLoraProjection(
-                    k_linear_layer, lora_scale, network_alpha=key_alpha, rank=current_rank, dtype=dtype
-                )
-                lora_parameters.extend(attn_module.k_proj.lora_linear_layer.parameters())
-
-                v_linear_layer = (
-                    attn_module.v_proj.regular_linear_layer
-                    if isinstance(attn_module.v_proj, PatchedLoraProjection)
-                    else attn_module.v_proj
-                )
-                attn_module.v_proj = PatchedLoraProjection(
-                    v_linear_layer, lora_scale, network_alpha=value_alpha, rank=current_rank, dtype=dtype
-                )
-                lora_parameters.extend(attn_module.v_proj.lora_linear_layer.parameters())
-
-                out_linear_layer = (
-                    attn_module.out_proj.regular_linear_layer
-                    if isinstance(attn_module.out_proj, PatchedLoraProjection)
-                    else attn_module.out_proj
-                )
-                attn_module.out_proj = PatchedLoraProjection(
-                    out_linear_layer, lora_scale, network_alpha=out_alpha, rank=current_rank, dtype=dtype
-                )
-                lora_parameters.extend(attn_module.out_proj.lora_linear_layer.parameters())
+                attn_module.q_proj = create_patched_linear_lora(
+                    attn_module.q_proj, query_alpha, current_rank, dtype, lora_parameters
+                )
+                attn_module.k_proj = create_patched_linear_lora(
+                    attn_module.k_proj, key_alpha, current_rank, dtype, lora_parameters
+                )
+                attn_module.v_proj = create_patched_linear_lora(
+                    attn_module.v_proj, value_alpha, current_rank, dtype, lora_parameters
+                )
+                attn_module.out_proj = create_patched_linear_lora(
+                    attn_module.out_proj, out_alpha, current_rank, dtype, lora_parameters
+                )

         if patch_mlp:
             for name, mlp_module in text_encoder_mlp_modules(text_encoder):
@@ -1563,25 +1606,12 @@ class LoraLoaderMixin:
                     current_rank_fc1 = rank.pop(f"{name}.fc1.lora_linear_layer.up.weight")
                     current_rank_fc2 = rank.pop(f"{name}.fc2.lora_linear_layer.up.weight")

-                fc1_linear_layer = (
-                    mlp_module.fc1.regular_linear_layer
-                    if isinstance(mlp_module.fc1, PatchedLoraProjection)
-                    else mlp_module.fc1
-                )
-                mlp_module.fc1 = PatchedLoraProjection(
-                    fc1_linear_layer, lora_scale, network_alpha=fc1_alpha, rank=current_rank_fc1, dtype=dtype
-                )
-                lora_parameters.extend(mlp_module.fc1.lora_linear_layer.parameters())
-
-                fc2_linear_layer = (
-                    mlp_module.fc2.regular_linear_layer
-                    if isinstance(mlp_module.fc2, PatchedLoraProjection)
-                    else mlp_module.fc2
-                )
-                mlp_module.fc2 = PatchedLoraProjection(
-                    fc2_linear_layer, lora_scale, network_alpha=fc2_alpha, rank=current_rank_fc2, dtype=dtype
-                )
-                lora_parameters.extend(mlp_module.fc2.lora_linear_layer.parameters())
+                mlp_module.fc1 = create_patched_linear_lora(
+                    mlp_module.fc1, fc1_alpha, current_rank_fc1, dtype, lora_parameters
+                )
+                mlp_module.fc2 = create_patched_linear_lora(
+                    mlp_module.fc2, fc2_alpha, current_rank_fc2, dtype, lora_parameters
+                )

         if is_network_alphas_populated and len(network_alphas) > 0:
             raise ValueError(
@@ -2375,8 +2405,7 @@ class FromOriginalVAEMixin:
         vae = AutoencoderKL(**vae_config)

         if is_accelerate_available():
-            for param_name, param in converted_vae_checkpoint.items():
-                set_module_tensor_to_device(vae, param_name, "cpu", value=param)
+            load_model_dict_into_meta(vae, converted_vae_checkpoint, device="cpu")
         else:
             vae.load_state_dict(converted_vae_checkpoint)
@@ -128,6 +128,31 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[
     )


+def load_model_dict_into_meta(model, state_dict, device=None, dtype=None, model_name_or_path=None):
+    device = device or torch.device("cpu")
+    dtype = dtype or torch.float32
+
+    unexpected_keys = []
+    empty_state_dict = model.state_dict()
+    for param_name, param in state_dict.items():
+        if param_name not in empty_state_dict:
+            unexpected_keys.append(param_name)
+            continue
+
+        if empty_state_dict[param_name].shape != param.shape:
+            model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else ""
+            raise ValueError(
+                f"Cannot load {model_name_or_path_str}because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
+            )
+
+        accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys())
+        if accepts_dtype:
+            set_module_tensor_to_device(model, param_name, device, value=param, dtype=dtype)
+        else:
+            set_module_tensor_to_device(model, param_name, device, value=param)
+    return unexpected_keys
+
+
 def _load_state_dict_into_model(model_to_load, state_dict):
     # Convert old format to new format if needed from a PyTorch state_dict
     # copy state_dict so _load_from_state_dict can modify it
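A hedged usage sketch of the helper added above, assuming a diffusers build that includes this PR; the toy `nn.Sequential` stands in for a real diffusers model built under `init_empty_weights()`:

```python
import torch
from accelerate import init_empty_weights
from diffusers.models.modeling_utils import load_model_dict_into_meta

with init_empty_weights():
    model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 8))  # all params on `meta`

state_dict = {
    "0.weight": torch.randn(8, 8), "0.bias": torch.zeros(8),
    "1.weight": torch.randn(8, 8), "1.bias": torch.zeros(8),
}

# Copies every matching tensor onto the requested device/dtype and returns the
# keys in `state_dict` that the model does not expect.
unexpected = load_model_dict_into_meta(model, state_dict, device="cpu", dtype=torch.float32)
assert unexpected == []
```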
@@ -624,29 +649,14 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                         " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
                         " those weights or else make sure your checkpoint file is correct."
                     )

-                unexpected_keys = []
-                empty_state_dict = model.state_dict()
-                for param_name, param in state_dict.items():
-                    accepts_dtype = "dtype" in set(
-                        inspect.signature(set_module_tensor_to_device).parameters.keys()
-                    )
-
-                    if param_name not in empty_state_dict:
-                        unexpected_keys.append(param_name)
-                        continue
-
-                    if empty_state_dict[param_name].shape != param.shape:
-                        raise ValueError(
-                            f"Cannot load {pretrained_model_name_or_path} because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
-                        )
-
-                    if accepts_dtype:
-                        set_module_tensor_to_device(
-                            model, param_name, param_device, value=param, dtype=torch_dtype
-                        )
-                    else:
-                        set_module_tensor_to_device(model, param_name, param_device, value=param)
+                unexpected_keys = load_model_dict_into_meta(
+                    model,
+                    state_dict,
+                    device=param_device,
+                    dtype=torch_dtype,
+                    model_name_or_path=pretrained_model_name_or_path,
+                )

                 if cls._keys_to_ignore_on_load_unexpected is not None:
                     for pat in cls._keys_to_ignore_on_load_unexpected: