Unverified Commit 4623f095 authored by Duong A. Nguyen's avatar Duong A. Nguyen Committed by GitHub
Browse files

[DreamBooth] Set train mode for text encoder (#1012)

Set train mode for text encoder
parent abe05822
...@@ -574,6 +574,8 @@ def main(args): ...@@ -574,6 +574,8 @@ def main(args):
for epoch in range(args.num_train_epochs): for epoch in range(args.num_train_epochs):
unet.train() unet.train()
if args.train_text_encoder:
text_encoder.train()
for step, batch in enumerate(train_dataloader): for step, batch in enumerate(train_dataloader):
with accelerator.accumulate(unet): with accelerator.accumulate(unet):
# Convert images to latent space # Convert images to latent space
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment