Unverified Commit 75aab346 authored by Pierre Dulac's avatar Pierre Dulac Committed by GitHub
Browse files

Allow users to save SDXL LoRA weights for only one text encoder (#7607)



SDXL LoRA weights for text encoders should be decoupled on save

The method checks if at least one of unet, text_encoder and
text_encoder_2 lora weights are passed, which was not reflected in the
implentation.
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
Co-authored-by: default avatarYiYi Xu <yixu310@gmail.com>
parent 35358a2d
......@@ -1406,6 +1406,9 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text
encoder LoRA state dict because it comes from 🤗 Transformers.
text_encoder_2_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
State dict of the LoRA layers corresponding to the `text_encoder_2`. Must explicitly pass the text
encoder LoRA state dict because it comes from 🤗 Transformers.
is_main_process (`bool`, *optional*, defaults to `True`):
Whether the process calling this is the main process or not. Useful during distributed training and you
need to call this function on all processes. In this case, set `is_main_process=True` only on the main
......@@ -1432,8 +1435,10 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
if unet_lora_layers:
state_dict.update(pack_weights(unet_lora_layers, "unet"))
if text_encoder_lora_layers and text_encoder_2_lora_layers:
if text_encoder_lora_layers:
state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))
if text_encoder_2_lora_layers:
state_dict.update(pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
cls.write_lora_layers(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment