"docs/source/vscode:/vscode.git/clone" did not exist on "80c26877de842382664becda414fe7b8bfbe3424"
Unverified commit 619e3ab6, authored by Linoy Tsaban, committed by GitHub

[bug fix] advanced dreambooth lora sdxl - fixes bugs described in #6486 (#6599)

* fixes bugs:
1. redundant retraction - retract_embeddings() was wrapped in an extra loop over the text encoders, repeating the reset once per encoder instead of once per step
2. param clone - the token-embedding params were upcast by rebinding the loop variable instead of assigning to param.data, so detached copies were collected for training
3. stopping optimization of text encoder params - the learning-rate zeroing is now re-applied on every training step once the pivot epoch is reached

* param upcasting (trainable token embeddings kept in float32)

* style
parent 9e2804f7
@@ -1279,7 +1279,7 @@ def main(args):
         for name, param in text_encoder_one.named_parameters():
             if "token_embedding" in name:
                 # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
-                param = param.to(dtype=torch.float32)
+                param.data = param.to(dtype=torch.float32)
                 param.requires_grad = True
                 text_lora_parameters_one.append(param)
             else:
@@ -1288,7 +1288,7 @@ def main(args):
         for name, param in text_encoder_two.named_parameters():
             if "token_embedding" in name:
                 # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
-                param = param.to(dtype=torch.float32)
+                param.data = param.to(dtype=torch.float32)
                 param.requires_grad = True
                 text_lora_parameters_two.append(param)
             else:
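The two hunks above fix the "param clone" issue: `param.to(dtype=torch.float32)` returns a new tensor, so rebinding the loop variable leaves the encoder's own weights in fp16 and appends detached copies to the trainable-parameter list, while assigning to `param.data` upcasts the registered parameter in place. Below is a minimal sketch of the difference; the `nn.Embedding` module and its sizes are illustrative stand-ins, not code from the training script.

import torch
from torch import nn

# Toy stand-in for a text encoder whose weights were loaded in fp16 and then
# frozen, mirroring how the script freezes the encoders before picking out
# the token embeddings (module and sizes here are illustrative only).
encoder = nn.Embedding(4, 8).to(dtype=torch.float16)
encoder.requires_grad_(False)
trainable = []

for name, param in encoder.named_parameters():
    # Old pattern: rebinding the loop variable only creates a detached fp32
    # copy; the module keeps its fp16 weights and `trainable` collects
    # tensors that the forward pass never uses.
    param = param.to(dtype=torch.float32)
    param.requires_grad = True
    trainable.append(param)
print(next(encoder.parameters()).dtype)  # torch.float16 -- module unchanged

for name, param in encoder.named_parameters():
    # Fixed pattern: assigning to `.data` upcasts the registered parameter in
    # place, so the same tensor is both trained and used by the module.
    param.data = param.to(dtype=torch.float32)
    param.requires_grad = True
print(next(encoder.parameters()).dtype)  # torch.float32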
@@ -1725,19 +1725,19 @@ def main(args):
         num_train_epochs_text_encoder = int(args.train_text_encoder_frac * args.num_train_epochs)
     elif args.train_text_encoder_ti:  # args.train_text_encoder_ti
         num_train_epochs_text_encoder = int(args.train_text_encoder_ti_frac * args.num_train_epochs)
+    # flag used for textual inversion
+    pivoted = False
     for epoch in range(first_epoch, args.num_train_epochs):
         # if performing any kind of optimization of text_encoder params
         if args.train_text_encoder or args.train_text_encoder_ti:
             if epoch == num_train_epochs_text_encoder:
                 print("PIVOT HALFWAY", epoch)
                 # stopping optimization of text_encoder params
-                # re setting the optimizer to optimize only on unet params
-                optimizer.param_groups[1]["lr"] = 0.0
-                optimizer.param_groups[2]["lr"] = 0.0
+                # this flag is used to reset the optimizer to optimize only on unet params
+                pivoted = True
             else:
-                # still optimizng the text encoder
+                # still optimizing the text encoder
                 text_encoder_one.train()
                 text_encoder_two.train()
                 # set top parameter requires_grad = True for gradient checkpointing works
@@ -1747,6 +1747,12 @@
         unet.train()
         for step, batch in enumerate(train_dataloader):
+            if pivoted:
+                # stopping optimization of text_encoder params
+                # re setting the optimizer to optimize only on unet params
+                optimizer.param_groups[1]["lr"] = 0.0
+                optimizer.param_groups[2]["lr"] = 0.0
+
             with accelerator.accumulate(unet):
                 prompts = batch["prompts"]
                 # encode batch prompts when custom prompts are provided for each image -
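The `pivoted` flag set in the previous hunk is consumed here: instead of zeroing the text-encoder learning rates once at the epoch boundary, the zeroing is re-applied at the top of every training step. A plausible reading of this change, sketched below with a toy optimizer and scheduler (the `AdamW`/`LambdaLR` setup and the group indices are illustrative assumptions, not code from the script), is that `lr_scheduler.step()` rewrites each param group's learning rate from its stored base values, so a one-time assignment would be overwritten on the very next step.

import torch

# Three dummy parameter groups standing in for unet / text_encoder_one /
# text_encoder_two (illustrative only; the real script builds its groups
# from the LoRA parameters and uses diffusers' get_scheduler()).
params = [torch.nn.Parameter(torch.zeros(1)) for _ in range(3)]
optimizer = torch.optim.AdamW([{"params": [p], "lr": 1e-4} for p in params])
lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda _: 1.0)

# One-time zeroing at the epoch boundary (old behaviour): the scheduler
# recomputes every group's lr from its stored base_lrs on step(), which
# silently re-enables the text-encoder groups.
optimizer.param_groups[1]["lr"] = 0.0
optimizer.param_groups[2]["lr"] = 0.0
optimizer.step()
lr_scheduler.step()
print(optimizer.param_groups[1]["lr"])  # 1e-4 again -- the zeroing did not stick

# Per-step zeroing once `pivoted` is set (new behaviour): the assignment runs
# before each optimizer.step(), so every update after the pivot sees lr == 0.0
# for the text-encoder groups, no matter what the scheduler writes back.
for _ in range(3):
    optimizer.param_groups[1]["lr"] = 0.0
    optimizer.param_groups[2]["lr"] = 0.0
    optimizer.step()
    lr_scheduler.step()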
@@ -1885,7 +1891,6 @@
                 # every step, we reset the embeddings to the original embeddings.
                 if args.train_text_encoder_ti:
-                    for idx, text_encoder in enumerate(text_encoders):
-                        embedding_handler.retract_embeddings()
+                    embedding_handler.retract_embeddings()
                 # Checks if the accelerator has performed an optimization step behind the scenes
...
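Finally, the "redundant retraction" fix above removes the extra loop over `text_encoders` around `embedding_handler.retract_embeddings()`. The sketch below is a hypothetical handler, not the script's `TokenEmbeddingsHandler`, showing the pattern this assumes: a retract method that already walks all of its encoders internally only needs to be invoked once per optimizer step.

from torch import nn

# Hypothetical handler, written only to illustrate the assumed pattern; it is
# not the TokenEmbeddingsHandler from the training script.
class EmbeddingHandlerSketch:
    def __init__(self, embedding_tables, num_original_tokens):
        # One embedding table per text encoder; snapshot the rows of the
        # original vocabulary so they can be restored after every step,
        # leaving only the newly inserted token rows free to train.
        self.embedding_tables = embedding_tables
        self.num_original_tokens = num_original_tokens
        self.original_rows = [
            table.weight[:num_original_tokens].detach().clone()
            for table in embedding_tables
        ]

    def retract_embeddings(self):
        # A single call restores the original rows of *every* encoder, so
        # wrapping it in an outer loop over the encoders (the removed lines)
        # just repeated the same reset several times per step.
        for table, rows in zip(self.embedding_tables, self.original_rows):
            table.weight.data[: self.num_original_tokens] = rows

tables = [nn.Embedding(10, 4), nn.Embedding(10, 4)]  # 8 original + 2 new tokens each
handler = EmbeddingHandlerSketch(tables, num_original_tokens=8)
handler.retract_embeddings()  # once per step is enough after the fix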