Unverified Commit 1be7df02 authored by erkams, committed by GitHub

[LoRA] Freezing the model weights (#2245)



* [LoRA] Freezing the model weights

Freeze the model weights since we don't need to calculate grads for them.

* Apply suggestions from code review
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Apply suggestions from code review

---------
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Suraj Patil <surajp815@gmail.com>
parent 62a15cec
@@ -415,7 +415,12 @@ def main():
     unet = UNet2DConditionModel.from_pretrained(
         args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
     )
+    # freeze parameters of models to save more memory
+    unet.requires_grad_(False)
+    vae.requires_grad_(False)
+    text_encoder.requires_grad_(False)
     # For mixed precision training we cast the text_encoder and vae weights to half-precision
     # as these models are only used for inference, keeping weights in full precision is not required.
     weight_dtype = torch.float32
...
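For context beyond the diff, here is a minimal sketch of the pattern this commit relies on: calling requires_grad_(False) on the base model so that no gradients are computed or stored for its weights, while separately registered trainable parameters (the LoRA weights in the actual script) still receive gradients. The base_model and lora_delta modules below are illustrative stand-ins, not part of the diffusers training script.

    import torch
    import torch.nn as nn

    # Stand-ins: base_model plays the role of the frozen unet/vae/text_encoder,
    # lora_delta plays the role of the trainable LoRA parameters.
    base_model = nn.Linear(16, 16)
    lora_delta = nn.Linear(16, 16, bias=False)

    # Freeze the base weights; autograd will not allocate or compute grads for them.
    base_model.requires_grad_(False)

    # Only the trainable parameters are handed to the optimizer.
    optimizer = torch.optim.AdamW(lora_delta.parameters(), lr=1e-4)

    x = torch.randn(4, 16)
    out = base_model(x) + lora_delta(x)  # LoRA-style additive update
    loss = out.pow(2).mean()
    loss.backward()

    # Frozen weights accumulate no gradients; the LoRA weights do.
    assert base_model.weight.grad is None
    assert lora_delta.weight.grad is not None
    optimizer.step()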