Unverified Commit 288ceebe authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

[T2I LoRA training] fix: unscale fp16 gradient problem (#6119)



* fix: unscale fp16 gradient problem

* fix for dreambooth lora sdxl

* make the type-casting conditional.

* Apply suggestions from code review
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

---------
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 9221da40
...@@ -991,6 +991,17 @@ def main(args): ...@@ -991,6 +991,17 @@ def main(args):
text_encoder_one.add_adapter(text_lora_config) text_encoder_one.add_adapter(text_lora_config)
text_encoder_two.add_adapter(text_lora_config) text_encoder_two.add_adapter(text_lora_config)
# Make sure the trainable params are in float32.
if args.mixed_precision == "fp16":
models = [unet]
if args.train_text_encoder:
models.extend([text_encoder_one, text_encoder_two])
for model in models:
for param in model.parameters():
# only upcast trainable parameters (LoRA) into fp32
if param.requires_grad:
param.data = param.to(torch.float32)
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir): def save_model_hook(models, weights, output_dir):
if accelerator.is_main_process: if accelerator.is_main_process:
......
...@@ -460,7 +460,13 @@ def main(): ...@@ -460,7 +460,13 @@ def main():
vae.to(accelerator.device, dtype=weight_dtype) vae.to(accelerator.device, dtype=weight_dtype)
text_encoder.to(accelerator.device, dtype=weight_dtype) text_encoder.to(accelerator.device, dtype=weight_dtype)
# Add adapter and make sure the trainable params are in float32.
unet.add_adapter(unet_lora_config) unet.add_adapter(unet_lora_config)
if args.mixed_precision == "fp16":
for param in unet.parameters():
# only upcast trainable parameters (LoRA) into fp32
if param.requires_grad:
param.data = param.to(torch.float32)
if args.enable_xformers_memory_efficient_attention: if args.enable_xformers_memory_efficient_attention:
if is_xformers_available(): if is_xformers_available():
...@@ -890,13 +896,17 @@ def main(): ...@@ -890,13 +896,17 @@ def main():
# Final inference # Final inference
# Load previous pipeline # Load previous pipeline
if args.validation_prompt is not None:
pipeline = DiffusionPipeline.from_pretrained( pipeline = DiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path, revision=args.revision, variant=args.variant, torch_dtype=weight_dtype args.pretrained_model_name_or_path,
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
) )
pipeline = pipeline.to(accelerator.device) pipeline = pipeline.to(accelerator.device)
# load attention processors # load attention processors
pipeline.unet.load_attn_procs(args.output_dir) pipeline.load_lora_weights(args.output_dir)
# run inference # run inference
generator = torch.Generator(device=accelerator.device) generator = torch.Generator(device=accelerator.device)
...@@ -906,7 +916,6 @@ def main(): ...@@ -906,7 +916,6 @@ def main():
for _ in range(args.num_validation_images): for _ in range(args.num_validation_images):
images.append(pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]) images.append(pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0])
if accelerator.is_main_process:
for tracker in accelerator.trackers: for tracker in accelerator.trackers:
if len(images) != 0: if len(images) != 0:
if tracker.name == "tensorboard": if tracker.name == "tensorboard":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment