Unverified Commit 23a2cd33 authored by Sayak Paul, committed by GitHub

[LoRA] training fix the position of param casting when loading them (#8460)

fix the position of param casting when loading them
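For context on what is being moved in the hunks below: `cast_training_params` from `diffusers.training_utils` upcasts only the trainable (LoRA) parameters to fp32, leaving the frozen base weights in their lower-precision dtype. A minimal sketch of that idea is shown here; the helper name and placement inside a load hook are illustrative, not the exact training-script code:

```python
import torch


def upcast_trainable_params(models, dtype=torch.float32):
    # Equivalent in spirit to diffusers.training_utils.cast_training_params:
    # only parameters that require gradients (the LoRA weights) are upcast,
    # so the frozen base model stays in fp16/bf16 while the trained params
    # accumulate updates in fp32.
    for model in models:
        for param in model.parameters():
            if param.requires_grad:
                param.data = param.data.to(dtype)
```

The commit's point is where this casting happens relative to loading the LoRA state: the call has to run at the right position in the load hook so the freshly loaded LoRA parameters are the ones upcast.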
parent 4edde134
@@ -1289,8 +1289,8 @@ def main(args):
            models = [unet_]
            if args.train_text_encoder:
                models.extend([text_encoder_one_, text_encoder_two_])
            # only upcast trainable parameters (LoRA) into fp32
            cast_training_params(models)

    accelerator.register_save_state_pre_hook(save_model_hook)
    accelerator.register_load_state_pre_hook(load_model_hook)
...
@@ -1363,8 +1363,8 @@ def main(args):
            models = [unet_]
            if args.train_text_encoder:
                models.extend([text_encoder_one_, text_encoder_two_])
            # only upcast trainable parameters (LoRA) into fp32
            cast_training_params(models)

    accelerator.register_save_state_pre_hook(save_model_hook)
    accelerator.register_load_state_pre_hook(load_model_hook)
...