Unverified Commit af92869d authored by Álvaro Somoza, committed by GitHub

[SD3 LoRA Training] Fix errors when not training text encoders (#8743)



* fix

* fix things.
Co-authored-by: Linoy Tsaban <linoy.tsaban@gmail.com>

* remove patch

* apply suggestions

---------
Co-authored-by: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com>
Co-authored-by: sayakpaul <spsayakpaul@gmail.com>
Co-authored-by: Linoy Tsaban <linoy.tsaban@gmail.com>
parent 0bae6e44
@@ -962,7 +962,7 @@ def encode_prompt(
             prompt=prompt,
             device=device if device is not None else text_encoder.device,
             num_images_per_prompt=num_images_per_prompt,
-            text_input_ids=text_input_ids_list[i],
+            text_input_ids=text_input_ids_list[i] if text_input_ids_list else None,
         )
         clip_prompt_embeds_list.append(prompt_embeds)
         clip_pooled_prompt_embeds_list.append(pooled_prompt_embeds)
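For context on this first hunk: `encode_prompt` receives pre-tokenized ids only when the text encoders are being trained; in the frozen-encoder path no `text_input_ids_list` is passed, so indexing it with `[i]` raised a `TypeError`. A minimal sketch of the guard, with plain lists standing in for the id tensors:

```python
from typing import List, Optional

def pick_input_ids(text_input_ids_list: Optional[List[list]], i: int) -> Optional[list]:
    # Old behaviour: text_input_ids_list[i] crashed when the list was None.
    # New behaviour: fall back to None so the encoding helper tokenizes the
    # raw prompt itself instead of expecting pre-tokenized ids.
    return text_input_ids_list[i] if text_input_ids_list else None

# Text encoders trained: ids were tokenized up front, one entry per encoder.
assert pick_input_ids([[101, 102], [101, 103]], 1) == [101, 103]
# Text encoders frozen: nothing was pre-tokenized, so the guard yields None.
assert pick_input_ids(None, 1) is None
```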
@@ -976,7 +976,7 @@ def encode_prompt(
         max_sequence_length,
         prompt=prompt,
         num_images_per_prompt=num_images_per_prompt,
-        text_input_ids=text_input_ids_list[:-1],
+        text_input_ids=text_input_ids_list[-1] if text_input_ids_list else None,
         device=device if device is not None else text_encoders[-1].device,
     )
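The second hunk fixes two problems in the T5 call at once: `text_input_ids_list[:-1]` is a slice, a list containing the two CLIP id tensors, while the encoder expects the single last entry, `text_input_ids_list[-1]`; the same truthiness guard covers the frozen-encoder path where the list is absent. Illustrated with strings standing in for the id tensors:

```python
tokens = ["clip_one_ids", "clip_two_ids", "t5_ids"]  # stand-ins for id tensors

# Before: a slice, i.e. the wrong type (a list) and the wrong encoders' ids.
assert tokens[:-1] == ["clip_one_ids", "clip_two_ids"]
# After: the single entry belonging to the last (T5) tokenizer.
assert tokens[-1] == "t5_ids"
```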
@@ -1491,6 +1491,9 @@ def main(args):
         ) = accelerator.prepare(
             transformer, text_encoder_one, text_encoder_two, optimizer, train_dataloader, lr_scheduler
         )
+        assert text_encoder_one is not None
+        assert text_encoder_two is not None
+        assert text_encoder_three is not None
     else:
         transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
             transformer, optimizer, train_dataloader, lr_scheduler
@@ -1598,7 +1601,7 @@
                 tokens_three = tokenize_prompt(tokenizer_three, prompts)
                 prompt_embeds, pooled_prompt_embeds = encode_prompt(
                     text_encoders=[text_encoder_one, text_encoder_two, text_encoder_three],
-                    tokenizers=[None, None, tokenizer_three],
+                    tokenizers=[None, None, None],
                     prompt=prompts,
                     max_sequence_length=args.max_sequence_length,
                     text_input_ids_list=[tokens_one, tokens_two, tokens_three],
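Passing `tokenizers=[None, None, None]` is consistent here because `tokens_one`, `tokens_two`, and `tokens_three` already arrive through `text_input_ids_list`; a tokenizer is only needed when `encode_prompt` has to produce the ids itself. A hedged sketch of that either/or dispatch, where `resolve_ids` is a hypothetical name and the real helper additionally handles devices and batching:

```python
def resolve_ids(tokenizer, prompt, text_input_ids):
    # A tokenizer was supplied: tokenize the raw prompt here.
    if tokenizer is not None:
        return tokenizer(prompt)
    # No tokenizer: the caller must have pre-tokenized the prompt.
    if text_input_ids is None:
        raise ValueError("provide text_input_ids when no tokenizer is given")
    return text_input_ids

# Pre-tokenized path used in this hunk: tokenizer slot is None, ids are given.
assert resolve_ids(None, "a photo of sks dog", [101, 102]) == [101, 102]
# Tokenizer path, with a toy callable standing in for a real tokenizer.
assert resolve_ids(lambda p: [1] * len(p.split()), "a photo", None) == [1, 1]
```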
@@ -1608,7 +1611,7 @@
                 prompt_embeds, pooled_prompt_embeds = encode_prompt(
                     text_encoders=[text_encoder_one, text_encoder_two, text_encoder_three],
                     tokenizers=[None, None, tokenizer_three],
-                    prompt=prompts,
+                    prompt=args.instance_prompt,
                     max_sequence_length=args.max_sequence_length,
                     text_input_ids_list=[tokens_one, tokens_two, tokens_three],
                 )
@@ -1685,10 +1688,12 @@
                 accelerator.backward(loss)
                 if accelerator.sync_gradients:
-                    params_to_clip = itertools.chain(
-                        transformer_lora_parameters,
-                        text_lora_parameters_one,
-                        text_lora_parameters_two if args.train_text_encoder else transformer_lora_parameters,
+                    params_to_clip = (
+                        itertools.chain(
+                            transformer_lora_parameters, text_lora_parameters_one, text_lora_parameters_two
+                        )
+                        if args.train_text_encoder
+                        else transformer_lora_parameters
                     )
                     accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
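The clipping rewrite matters because the old expression evaluated `text_lora_parameters_one` and `text_lora_parameters_two` unconditionally, names that are only defined when the encoders are trained, and its fallback chained `transformer_lora_parameters` in twice. The new conditional selects one iterable or the other; a runnable sketch with lists standing in for parameter tensors:

```python
import itertools

transformer_lora_parameters = ["t0", "t1"]  # stand-ins for LoRA parameters
text_lora_parameters_one = ["e1"]
text_lora_parameters_two = ["e2"]

for train_text_encoder in (True, False):
    params_to_clip = (
        itertools.chain(
            transformer_lora_parameters, text_lora_parameters_one, text_lora_parameters_two
        )
        if train_text_encoder
        else transformer_lora_parameters
    )
    print(train_text_encoder, list(params_to_clip))
# True ['t0', 't1', 'e1', 'e2']
# False ['t0', 't1']
```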
@@ -1741,13 +1746,6 @@
                     text_encoder_one, text_encoder_two, text_encoder_three = load_text_encoders(
                         text_encoder_cls_one, text_encoder_cls_two, text_encoder_cls_three
                     )
-                else:
-                    text_encoder_three = text_encoder_cls_three.from_pretrained(
-                        args.pretrained_model_name_or_path,
-                        subfolder="text_encoder_3",
-                        revision=args.revision,
-                        variant=args.variant,
-                    )
                 pipeline = StableDiffusion3Pipeline.from_pretrained(
                     args.pretrained_model_name_or_path,
                     vae=vae,
@@ -1767,7 +1765,9 @@
                     pipeline_args=pipeline_args,
                     epoch=epoch,
                 )
-                del text_encoder_one, text_encoder_two, text_encoder_three
+                if not args.train_text_encoder:
+                    del text_encoder_one, text_encoder_two, text_encoder_three
                 torch.cuda.empty_cache()
                 gc.collect()
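Finally, the post-validation cleanup becomes conditional: when the encoders are frozen they were loaded solely to build the validation pipeline and can be released afterwards, but when they are being trained they must survive into the next step. A sketch of the guarded pattern, assuming torch is installed, with plain objects standing in for the modules:

```python
import gc

import torch

train_text_encoder = False  # mirrors args.train_text_encoder
# Stand-ins for the three encoders; the real ones are nn.Modules on the GPU.
text_encoder_one, text_encoder_two, text_encoder_three = object(), object(), object()

if not train_text_encoder:
    # Drop the references so the modules become garbage-collectable.
    del text_encoder_one, text_encoder_two, text_encoder_three

# Unconditional in the script: reclaim whatever just became unreachable.
if torch.cuda.is_available():
    torch.cuda.empty_cache()
gc.collect()
```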