"gallery/others/plot_optical_flow.py" did not exist on "d9a69506109cb970c7568c27cbeda1f9ffb7ad70"
Unverified Commit 85c4a326 authored by Dhruv Nair's avatar Dhruv Nair Committed by GitHub
Browse files

Fix saving text encoder weights and kohya weights in advanced dreambooth lora script (#8766)

* update

* update

* update
parent 0bab9d6b
......@@ -1290,6 +1290,7 @@ def main(args):
text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers(
get_peft_model_state_dict(model)
)
else:
raise ValueError(f"unexpected save model: {model.__class__}")
# make sure to pop weight so that corresponding model is not saved again
......@@ -1981,7 +1982,7 @@ def main(args):
lora_state_dict = load_file(f"{args.output_dir}/pytorch_lora_weights.safetensors")
peft_state_dict = convert_all_state_dict_to_peft(lora_state_dict)
kohya_state_dict = convert_state_dict_to_kohya(peft_state_dict)
save_file(kohya_state_dict, f"{args.output_dir}/{args.output_dir}.safetensors")
save_file(kohya_state_dict, f"{args.output_dir}/{Path(args.output_dir).name}.safetensors")
save_model_card(
model_id if not args.push_to_hub else repo_id,
......
......@@ -2425,7 +2425,7 @@ def main(args):
lora_state_dict = load_file(f"{args.output_dir}/pytorch_lora_weights.safetensors")
peft_state_dict = convert_all_state_dict_to_peft(lora_state_dict)
kohya_state_dict = convert_state_dict_to_kohya(peft_state_dict)
save_file(kohya_state_dict, f"{args.output_dir}/{args.output_dir}.safetensors")
save_file(kohya_state_dict, f"{args.output_dir}/{Path(args.output_dir).name}.safetensors")
save_model_card(
model_id if not args.push_to_hub else repo_id,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment