Unverified Commit 170833c2 authored by SahilCarterr's avatar SahilCarterr Committed by GitHub
Browse files

[Fix] fp16 unscaling in train_dreambooth_lora_sdxl (#10889)



Fix fp16 bug
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
parent db21c970
...@@ -203,7 +203,7 @@ def log_validation( ...@@ -203,7 +203,7 @@ def log_validation(
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args) pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)
pipeline = pipeline.to(accelerator.device, dtype=torch_dtype) pipeline = pipeline.to(accelerator.device)
pipeline.set_progress_bar_config(disable=True) pipeline.set_progress_bar_config(disable=True)
# run inference # run inference
...@@ -213,7 +213,7 @@ def log_validation( ...@@ -213,7 +213,7 @@ def log_validation(
if torch.backends.mps.is_available() or "playground" in args.pretrained_model_name_or_path: if torch.backends.mps.is_available() or "playground" in args.pretrained_model_name_or_path:
autocast_ctx = nullcontext() autocast_ctx = nullcontext()
else: else:
autocast_ctx = torch.autocast(accelerator.device.type) autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()
with autocast_ctx: with autocast_ctx:
images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)] images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment