Unverified commit ed507680 authored by Sayak Paul, committed by GitHub

[LoRA] don't break offloading for incompatible lora ckpts. (#5085)



* don't break offloading for incompatible lora ckpts.

* debugging

* better condition.

* fix

* fix

* fix

* fix

---------
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 7974fad1
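The change is easiest to see from the user's side. Below is a minimal sketch, not part of this commit, of the scenario the PR guards against: calling `load_lora_weights()` with an incompatible checkpoint on a pipeline that has CPU offloading enabled. With this patch the checkpoint is validated before any accelerate hooks are touched, so a bad checkpoint raises cleanly and offloading keeps working. The repo id and weight file name are placeholders.

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)
pipe.enable_model_cpu_offload()  # installs accelerate CpuOffload hooks on the components

try:
    # placeholder repo/file standing in for a checkpoint that is not a valid LoRA
    pipe.load_lora_weights("some-user/not-a-lora", weight_name="weights.safetensors")
except ValueError:
    # raised by the new `is_correct_format` check, before any hooks are removed,
    # so the pipeline's offloading setup is left intact
    pass

image = pipe("an astronaut riding a horse", num_inference_steps=20).images[0]
```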
@@ -13,7 +13,6 @@
 # limitations under the License.
 import os
 import re
-import warnings
 from collections import defaultdict
 from contextlib import nullcontext
 from io import BytesIO
@@ -33,7 +32,6 @@ from .utils import (
     _get_model_file,
     deprecate,
     is_accelerate_available,
-    is_accelerate_version,
     is_omegaconf_available,
     is_transformers_available,
     logging,
@@ -308,6 +306,9 @@ class UNet2DConditionLoadersMixin:
         # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
         # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
         network_alphas = kwargs.pop("network_alphas", None)
+        _pipeline = kwargs.pop("_pipeline", None)
         is_network_alphas_none = network_alphas is None
         allow_pickle = False
@@ -461,6 +462,7 @@ class UNet2DConditionLoadersMixin:
                         load_model_dict_into_meta(lora, value_dict, device=device, dtype=dtype)
                     else:
                         lora.load_state_dict(value_dict)
         elif is_custom_diffusion:
             attn_processors = {}
             custom_diffusion_grouped_dict = defaultdict(dict)
@@ -490,19 +492,44 @@ class UNet2DConditionLoadersMixin:
                     cross_attention_dim=cross_attention_dim,
                 )
                 attn_processors[key].load_state_dict(value_dict)
-            self.set_attn_processor(attn_processors)
         else:
             raise ValueError(
                 f"{model_file} does not seem to be in the correct format expected by LoRA or Custom Diffusion training."
             )

+        # <Unsafe code
+        # We can be sure that the following works as it just sets attention processors, lora layers and puts all in the same dtype
+        # Now we remove any existing hooks to
+        is_model_cpu_offload = False
+        is_sequential_cpu_offload = False
+        if _pipeline is not None:
+            for _, component in _pipeline.components.items():
+                if isinstance(component, nn.Module):
+                    if hasattr(component, "_hf_hook"):
+                        is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
+                        is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
+                        logger.info(
+                            "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
+                        )
+                        remove_hook_from_module(component, recurse=is_sequential_cpu_offload)

+        # only custom diffusion needs to set attn processors
+        if is_custom_diffusion:
+            self.set_attn_processor(attn_processors)

         # set lora layers
         for target_module, lora_layer in lora_layers_list:
             target_module.set_lora_layer(lora_layer)

         self.to(dtype=self.dtype, device=self.device)

+        # Offload back.
+        if is_model_cpu_offload:
+            _pipeline.enable_model_cpu_offload()
+        elif is_sequential_cpu_offload:
+            _pipeline.enable_sequential_cpu_offload()
+        # Unsafe code />

     def convert_state_dict_legacy_attn_format(self, state_dict, network_alphas):
         is_new_lora_format = all(
             key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in state_dict.keys()
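For readers who want the hook bookkeeping above in isolation, here is a hedged, standalone sketch of the same pattern. The helper names are invented for illustration; the diff inlines this logic directly in `load_attn_procs`. It assumes `accelerate` is installed and that `pipeline.components` is a dict of the pipeline's modules.

```python
import torch.nn as nn
from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module


def detach_offload_hooks(pipeline):
    """Remove accelerate offload hooks and report which offloading mode was active."""
    is_model_cpu_offload = False
    is_sequential_cpu_offload = False
    for component in pipeline.components.values():
        if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
            is_model_cpu_offload |= isinstance(component._hf_hook, CpuOffload)
            is_sequential_cpu_offload |= isinstance(component._hf_hook, AlignDevicesHook)
            # sequential offload attaches hooks to submodules, so remove recursively
            remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
    return is_model_cpu_offload, is_sequential_cpu_offload


def restore_offload(pipeline, is_model_cpu_offload, is_sequential_cpu_offload):
    """Re-enable whichever offloading mode was active before the hooks were removed."""
    if is_model_cpu_offload:
        pipeline.enable_model_cpu_offload()
    elif is_sequential_cpu_offload:
        pipeline.enable_sequential_cpu_offload()
```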
@@ -1072,26 +1099,21 @@ class LoraLoaderMixin:
             kwargs (`dict`, *optional*):
                 See [`~loaders.LoraLoaderMixin.lora_state_dict`].
         """
-        # Remove any existing hooks.
-        is_model_cpu_offload = False
-        is_sequential_cpu_offload = False
-        recurive = False
-        for _, component in self.components.items():
-            if isinstance(component, nn.Module):
-                if hasattr(component, "_hf_hook"):
-                    is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
-                    is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
-                    logger.info(
-                        "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
-                    )
-                    recurive = is_sequential_cpu_offload
-                    remove_hook_from_module(component, recurse=recurive)
+        # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
+        state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+        is_correct_format = all("lora" in key for key in state_dict.keys())
+        if not is_correct_format:
+            raise ValueError("Invalid LoRA checkpoint.")

         low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
-        state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
         self.load_lora_into_unet(
-            state_dict, network_alphas=network_alphas, unet=self.unet, low_cpu_mem_usage=low_cpu_mem_usage
+            state_dict,
+            network_alphas=network_alphas,
+            unet=self.unet,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+            _pipeline=self,
         )
         self.load_lora_into_text_encoder(
             state_dict,
@@ -1099,14 +1121,9 @@ class LoraLoaderMixin:
             text_encoder=self.text_encoder,
             lora_scale=self.lora_scale,
             low_cpu_mem_usage=low_cpu_mem_usage,
+            _pipeline=self,
         )
-
-        # Offload back.
-        if is_model_cpu_offload:
-            self.enable_model_cpu_offload()
-        elif is_sequential_cpu_offload:
-            self.enable_sequential_cpu_offload()

     @classmethod
     def lora_state_dict(
         cls,
@@ -1403,7 +1420,7 @@ class LoraLoaderMixin:
         return new_state_dict

     @classmethod
-    def load_lora_into_unet(cls, state_dict, network_alphas, unet, low_cpu_mem_usage=None):
+    def load_lora_into_unet(cls, state_dict, network_alphas, unet, low_cpu_mem_usage=None, _pipeline=None):
         """
         This will load the LoRA layers specified in `state_dict` into `unet`.
@@ -1445,13 +1462,22 @@ class LoraLoaderMixin:
             # Otherwise, we're dealing with the old format. This means the `state_dict` should only
             # contain the module names of the `unet` as its keys WITHOUT any prefix.
             warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet.{module_name}': params for module_name, params in old_state_dict.items()}`."
-            warnings.warn(warn_message)
+            logger.warn(warn_message)

-        unet.load_attn_procs(state_dict, network_alphas=network_alphas, low_cpu_mem_usage=low_cpu_mem_usage)
+        unet.load_attn_procs(
+            state_dict, network_alphas=network_alphas, low_cpu_mem_usage=low_cpu_mem_usage, _pipeline=_pipeline
+        )

     @classmethod
     def load_lora_into_text_encoder(
-        cls, state_dict, network_alphas, text_encoder, prefix=None, lora_scale=1.0, low_cpu_mem_usage=None
+        cls,
+        state_dict,
+        network_alphas,
+        text_encoder,
+        prefix=None,
+        lora_scale=1.0,
+        low_cpu_mem_usage=None,
+        _pipeline=None,
     ):
         """
         This will load the LoRA layers specified in `state_dict` into `text_encoder`
@@ -1561,11 +1587,15 @@ class LoraLoaderMixin:
                 low_cpu_mem_usage=low_cpu_mem_usage,
             )

-            # set correct dtype & device
-            text_encoder_lora_state_dict = {
-                k: v.to(device=text_encoder.device, dtype=text_encoder.dtype)
-                for k, v in text_encoder_lora_state_dict.items()
-            }
+            is_pipeline_offloaded = _pipeline is not None and any(
+                isinstance(c, torch.nn.Module) and hasattr(c, "_hf_hook") for c in _pipeline.components.values()
+            )
+            if is_pipeline_offloaded and low_cpu_mem_usage:
+                low_cpu_mem_usage = True
+                logger.info(
+                    f"Pipeline {_pipeline.__class__} is offloaded. Therefore low cpu mem usage loading is forced."
+                )

             if low_cpu_mem_usage:
                 device = next(iter(text_encoder_lora_state_dict.values())).device
                 dtype = next(iter(text_encoder_lora_state_dict.values())).dtype
@@ -1581,8 +1611,33 @@ class LoraLoaderMixin:
                     f"failed to load text encoder state dict, unexpected keys: {load_state_dict_results.unexpected_keys}"
                 )

+            # <Unsafe code
+            # We can be sure that the following works as all we do is change the dtype and device of the text encoder
+            # Now we remove any existing hooks to
+            is_model_cpu_offload = False
+            is_sequential_cpu_offload = False
+            if _pipeline is not None:
+                for _, component in _pipeline.components.items():
+                    if isinstance(component, torch.nn.Module):
+                        if hasattr(component, "_hf_hook"):
+                            is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
+                            is_sequential_cpu_offload = isinstance(
+                                getattr(component, "_hf_hook"), AlignDevicesHook
+                            )
+                            logger.info(
+                                "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
+                            )
+                            remove_hook_from_module(component, recurse=is_sequential_cpu_offload)

             text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)

+            # Offload back.
+            if is_model_cpu_offload:
+                _pipeline.enable_model_cpu_offload()
+            elif is_sequential_cpu_offload:
+                _pipeline.enable_sequential_cpu_offload()
+            # Unsafe code />

     @property
     def lora_scale(self) -> float:
         # property function that returns the lora scale which can be set at run time by the pipeline.
@@ -2652,31 +2707,17 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
         # it here explicitly to be able to tell that it's coming from an SDXL
         # pipeline.

-        # Remove any existing hooks.
-        if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
-            from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
-        else:
-            raise ImportError("Offloading requires `accelerate v0.17.0` or higher.")
-
-        is_model_cpu_offload = False
-        is_sequential_cpu_offload = False
-        for _, component in self.components.items():
-            if isinstance(component, torch.nn.Module):
-                if hasattr(component, "_hf_hook"):
-                    is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
-                    is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
-                    logger.info(
-                        "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
-                    )
-                    remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
-
+        # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
         state_dict, network_alphas = self.lora_state_dict(
             pretrained_model_name_or_path_or_dict,
             unet_config=self.unet.config,
             **kwargs,
         )
-        self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet)
+        is_correct_format = all("lora" in key for key in state_dict.keys())
+        if not is_correct_format:
+            raise ValueError("Invalid LoRA checkpoint.")
+
+        self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet, _pipeline=self)

         text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
         if len(text_encoder_state_dict) > 0:
             self.load_lora_into_text_encoder(
@@ -2685,6 +2726,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
                 text_encoder=self.text_encoder,
                 prefix="text_encoder",
                 lora_scale=self.lora_scale,
+                _pipeline=self,
             )

         text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
@@ -2695,14 +2737,9 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
                 text_encoder=self.text_encoder_2,
                 prefix="text_encoder_2",
                 lora_scale=self.lora_scale,
+                _pipeline=self,
             )
-
-        # Offload back.
-        if is_model_cpu_offload:
-            self.enable_model_cpu_offload()
-        elif is_sequential_cpu_offload:
-            self.enable_sequential_cpu_offload()

     @classmethod
     def save_lora_weights(
         self,
......
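The SDXL mixin gets the same treatment: the checkpoint is validated first, then `_pipeline=self` is threaded into the UNet and both text-encoder loaders so that hook removal and re-offloading happen there. A hedged end-to-end sketch with sequential offloading follows; the repo ids and the weight file name are placeholders.

```python
import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
)
pipe.enable_sequential_cpu_offload()  # attaches accelerate AlignDevicesHook to submodules

# placeholder LoRA repo; with this patch the hooks are removed and re-applied
# inside load_lora_into_unet / load_lora_into_text_encoder via `_pipeline=self`
pipe.load_lora_weights("some-user/sdxl-lora", weight_name="pytorch_lora_weights.safetensors")

image = pipe("a watercolor painting of a fox", num_inference_steps=25).images[0]
```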