Unverified Commit 9edd0aa7 authored by Haofan Wang's avatar Haofan Wang Committed by GitHub
Browse files

Update train_dreambooth_colossalai.py

accelerator.num_processes -> gpc.get_world_size(ParallelMode.DATA)
parent f1bc2418
...@@ -484,7 +484,7 @@ def main(args): ...@@ -484,7 +484,7 @@ def main(args):
unet.enable_gradient_checkpointing() unet.enable_gradient_checkpointing()
if args.scale_lr: if args.scale_lr:
args.learning_rate = args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * 2 args.learning_rate = args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * gpc.get_world_size(ParallelMode.DATA)
unet = gemini_zero_dpp(unet, pg, args.placement) unet = gemini_zero_dpp(unet, pg, args.placement)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment