Commit cbbad0af authored by Patrick von Platen

correct example

parent 00132de3
@@ -479,7 +479,6 @@ def main():
        weight_dtype = torch.bfloat16

    if args.use_peft:
        from peft import LoraConfig, LoraModel, get_peft_model_state_dict, set_peft_model_state_dict

        UNET_TARGET_MODULES = ["to_q", "to_v", "query", "value"]
@@ -496,7 +495,6 @@ def main():
        vae.requires_grad_(False)

        if args.train_text_encoder:
            config = LoraConfig(
                r=args.lora_text_encoder_r,
                lora_alpha=args.lora_text_encoder_alpha,
@@ -506,7 +504,6 @@ def main():
            )
            text_encoder = LoraModel(config, text_encoder)
    else:
        # freeze parameters of models to save more memory
        unet.requires_grad_(False)
        vae.requires_grad_(False)
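For context, the diff hunks above configure LoRA adapters through peft: a LoraConfig names the target submodules (e.g. UNET_TARGET_MODULES) and the rank/alpha hyperparameters, and the base model is then wrapped so only the injected low-rank matrices remain trainable. Below is a minimal, self-contained sketch of that pattern. The ToyAttention module and the hyperparameter values are illustrative assumptions, not the script's actual models or defaults; the script itself uses the older LoraModel(config, model) constructor, while the sketch uses get_peft_model(model, config) so it runs on a recent peft install.

```python
import torch
import torch.nn as nn
from peft import LoraConfig, get_peft_model


class ToyAttention(nn.Module):
    """Hypothetical stand-in for an attention block; submodule names match
    two of the UNET_TARGET_MODULES entries in the diff above."""

    def __init__(self, dim: int = 32):
        super().__init__()
        self.to_q = nn.Linear(dim, dim)
        self.to_v = nn.Linear(dim, dim)

    def forward(self, x):
        return self.to_q(x) + self.to_v(x)


config = LoraConfig(
    r=4,                              # LoRA rank (assumed value, not the script's default)
    lora_alpha=4,                     # LoRA scaling factor (assumed value)
    target_modules=["to_q", "to_v"],  # subset of UNET_TARGET_MODULES
    lora_dropout=0.0,
    bias="none",
)

# Wrap the base model; peft freezes the original weights and injects
# trainable low-rank A/B matrices into the targeted linear layers.
model = get_peft_model(ToyAttention(), config)

model.print_trainable_parameters()      # only the LoRA parameters are trainable
out = model(torch.randn(2, 32))
print(out.shape)                         # torch.Size([2, 32])
```

The same idea applies to the UNet and, when `--train_text_encoder` is set, to the text encoder: each gets its own LoraConfig (with separate `r`/`lora_alpha` arguments) before being wrapped.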