"vscode:/vscode.git/clone" did not exist on "2120b4eee35bcc0db5f3acd3900fb31188ed0160"
Unverified Commit 3deed729 authored by Boseong Jeon's avatar Boseong Jeon Committed by GitHub
Browse files

Handling mixed precision for dreambooth flux lora training (#9565)



Handling mixed precision and add unwrap
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
Co-authored-by: default avatarLinoy Tsaban <57615435+linoytsaban@users.noreply.github.com>
parent 7ffbc252
......@@ -177,7 +177,7 @@ def log_validation(
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
f" {args.validation_prompt}."
)
pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
pipeline = pipeline.to(accelerator.device)
pipeline.set_progress_bar_config(disable=True)
# run inference
......@@ -1706,7 +1706,7 @@ def main(args):
)
# handle guidance
if transformer.config.guidance_embeds:
if accelerator.unwrap_model(transformer).config.guidance_embeds:
guidance = torch.tensor([args.guidance_scale], device=accelerator.device)
guidance = guidance.expand(model_input.shape[0])
else:
......@@ -1819,6 +1819,8 @@ def main(args):
# create pipeline
if not args.train_text_encoder:
text_encoder_one, text_encoder_two = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two)
text_encoder_one.to(weight_dtype)
text_encoder_two.to(weight_dtype)
pipeline = FluxPipeline.from_pretrained(
args.pretrained_model_name_or_path,
vae=vae,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment