"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "5bacc2f5af7904749827fac15df05df0cb782fe4"
Unverified Commit 15ed53d2 authored by Younes Belkada's avatar Younes Belkada Committed by GitHub
Browse files

Fixes LoRA SDXL training script with DDP + PEFT (#6816)

Update train_dreambooth_lora_sdxl.py
parent 9cc59ba0
...@@ -1399,8 +1399,8 @@ def main(args): ...@@ -1399,8 +1399,8 @@ def main(args):
text_encoder_two.train() text_encoder_two.train()
# set top parameter requires_grad = True for gradient checkpointing works # set top parameter requires_grad = True for gradient checkpointing works
text_encoder_one.text_model.embeddings.requires_grad_(True) accelerator.unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
text_encoder_two.text_model.embeddings.requires_grad_(True) accelerator.unwrap_model(text_encoder_two).text_model.embeddings.requires_grad_(True)
for step, batch in enumerate(train_dataloader): for step, batch in enumerate(train_dataloader):
with accelerator.accumulate(unet): with accelerator.accumulate(unet):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment