Unverified Commit 0fa32bd6 authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

[Examples] use loralinear instead of depecrecated lora attn procs. (#5331)

* use loralinear instead of depecrecated lora attn procs.

* fix parameters()

* fix saving

* add back support for add kv proj.

* fix: param accumul,ation.

* propagate the changes.
parent aea73834
......@@ -24,7 +24,6 @@ import os
import shutil
import warnings
from pathlib import Path
from typing import Dict
import numpy as np
import torch
......@@ -59,12 +58,11 @@ from diffusers.loaders import (
from diffusers.models.attention_processor import (
AttnAddedKVProcessor,
AttnAddedKVProcessor2_0,
LoRAAttnAddedKVProcessor,
LoRAAttnProcessor,
LoRAAttnProcessor2_0,
SlicedAttnAddedKVProcessor,
)
from diffusers.models.lora import LoRALinearLayer
from diffusers.optimization import get_scheduler
from diffusers.training_utils import unet_lora_state_dict
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available
......@@ -658,22 +656,6 @@ def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_atte
return prompt_embeds
def unet_attn_processors_state_dict(unet) -> Dict[str, torch.tensor]:
r"""
Returns:
a state dict containing just the attention processor parameters.
"""
attn_processors = unet.attn_processors
attn_processors_state_dict = {}
for attn_processor_key, attn_processor in attn_processors.items():
for parameter_key, parameter in attn_processor.state_dict().items():
attn_processors_state_dict[f"{attn_processor_key}.{parameter_key}"] = parameter
return attn_processors_state_dict
def main(args):
logging_dir = Path(args.output_dir, args.logging_dir)
......@@ -858,33 +840,60 @@ def main(args):
# => 32 layers
# Set correct lora layers
unet_lora_attn_procs = {}
unet_lora_parameters = []
for name, attn_processor in unet.attn_processors.items():
cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
if name.startswith("mid_block"):
hidden_size = unet.config.block_out_channels[-1]
elif name.startswith("up_blocks"):
block_id = int(name[len("up_blocks.")])
hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = unet.config.block_out_channels[block_id]
if isinstance(attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0)):
lora_attn_processor_class = LoRAAttnAddedKVProcessor
else:
lora_attn_processor_class = (
LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
for attn_processor_name, attn_processor in unet.attn_processors.items():
# Parse the attention module.
attn_module = unet
for n in attn_processor_name.split(".")[:-1]:
attn_module = getattr(attn_module, n)
# Set the `lora_layer` attribute of the attention-related matrices.
attn_module.to_q.set_lora_layer(
LoRALinearLayer(
in_features=attn_module.to_q.in_features, out_features=attn_module.to_q.out_features, rank=args.rank
)
)
attn_module.to_k.set_lora_layer(
LoRALinearLayer(
in_features=attn_module.to_k.in_features, out_features=attn_module.to_k.out_features, rank=args.rank
)
)
attn_module.to_v.set_lora_layer(
LoRALinearLayer(
in_features=attn_module.to_v.in_features, out_features=attn_module.to_v.out_features, rank=args.rank
)
module = lora_attn_processor_class(
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=args.rank
)
unet_lora_attn_procs[name] = module
unet_lora_parameters.extend(module.parameters())
attn_module.to_out[0].set_lora_layer(
LoRALinearLayer(
in_features=attn_module.to_out[0].in_features,
out_features=attn_module.to_out[0].out_features,
rank=args.rank,
)
)
# Accumulate the LoRA params to optimize.
unet_lora_parameters.extend(attn_module.to_q.lora_layer.parameters())
unet_lora_parameters.extend(attn_module.to_k.lora_layer.parameters())
unet_lora_parameters.extend(attn_module.to_v.lora_layer.parameters())
unet_lora_parameters.extend(attn_module.to_out[0].lora_layer.parameters())
unet.set_attn_processor(unet_lora_attn_procs)
if isinstance(attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0)):
attn_module.add_k_proj.set_lora_layer(
LoRALinearLayer(
in_features=attn_module.add_k_proj.in_features,
out_features=attn_module.add_k_proj.out_features,
rank=args.rank,
)
)
attn_module.add_v_proj.set_lora_layer(
LoRALinearLayer(
in_features=attn_module.add_v_proj.in_features,
out_features=attn_module.add_v_proj.out_features,
rank=args.rank,
)
)
unet_lora_parameters.extend(attn_module.add_k_proj.lora_layer.parameters())
unet_lora_parameters.extend(attn_module.add_v_proj.lora_layer.parameters())
# The text encoder comes from 🤗 transformers, so we cannot directly modify it.
# So, instead, we monkey-patch the forward calls of its attention-blocks.
......@@ -902,7 +911,7 @@ def main(args):
for model in models:
if isinstance(model, type(accelerator.unwrap_model(unet))):
unet_lora_layers_to_save = unet_attn_processors_state_dict(model)
unet_lora_layers_to_save = unet_lora_state_dict(model)
elif isinstance(model, type(accelerator.unwrap_model(text_encoder))):
text_encoder_lora_layers_to_save = text_encoder_lora_state_dict(model)
else:
......@@ -1338,7 +1347,7 @@ def main(args):
if accelerator.is_main_process:
unet = accelerator.unwrap_model(unet)
unet = unet.to(torch.float32)
unet_lora_layers = unet_attn_processors_state_dict(unet)
unet_lora_layers = unet_lora_state_dict(unet)
if text_encoder is not None and args.train_text_encoder:
text_encoder = accelerator.unwrap_model(text_encoder)
......
......@@ -23,7 +23,6 @@ import os
import shutil
import warnings
from pathlib import Path
from typing import Dict
import numpy as np
import torch
......@@ -51,8 +50,9 @@ from diffusers import (
UNet2DConditionModel,
)
from diffusers.loaders import LoraLoaderMixin, text_encoder_lora_state_dict
from diffusers.models.attention_processor import LoRAAttnProcessor, LoRAAttnProcessor2_0
from diffusers.models.lora import LoRALinearLayer
from diffusers.optimization import get_scheduler
from diffusers.training_utils import unet_lora_state_dict
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available
......@@ -575,22 +575,6 @@ def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None):
return prompt_embeds, pooled_prompt_embeds
def unet_attn_processors_state_dict(unet) -> Dict[str, torch.tensor]:
"""
Returns:
a state dict containing just the attention processor parameters.
"""
attn_processors = unet.attn_processors
attn_processors_state_dict = {}
for attn_processor_key, attn_processor in attn_processors.items():
for parameter_key, parameter in attn_processor.state_dict().items():
attn_processors_state_dict[f"{attn_processor_key}.{parameter_key}"] = parameter
return attn_processors_state_dict
def main(args):
logging_dir = Path(args.output_dir, args.logging_dir)
......@@ -761,29 +745,42 @@ def main(args):
# now we will add new LoRA weights to the attention layers
# Set correct lora layers
unet_lora_attn_procs = {}
unet_lora_parameters = []
for name, attn_processor in unet.attn_processors.items():
cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
if name.startswith("mid_block"):
hidden_size = unet.config.block_out_channels[-1]
elif name.startswith("up_blocks"):
block_id = int(name[len("up_blocks.")])
hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = unet.config.block_out_channels[block_id]
lora_attn_processor_class = (
LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
for attn_processor_name, attn_processor in unet.attn_processors.items():
# Parse the attention module.
attn_module = unet
for n in attn_processor_name.split(".")[:-1]:
attn_module = getattr(attn_module, n)
# Set the `lora_layer` attribute of the attention-related matrices.
attn_module.to_q.set_lora_layer(
LoRALinearLayer(
in_features=attn_module.to_q.in_features, out_features=attn_module.to_q.out_features, rank=args.rank
)
)
attn_module.to_k.set_lora_layer(
LoRALinearLayer(
in_features=attn_module.to_k.in_features, out_features=attn_module.to_k.out_features, rank=args.rank
)
)
module = lora_attn_processor_class(
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=args.rank
attn_module.to_v.set_lora_layer(
LoRALinearLayer(
in_features=attn_module.to_v.in_features, out_features=attn_module.to_v.out_features, rank=args.rank
)
)
attn_module.to_out[0].set_lora_layer(
LoRALinearLayer(
in_features=attn_module.to_out[0].in_features,
out_features=attn_module.to_out[0].out_features,
rank=args.rank,
)
)
unet_lora_attn_procs[name] = module
unet_lora_parameters.extend(module.parameters())
unet.set_attn_processor(unet_lora_attn_procs)
# Accumulate the LoRA params to optimize.
unet_lora_parameters.extend(attn_module.to_q.lora_layer.parameters())
unet_lora_parameters.extend(attn_module.to_k.lora_layer.parameters())
unet_lora_parameters.extend(attn_module.to_v.lora_layer.parameters())
unet_lora_parameters.extend(attn_module.to_out[0].lora_layer.parameters())
# The text encoder comes from 🤗 transformers, so we cannot directly modify it.
# So, instead, we monkey-patch the forward calls of its attention-blocks.
......@@ -807,7 +804,7 @@ def main(args):
for model in models:
if isinstance(model, type(accelerator.unwrap_model(unet))):
unet_lora_layers_to_save = unet_attn_processors_state_dict(model)
unet_lora_layers_to_save = unet_lora_state_dict(model)
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
text_encoder_one_lora_layers_to_save = text_encoder_lora_state_dict(model)
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
......@@ -1274,7 +1271,7 @@ def main(args):
if accelerator.is_main_process:
unet = accelerator.unwrap_model(unet)
unet = unet.to(torch.float32)
unet_lora_layers = unet_attn_processors_state_dict(unet)
unet_lora_layers = unet_lora_state_dict(unet)
if args.train_text_encoder:
text_encoder_one = accelerator.unwrap_model(text_encoder_one)
......
......@@ -50,7 +50,7 @@ from diffusers import (
UNet2DConditionModel,
)
from diffusers.loaders import LoraLoaderMixin, text_encoder_lora_state_dict
from diffusers.models.attention_processor import LoRAAttnProcessor, LoRAAttnProcessor2_0
from diffusers.models.lora import LoRALinearLayer
from diffusers.optimization import get_scheduler
from diffusers.training_utils import compute_snr
from diffusers.utils import check_min_version, is_wandb_available
......@@ -609,29 +609,42 @@ def main(args):
# now we will add new LoRA weights to the attention layers
# Set correct lora layers
unet_lora_attn_procs = {}
unet_lora_parameters = []
for name, attn_processor in unet.attn_processors.items():
cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
if name.startswith("mid_block"):
hidden_size = unet.config.block_out_channels[-1]
elif name.startswith("up_blocks"):
block_id = int(name[len("up_blocks.")])
hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = unet.config.block_out_channels[block_id]
lora_attn_processor_class = (
LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
for attn_processor_name, attn_processor in unet.attn_processors.items():
# Parse the attention module.
attn_module = unet
for n in attn_processor_name.split(".")[:-1]:
attn_module = getattr(attn_module, n)
# Set the `lora_layer` attribute of the attention-related matrices.
attn_module.to_q.set_lora_layer(
LoRALinearLayer(
in_features=attn_module.to_q.in_features, out_features=attn_module.to_q.out_features, rank=args.rank
)
)
attn_module.to_k.set_lora_layer(
LoRALinearLayer(
in_features=attn_module.to_k.in_features, out_features=attn_module.to_k.out_features, rank=args.rank
)
)
module = lora_attn_processor_class(
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=args.rank
attn_module.to_v.set_lora_layer(
LoRALinearLayer(
in_features=attn_module.to_v.in_features, out_features=attn_module.to_v.out_features, rank=args.rank
)
)
attn_module.to_out[0].set_lora_layer(
LoRALinearLayer(
in_features=attn_module.to_out[0].in_features,
out_features=attn_module.to_out[0].out_features,
rank=args.rank,
)
)
unet_lora_attn_procs[name] = module
unet_lora_parameters.extend(module.parameters())
unet.set_attn_processor(unet_lora_attn_procs)
# Accumulate the LoRA params to optimize.
unet_lora_parameters.extend(attn_module.to_q.lora_layer.parameters())
unet_lora_parameters.extend(attn_module.to_k.lora_layer.parameters())
unet_lora_parameters.extend(attn_module.to_v.lora_layer.parameters())
unet_lora_parameters.extend(attn_module.to_out[0].lora_layer.parameters())
# The text encoder comes from 🤗 transformers, so we cannot directly modify it.
# So, instead, we monkey-patch the forward calls of its attention-blocks.
......
......@@ -1647,7 +1647,7 @@ class LoRAXFormersAttnProcessor(nn.Module):
"0.26.0",
(
f"Make sure use {self_cls_name[4:]} instead by setting"
"LoRA layers to `self.{to_q,to_k,to_v,to_out[0]}.lora_layer` respectively. This will be done automatically when using"
"LoRA layers to `self.{to_q,to_k,to_v,add_k_proj,add_v_proj,to_out[0]}.lora_layer` respectively. This will be done automatically when using"
" `LoraLoaderMixin.load_lora_weights`"
),
)
......@@ -1697,7 +1697,7 @@ class LoRAAttnAddedKVProcessor(nn.Module):
"0.26.0",
(
f"Make sure use {self_cls_name[4:]} instead by setting"
"LoRA layers to `self.{to_q,to_k,to_v,to_out[0]}.lora_layer` respectively. This will be done automatically when using"
"LoRA layers to `self.{to_q,to_k,to_v,add_k_proj,add_v_proj,to_out[0]}.lora_layer` respectively. This will be done automatically when using"
" `LoraLoaderMixin.load_lora_weights`"
),
)
......
......@@ -6,6 +6,7 @@ from typing import Any, Dict, Iterable, Optional, Union
import numpy as np
import torch
from .models import UNet2DConditionModel
from .utils import deprecate, is_transformers_available
......@@ -52,6 +53,25 @@ def compute_snr(noise_scheduler, timesteps):
return snr
def unet_lora_state_dict(unet: UNet2DConditionModel) -> Dict[str, torch.Tensor]:
r"""
Returns:
A state dict containing just the LoRA parameters.
"""
lora_state_dict = {}
for name, module in unet.named_modules():
if hasattr(module, "set_lora_layer"):
lora_layer = getattr(module, "lora_layer")
if lora_layer is not None:
current_lora_layer_sd = lora_layer.state_dict()
for lora_layer_matrix_name, lora_param in current_lora_layer_sd.items():
# The matrix name can either be "down" or "up".
lora_state_dict[f"unet.{name}.lora.{lora_layer_matrix_name}"] = lora_param
return lora_state_dict
# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
class EMAModel:
"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment