Unverified commit cfdfcf20, authored by Bagheera, committed by GitHub
Browse files

Add --vae_precision option to the SDXL pix2pix script so that we have… (#4881)



* Add --vae_precision option to the SDXL pix2pix script so that we have the option of avoiding float32 overhead

* style

---------
Co-authored-by: bghira <bghira@users.github.com>
parent e4b8e792
...@@ -63,6 +63,7 @@ DATASET_NAME_MAPPING = { ...@@ -63,6 +63,7 @@ DATASET_NAME_MAPPING = {
"fusing/instructpix2pix-1000-samples": ("file_name", "edited_image", "edit_prompt"), "fusing/instructpix2pix-1000-samples": ("file_name", "edited_image", "edit_prompt"),
} }
WANDB_TABLE_COL_NAMES = ["file_name", "edited_image", "edit_prompt"] WANDB_TABLE_COL_NAMES = ["file_name", "edited_image", "edit_prompt"]
# Maps the CLI --vae_precision choices to the corresponding torch dtypes,
# used when moving the VAE to the accelerator device.
TORCH_DTYPE_MAPPING = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}
def import_model_class_from_model_name_or_path( def import_model_class_from_model_name_or_path(
...@@ -100,6 +101,16 @@ def parse_args(): ...@@ -100,6 +101,16 @@ def parse_args():
default=None, default=None,
help="Path to an improved VAE to stabilize training. For more details check out: https://github.com/huggingface/diffusers/pull/4038.", help="Path to an improved VAE to stabilize training. For more details check out: https://github.com/huggingface/diffusers/pull/4038.",
) )
parser.add_argument(
"--vae_precision",
type="choice",
choices=["fp32", "fp16", "bf16"],
default="fp32",
help=(
"The vanilla SDXL 1.0 VAE can cause NaNs due to large activation values. Some custom models might already have a solution"
" to this problem, and this flag allows you to use mixed precision to stabilize training."
),
)
parser.add_argument( parser.add_argument(
"--revision", "--revision",
type=str, type=str,
...@@ -878,7 +889,7 @@ def main(): ...@@ -878,7 +889,7 @@ def main():
if args.pretrained_vae_model_name_or_path is not None: if args.pretrained_vae_model_name_or_path is not None:
vae.to(accelerator.device, dtype=weight_dtype) vae.to(accelerator.device, dtype=weight_dtype)
else: else:
vae.to(accelerator.device, dtype=torch.float32) vae.to(accelerator.device, dtype=TORCH_DTYPE_MAPPING[args.vae_precision])
# We need to recalculate our total training steps as the size of the training dataloader may have changed. # We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment