Unverified commit 2cb383f5 authored by captainzz, committed by GitHub

fix vae dtype when accelerate config using --mixed_precision="fp16" (#9601)

* fix vae dtype when accelerate config using --mixed_precision="fp16"

* Add param for upcast vae
parent 31010ecc
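
Before this patch the VAE was always cast to torch.float32, regardless of the mixed-precision setting passed to accelerate. With the patch it follows weight_dtype by default (torch.float16 when accelerate is configured with --mixed_precision="fp16"), and the new --upcast_vae flag opts back into keeping the VAE in fp32.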
@@ -357,6 +357,11 @@ def parse_args(input_args=None):
         action="store_true",
         help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
     )
+    parser.add_argument(
+        "--upcast_vae",
+        action="store_true",
+        help="Whether or not to upcast vae to fp32",
+    )
     parser.add_argument(
         "--learning_rate",
         type=float,
@@ -1094,7 +1099,10 @@ def main(args):
         weight_dtype = torch.bfloat16
 
     # Move vae, transformer and text_encoder to device and cast to weight_dtype
-    vae.to(accelerator.device, dtype=torch.float32)
+    if args.upcast_vae:
+        vae.to(accelerator.device, dtype=torch.float32)
+    else:
+        vae.to(accelerator.device, dtype=weight_dtype)
     transformer.to(accelerator.device, dtype=weight_dtype)
     text_encoder_one.to(accelerator.device, dtype=weight_dtype)
     text_encoder_two.to(accelerator.device, dtype=weight_dtype)
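
The sketch below illustrates the resulting dtype handling in isolation. It assumes a CUDA machine (accelerate rejects fp16 mixed precision on CPU) and uses torch.nn.Identity placeholders instead of the real VAE and transformer; only the --upcast_vae flag and the dtype logic come from the patch, everything else is illustrative scaffolding.

import argparse

import torch
from accelerate import Accelerator

# Placeholder modules standing in for the real VAE / transformer; only dtype routing matters here.
vae = torch.nn.Identity()
transformer = torch.nn.Identity()

parser = argparse.ArgumentParser()
parser.add_argument("--upcast_vae", action="store_true", help="Whether or not to upcast vae to fp32")
args = parser.parse_args([])  # no flags passed: the VAE will follow weight_dtype

# Equivalent to launching the script with accelerate and --mixed_precision="fp16" (needs a GPU).
accelerator = Accelerator(mixed_precision="fp16")

# weight_dtype is derived from the accelerator, as in the training script.
weight_dtype = torch.float32
if accelerator.mixed_precision == "fp16":
    weight_dtype = torch.float16
elif accelerator.mixed_precision == "bf16":
    weight_dtype = torch.bfloat16

# After the patch the VAE stays in fp32 only when --upcast_vae is passed;
# otherwise it is cast to weight_dtype like the transformer and text encoders.
if args.upcast_vae:
    vae.to(accelerator.device, dtype=torch.float32)
else:
    vae.to(accelerator.device, dtype=weight_dtype)
transformer.to(accelerator.device, dtype=weight_dtype)

print(weight_dtype)  # torch.float16 under fp16 mixed precision

Passing --upcast_vae restores the previous behaviour (VAE kept in fp32) while the transformer and text encoders still run in fp16.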