Unverified Commit f9225301 authored by Patrick von Platen, committed by GitHub

Fix various bugs with LoRA Dreambooth and Dreambooth script (#3353)



* Improve checkpointing lora

* fix more

* Improve doc string

* Update src/diffusers/loaders.py

* make style

* Apply suggestions from code review

* Update src/diffusers/loaders.py

* Apply suggestions from code review

* Apply suggestions from code review

* better

* Fix all

* Fix multi-GPU dreambooth

* Apply suggestions from code review
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Fix all

* make style

* make style

---------
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
parent 58c6f9cb
@@ -22,7 +22,6 @@ import os
 import warnings
 from pathlib import Path

-import accelerate
 import numpy as np
 import torch
 import torch.nn.functional as F
@@ -733,36 +732,34 @@ def main(args):
         args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
     )

-    # `accelerate` 0.16.0 will have better support for customized saving
-    if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
-        # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
-        def save_model_hook(models, weights, output_dir):
-            for model in models:
-                sub_dir = "unet" if type(model) == type(unet) else "text_encoder"
-                model.save_pretrained(os.path.join(output_dir, sub_dir))
-
-                # make sure to pop weight so that corresponding model is not saved again
-                weights.pop()
-
-        def load_model_hook(models, input_dir):
-            while len(models) > 0:
-                # pop models so that they are not loaded again
-                model = models.pop()
-
-                if type(model) == type(text_encoder):
-                    # load transformers style into model
-                    load_model = text_encoder_cls.from_pretrained(input_dir, subfolder="text_encoder")
-                    model.config = load_model.config
-                else:
-                    # load diffusers style into model
-                    load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
-                    model.register_to_config(**load_model.config)
-
-                model.load_state_dict(load_model.state_dict())
-                del load_model
-
-        accelerator.register_save_state_pre_hook(save_model_hook)
-        accelerator.register_load_state_pre_hook(load_model_hook)
+    # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+    def save_model_hook(models, weights, output_dir):
+        for model in models:
+            sub_dir = "unet" if isinstance(model, type(accelerator.unwrap_model(unet))) else "text_encoder"
+            model.save_pretrained(os.path.join(output_dir, sub_dir))
+
+            # make sure to pop weight so that corresponding model is not saved again
+            weights.pop()
+
+    def load_model_hook(models, input_dir):
+        while len(models) > 0:
+            # pop models so that they are not loaded again
+            model = models.pop()
+
+            if isinstance(model, type(accelerator.unwrap_model(text_encoder))):
+                # load transformers style into model
+                load_model = text_encoder_cls.from_pretrained(input_dir, subfolder="text_encoder")
+                model.config = load_model.config
+            else:
+                # load diffusers style into model
+                load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
+                model.register_to_config(**load_model.config)
+
+            model.load_state_dict(load_model.state_dict())
+            del load_model
+
+    accelerator.register_save_state_pre_hook(save_model_hook)
+    accelerator.register_load_state_pre_hook(load_model_hook)

     vae.requires_grad_(False)
     if not args.train_text_encoder:
...
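For orientation only (this sketch is not part of the diff): once the hooks above are registered, `accelerator.save_state(...)` serializes each prepared model with `save_pretrained` into a `unet/` or `text_encoder/` subfolder of the checkpoint directory, and `accelerator.load_state(...)` restores them through `load_model_hook`. The checkpoint path below is hypothetical.

# Illustrative sketch only; `accelerator`, `unet`, `text_encoder`, and the hook
# functions are assumed to be defined as in the diff above.
accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)

unet, text_encoder = accelerator.prepare(unet, text_encoder)

# Writes the usual accelerate state plus `unet/` and `text_encoder/` subfolders
# produced by `save_model_hook` via `save_pretrained`.
accelerator.save_state("output/checkpoint-500")

# On resume, `load_model_hook` reloads those subfolders into the live models.
accelerator.load_state("output/checkpoint-500")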
@@ -834,7 +834,6 @@ def main(args):

     unet.set_attn_processor(unet_lora_attn_procs)
     unet_lora_layers = AttnProcsLayers(unet.attn_processors)
-    accelerator.register_for_checkpointing(unet_lora_layers)

     # The text encoder comes from 🤗 transformers, so we cannot directly modify it.
     # So, instead, we monkey-patch the forward calls of its attention-blocks. For this,
@@ -853,9 +852,68 @@ def main(args):
         )
         temp_pipeline._modify_text_encoder(text_lora_attn_procs)
         text_encoder = temp_pipeline.text_encoder
-        accelerator.register_for_checkpointing(text_encoder_lora_layers)
         del temp_pipeline

+    # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+    def save_model_hook(models, weights, output_dir):
+        # there are only two options here: either there are just the unet attn processor layers
+        # or there are the unet and text encoder attn layers
+        unet_lora_layers_to_save = None
+        text_encoder_lora_layers_to_save = None
+
+        if args.train_text_encoder:
+            text_encoder_keys = accelerator.unwrap_model(text_encoder_lora_layers).state_dict().keys()
+        unet_keys = accelerator.unwrap_model(unet_lora_layers).state_dict().keys()
+
+        for model in models:
+            state_dict = model.state_dict()
+
+            if (
+                text_encoder_lora_layers is not None
+                and text_encoder_keys is not None
+                and state_dict.keys() == text_encoder_keys
+            ):
+                # text encoder
+                text_encoder_lora_layers_to_save = state_dict
+            elif state_dict.keys() == unet_keys:
+                # unet
+                unet_lora_layers_to_save = state_dict
+
+            # make sure to pop weight so that corresponding model is not saved again
+            weights.pop()
+
+        LoraLoaderMixin.save_lora_weights(
+            output_dir,
+            unet_lora_layers=unet_lora_layers_to_save,
+            text_encoder_lora_layers=text_encoder_lora_layers_to_save,
+        )

+    def load_model_hook(models, input_dir):
+        # Note we DON'T pass the unet and text encoder here on purpose
+        # so that we don't accidentally override the LoRA layers of
+        # unet_lora_layers and text_encoder_lora_layers which are stored in `models`
+        # with new torch.nn.Modules / weights. We simply use the pipeline class as
+        # an easy way to load the lora checkpoints
+        temp_pipeline = DiffusionPipeline.from_pretrained(
+            args.pretrained_model_name_or_path,
+            revision=args.revision,
+            torch_dtype=weight_dtype,
+        )
+        temp_pipeline.load_lora_weights(input_dir)
+
+        # load lora weights into models
+        models[0].load_state_dict(AttnProcsLayers(temp_pipeline.unet.attn_processors).state_dict())
+        if len(models) > 1:
+            models[1].load_state_dict(AttnProcsLayers(temp_pipeline.text_encoder_lora_attn_procs).state_dict())
+
+        # delete temporary pipeline and pop models
+        del temp_pipeline
+        for _ in range(len(models)):
+            models.pop()
+
+    accelerator.register_save_state_pre_hook(save_model_hook)
+    accelerator.register_load_state_pre_hook(load_model_hook)

     # Enable TF32 for faster training on Ampere GPUs,
     # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
     if args.allow_tf32:
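A note on why this works (sketch, not part of the diff): the objects handed to `save_model_hook` are the `AttnProcsLayers` wrappers that were passed to `accelerator.prepare`, and the hook tells the UNet and text-encoder wrappers apart by comparing their `state_dict()` key sets against the unwrapped reference layers. Because the weights are then written through `LoraLoaderMixin.save_lora_weights`, an intermediate checkpoint directory is in the standard LoRA format and is meant to be loadable directly into a pipeline. The model id and checkpoint path below are only examples.

import torch
from diffusers import DiffusionPipeline

# Illustrative only: load an intermediate checkpoint written by accelerator.save_state(...)
pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
pipe = pipe.to("cuda")
pipe.load_lora_weights("output/checkpoint-500")  # hypothetical checkpoint directory
image = pipe("a photo of sks dog in a bucket", num_inference_steps=25).images[0]
image.save("sks_dog.png")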
@@ -1130,17 +1188,10 @@ def main(args):
                 progress_bar.update(1)
                 global_step += 1

-                if global_step % args.checkpointing_steps == 0:
-                    if accelerator.is_main_process:
-                        save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
-                        # We combine the text encoder and UNet LoRA parameters with a simple
-                        # custom logic. `accelerator.save_state()` won't know that. So,
-                        # use `LoraLoaderMixin.save_lora_weights()`.
-                        LoraLoaderMixin.save_lora_weights(
-                            save_directory=save_path,
-                            unet_lora_layers=unet_lora_layers,
-                            text_encoder_lora_layers=text_encoder_lora_layers,
-                        )
-                        logger.info(f"Saved state to {save_path}")
+                if accelerator.is_main_process:
+                    if global_step % args.checkpointing_steps == 0:
+                        save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+                        accelerator.save_state(save_path)
+                        logger.info(f"Saved state to {save_path}")

             logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
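The save path above pairs with the script's resume handling; a minimal sketch of that side, assuming the usual `--resume_from_checkpoint` argument and the `checkpoint-<step>` naming used here:

# Sketch only, assuming args.resume_from_checkpoint and the checkpoint-<step> naming above.
if args.resume_from_checkpoint:
    # pick the latest "checkpoint-<step>" directory in args.output_dir
    dirs = [d for d in os.listdir(args.output_dir) if d.startswith("checkpoint")]
    dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
    path = dirs[-1] if len(dirs) > 0 else None
    if path is not None:
        # load_model_hook restores the LoRA layers from the checkpoint directory
        accelerator.load_state(os.path.join(args.output_dir, path))
        global_step = int(path.split("-")[1])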
@@ -1217,8 +1268,12 @@ def main(args):
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
         unet = unet.to(torch.float32)
+        unet_lora_layers = accelerator.unwrap_model(unet_lora_layers)
+
         if text_encoder is not None:
             text_encoder = text_encoder.to(torch.float32)
+            text_encoder_lora_layers = accelerator.unwrap_model(text_encoder_lora_layers)
+
         LoraLoaderMixin.save_lora_weights(
             save_directory=args.output_dir,
             unet_lora_layers=unet_lora_layers,
@@ -1250,6 +1305,7 @@ def main(args):
         pipeline.load_lora_weights(args.output_dir)

         # run inference
+        images = []
         if args.validation_prompt and args.num_validation_images > 0:
             generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
             images = [
...
@@ -70,6 +70,9 @@ class AttnProcsLayers(torch.nn.Module):
         self.mapping = dict(enumerate(state_dict.keys()))
         self.rev_mapping = {v: k for k, v in enumerate(state_dict.keys())}

+        # .processor for unet, .k_proj, ".q_proj", ".v_proj", and ".out_proj" for text encoder
+        self.split_keys = [".processor", ".k_proj", ".q_proj", ".v_proj", ".out_proj"]
+
         # we add a hook to state_dict() and load_state_dict() so that the
         # naming fits with `unet.attn_processors`
         def map_to(module, state_dict, *args, **kwargs):
@@ -81,10 +84,19 @@ class AttnProcsLayers(torch.nn.Module):
             return new_state_dict

+        def remap_key(key, state_dict):
+            for k in self.split_keys:
+                if k in key:
+                    return key.split(k)[0] + k
+
+            raise ValueError(
+                f"There seems to be a problem with the state_dict: {set(state_dict.keys())}. {key} has to have one of {self.split_keys}."
+            )
+
         def map_from(module, state_dict, *args, **kwargs):
             all_keys = list(state_dict.keys())
             for key in all_keys:
-                replace_key = key.split(".processor")[0] + ".processor"
+                replace_key = remap_key(key, state_dict)
                 new_key = key.replace(replace_key, f"layers.{module.rev_mapping[replace_key]}")
                 state_dict[new_key] = state_dict[key]
                 del state_dict[key]
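As a standalone illustration (not part of the diff), `remap_key` generalizes the old `.processor`-only split so that text-encoder keys, which contain `.q_proj`/`.k_proj`/`.v_proj`/`.out_proj` instead, are also truncated to a prefix that `rev_mapping` knows about. The example keys below are hypothetical.

# Standalone sketch of the remapping logic; the example keys are made up for demonstration.
split_keys = [".processor", ".k_proj", ".q_proj", ".v_proj", ".out_proj"]

def remap_key(key):
    # keep everything up to and including the first split marker found in the key
    for k in split_keys:
        if k in key:
            return key.split(k)[0] + k
    raise ValueError(f"{key} has to have one of {split_keys}.")

# UNet-style key: everything after ".processor" is dropped.
print(remap_key("down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor.to_q_lora.up.weight"))
# -> "down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor"

# Text-encoder-style key: split on the projection name instead.
print(remap_key("text_model.encoder.layers.0.self_attn.out_proj.lora_linear_layer.up.weight"))
# -> "text_model.encoder.layers.0.self_attn.out_proj"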
@@ -898,6 +910,9 @@ class LoraLoaderMixin:
             attn_procs_text_encoder = self._load_text_encoder_attn_procs(text_encoder_lora_state_dict)
             self._modify_text_encoder(attn_procs_text_encoder)

+            # save lora attn procs of text encoder so that it can be easily retrieved
+            self._text_encoder_lora_attn_procs = attn_procs_text_encoder
+
         # Otherwise, we're dealing with the old format. This means the `state_dict` should only
         # contain the module names of the `unet` as its keys WITHOUT any prefix.
         elif not all(
@@ -907,6 +922,12 @@ class LoraLoaderMixin:
             warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet'.{module_name}: params for module_name, params in old_state_dict.items()}`."
             warnings.warn(warn_message)

+    @property
+    def text_encoder_lora_attn_procs(self):
+        if hasattr(self, "_text_encoder_lora_attn_procs"):
+            return self._text_encoder_lora_attn_procs
+        return
+
     def _modify_text_encoder(self, attn_processors: Dict[str, LoRAAttnProcessor]):
         r"""
         Monkey-patches the forward passes of attention modules of the text encoder.
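A short usage sketch for the new property (not part of the diff; the model id and checkpoint directory are placeholders): after `load_lora_weights` has monkey-patched the text encoder, the loaded processors can be retrieved again, which is exactly what the training script's `load_model_hook` relies on. The property returns `None` when no text-encoder LoRA has been loaded.

from diffusers import DiffusionPipeline
from diffusers.loaders import AttnProcsLayers

# Illustrative only: model id and checkpoint directory are placeholders.
pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe.load_lora_weights("output/checkpoint-500")

attn_procs = pipe.text_encoder_lora_attn_procs  # dict of LoRA attn processors, or None
if attn_procs is not None:
    # rebuild the same wrapper the training script checkpoints
    text_encoder_lora_layers = AttnProcsLayers(attn_procs)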
@@ -1110,7 +1131,7 @@ class LoraLoaderMixin:
     def save_lora_weights(
         self,
         save_directory: Union[str, os.PathLike],
-        unet_lora_layers: Dict[str, torch.nn.Module] = None,
+        unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
         text_encoder_lora_layers: Dict[str, torch.nn.Module] = None,
         is_main_process: bool = True,
         weight_name: str = None,
@@ -1123,13 +1144,14 @@ class LoraLoaderMixin:
         Arguments:
             save_directory (`str` or `os.PathLike`):
                 Directory to which to save. Will be created if it doesn't exist.
-            unet_lora_layers (`Dict[str, torch.nn.Module`]):
+            unet_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
                 State dict of the LoRA layers corresponding to the UNet. Specifying this helps to make the
-                serialization process easier and cleaner.
-            text_encoder_lora_layers (`Dict[str, torch.nn.Module`]):
+                serialization process easier and cleaner. Values can either be LoRA `torch.nn.Module` layers or raw
+                torch weights.
+            text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
                 State dict of the LoRA layers corresponding to the `text_encoder`. Since the `text_encoder` comes from
                 `transformers`, we cannot rejig it. That is why we have to explicitly pass the text encoder LoRA state
-                dict.
+                dict. Values can either be LoRA `torch.nn.Module` layers or raw torch weights.
             is_main_process (`bool`, *optional*, defaults to `True`):
                 Whether the process calling this is the main process or not. Useful when in distributed training like
                 TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on
@@ -1157,15 +1179,22 @@ class LoraLoaderMixin:
         # Create a flat dictionary.
         state_dict = {}
         if unet_lora_layers is not None:
-            unet_lora_state_dict = {
-                f"{self.unet_name}.{module_name}": param
-                for module_name, param in unet_lora_layers.state_dict().items()
-            }
+            weights = (
+                unet_lora_layers.state_dict() if isinstance(unet_lora_layers, torch.nn.Module) else unet_lora_layers
+            )
+
+            unet_lora_state_dict = {f"{self.unet_name}.{module_name}": param for module_name, param in weights.items()}
             state_dict.update(unet_lora_state_dict)

         if text_encoder_lora_layers is not None:
+            weights = (
+                text_encoder_lora_layers.state_dict()
+                if isinstance(text_encoder_lora_layers, torch.nn.Module)
+                else text_encoder_lora_layers
+            )
+
             text_encoder_lora_state_dict = {
-                f"{self.text_encoder_name}.{module_name}": param
-                for module_name, param in text_encoder_lora_layers.state_dict().items()
+                f"{self.text_encoder_name}.{module_name}": param for module_name, param in weights.items()
             }
             state_dict.update(text_encoder_lora_state_dict)
...
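A sketch of what the widened `save_lora_weights` signature allows (assuming a pipeline that inherits `LoraLoaderMixin`, such as `StableDiffusionPipeline`; the directories are placeholders): callers can pass either the `AttnProcsLayers` modules themselves or their plain state dicts, which is what the new `save_model_hook` passes, and the two forms can be mixed.

# Illustrative only; `pipe`, `unet_lora_layers`, and `text_encoder_lora_layers` are
# assumed to exist as in the training script, and the directories are placeholders.
pipe.save_lora_weights(
    save_directory="output/lora",
    unet_lora_layers=unet_lora_layers,                       # torch.nn.Module (previous behaviour)
    text_encoder_lora_layers=text_encoder_lora_layers,
)

pipe.save_lora_weights(
    save_directory="output/checkpoint-500",
    unet_lora_layers=unet_lora_layers.state_dict(),          # plain tensor dict (newly accepted)
    text_encoder_lora_layers=text_encoder_lora_layers.state_dict(),
)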