Unverified Commit 85c4a326 authored by Dhruv Nair's avatar Dhruv Nair Committed by GitHub
Browse files

Fix saving text encoder weights and kohya weights in advanced dreambooth lora script (#8766)

* update

* update

* update
parent 0bab9d6b
...@@ -1290,6 +1290,7 @@ def main(args): ...@@ -1290,6 +1290,7 @@ def main(args):
text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers( text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers(
get_peft_model_state_dict(model) get_peft_model_state_dict(model)
) )
else:
raise ValueError(f"unexpected save model: {model.__class__}") raise ValueError(f"unexpected save model: {model.__class__}")
# make sure to pop weight so that corresponding model is not saved again # make sure to pop weight so that corresponding model is not saved again
...@@ -1981,7 +1982,7 @@ def main(args): ...@@ -1981,7 +1982,7 @@ def main(args):
lora_state_dict = load_file(f"{args.output_dir}/pytorch_lora_weights.safetensors") lora_state_dict = load_file(f"{args.output_dir}/pytorch_lora_weights.safetensors")
peft_state_dict = convert_all_state_dict_to_peft(lora_state_dict) peft_state_dict = convert_all_state_dict_to_peft(lora_state_dict)
kohya_state_dict = convert_state_dict_to_kohya(peft_state_dict) kohya_state_dict = convert_state_dict_to_kohya(peft_state_dict)
save_file(kohya_state_dict, f"{args.output_dir}/{args.output_dir}.safetensors") save_file(kohya_state_dict, f"{args.output_dir}/{Path(args.output_dir).name}.safetensors")
save_model_card( save_model_card(
model_id if not args.push_to_hub else repo_id, model_id if not args.push_to_hub else repo_id,
......
...@@ -2425,7 +2425,7 @@ def main(args): ...@@ -2425,7 +2425,7 @@ def main(args):
lora_state_dict = load_file(f"{args.output_dir}/pytorch_lora_weights.safetensors") lora_state_dict = load_file(f"{args.output_dir}/pytorch_lora_weights.safetensors")
peft_state_dict = convert_all_state_dict_to_peft(lora_state_dict) peft_state_dict = convert_all_state_dict_to_peft(lora_state_dict)
kohya_state_dict = convert_state_dict_to_kohya(peft_state_dict) kohya_state_dict = convert_state_dict_to_kohya(peft_state_dict)
save_file(kohya_state_dict, f"{args.output_dir}/{args.output_dir}.safetensors") save_file(kohya_state_dict, f"{args.output_dir}/{Path(args.output_dir).name}.safetensors")
save_model_card( save_model_card(
model_id if not args.push_to_hub else repo_id, model_id if not args.push_to_hub else repo_id,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment