Unverified Commit 4dcab922 authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

[SDXL ControlNet Training] Follow-up fixes (#4188)

* hash computation. thanks to @lhoestq

* disable dtype casting.

* remove comments.
parent aed30dff
...@@ -1001,7 +1001,12 @@ def main(args): ...@@ -1001,7 +1001,12 @@ def main(args):
proportion_empty_prompts=args.proportion_empty_prompts, proportion_empty_prompts=args.proportion_empty_prompts,
) )
with accelerator.main_process_first(): with accelerator.main_process_first():
train_dataset = train_dataset.map(compute_embeddings_fn, batched=True) from datasets.fingerprint import Hasher
# fingerprint used by the cache for the other processes to load the result
# details: https://github.com/huggingface/diffusers/pull/4038#discussion_r1266078401
new_fingerprint = Hasher.hash(args)
train_dataset = train_dataset.map(compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint)
del text_encoders, tokenizers del text_encoders, tokenizers
gc.collect() gc.collect()
...@@ -1113,8 +1118,6 @@ def main(args): ...@@ -1113,8 +1118,6 @@ def main(args):
# Convert images to latent space # Convert images to latent space
if args.pretrained_vae_model_name_or_path is not None: if args.pretrained_vae_model_name_or_path is not None:
pixel_values = batch["pixel_values"].to(dtype=weight_dtype) pixel_values = batch["pixel_values"].to(dtype=weight_dtype)
if vae.dtype != weight_dtype:
vae.to(dtype=weight_dtype)
else: else:
pixel_values = batch["pixel_values"] pixel_values = batch["pixel_values"]
latents = vae.encode(pixel_values).latent_dist.sample() latents = vae.encode(pixel_values).latent_dist.sample()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment