Unverified Commit 90b94799 authored by Sayak Paul, committed by GitHub

[LoRA PEFT] fix LoRA loading so that correct alphas are parsed (#6225)

* initialize alpha too.

* add: test

* remove config parsing

* store rank

* debug

* remove faulty test
parent df76a39e
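
For context (not part of this commit), PEFT scales each LoRA update by lora_alpha / r. The training scripts below previously passed only r=args.rank, so lora_alpha silently stayed at the PEFT library default and the alpha stored with (and later parsed from) the adapter no longer matched the configured rank. A minimal sketch of the difference, assuming only peft is installed; the rank value and the "to_q" target are stand-ins for the scripts' arguments:

# Hypothetical illustration, not code from this commit.
from peft import LoraConfig

rank = 16  # stand-in for args.rank

# Before: lora_alpha is left at the PEFT default, so alpha / r drifts away from 1.0.
implicit = LoraConfig(r=rank, init_lora_weights="gaussian", target_modules=["to_q"])
# After: alpha tracks the rank, so the LoRA update is scaled by exactly 1.0.
explicit = LoraConfig(r=rank, lora_alpha=rank, init_lora_weights="gaussian", target_modules=["to_q"])

print(implicit.lora_alpha / implicit.r)  # library-default alpha / 16, not 1.0
print(explicit.lora_alpha / explicit.r)  # 16 / 16 == 1.0

Because the alpha is serialized alongside the weights, the same mismatch resurfaced when the adapter was reloaded, which is what the diffs below correct across the training scripts and the test utilities.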
@@ -827,6 +827,7 @@ def main(args):
     # now we will add new LoRA weights to the attention layers
     unet_lora_config = LoraConfig(
         r=args.rank,
+        lora_alpha=args.rank,
         init_lora_weights="gaussian",
         target_modules=["to_k", "to_q", "to_v", "to_out.0", "add_k_proj", "add_v_proj"],
     )
@@ -835,7 +836,10 @@ def main(args):
     # The text encoder comes from 🤗 transformers, we will also attach adapters to it.
     if args.train_text_encoder:
         text_lora_config = LoraConfig(
-            r=args.rank, init_lora_weights="gaussian", target_modules=["q_proj", "k_proj", "v_proj", "out_proj"]
+            r=args.rank,
+            lora_alpha=args.rank,
+            init_lora_weights="gaussian",
+            target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
         )
         text_encoder.add_adapter(text_lora_config)
...
@@ -978,7 +978,10 @@ def main(args):
     # now we will add new LoRA weights to the attention layers
     unet_lora_config = LoraConfig(
-        r=args.rank, init_lora_weights="gaussian", target_modules=["to_k", "to_q", "to_v", "to_out.0"]
+        r=args.rank,
+        lora_alpha=args.rank,
+        init_lora_weights="gaussian",
+        target_modules=["to_k", "to_q", "to_v", "to_out.0"],
     )
     unet.add_adapter(unet_lora_config)
@@ -986,7 +989,10 @@ def main(args):
     # So, instead, we monkey-patch the forward calls of its attention-blocks.
     if args.train_text_encoder:
         text_lora_config = LoraConfig(
-            r=args.rank, init_lora_weights="gaussian", target_modules=["q_proj", "k_proj", "v_proj", "out_proj"]
+            r=args.rank,
+            lora_alpha=args.rank,
+            init_lora_weights="gaussian",
+            target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
         )
         text_encoder_one.add_adapter(text_lora_config)
         text_encoder_two.add_adapter(text_lora_config)
...
@@ -452,7 +452,10 @@ def main():
         param.requires_grad_(False)
     unet_lora_config = LoraConfig(
-        r=args.rank, init_lora_weights="gaussian", target_modules=["to_k", "to_q", "to_v", "to_out.0"]
+        r=args.rank,
+        lora_alpha=args.rank,
+        init_lora_weights="gaussian",
+        target_modules=["to_k", "to_q", "to_v", "to_out.0"],
     )
     # Move unet, vae and text_encoder to device and cast to weight_dtype
...
@@ -609,7 +609,10 @@ def main(args):
     # now we will add new LoRA weights to the attention layers
     # Set correct lora layers
     unet_lora_config = LoraConfig(
-        r=args.rank, init_lora_weights="gaussian", target_modules=["to_k", "to_q", "to_v", "to_out.0"]
+        r=args.rank,
+        lora_alpha=args.rank,
+        init_lora_weights="gaussian",
+        target_modules=["to_k", "to_q", "to_v", "to_out.0"],
     )
     unet.add_adapter(unet_lora_config)
@@ -618,7 +621,10 @@ def main(args):
     if args.train_text_encoder:
         # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
         text_lora_config = LoraConfig(
-            r=args.rank, init_lora_weights="gaussian", target_modules=["q_proj", "k_proj", "v_proj", "out_proj"]
+            r=args.rank,
+            lora_alpha=args.rank,
+            init_lora_weights="gaussian",
+            target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
         )
         text_encoder_one.add_adapter(text_lora_config)
         text_encoder_two.add_adapter(text_lora_config)
...
@@ -111,6 +111,7 @@ class PeftLoraLoaderMixinTests:
     def get_dummy_components(self, scheduler_cls=None):
         scheduler_cls = self.scheduler_cls if scheduler_cls is None else LCMScheduler
+        rank = 4
         torch.manual_seed(0)
         unet = UNet2DConditionModel(**self.unet_kwargs)
@@ -125,11 +126,14 @@ class PeftLoraLoaderMixinTests:
         tokenizer_2 = CLIPTokenizer.from_pretrained("peft-internal-testing/tiny-clip-text-2")
         text_lora_config = LoraConfig(
-            r=4, lora_alpha=4, target_modules=["q_proj", "k_proj", "v_proj", "out_proj"], init_lora_weights=False
+            r=rank,
+            lora_alpha=rank,
+            target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
+            init_lora_weights=False,
         )
         unet_lora_config = LoraConfig(
-            r=4, lora_alpha=4, target_modules=["to_q", "to_k", "to_v", "to_out.0"], init_lora_weights=False
+            r=rank, lora_alpha=rank, target_modules=["to_q", "to_k", "to_v", "to_out.0"], init_lora_weights=False
         )
         unet_lora_attn_procs, unet_lora_layers = create_unet_lora_layers(unet)
...
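
The shared rank variable in the test helper makes the intended invariant easy to state: after an adapter is attached, the parsed alpha and the effective per-layer scaling should both reflect the configured rank. A rough, self-contained sketch of that invariant using plain peft (the Tiny module and its to_q projection are made-up stand-ins, not the test suite's pipeline components):

# Illustrative only; not part of the diffusers test suite.
import torch
from peft import LoraConfig, get_peft_model


class Tiny(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.to_q = torch.nn.Linear(8, 8)

    def forward(self, x):
        return self.to_q(x)


rank = 4
config = LoraConfig(r=rank, lora_alpha=rank, target_modules=["to_q"], init_lora_weights=False)
model = get_peft_model(Tiny(), config)

# The stored alpha matches the rank, and the LoRA update is scaled by alpha / r == 1.0.
assert model.peft_config["default"].lora_alpha == rank
assert model.base_model.model.to_q.scaling["default"] == 1.0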