Unverified commit 2fada8dc authored by Linoy Tsaban, committed by GitHub

[bug fix] fixes #6444 - checkpointing save issue in advanced dreambooth lora sdxl script (#6464)



* unwrap text encoder when saving hook only for full text encoder tuning

* unwrap text encoder when saving hook only for full text encoder tuning

* save embeddings in each checkpoint as well

* save embeddings in each checkpoint as well

* save embeddings in each checkpoint as well

* Update examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent f2d51a28
examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
@@ -1316,10 +1316,12 @@ def main(args):
                 if isinstance(model, type(accelerator.unwrap_model(unet))):
                     unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
                 elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
-                    text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers(
-                        get_peft_model_state_dict(model)
-                    )
+                    if args.train_text_encoder:
+                        text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers(
+                            get_peft_model_state_dict(model)
+                        )
                 elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
-                    text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers(
-                        get_peft_model_state_dict(model)
-                    )
+                    if args.train_text_encoder:
+                        text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers(
+                            get_peft_model_state_dict(model)
+                        )
@@ -1335,6 +1337,8 @@ def main(args):
                 text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,
                 text_encoder_2_lora_layers=text_encoder_two_lora_layers_to_save,
             )
+            if args.train_text_encoder_ti:
+                embedding_handler.save_embeddings(f"{output_dir}/{args.output_dir}_emb.safetensors")

     def load_model_hook(models, input_dir):
         unet_ = None
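For context, here is a minimal self-contained sketch of the fixed save-hook control flow. The names (DummyArgs, save_checkpoint, the plain dicts standing in for models) are invented for illustration; the actual script works with accelerate save hooks, get_peft_model_state_dict / convert_state_dict_to_diffusers, StableDiffusionXLPipeline.save_lora_weights, and the script's embedding handler.

```python
# Sketch only: shows why the text-encoder LoRA export must be gated on
# args.train_text_encoder, and why pivotal-tuning embeddings are now written
# with every checkpoint. All names here are hypothetical stand-ins.
from dataclasses import dataclass
from pathlib import Path


@dataclass
class DummyArgs:
    train_text_encoder: bool = False    # full text-encoder LoRA tuning
    train_text_encoder_ti: bool = True  # pivotal tuning: new token embeddings only
    output_dir: str = "my_lora"


def save_checkpoint(models: dict, checkpoint_dir: str, args: DummyArgs) -> None:
    unet_lora = te_one_lora = te_two_lora = None

    for name, state in models.items():
        if name == "unet":
            unet_lora = state
        elif name == "text_encoder_one":
            # Only export text-encoder LoRA layers when the text encoder is
            # actually LoRA-tuned; under pivotal tuning it carries no PEFT
            # adapters, which is what broke intermediate checkpointing (#6444).
            if args.train_text_encoder:
                te_one_lora = state
        elif name == "text_encoder_two":
            if args.train_text_encoder:
                te_two_lora = state

    # Stand-in for StableDiffusionXLPipeline.save_lora_weights(...)
    print("LoRA layers to save:", unet_lora, te_one_lora, te_two_lora)

    if args.train_text_encoder_ti:
        # Persist the learned token embeddings alongside every checkpoint,
        # not only at the end of training.
        emb_path = Path(checkpoint_dir) / f"{args.output_dir}_emb.safetensors"
        print("would save pivotal-tuning embeddings to", emb_path)


if __name__ == "__main__":
    dummy_models = {
        "unet": {"lora.up.weight": [0.1]},
        "text_encoder_one": None,
        "text_encoder_two": None,
    }
    save_checkpoint(dummy_models, "output/checkpoint-500", DummyArgs())
```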