Unverified commit ffed2420, authored by Fazzie-Maqianli, committed by GitHub

fix distributed init twice (#2252)

fix colossalai dreambooth
parent 8178c840
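Context for the change: before this commit, the script launched ColossalAI once in the seed/no-seed branch and then unconditionally called colossalai.launch_from_torch(config={}) a second time, so torch.distributed's default process group was initialized twice and the second call raised at startup. A minimal sketch of the single-init pattern the diff restores; the guarded_launch helper is hypothetical, for illustration only:

# Sketch of the bug this commit removes: launching twice.
# colossalai.launch_from_torch ultimately calls
# torch.distributed.init_process_group, which raises a RuntimeError
# if the default process group already exists.
import colossalai
import torch.distributed as dist

def guarded_launch(seed=None):  # hypothetical helper, for illustration
    """Initialize the distributed runtime exactly once."""
    if dist.is_initialized():
        return  # already launched; a second launch would raise
    if seed is None:
        colossalai.launch_from_torch(config={})
    else:
        colossalai.launch_from_torch(config={}, seed=seed)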
@@ -161,12 +161,6 @@ def parse_args(input_args=None):
         help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
     )
     parser.add_argument("--save_steps", type=int, default=500, help="Save checkpoint every X updates steps.")
-    parser.add_argument(
-        "--gradient_accumulation_steps",
-        type=int,
-        default=1,
-        help="Number of updates steps to accumulate before performing a backward/update pass.",
-    )
     parser.add_argument(
         "--gradient_checkpointing",
         action="store_true",
@@ -376,10 +370,8 @@ def main(args):
     else:
         colossalai.launch_from_torch(config={}, seed=args.seed)
 
-    colossalai.launch_from_torch(config={})
-
-    if args.seed is not None:
-        gpc.set_seed(args.seed)
+    local_rank = gpc.get_local_rank(ParallelMode.DATA)
+    world_size = gpc.get_world_size(ParallelMode.DATA)
 
     if args.with_prior_preservation:
         class_images_dir = Path(args.class_data_dir)
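With the duplicate launch gone, the rank and world size are read once from ColossalAI's global parallel context and cached, instead of re-querying gpc at every use site below. A minimal sketch of the lookup, using the same imports as the script; note that the script gates on the data-parallel local rank, so on a multi-node run each node's rank 0 would pass the == 0 checks:

# After a single colossalai.launch_from_torch(...), the global context
# exposes this process's coordinates; caching them avoids repeated lookups.
from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode

local_rank = gpc.get_local_rank(ParallelMode.DATA)   # rank within the node, e.g. 0..7
world_size = gpc.get_world_size(ParallelMode.DATA)   # total data-parallel processes
is_main_process = local_rank == 0                    # gate logging/saving on one rank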
@@ -408,7 +400,7 @@ def main(args):
             for example in tqdm(
                 sample_dataloader,
                 desc="Generating class images",
-                disable=not gpc.get_local_rank(ParallelMode.DATA) == 0,
+                disable=not local_rank == 0,
             ):
                 images = pipeline(example["prompt"]).images
@@ -420,7 +412,7 @@ def main(args):
            del pipeline
 
    # Handle the repository creation
-    if gpc.get_local_rank(ParallelMode.DATA) == 0:
+    if local_rank == 0:
        if args.push_to_hub:
            if args.hub_model_id is None:
                repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
@@ -486,12 +478,7 @@ def main(args):
         unet.enable_gradient_checkpointing()
 
     if args.scale_lr:
-        args.learning_rate = (
-            args.learning_rate
-            * args.gradient_accumulation_steps
-            * args.train_batch_size
-            * gpc.get_world_size(ParallelMode.DATA)
-        )
+        args.learning_rate = args.learning_rate * args.train_batch_size * world_size
 
     unet = gemini_zero_dpp(unet, args.placement)
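The scale_lr branch now collapses to the plain linear-scaling rule: with accumulation gone, the effective global batch is just the per-device batch size times the world size. A worked example with assumed values, not taken from the commit:

# Linear LR scaling after the change: lr grows with the global batch size.
# All values below are assumed, for illustration only.
base_lr = 5e-6          # args.learning_rate as passed on the CLI
train_batch_size = 2    # per-device batch size
world_size = 8          # data-parallel processes

scaled_lr = base_lr * train_batch_size * world_size
print(scaled_lr)        # 8e-05: the value the optimizer actually sees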
@@ -547,7 +534,7 @@ def main(args):
     # Scheduler and math around the number of training steps.
     overrode_max_train_steps = False
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    num_update_steps_per_epoch = math.ceil(len(train_dataloader))
     if args.max_train_steps is None:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
         overrode_max_train_steps = True
@@ -555,8 +542,8 @@ def main(args):
     lr_scheduler = get_scheduler(
         args.lr_scheduler,
         optimizer=optimizer,
-        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
-        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+        num_warmup_steps=args.lr_warmup_steps,
+        num_training_steps=args.max_train_steps,
     )
     weight_dtype = torch.float32
     if args.mixed_precision == "fp16":
@@ -571,14 +558,14 @@ def main(args):
     text_encoder.to(get_current_device(), dtype=weight_dtype)
 
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    num_update_steps_per_epoch = math.ceil(len(train_dataloader))
     if overrode_max_train_steps:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
     # Afterwards we recalculate our number of training epochs
     args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
 
     # Train!
-    total_batch_size = args.train_batch_size * gpc.get_world_size(ParallelMode.DATA) * args.gradient_accumulation_steps
+    total_batch_size = args.train_batch_size * world_size
 
     logger.info("***** Running training *****", ranks=[0])
     logger.info(f"  Num examples = {len(train_dataset)}", ranks=[0])
@@ -586,11 +573,10 @@ def main(args):
     logger.info(f"  Num Epochs = {args.num_train_epochs}", ranks=[0])
     logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}", ranks=[0])
     logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}", ranks=[0])
-    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}", ranks=[0])
     logger.info(f"  Total optimization steps = {args.max_train_steps}", ranks=[0])
 
     # Only show the progress bar once on each machine.
-    progress_bar = tqdm(range(args.max_train_steps), disable=not gpc.get_local_rank(ParallelMode.DATA) == 0)
+    progress_bar = tqdm(range(args.max_train_steps), disable=not local_rank == 0)
     progress_bar.set_description("Steps")
     global_step = 0
@@ -607,7 +593,7 @@ def main(args):
             optimizer.zero_grad()
 
             latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
-            latents = latents * vae.config.scaling_factor
+            latents = latents * 0.18215
 
             # Sample noise that we'll add to the latents
             noise = torch.randn_like(latents)
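The 0.18215 literal is Stable Diffusion's VAE latent scaling factor: encoded latents are multiplied by it so their scale roughly matches the unit-variance noise the UNet predicts against, and inference divides it back out before decoding. Newer diffusers versions read it from vae.config.scaling_factor, which is what this hunk replaces with the pinned constant. A sketch of the round trip, assuming a diffusers AutoencoderKL and an example checkpoint id:

# Round trip through the Stable Diffusion VAE scaling factor.
# The checkpoint id and tensor shapes are illustrative assumptions.
import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae")
scaling_factor = 0.18215

image = torch.randn(1, 3, 512, 512)                                # dummy pixel batch
latents = vae.encode(image).latent_dist.sample() * scaling_factor  # into UNet space
decoded = vae.decode(latents / scaling_factor).sample              # back to pixels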
@@ -667,7 +653,7 @@ def main(args):
                if global_step % args.save_steps == 0:
                    torch.cuda.synchronize()
                    torch_unet = get_static_torch_model(unet)
-                    if gpc.get_local_rank(ParallelMode.DATA) == 0:
+                    if local_rank == 0:
                        pipeline = DiffusionPipeline.from_pretrained(
                            args.pretrained_model_name_or_path,
                            unet=torch_unet,
@@ -682,7 +668,7 @@ def main(args):
     torch.cuda.synchronize()
     unet = get_static_torch_model(unet)
 
-    if gpc.get_local_rank(ParallelMode.DATA) == 0:
+    if local_rank == 0:
         pipeline = DiffusionPipeline.from_pretrained(
             args.pretrained_model_name_or_path,
             unet=unet,
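The two save hunks follow the usual gather-then-save-on-one-rank pattern: get_static_torch_model reconstructs a plain torch.nn.Module from the Gemini/ZeRO-sharded UNet (it is called on every rank, outside the rank check, presumably because the gather is collective), and only local rank 0 builds and writes the pipeline. A condensed sketch of that tail; the revision and save_pretrained lines are assumed from the script's usual shape, since the hunk is truncated above:

# Save-on-one-rank pattern at the end of training. All ranks call
# get_static_torch_model to collect the sharded parameters; only
# local rank 0 touches the filesystem. args/local_rank come from
# the surrounding script; the last two lines are assumed, not shown.
torch.cuda.synchronize()                 # let pending kernels finish
unet = get_static_torch_model(unet)      # collect shards into a plain nn.Module

if local_rank == 0:
    pipeline = DiffusionPipeline.from_pretrained(
        args.pretrained_model_name_or_path,
        unet=unet,
        revision=args.revision,
    )
    pipeline.save_pretrained(args.output_dir)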