Unverified commit 0fa32bd6, authored by Sayak Paul and committed by GitHub

[Examples] use LoRALinearLayer instead of deprecated LoRA attn procs. (#5331)

* use LoRALinearLayer instead of deprecated LoRA attn procs.

* fix parameters()

* fix saving

* add back support for the added KV projections.

* fix: param accumulation.

* propagate the changes.
parent aea73834
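The gist of the change, as a minimal standalone sketch: instead of building the deprecated `LoRAAttnProcessor` / `LoRAAttnProcessor2_0` / `LoRAAttnAddedKVProcessor` objects and calling `unet.set_attn_processor(...)`, the example scripts now attach a `LoRALinearLayer` to each attention projection via `set_lora_layer()` and collect the LoRA parameters from those layers. The checkpoint id and the rank value below are placeholders, not part of this commit, and the added-KV branch is omitted for brevity.

```python
from diffusers import UNet2DConditionModel
from diffusers.models.lora import LoRALinearLayer

# Placeholder checkpoint id; any UNet2DConditionModel is wired the same way.
unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")

rank = 4  # example LoRA rank
unet_lora_parameters = []
for attn_processor_name in unet.attn_processors:
    # Walk from the UNet root down to the attention module that owns this processor.
    attn_module = unet
    for name in attn_processor_name.split(".")[:-1]:
        attn_module = getattr(attn_module, name)

    # Attach a LoRA layer to each attention projection instead of swapping in a
    # dedicated LoRA attention processor class.
    for proj in (attn_module.to_q, attn_module.to_k, attn_module.to_v, attn_module.to_out[0]):
        proj.set_lora_layer(
            LoRALinearLayer(in_features=proj.in_features, out_features=proj.out_features, rank=rank)
        )
        unet_lora_parameters.extend(proj.lora_layer.parameters())
    # (The add_k_proj / add_v_proj handling in the full scripts is omitted here.)
```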
@@ -24,7 +24,6 @@ import os
 import shutil
 import warnings
 from pathlib import Path
-from typing import Dict

 import numpy as np
 import torch
@@ -59,12 +58,11 @@ from diffusers.loaders import (
 from diffusers.models.attention_processor import (
     AttnAddedKVProcessor,
     AttnAddedKVProcessor2_0,
-    LoRAAttnAddedKVProcessor,
-    LoRAAttnProcessor,
-    LoRAAttnProcessor2_0,
     SlicedAttnAddedKVProcessor,
 )
+from diffusers.models.lora import LoRALinearLayer
 from diffusers.optimization import get_scheduler
+from diffusers.training_utils import unet_lora_state_dict
 from diffusers.utils import check_min_version, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available
@@ -658,22 +656,6 @@ def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_attention_mask=None):
     return prompt_embeds


-def unet_attn_processors_state_dict(unet) -> Dict[str, torch.tensor]:
-    r"""
-    Returns:
-        a state dict containing just the attention processor parameters.
-    """
-    attn_processors = unet.attn_processors
-
-    attn_processors_state_dict = {}
-
-    for attn_processor_key, attn_processor in attn_processors.items():
-        for parameter_key, parameter in attn_processor.state_dict().items():
-            attn_processors_state_dict[f"{attn_processor_key}.{parameter_key}"] = parameter
-
-    return attn_processors_state_dict
-
-
 def main(args):
     logging_dir = Path(args.output_dir, args.logging_dir)
@@ -858,33 +840,60 @@ def main(args):
     # => 32 layers

     # Set correct lora layers
-    unet_lora_attn_procs = {}
     unet_lora_parameters = []
-    for name, attn_processor in unet.attn_processors.items():
-        cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
-        if name.startswith("mid_block"):
-            hidden_size = unet.config.block_out_channels[-1]
-        elif name.startswith("up_blocks"):
-            block_id = int(name[len("up_blocks.")])
-            hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
-        elif name.startswith("down_blocks"):
-            block_id = int(name[len("down_blocks.")])
-            hidden_size = unet.config.block_out_channels[block_id]
-
-        if isinstance(attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0)):
-            lora_attn_processor_class = LoRAAttnAddedKVProcessor
-        else:
-            lora_attn_processor_class = (
-                LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
-            )
-
-        module = lora_attn_processor_class(
-            hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=args.rank
-        )
-        unet_lora_attn_procs[name] = module
-        unet_lora_parameters.extend(module.parameters())
-
-    unet.set_attn_processor(unet_lora_attn_procs)
+    for attn_processor_name, attn_processor in unet.attn_processors.items():
+        # Parse the attention module.
+        attn_module = unet
+        for n in attn_processor_name.split(".")[:-1]:
+            attn_module = getattr(attn_module, n)
+
+        # Set the `lora_layer` attribute of the attention-related matrices.
+        attn_module.to_q.set_lora_layer(
+            LoRALinearLayer(
+                in_features=attn_module.to_q.in_features, out_features=attn_module.to_q.out_features, rank=args.rank
+            )
+        )
+        attn_module.to_k.set_lora_layer(
+            LoRALinearLayer(
+                in_features=attn_module.to_k.in_features, out_features=attn_module.to_k.out_features, rank=args.rank
+            )
+        )
+        attn_module.to_v.set_lora_layer(
+            LoRALinearLayer(
+                in_features=attn_module.to_v.in_features, out_features=attn_module.to_v.out_features, rank=args.rank
+            )
+        )
+        attn_module.to_out[0].set_lora_layer(
+            LoRALinearLayer(
+                in_features=attn_module.to_out[0].in_features,
+                out_features=attn_module.to_out[0].out_features,
+                rank=args.rank,
+            )
+        )
+
+        # Accumulate the LoRA params to optimize.
+        unet_lora_parameters.extend(attn_module.to_q.lora_layer.parameters())
+        unet_lora_parameters.extend(attn_module.to_k.lora_layer.parameters())
+        unet_lora_parameters.extend(attn_module.to_v.lora_layer.parameters())
+        unet_lora_parameters.extend(attn_module.to_out[0].lora_layer.parameters())
+
+        if isinstance(attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0)):
+            attn_module.add_k_proj.set_lora_layer(
+                LoRALinearLayer(
+                    in_features=attn_module.add_k_proj.in_features,
+                    out_features=attn_module.add_k_proj.out_features,
+                    rank=args.rank,
+                )
+            )
+            attn_module.add_v_proj.set_lora_layer(
+                LoRALinearLayer(
+                    in_features=attn_module.add_v_proj.in_features,
+                    out_features=attn_module.add_v_proj.out_features,
+                    rank=args.rank,
+                )
+            )
+            unet_lora_parameters.extend(attn_module.add_k_proj.lora_layer.parameters())
+            unet_lora_parameters.extend(attn_module.add_v_proj.lora_layer.parameters())

     # The text encoder comes from 🤗 transformers, so we cannot directly modify it.
     # So, instead, we monkey-patch the forward calls of its attention-blocks.
@@ -902,7 +911,7 @@ def main(args):
             for model in models:
                 if isinstance(model, type(accelerator.unwrap_model(unet))):
-                    unet_lora_layers_to_save = unet_attn_processors_state_dict(model)
+                    unet_lora_layers_to_save = unet_lora_state_dict(model)
                 elif isinstance(model, type(accelerator.unwrap_model(text_encoder))):
                     text_encoder_lora_layers_to_save = text_encoder_lora_state_dict(model)
                 else:
@@ -1338,7 +1347,7 @@ def main(args):
     if accelerator.is_main_process:
         unet = accelerator.unwrap_model(unet)
         unet = unet.to(torch.float32)
-        unet_lora_layers = unet_attn_processors_state_dict(unet)
+        unet_lora_layers = unet_lora_state_dict(unet)

         if text_encoder is not None and args.train_text_encoder:
             text_encoder = accelerator.unwrap_model(text_encoder)
...
@@ -23,7 +23,6 @@ import os
 import shutil
 import warnings
 from pathlib import Path
-from typing import Dict

 import numpy as np
 import torch
@@ -51,8 +50,9 @@ from diffusers import (
     UNet2DConditionModel,
 )
 from diffusers.loaders import LoraLoaderMixin, text_encoder_lora_state_dict
-from diffusers.models.attention_processor import LoRAAttnProcessor, LoRAAttnProcessor2_0
+from diffusers.models.lora import LoRALinearLayer
 from diffusers.optimization import get_scheduler
+from diffusers.training_utils import unet_lora_state_dict
 from diffusers.utils import check_min_version, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available
@@ -575,22 +575,6 @@ def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None):
     return prompt_embeds, pooled_prompt_embeds


-def unet_attn_processors_state_dict(unet) -> Dict[str, torch.tensor]:
-    """
-    Returns:
-        a state dict containing just the attention processor parameters.
-    """
-    attn_processors = unet.attn_processors
-
-    attn_processors_state_dict = {}
-
-    for attn_processor_key, attn_processor in attn_processors.items():
-        for parameter_key, parameter in attn_processor.state_dict().items():
-            attn_processors_state_dict[f"{attn_processor_key}.{parameter_key}"] = parameter
-
-    return attn_processors_state_dict
-
-
 def main(args):
     logging_dir = Path(args.output_dir, args.logging_dir)
@@ -761,29 +745,42 @@ def main(args):

     # now we will add new LoRA weights to the attention layers
     # Set correct lora layers
-    unet_lora_attn_procs = {}
     unet_lora_parameters = []
-    for name, attn_processor in unet.attn_processors.items():
-        cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
-        if name.startswith("mid_block"):
-            hidden_size = unet.config.block_out_channels[-1]
-        elif name.startswith("up_blocks"):
-            block_id = int(name[len("up_blocks.")])
-            hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
-        elif name.startswith("down_blocks"):
-            block_id = int(name[len("down_blocks.")])
-            hidden_size = unet.config.block_out_channels[block_id]
-
-        lora_attn_processor_class = (
-            LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
-        )
-        module = lora_attn_processor_class(
-            hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=args.rank
-        )
-        unet_lora_attn_procs[name] = module
-        unet_lora_parameters.extend(module.parameters())
-
-    unet.set_attn_processor(unet_lora_attn_procs)
+    for attn_processor_name, attn_processor in unet.attn_processors.items():
+        # Parse the attention module.
+        attn_module = unet
+        for n in attn_processor_name.split(".")[:-1]:
+            attn_module = getattr(attn_module, n)
+
+        # Set the `lora_layer` attribute of the attention-related matrices.
+        attn_module.to_q.set_lora_layer(
+            LoRALinearLayer(
+                in_features=attn_module.to_q.in_features, out_features=attn_module.to_q.out_features, rank=args.rank
+            )
+        )
+        attn_module.to_k.set_lora_layer(
+            LoRALinearLayer(
+                in_features=attn_module.to_k.in_features, out_features=attn_module.to_k.out_features, rank=args.rank
+            )
+        )
+        attn_module.to_v.set_lora_layer(
+            LoRALinearLayer(
+                in_features=attn_module.to_v.in_features, out_features=attn_module.to_v.out_features, rank=args.rank
+            )
+        )
+        attn_module.to_out[0].set_lora_layer(
+            LoRALinearLayer(
+                in_features=attn_module.to_out[0].in_features,
+                out_features=attn_module.to_out[0].out_features,
+                rank=args.rank,
+            )
+        )
+
+        # Accumulate the LoRA params to optimize.
+        unet_lora_parameters.extend(attn_module.to_q.lora_layer.parameters())
+        unet_lora_parameters.extend(attn_module.to_k.lora_layer.parameters())
+        unet_lora_parameters.extend(attn_module.to_v.lora_layer.parameters())
+        unet_lora_parameters.extend(attn_module.to_out[0].lora_layer.parameters())

     # The text encoder comes from 🤗 transformers, so we cannot directly modify it.
     # So, instead, we monkey-patch the forward calls of its attention-blocks.
@@ -807,7 +804,7 @@ def main(args):
             for model in models:
                 if isinstance(model, type(accelerator.unwrap_model(unet))):
-                    unet_lora_layers_to_save = unet_attn_processors_state_dict(model)
+                    unet_lora_layers_to_save = unet_lora_state_dict(model)
                 elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
                     text_encoder_one_lora_layers_to_save = text_encoder_lora_state_dict(model)
                 elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
@@ -1274,7 +1271,7 @@ def main(args):
     if accelerator.is_main_process:
         unet = accelerator.unwrap_model(unet)
         unet = unet.to(torch.float32)
-        unet_lora_layers = unet_attn_processors_state_dict(unet)
+        unet_lora_layers = unet_lora_state_dict(unet)

         if args.train_text_encoder:
             text_encoder_one = accelerator.unwrap_model(text_encoder_one)
...
@@ -50,7 +50,7 @@ from diffusers import (
     UNet2DConditionModel,
 )
 from diffusers.loaders import LoraLoaderMixin, text_encoder_lora_state_dict
-from diffusers.models.attention_processor import LoRAAttnProcessor, LoRAAttnProcessor2_0
+from diffusers.models.lora import LoRALinearLayer
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import compute_snr
 from diffusers.utils import check_min_version, is_wandb_available
@@ -609,29 +609,42 @@ def main(args):

     # now we will add new LoRA weights to the attention layers
     # Set correct lora layers
-    unet_lora_attn_procs = {}
     unet_lora_parameters = []
-    for name, attn_processor in unet.attn_processors.items():
-        cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
-        if name.startswith("mid_block"):
-            hidden_size = unet.config.block_out_channels[-1]
-        elif name.startswith("up_blocks"):
-            block_id = int(name[len("up_blocks.")])
-            hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
-        elif name.startswith("down_blocks"):
-            block_id = int(name[len("down_blocks.")])
-            hidden_size = unet.config.block_out_channels[block_id]
-
-        lora_attn_processor_class = (
-            LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
-        )
-        module = lora_attn_processor_class(
-            hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=args.rank
-        )
-        unet_lora_attn_procs[name] = module
-        unet_lora_parameters.extend(module.parameters())
-
-    unet.set_attn_processor(unet_lora_attn_procs)
+    for attn_processor_name, attn_processor in unet.attn_processors.items():
+        # Parse the attention module.
+        attn_module = unet
+        for n in attn_processor_name.split(".")[:-1]:
+            attn_module = getattr(attn_module, n)
+
+        # Set the `lora_layer` attribute of the attention-related matrices.
+        attn_module.to_q.set_lora_layer(
+            LoRALinearLayer(
+                in_features=attn_module.to_q.in_features, out_features=attn_module.to_q.out_features, rank=args.rank
+            )
+        )
+        attn_module.to_k.set_lora_layer(
+            LoRALinearLayer(
+                in_features=attn_module.to_k.in_features, out_features=attn_module.to_k.out_features, rank=args.rank
+            )
+        )
+        attn_module.to_v.set_lora_layer(
+            LoRALinearLayer(
+                in_features=attn_module.to_v.in_features, out_features=attn_module.to_v.out_features, rank=args.rank
+            )
+        )
+        attn_module.to_out[0].set_lora_layer(
+            LoRALinearLayer(
+                in_features=attn_module.to_out[0].in_features,
+                out_features=attn_module.to_out[0].out_features,
+                rank=args.rank,
+            )
+        )
+
+        # Accumulate the LoRA params to optimize.
+        unet_lora_parameters.extend(attn_module.to_q.lora_layer.parameters())
+        unet_lora_parameters.extend(attn_module.to_k.lora_layer.parameters())
+        unet_lora_parameters.extend(attn_module.to_v.lora_layer.parameters())
+        unet_lora_parameters.extend(attn_module.to_out[0].lora_layer.parameters())

     # The text encoder comes from 🤗 transformers, so we cannot directly modify it.
     # So, instead, we monkey-patch the forward calls of its attention-blocks.
...
@@ -1647,7 +1647,7 @@ class LoRAXFormersAttnProcessor(nn.Module):
             "0.26.0",
             (
                 f"Make sure use {self_cls_name[4:]} instead by setting"
-                "LoRA layers to `self.{to_q,to_k,to_v,to_out[0]}.lora_layer` respectively. This will be done automatically when using"
+                "LoRA layers to `self.{to_q,to_k,to_v,add_k_proj,add_v_proj,to_out[0]}.lora_layer` respectively. This will be done automatically when using"
                 " `LoraLoaderMixin.load_lora_weights`"
             ),
         )
@@ -1697,7 +1697,7 @@ class LoRAAttnAddedKVProcessor(nn.Module):
             "0.26.0",
             (
                 f"Make sure use {self_cls_name[4:]} instead by setting"
-                "LoRA layers to `self.{to_q,to_k,to_v,to_out[0]}.lora_layer` respectively. This will be done automatically when using"
+                "LoRA layers to `self.{to_q,to_k,to_v,add_k_proj,add_v_proj,to_out[0]}.lora_layer` respectively. This will be done automatically when using"
                 " `LoraLoaderMixin.load_lora_weights`"
             ),
         )
...
@@ -6,6 +6,7 @@ from typing import Any, Dict, Iterable, Optional, Union
 import numpy as np
 import torch

+from .models import UNet2DConditionModel
 from .utils import deprecate, is_transformers_available

@@ -52,6 +53,25 @@ def compute_snr(noise_scheduler, timesteps):
     return snr


+def unet_lora_state_dict(unet: UNet2DConditionModel) -> Dict[str, torch.Tensor]:
+    r"""
+    Returns:
+        A state dict containing just the LoRA parameters.
+    """
+    lora_state_dict = {}
+
+    for name, module in unet.named_modules():
+        if hasattr(module, "set_lora_layer"):
+            lora_layer = getattr(module, "lora_layer")
+            if lora_layer is not None:
+                current_lora_layer_sd = lora_layer.state_dict()
+                for lora_layer_matrix_name, lora_param in current_lora_layer_sd.items():
+                    # The matrix name can either be "down" or "up".
+                    lora_state_dict[f"unet.{name}.lora.{lora_layer_matrix_name}"] = lora_param
+
+    return lora_state_dict
+
+
 # Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
 class EMAModel:
     """
...
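For reference, a rough sketch of how the new `unet_lora_state_dict` helper pairs with `LoraLoaderMixin.save_lora_weights` when serializing trained LoRA weights, mirroring what the updated example scripts do. The wrapper function name and the directory argument are illustrative only, not part of this commit.

```python
from diffusers.loaders import LoraLoaderMixin
from diffusers.training_utils import unet_lora_state_dict


def save_unet_lora(unet, output_dir: str) -> None:
    # Collect every LoRA layer attached via `set_lora_layer` into a flat state dict
    # with keys like "unet.<module path>.lora.down.weight" / "unet.<module path>.lora.up.weight".
    lora_layers = unet_lora_state_dict(unet)
    # Write them out in the format `LoraLoaderMixin.load_lora_weights` expects.
    LoraLoaderMixin.save_lora_weights(
        save_directory=output_dir,
        unet_lora_layers=lora_layers,
        text_encoder_lora_layers=None,
    )
```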