"vscode:/vscode.git/clone" did not exist on "4fca88177e2763a2279ce713ae7779ed4122c9a1"
Unverified Commit 75aab346 authored by Pierre Dulac's avatar Pierre Dulac Committed by GitHub
Browse files

Allow users to save SDXL LoRA weights for only one text encoder (#7607)



SDXL LoRA weights for text encoders should be decoupled on save

The method checks if at least one of unet, text_encoder and
text_encoder_2 lora weights are passed, which was not reflected in the
implentation.
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
Co-authored-by: default avatarYiYi Xu <yixu310@gmail.com>
parent 35358a2d
...@@ -1406,6 +1406,9 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin): ...@@ -1406,6 +1406,9 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text
encoder LoRA state dict because it comes from 🤗 Transformers. encoder LoRA state dict because it comes from 🤗 Transformers.
text_encoder_2_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
State dict of the LoRA layers corresponding to the `text_encoder_2`. Must explicitly pass the text
encoder LoRA state dict because it comes from 🤗 Transformers.
is_main_process (`bool`, *optional*, defaults to `True`): is_main_process (`bool`, *optional*, defaults to `True`):
Whether the process calling this is the main process or not. Useful during distributed training and you Whether the process calling this is the main process or not. Useful during distributed training and you
need to call this function on all processes. In this case, set `is_main_process=True` only on the main need to call this function on all processes. In this case, set `is_main_process=True` only on the main
...@@ -1432,8 +1435,10 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin): ...@@ -1432,8 +1435,10 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
if unet_lora_layers: if unet_lora_layers:
state_dict.update(pack_weights(unet_lora_layers, "unet")) state_dict.update(pack_weights(unet_lora_layers, "unet"))
if text_encoder_lora_layers and text_encoder_2_lora_layers: if text_encoder_lora_layers:
state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder")) state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))
if text_encoder_2_lora_layers:
state_dict.update(pack_weights(text_encoder_2_lora_layers, "text_encoder_2")) state_dict.update(pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
cls.write_lora_layers( cls.write_lora_layers(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment