"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "b954c22a446301c644120f41c046a4d4c4553c7a"
Unverified Commit 006d0927 authored by Linoy Tsaban's avatar Linoy Tsaban Committed by GitHub
Browse files

[Flux LoRA] fix for prior preservation and mixed precision sampling, follow up on #11873 (#12264)

* propagate fixes from https://github.com/huggingface/diffusers/pull/11873/ to flux script

* propagate fixes from https://github.com/huggingface/diffusers/pull/11873/ to flux script

* propagate fixes from https://github.com/huggingface/diffusers/pull/11873/ to flux script

* Apply style fixes

---------
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
parent 9e4a75b1
@@ -1399,6 +1399,7 @@ def main(args):
             torch_dtype = torch.float16
         elif args.prior_generation_precision == "bf16":
             torch_dtype = torch.bfloat16
+
         pipeline = FluxPipeline.from_pretrained(
             args.pretrained_model_name_or_path,
             torch_dtype=torch_dtype,
@@ -1419,7 +1420,8 @@ def main(args):
         for example in tqdm(
             sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
         ):
-            images = pipeline(example["prompt"]).images
+            with torch.autocast(device_type=accelerator.device.type, dtype=torch_dtype):
+                images = pipeline(prompt=example["prompt"]).images
             for i, image in enumerate(images):
                 hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()
...
@@ -1131,6 +1131,7 @@ def main(args):
             torch_dtype = torch.float16
         elif args.prior_generation_precision == "bf16":
             torch_dtype = torch.bfloat16
+
         pipeline = FluxPipeline.from_pretrained(
             args.pretrained_model_name_or_path,
             torch_dtype=torch_dtype,
@@ -1151,7 +1152,8 @@ def main(args):
         for example in tqdm(
             sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
         ):
-            images = pipeline(example["prompt"]).images
+            with torch.autocast(device_type=accelerator.device.type, dtype=torch_dtype):
+                images = pipeline(prompt=example["prompt"]).images
             for i, image in enumerate(images):
                 hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()
@@ -1159,8 +1161,7 @@ def main(args):
                 image.save(image_filename)

         del pipeline
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
+        free_memory()

     # Handle the repository creation
     if accelerator.is_main_process:
@@ -1728,6 +1729,10 @@ def main(args):
                 device=accelerator.device,
                 prompt=args.instance_prompt,
             )
+        else:
+            prompt_embeds, pooled_prompt_embeds, text_ids = compute_text_embeddings(
+                prompts, text_encoders, tokenizers
+            )
         # Convert images to latent space
         if args.cache_latents:
...
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment