"docs/vscode:/vscode.git/clone" did not exist on "b975bceff3558b7d93566e18f47f20862cb6b977"
Unverified Commit 7d0a47f3 authored by Haofan Wang's avatar Haofan Wang Committed by GitHub
Browse files

Update train_text_to_image_lora.py (#6144)



* Update train_text_to_image_lora.py

* Fix typo?

---------
Co-authored-by: default avatarM. Tolga Cangöz <46008593+standardAI@users.noreply.github.com>
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
parent 67b3d326
...@@ -799,7 +799,8 @@ def main(): ...@@ -799,7 +799,8 @@ def main():
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
accelerator.save_state(save_path) accelerator.save_state(save_path)
unet_lora_state_dict = get_peft_model_state_dict(unet) unwrapped_unet = accelerator.unwrap_model(unet)
unet_lora_state_dict = get_peft_model_state_dict(unwrapped_unet)
StableDiffusionPipeline.save_lora_weights( StableDiffusionPipeline.save_lora_weights(
save_directory=save_path, save_directory=save_path,
...@@ -864,7 +865,8 @@ def main(): ...@@ -864,7 +865,8 @@ def main():
if accelerator.is_main_process: if accelerator.is_main_process:
unet = unet.to(torch.float32) unet = unet.to(torch.float32)
unet_lora_state_dict = get_peft_model_state_dict(unet) unwrapped_unet = accelerator.unwrap_model(unet)
unet_lora_state_dict = get_peft_model_state_dict(unwrapped_unet)
StableDiffusionPipeline.save_lora_weights( StableDiffusionPipeline.save_lora_weights(
save_directory=args.output_dir, save_directory=args.output_dir,
unet_lora_layers=unet_lora_state_dict, unet_lora_layers=unet_lora_state_dict,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment