Unverified Commit 4a98d6e0 authored by Haofan Wang's avatar Haofan Wang Committed by GitHub
Browse files

Update train_text_to_image_lora.py (#2795)

parent b94880e5
...@@ -542,9 +542,9 @@ def main(): ...@@ -542,9 +542,9 @@ def main():
lora_layers = AttnProcsLayers(unet.attn_processors) lora_layers = AttnProcsLayers(unet.attn_processors)
# Move unet, vae and text_encoder to device and cast to weight_dtype # Move unet, vae and text_encoder to device and cast to weight_dtype
unet.to(accelerator.device, dtype=weight_dtype)
vae.to(accelerator.device, dtype=weight_dtype) vae.to(accelerator.device, dtype=weight_dtype)
text_encoder.to(accelerator.device, dtype=weight_dtype) if not args.train_text_encoder:
text_encoder.to(accelerator.device, dtype=weight_dtype)
if args.enable_xformers_memory_efficient_attention: if args.enable_xformers_memory_efficient_attention:
if is_xformers_available(): if is_xformers_available():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment