"git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "208fa3683de06712b70a4b4773199cb39cd66523"
Unverified Commit acf479bd authored by Linoy Tsaban, committed by GitHub
Browse files

[advanced flux training] bug fix + reduce memory cost as in #9829 (#9838)

* memory improvement as done here: https://github.com/huggingface/diffusers/pull/9829



* fix bug

* fix bug

* style

---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent 03bf77c4
...@@ -2154,6 +2154,7 @@ def main(args): ...@@ -2154,6 +2154,7 @@ def main(args):
# encode batch prompts when custom prompts are provided for each image - # encode batch prompts when custom prompts are provided for each image -
if train_dataset.custom_instance_prompts: if train_dataset.custom_instance_prompts:
elems_to_repeat = 1
if freeze_text_encoder: if freeze_text_encoder:
prompt_embeds, pooled_prompt_embeds, text_ids = compute_text_embeddings( prompt_embeds, pooled_prompt_embeds, text_ids = compute_text_embeddings(
prompts, text_encoders, tokenizers prompts, text_encoders, tokenizers
...@@ -2168,17 +2169,21 @@ def main(args): ...@@ -2168,17 +2169,21 @@ def main(args):
max_sequence_length=args.max_sequence_length, max_sequence_length=args.max_sequence_length,
add_special_tokens=add_special_tokens_t5, add_special_tokens=add_special_tokens_t5,
) )
else:
elems_to_repeat = len(prompts)
if not freeze_text_encoder: if not freeze_text_encoder:
prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt( prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(
text_encoders=[text_encoder_one, text_encoder_two], text_encoders=[text_encoder_one, text_encoder_two],
tokenizers=[None, None], tokenizers=[None, None],
text_input_ids_list=[tokens_one, tokens_two], text_input_ids_list=[
tokens_one.repeat(elems_to_repeat, 1),
tokens_two.repeat(elems_to_repeat, 1),
],
max_sequence_length=args.max_sequence_length, max_sequence_length=args.max_sequence_length,
device=accelerator.device, device=accelerator.device,
prompt=prompts, prompt=prompts,
) )
# Convert images to latent space # Convert images to latent space
if args.cache_latents: if args.cache_latents:
model_input = latents_cache[step].sample() model_input = latents_cache[step].sample()
...@@ -2371,6 +2376,9 @@ def main(args): ...@@ -2371,6 +2376,9 @@ def main(args):
epoch=epoch, epoch=epoch,
torch_dtype=weight_dtype, torch_dtype=weight_dtype,
) )
images = None
del pipeline
if freeze_text_encoder: if freeze_text_encoder:
del text_encoder_one, text_encoder_two del text_encoder_one, text_encoder_two
free_memory() free_memory()
...@@ -2448,6 +2456,8 @@ def main(args): ...@@ -2448,6 +2456,8 @@ def main(args):
commit_message="End of training", commit_message="End of training",
ignore_patterns=["step_*", "epoch_*"], ignore_patterns=["step_*", "epoch_*"],
) )
images = None
del pipeline
accelerator.end_training() accelerator.end_training()
......
...@@ -1648,11 +1648,15 @@ def main(args): ...@@ -1648,11 +1648,15 @@ def main(args):
prompt=prompts, prompt=prompts,
) )
else: else:
elems_to_repeat = len(prompts)
if args.train_text_encoder: if args.train_text_encoder:
prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt( prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(
text_encoders=[text_encoder_one, text_encoder_two], text_encoders=[text_encoder_one, text_encoder_two],
tokenizers=[None, None], tokenizers=[None, None],
text_input_ids_list=[tokens_one, tokens_two], text_input_ids_list=[
tokens_one.repeat(elems_to_repeat, 1),
tokens_two.repeat(elems_to_repeat, 1),
],
max_sequence_length=args.max_sequence_length, max_sequence_length=args.max_sequence_length,
device=accelerator.device, device=accelerator.device,
prompt=args.instance_prompt, prompt=args.instance_prompt,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment