Unverified Commit 87ae3300 authored by Mario Namtao Shianti Larcher's avatar Mario Namtao Shianti Larcher Committed by GitHub
Browse files

[Examples] Save SDXL LoRA weights with chosen precision (#4791)

* Increase minimum accelerate version to avoid OOM when using mixed precision

* Remove re-instantiation of VAE

* Remove casting to float32

* Delete unused models and free GPU memory

* Fix style
parent 1b46c661
accelerate>=0.16.0 accelerate>=0.22.0
torchvision torchvision
transformers>=4.25.1 transformers>=4.25.1
ftfy ftfy
......
...@@ -1188,14 +1188,13 @@ def main(args): ...@@ -1188,14 +1188,13 @@ def main(args):
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
if accelerator.is_main_process: if accelerator.is_main_process:
unet = accelerator.unwrap_model(unet) unet = accelerator.unwrap_model(unet)
unet = unet.to(torch.float32)
unet_lora_layers = unet_attn_processors_state_dict(unet) unet_lora_layers = unet_attn_processors_state_dict(unet)
if args.train_text_encoder: if args.train_text_encoder:
text_encoder_one = accelerator.unwrap_model(text_encoder_one) text_encoder_one = accelerator.unwrap_model(text_encoder_one)
text_encoder_lora_layers = text_encoder_lora_state_dict(text_encoder_one.to(torch.float32)) text_encoder_lora_layers = text_encoder_lora_state_dict(text_encoder_one)
text_encoder_two = accelerator.unwrap_model(text_encoder_two) text_encoder_two = accelerator.unwrap_model(text_encoder_two)
text_encoder_2_lora_layers = text_encoder_lora_state_dict(text_encoder_two.to(torch.float32)) text_encoder_2_lora_layers = text_encoder_lora_state_dict(text_encoder_two)
else: else:
text_encoder_lora_layers = None text_encoder_lora_layers = None
text_encoder_2_lora_layers = None text_encoder_2_lora_layers = None
...@@ -1207,14 +1206,15 @@ def main(args): ...@@ -1207,14 +1206,15 @@ def main(args):
text_encoder_2_lora_layers=text_encoder_2_lora_layers, text_encoder_2_lora_layers=text_encoder_2_lora_layers,
) )
del unet
del text_encoder_one
del text_encoder_two
del text_encoder_lora_layers
del text_encoder_2_lora_layers
torch.cuda.empty_cache()
# Final inference # Final inference
# Load previous pipeline # Load previous pipeline
vae = AutoencoderKL.from_pretrained(
vae_path,
subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
revision=args.revision,
torch_dtype=weight_dtype,
)
pipeline = StableDiffusionXLPipeline.from_pretrained( pipeline = StableDiffusionXLPipeline.from_pretrained(
args.pretrained_model_name_or_path, vae=vae, revision=args.revision, torch_dtype=weight_dtype args.pretrained_model_name_or_path, vae=vae, revision=args.revision, torch_dtype=weight_dtype
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment