Unverified Commit 0267c523 authored by HelloWorldBeginner's avatar HelloWorldBeginner Committed by GitHub
Browse files

fix bugs when using deepspeed in sdxl (#7917)



fix bugs when using deepspeed
Co-authored-by: default avatarmhh001 <mahonghao1@huawei.com>
parent be4afa0b
...@@ -35,7 +35,7 @@ import torch.utils.checkpoint ...@@ -35,7 +35,7 @@ import torch.utils.checkpoint
import transformers import transformers
from accelerate import Accelerator from accelerate import Accelerator
from accelerate.logging import get_logger from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed from accelerate.utils import DistributedType, ProjectConfiguration, set_seed
from datasets import concatenate_datasets, load_dataset from datasets import concatenate_datasets, load_dataset
from huggingface_hub import create_repo, upload_folder from huggingface_hub import create_repo, upload_folder
from packaging import version from packaging import version
...@@ -742,7 +742,8 @@ def main(args): ...@@ -742,7 +742,8 @@ def main(args):
model.save_pretrained(os.path.join(output_dir, "unet")) model.save_pretrained(os.path.join(output_dir, "unet"))
# make sure to pop weight so that corresponding model is not saved again # make sure to pop weight so that corresponding model is not saved again
weights.pop() if weights:
weights.pop()
def load_model_hook(models, input_dir): def load_model_hook(models, input_dir):
if args.use_ema: if args.use_ema:
...@@ -914,7 +915,7 @@ def main(args): ...@@ -914,7 +915,7 @@ def main(args):
train_dataset_with_vae = train_dataset.map( train_dataset_with_vae = train_dataset.map(
compute_vae_encodings_fn, compute_vae_encodings_fn,
batched=True, batched=True,
batch_size=args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps, batch_size=args.train_batch_size,
new_fingerprint=new_fingerprint_for_vae, new_fingerprint=new_fingerprint_for_vae,
) )
precomputed_dataset = concatenate_datasets( precomputed_dataset = concatenate_datasets(
...@@ -1160,7 +1161,8 @@ def main(args): ...@@ -1160,7 +1161,8 @@ def main(args):
accelerator.log({"train_loss": train_loss}, step=global_step) accelerator.log({"train_loss": train_loss}, step=global_step)
train_loss = 0.0 train_loss = 0.0
if accelerator.is_main_process: # DeepSpeed requires saving weights on every device; saving weights only on the main process would cause issues.
if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process:
if global_step % args.checkpointing_steps == 0: if global_step % args.checkpointing_steps == 0:
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit` # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
if args.checkpoints_total_limit is not None: if args.checkpoints_total_limit is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment