Unverified Commit 589faa8c authored by Haofan Wang's avatar Haofan Wang Committed by GitHub
Browse files

Update train_text_to_image_lora.py (#2464)

* Update train_text_to_image_lora.py

* Update train_text_to_image_lora.py
parent 39a3c77e
...@@ -448,19 +448,6 @@ def main():
vae.to(accelerator.device, dtype=weight_dtype)
text_encoder.to(accelerator.device, dtype=weight_dtype)
if args.enable_xformers_memory_efficient_attention:
if is_xformers_available():
import xformers
xformers_version = version.parse(xformers.__version__)
if xformers_version == version.parse("0.0.16"):
logger.warn(
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
)
unet.enable_xformers_memory_efficient_attention()
else:
raise ValueError("xformers is not available. Make sure it is installed correctly")
# now we will add new LoRA weights to the attention layers
# It's important to realize here how many attention weights will be added and of which sizes
# The sizes of the attention layers consist only of two different variables:
...@@ -492,6 +479,20 @@ def main():
)
unet.set_attn_processor(lora_attn_procs)
if args.enable_xformers_memory_efficient_attention:
if is_xformers_available():
import xformers
xformers_version = version.parse(xformers.__version__)
if xformers_version == version.parse("0.0.16"):
logger.warn(
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
)
unet.enable_xformers_memory_efficient_attention()
else:
raise ValueError("xformers is not available. Make sure it is installed correctly")
lora_layers = AttnProcsLayers(unet.attn_processors)
# Enable TF32 for faster training on Ampere GPUs,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment