"docs/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "e3ddbe25edeadaa5afc3f8f5bb0d645098a8b26a"
Unverified Commit 3deed729 authored by Boseong Jeon's avatar Boseong Jeon Committed by GitHub
Browse files

Handling mixed precision for dreambooth flux lora training (#9565)



Handling mixed precision and add unwarp
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
Co-authored-by: default avatarLinoy Tsaban <57615435+linoytsaban@users.noreply.github.com>
parent 7ffbc252
...@@ -177,7 +177,7 @@ def log_validation( ...@@ -177,7 +177,7 @@ def log_validation(
f"Running validation... \n Generating {args.num_validation_images} images with prompt:" f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
f" {args.validation_prompt}." f" {args.validation_prompt}."
) )
pipeline = pipeline.to(accelerator.device, dtype=torch_dtype) pipeline = pipeline.to(accelerator.device)
pipeline.set_progress_bar_config(disable=True) pipeline.set_progress_bar_config(disable=True)
# run inference # run inference
...@@ -1706,7 +1706,7 @@ def main(args): ...@@ -1706,7 +1706,7 @@ def main(args):
) )
# handle guidance # handle guidance
if transformer.config.guidance_embeds: if accelerator.unwrap_model(transformer).config.guidance_embeds:
guidance = torch.tensor([args.guidance_scale], device=accelerator.device) guidance = torch.tensor([args.guidance_scale], device=accelerator.device)
guidance = guidance.expand(model_input.shape[0]) guidance = guidance.expand(model_input.shape[0])
else: else:
...@@ -1819,6 +1819,8 @@ def main(args): ...@@ -1819,6 +1819,8 @@ def main(args):
# create pipeline # create pipeline
if not args.train_text_encoder: if not args.train_text_encoder:
text_encoder_one, text_encoder_two = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two) text_encoder_one, text_encoder_two = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two)
text_encoder_one.to(weight_dtype)
text_encoder_two.to(weight_dtype)
pipeline = FluxPipeline.from_pretrained( pipeline = FluxPipeline.from_pretrained(
args.pretrained_model_name_or_path, args.pretrained_model_name_or_path,
vae=vae, vae=vae,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment