Unverified Commit 6ca9c4af authored by lvzi's avatar lvzi Committed by GitHub
Browse files

fix: unscale fp16 gradient problem & potential error (#6086) (#6231)


Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent 0532cece
@@ -640,6 +640,17 @@ def main(args):
text_encoder_one.add_adapter(text_lora_config)
text_encoder_two.add_adapter(text_lora_config)
# Make sure the trainable params are in float32.
if args.mixed_precision == "fp16":
models = [unet]
if args.train_text_encoder:
models.extend([text_encoder_one, text_encoder_two])
for model in models:
for param in model.parameters():
# only upcast trainable parameters (LoRA) into fp32
if param.requires_grad:
param.data = param.to(torch.float32)
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir):
    if accelerator.is_main_process:
@@ -1187,6 +1198,9 @@ def main(args):
torch.cuda.empty_cache()
# Final inference
# Make sure vae.dtype is consistent with the unet.dtype
if args.mixed_precision == "fp16":
vae.to(weight_dtype)
# Load previous pipeline
pipeline = StableDiffusionXLPipeline.from_pretrained(
    args.pretrained_model_name_or_path,
...
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment