Unverified commit 2868d991, authored by Tim Hinderliter, committed by GitHub
Browse files

dreambooth: fix #1566: maintain fp32 wrapper when saving a checkpoint to avoid...


dreambooth: fix #1566: maintain fp32 wrapper when saving a checkpoint to avoid crash when running fp16 (#1618)

* dreambooth: fix #1566: maintain fp32 wrapper when saving a checkpoint to avoid crash when running fp16

* dreambooth: guard against passing keep_fp32_wrapper arg to older versions of accelerate. part of fix for #1566

* Apply suggestions from code review
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update examples/dreambooth/train_dreambooth.py
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
parent 0c18d02c
import argparse
import hashlib
import inspect
import itertools
import math
import os
......@@ -690,10 +691,19 @@ def main(args):
if global_step % args.save_steps == 0:
if accelerator.is_main_process:
# When 'keep_fp32_wrapper' is `False` (the default), then the models are
# unwrapped and the mixed precision hooks are removed, so training crashes
# when the unwrapped models are used for further training.
# This is only supported in newer versions of `accelerate`.
# TODO(Pedro, Suraj): Remove `accepts_keep_fp32_wrapper` when forcing newer accelerate versions
accepts_keep_fp32_wrapper = "keep_fp32_wrapper" in set(
inspect.signature(accelerator.unwrap_model).parameters.keys()
)
extra_args = {"keep_fp32_wrapper": True} if accepts_keep_fp32_wrapper else {}
pipeline = DiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path,
unet=accelerator.unwrap_model(unet),
text_encoder=accelerator.unwrap_model(text_encoder),
unet=accelerator.unwrap_model(unet, **extra_args),
text_encoder=accelerator.unwrap_model(text_encoder, **extra_args),
revision=args.revision,
)
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment