Unverified Commit 2cb383f5 authored by captainzz's avatar captainzz Committed by GitHub
Browse files

fix vae dtype when accelerate config using --mixed_precision="fp16" (#9601)

* fix vae dtype when accelerate config using --mixed_precision="fp16"

* Add param for upcast vae
parent 31010ecc
......@@ -357,6 +357,11 @@ def parse_args(input_args=None):
action="store_true",
help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
)
parser.add_argument(
"--upcast_vae",
action="store_true",
help="Whether or not to upcast vae to fp32",
)
parser.add_argument(
"--learning_rate",
type=float,
......@@ -1094,7 +1099,10 @@ def main(args):
weight_dtype = torch.bfloat16
# Move vae, transformer and text_encoder to device and cast to weight_dtype
if args.upcast_vae:
vae.to(accelerator.device, dtype=torch.float32)
else:
vae.to(accelerator.device, dtype=weight_dtype)
transformer.to(accelerator.device, dtype=weight_dtype)
text_encoder_one.to(accelerator.device, dtype=weight_dtype)
text_encoder_two.to(accelerator.device, dtype=weight_dtype)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment