Unverified Commit 170833c2 authored by SahilCarterr, committed by GitHub
Browse files

[Fix] fp16 unscaling in train_dreambooth_lora_sdxl (#10889)



Fix fp16 bug
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent db21c970
......@@ -203,7 +203,7 @@ def log_validation(
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)
pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
pipeline = pipeline.to(accelerator.device)
pipeline.set_progress_bar_config(disable=True)
# run inference
......@@ -213,7 +213,7 @@ def log_validation(
if torch.backends.mps.is_available() or "playground" in args.pretrained_model_name_or_path:
autocast_ctx = nullcontext()
else:
autocast_ctx = torch.autocast(accelerator.device.type)
autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()
with autocast_ctx:
images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment