Unverified Commit 2d753b6f authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

fix train_dreambooth_lora_sd3.py loading hook (#9107)

parent 39e1f7ea
...@@ -1271,7 +1271,7 @@ def main(args): ...@@ -1271,7 +1271,7 @@ def main(args):
lora_state_dict = StableDiffusion3Pipeline.lora_state_dict(input_dir) lora_state_dict = StableDiffusion3Pipeline.lora_state_dict(input_dir)
transformer_state_dict = { transformer_state_dict = {
f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.") f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.")
} }
transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict) transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default") incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment