Unverified commit 9cc96a64, authored by SahilCarterr, committed by GitHub
Browse files

[FIX] Fix TypeError in DreamBooth SDXL when use_dora is False (#9879)



* fix use_dora

* fix style and quality

* fix use_dora with peft version

---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent 5b972fbd
......@@ -67,6 +67,7 @@ from diffusers.utils import (
convert_state_dict_to_diffusers,
convert_state_dict_to_kohya,
convert_unet_state_dict_to_peft,
is_peft_version,
is_wandb_available,
)
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
......@@ -1183,26 +1184,33 @@ def main(args):
text_encoder_one.gradient_checkpointing_enable()
text_encoder_two.gradient_checkpointing_enable()
def get_lora_config(rank, use_dora, target_modules):
    """Build the `LoraConfig` shared by the UNet and text-encoder adapters.

    The `use_dora` keyword is only forwarded to `LoraConfig` when DoRA is
    actually requested: peft versions older than 0.9.0 do not accept the
    keyword at all, so passing `use_dora=False` unconditionally raises a
    `TypeError` on those versions (the bug this helper fixes).

    Args:
        rank: LoRA rank; also reused as `lora_alpha`.
        use_dora: whether to enable DoRA (requires peft >= 0.9.0).
        target_modules: names of the modules to attach LoRA layers to.

    Returns:
        A configured `LoraConfig`.

    Raises:
        ValueError: if `use_dora` is True but the installed peft is < 0.9.0.
    """
    base_config = {
        "r": rank,
        "lora_alpha": rank,
        "init_lora_weights": "gaussian",
        "target_modules": target_modules,
    }
    if use_dora:
        # Guard clause: fail fast on incompatible peft before building the config.
        if is_peft_version("<", "0.9.0"):
            raise ValueError(
                "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
            )
        base_config["use_dora"] = True
    return LoraConfig(**base_config)
# now we will add new LoRA weights to the attention layers
unet_lora_config = LoraConfig(
r=args.rank,
use_dora=args.use_dora,
lora_alpha=args.rank,
init_lora_weights="gaussian",
target_modules=["to_k", "to_q", "to_v", "to_out.0"],
)
unet_target_modules = ["to_k", "to_q", "to_v", "to_out.0"]
unet_lora_config = get_lora_config(rank=args.rank, use_dora=args.use_dora, target_modules=unet_target_modules)
unet.add_adapter(unet_lora_config)
# The text encoder comes from 🤗 transformers, so we cannot directly modify it.
# So, instead, we monkey-patch the forward calls of its attention-blocks.
if args.train_text_encoder:
text_lora_config = LoraConfig(
r=args.rank,
use_dora=args.use_dora,
lora_alpha=args.rank,
init_lora_weights="gaussian",
target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
)
text_target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"]
text_lora_config = get_lora_config(rank=args.rank, use_dora=args.use_dora, target_modules=text_target_modules)
text_encoder_one.add_adapter(text_lora_config)
text_encoder_two.add_adapter(text_lora_config)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment