Unverified Commit 619e3ab6 authored by Linoy Tsaban's avatar Linoy Tsaban Committed by GitHub
Browse files

[bug fix] advanced dreambooth lora sdxl - fixes bugs described in #6486 (#6599)

* fixes bugs:
1. redundant retraction
2. param clone
3. stopping optimization of text encoder params

* param upscaling

* style
parent 9e2804f7
......@@ -1279,7 +1279,7 @@ def main(args):
for name, param in text_encoder_one.named_parameters():
if "token_embedding" in name:
# ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
param = param.to(dtype=torch.float32)
param.data = param.to(dtype=torch.float32)
param.requires_grad = True
text_lora_parameters_one.append(param)
else:
......@@ -1288,7 +1288,7 @@ def main(args):
for name, param in text_encoder_two.named_parameters():
if "token_embedding" in name:
# ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
param = param.to(dtype=torch.float32)
param.data = param.to(dtype=torch.float32)
param.requires_grad = True
text_lora_parameters_two.append(param)
else:
......@@ -1725,19 +1725,19 @@ def main(args):
num_train_epochs_text_encoder = int(args.train_text_encoder_frac * args.num_train_epochs)
elif args.train_text_encoder_ti: # args.train_text_encoder_ti
num_train_epochs_text_encoder = int(args.train_text_encoder_ti_frac * args.num_train_epochs)
# flag used for textual inversion
pivoted = False
for epoch in range(first_epoch, args.num_train_epochs):
# if performing any kind of optimization of text_encoder params
if args.train_text_encoder or args.train_text_encoder_ti:
if epoch == num_train_epochs_text_encoder:
print("PIVOT HALFWAY", epoch)
# stopping optimization of text_encoder params
# re setting the optimizer to optimize only on unet params
optimizer.param_groups[1]["lr"] = 0.0
optimizer.param_groups[2]["lr"] = 0.0
# this flag is used to reset the optimizer to optimize only on unet params
pivoted = True
else:
# still optimizng the text encoder
# still optimizing the text encoder
text_encoder_one.train()
text_encoder_two.train()
# set top parameter requires_grad = True for gradient checkpointing works
......@@ -1747,6 +1747,12 @@ def main(args):
unet.train()
for step, batch in enumerate(train_dataloader):
if pivoted:
# stopping optimization of text_encoder params
# re setting the optimizer to optimize only on unet params
optimizer.param_groups[1]["lr"] = 0.0
optimizer.param_groups[2]["lr"] = 0.0
with accelerator.accumulate(unet):
prompts = batch["prompts"]
# encode batch prompts when custom prompts are provided for each image -
......@@ -1885,8 +1891,7 @@ def main(args):
# every step, we reset the embeddings to the original embeddings.
if args.train_text_encoder_ti:
for idx, text_encoder in enumerate(text_encoders):
embedding_handler.retract_embeddings()
embedding_handler.retract_embeddings()
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment