Unverified commit 35b81fff authored by Kashif Rasul, committed by GitHub

[Wuerstchen] fix fp16 training and correct lora args (#6245)



fix fp16 training
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent e0d8c910
@@ -527,9 +527,17 @@ def main():
     # lora attn processor
     prior_lora_config = LoraConfig(
-        r=args.rank, target_modules=["to_k", "to_q", "to_v", "to_out.0", "add_k_proj", "add_v_proj"]
+        r=args.rank,
+        lora_alpha=args.rank,
+        target_modules=["to_k", "to_q", "to_v", "to_out.0", "add_k_proj", "add_v_proj"],
     )
+    # Add adapter and make sure the trainable params are in float32.
     prior.add_adapter(prior_lora_config)
+    if args.mixed_precision == "fp16":
+        for param in prior.parameters():
+            # only upcast trainable parameters (LoRA) into fp32
+            if param.requires_grad:
+                param.data = param.to(torch.float32)
 
     # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
     def save_model_hook(models, weights, output_dir):
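
For reference, the pattern in this hunk can be reproduced outside the training script with a minimal sketch. The snippet below assumes peft and torch are installed; TinyPrior is a hypothetical stand-in for the Wuerstchen prior, and inject_adapter_in_model approximates what prior.add_adapter does under the hood. It demonstrates the corrected LoRA arguments (lora_alpha set equal to r) and the fp16 fix: only the trainable LoRA parameters are upcast to float32 while the frozen base weights stay in fp16.

# Minimal sketch, not the training script itself. Assumptions: peft and torch
# are installed; TinyPrior is a hypothetical toy module standing in for the
# Wuerstchen prior; inject_adapter_in_model stands in for prior.add_adapter.
import torch
import torch.nn as nn
from peft import LoraConfig, inject_adapter_in_model


class TinyPrior(nn.Module):
    """Hypothetical module with attention-style projections to target with LoRA."""

    def __init__(self, dim: int = 32):
        super().__init__()
        self.to_q = nn.Linear(dim, dim)
        self.to_k = nn.Linear(dim, dim)
        self.to_v = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.to_q(x) + self.to_k(x) + self.to_v(x)


rank = 4
lora_config = LoraConfig(
    r=rank,
    lora_alpha=rank,  # the corrected argument: alpha matches the rank
    target_modules=["to_q", "to_k", "to_v"],
)

# Frozen base model kept in fp16, as in mixed-precision training.
prior = TinyPrior().to(torch.float16).requires_grad_(False)
prior = inject_adapter_in_model(lora_config, prior)

# The fp16 fix: upcast only the trainable (LoRA) parameters to fp32;
# the frozen base weights remain in fp16.
for param in prior.parameters():
    if param.requires_grad:
        param.data = param.to(torch.float32)

# Sanity check: LoRA params should report float32, base weights float16.
for name, param in prior.named_parameters():
    print(name, param.dtype, param.requires_grad)

The handful of LoRA weights kept in fp32 costs little memory; the usual motivation is that optimizer updates on fp16 master weights can underflow, so mixed-precision recipes keep the trainable weights in fp32 and leave only the frozen base in half precision.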