Unverified Commit 7a4a126d authored by PromeAI, committed by GitHub

fix issue that training flux controlnet was unstable and validation r… (#11373)



* fix issue that training flux controlnet was unstable and validation results were unstable

* del unused code pieces, fix grammar

---------
Co-authored-by: Your Name <you@example.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent 0dec414d
...@@ -8,6 +8,18 @@ Training script provided by LibAI, which is an institution dedicated to the prog
>
> Flux can be quite expensive to run on consumer hardware devices and as a result, ControlNet training of it comes with higher memory requirements than usual.
Here is the GPU memory consumption for reference, tested on a single A100 (80GB).

| phase | GPU memory |
| - | - |
| load as float32 | ~70G |
| move transformer and VAE to bf16 | ~48G |
| precompute text embeddings | ~62G |
| **offload text encoders to CPU** | ~30G |
| training | ~58G |
| validation | ~71G |
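For orientation, the sketch below walks through those memory-reduction steps (an illustrative sketch only, not the training script itself; variable names and the exact ordering are assumptions):

```python
import torch
from diffusers import AutoencoderKL, FluxTransformer2DModel
from transformers import CLIPTextModel, T5EncoderModel

# Load in float32 first (~70G), then cast the heavy modules to bf16 (~48G).
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer"
)
vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae")
transformer.to(dtype=torch.bfloat16)
vae.to(dtype=torch.bfloat16)

# Text encoders are only needed to precompute prompt embeddings (~62G peak);
# once the embeddings are cached, offloading them to CPU drops usage to ~30G.
text_encoder_one = CLIPTextModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="text_encoder"
)
text_encoder_two = T5EncoderModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="text_encoder_2"
)
# ... precompute and cache the prompt embeddings here ...
text_encoder_one.to("cpu")
text_encoder_two.to("cpu")
torch.cuda.empty_cache()
```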
> **Gated access**
>
> As the model is gated, before using it with diffusers you first need to go to the [FLUX.1 [dev] Hugging Face page](https://huggingface.co/black-forest-labs/FLUX.1-dev), fill in the form and accept the gate. Once you are in, you need to log in so that your system knows you've accepted the gate. Use the command below to log in: `huggingface-cli login`
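Logging in programmatically works as well; a minimal sketch using `huggingface_hub` (the token is a placeholder you supply yourself):

```python
from huggingface_hub import login

# Prompts for a token from https://huggingface.co/settings/tokens;
# alternatively pass it directly, e.g. login(token="hf_...") (placeholder).
login()
```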
...@@ -98,8 +110,9 @@ accelerate launch train_controlnet_flux.py \
  --validation_image "./conditioning_image_1.png" "./conditioning_image_2.png" \
  --validation_prompt "red circle with blue background" "cyan circle with brown floral background" \
  --train_batch_size=1 \
- --gradient_accumulation_steps=4 \
+ --gradient_accumulation_steps=16 \
  --report_to="wandb" \
+ --lr_scheduler="cosine" \
  --num_double_layers=4 \
  --num_single_layers=0 \
  --seed=42 \
...
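With the updated flags the effective batch size per optimizer step grows accordingly; a quick sanity check (illustrative values, assuming a single-GPU launch):

```python
train_batch_size = 1
gradient_accumulation_steps = 16
num_processes = 1  # GPUs launched by accelerate (assumed)

effective_batch_size = train_batch_size * gradient_accumulation_steps * num_processes
print(effective_batch_size)  # 16 samples contribute to each optimizer step
```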
...@@ -148,7 +148,7 @@ def log_validation(
  pooled_prompt_embeds=pooled_prompt_embeds,
  control_image=validation_image,
  num_inference_steps=28,
- controlnet_conditioning_scale=0.7,
+ controlnet_conditioning_scale=1,
  guidance_scale=3.5,
  generator=generator,
  ).images[0]
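For reference, validation-style inference with the updated defaults looks roughly like the sketch below (a minimal example; the checkpoint path, prompt, and image path are placeholders rather than part of this change):

```python
import torch
from diffusers import FluxControlNetModel, FluxControlNetPipeline
from diffusers.utils import load_image

controlnet = FluxControlNetModel.from_pretrained(
    "path/to/trained/controlnet", torch_dtype=torch.bfloat16  # placeholder path
)
pipe = FluxControlNetPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", controlnet=controlnet, torch_dtype=torch.bfloat16
).to("cuda")

control_image = load_image("./conditioning_image_1.png")
image = pipe(
    prompt="red circle with blue background",
    control_image=control_image,
    num_inference_steps=28,
    controlnet_conditioning_scale=1.0,  # raised from 0.7 in this change
    guidance_scale=3.5,
    generator=torch.Generator(device="cuda").manual_seed(42),
).images[0]
image.save("validation_sample.png")
```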
...@@ -1085,8 +1085,6 @@ def main(args):
  return {"prompt_embeds": prompt_embeds, "pooled_prompt_embeds": pooled_prompt_embeds, "text_ids": text_ids}
  train_dataset = get_train_dataset(args, accelerator)
- text_encoders = [text_encoder_one, text_encoder_two]
- tokenizers = [tokenizer_one, tokenizer_two]
  compute_embeddings_fn = functools.partial(
  compute_embeddings,
  flux_controlnet_pipeline=flux_controlnet_pipeline,
...@@ -1103,7 +1101,8 @@ def main(args):
  compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint, batch_size=50
  )
- del text_encoders, tokenizers, text_encoder_one, text_encoder_two, tokenizer_one, tokenizer_two
+ text_encoder_one.to("cpu")
+ text_encoder_two.to("cpu")
  free_memory()
  # Then get the training dataset ready to be passed to the dataloader.
...
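The same offloading idea in isolation (a sketch of the pattern, assuming the two encoders are ordinary `torch.nn.Module`s already loaded on the GPU; the helper names are made up for illustration):

```python
import gc
import torch

def offload_text_encoders(text_encoder_one, text_encoder_two):
    # Keep the text encoders alive (instead of deleting them) but move them
    # off the GPU once prompt embeddings have been precomputed and cached.
    text_encoder_one.to("cpu")
    text_encoder_two.to("cpu")
    gc.collect()
    torch.cuda.empty_cache()

def onload_text_encoders(text_encoder_one, text_encoder_two, device="cuda"):
    # Bring them back only if prompt embeddings need to be recomputed later.
    text_encoder_one.to(device)
    text_encoder_two.to(device)
```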