Unverified commit acf479bd, authored by Linoy Tsaban, committed by GitHub
Browse files

[advanced flux training] bug fix + reduce memory cost as in #9829 (#9838)

* memory improvement as done here: https://github.com/huggingface/diffusers/pull/9829



* fix bug

* fix bug

* style

---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent 03bf77c4
......@@ -2154,6 +2154,7 @@ def main(args):
# encode batch prompts when custom prompts are provided for each image -
if train_dataset.custom_instance_prompts:
elems_to_repeat = 1
if freeze_text_encoder:
prompt_embeds, pooled_prompt_embeds, text_ids = compute_text_embeddings(
prompts, text_encoders, tokenizers
......@@ -2168,17 +2169,21 @@ def main(args):
max_sequence_length=args.max_sequence_length,
add_special_tokens=add_special_tokens_t5,
)
else:
elems_to_repeat = len(prompts)
if not freeze_text_encoder:
prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(
text_encoders=[text_encoder_one, text_encoder_two],
tokenizers=[None, None],
text_input_ids_list=[tokens_one, tokens_two],
text_input_ids_list=[
tokens_one.repeat(elems_to_repeat, 1),
tokens_two.repeat(elems_to_repeat, 1),
],
max_sequence_length=args.max_sequence_length,
device=accelerator.device,
prompt=prompts,
)
# Convert images to latent space
if args.cache_latents:
model_input = latents_cache[step].sample()
......@@ -2371,6 +2376,9 @@ def main(args):
epoch=epoch,
torch_dtype=weight_dtype,
)
images = None
del pipeline
if freeze_text_encoder:
del text_encoder_one, text_encoder_two
free_memory()
......@@ -2448,6 +2456,8 @@ def main(args):
commit_message="End of training",
ignore_patterns=["step_*", "epoch_*"],
)
images = None
del pipeline
accelerator.end_training()
......
......@@ -1648,11 +1648,15 @@ def main(args):
prompt=prompts,
)
else:
elems_to_repeat = len(prompts)
if args.train_text_encoder:
prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(
text_encoders=[text_encoder_one, text_encoder_two],
tokenizers=[None, None],
text_input_ids_list=[tokens_one, tokens_two],
text_input_ids_list=[
tokens_one.repeat(elems_to_repeat, 1),
tokens_two.repeat(elems_to_repeat, 1),
],
max_sequence_length=args.max_sequence_length,
device=accelerator.device,
prompt=args.instance_prompt,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment