Unverified Commit 908e5e9c authored by Patrick von Platen, committed by GitHub

Fix some bad comment in training scripts (#3798)

* relax tolerance slightly

* correct incorrect naming
parent 27150793
@@ -1092,8 +1092,8 @@ def main(args):
         unet, optimizer, train_dataloader, lr_scheduler
     )
 
-    # For mixed precision training we cast the text_encoder and vae weights to half-precision
-    # as these models are only used for inference, keeping weights in full precision is not required.
+    # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
+    # as these weights are only used for inference, keeping weights in full precision is not required.
     weight_dtype = torch.float32
     if accelerator.mixed_precision == "fp16":
         weight_dtype = torch.float16
......
@@ -790,8 +790,8 @@ def main(args):
     text_encoder.requires_grad_(False)
     unet.requires_grad_(False)
 
-    # For mixed precision training we cast the text_encoder and vae weights to half-precision
-    # as these models are only used for inference, keeping weights in full precision is not required.
+    # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
+    # as these weights are only used for inference, keeping weights in full precision is not required.
     weight_dtype = torch.float32
     if accelerator.mixed_precision == "fp16":
         weight_dtype = torch.float16
......
@@ -747,8 +747,8 @@ def main():
     if args.use_ema:
         ema_unet.to(accelerator.device)
 
-    # For mixed precision training we cast the text_encoder and vae weights to half-precision
-    # as these models are only used for inference, keeping weights in full precision is not required.
+    # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
+    # as these weights are only used for inference, keeping weights in full precision is not required.
     weight_dtype = torch.float32
     if accelerator.mixed_precision == "fp16":
         weight_dtype = torch.float16
......
@@ -430,8 +430,8 @@ def main():
     text_encoder.requires_grad_(False)
 
-    # For mixed precision training we cast the text_encoder and vae weights to half-precision
-    # as these models are only used for inference, keeping weights in full precision is not required.
+    # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
+    # as these weights are only used for inference, keeping weights in full precision is not required.
     weight_dtype = torch.float32
     if accelerator.mixed_precision == "fp16":
         weight_dtype = torch.float16
......
@@ -752,8 +752,8 @@ def main():
         text_encoder, optimizer, train_dataloader, lr_scheduler
     )
 
-    # For mixed precision training we cast the unet and vae weights to half-precision
-    # as these models are only used for inference, keeping weights in full precision is not required.
+    # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
+    # as these weights are only used for inference, keeping weights in full precision is not required.
     weight_dtype = torch.float32
     if accelerator.mixed_precision == "fp16":
         weight_dtype = torch.float16
......
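
Taken together, the five hunks apply the same correction to five training scripts: everything that is not being trained (the vae, a frozen text encoder, a frozen or non-LoRA unet) is cast to half precision because those weights are only used for inference, while the trainable parameters stay in float32 and are handled by accelerate's mixed-precision machinery. Below is a minimal, self-contained sketch of that pattern for a unet-training setup; the checkpoint id, the variable names, and the bf16 branch are illustrative assumptions, and only the weight_dtype selection mirrors the diff above.

import torch
from accelerate import Accelerator
from diffusers import AutoencoderKL, UNet2DConditionModel
from transformers import CLIPTextModel

accelerator = Accelerator(mixed_precision="fp16")

# Hypothetical checkpoint, used only to keep the sketch self-contained.
model_id = "runwayml/stable-diffusion-v1-5"
vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae")
text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder")
unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet")

# Freeze the models that are only used for inference in this setup.
vae.requires_grad_(False)
text_encoder.requires_grad_(False)

# For mixed precision training we cast all non-trainable weights to half precision,
# as these weights are only used for inference and full precision is not required.
weight_dtype = torch.float32
if accelerator.mixed_precision == "fp16":
    weight_dtype = torch.float16
elif accelerator.mixed_precision == "bf16":
    weight_dtype = torch.bfloat16

vae.to(accelerator.device, dtype=weight_dtype)
text_encoder.to(accelerator.device, dtype=weight_dtype)

# The trainable model stays in float32; accelerator.prepare wraps it for
# mixed-precision training without downcasting its parameters.
unet = accelerator.prepare(unet)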