Unverified Commit 1fddee21 authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

[LoRA] Improve copied from comments in the LoRA loader classes (#10995)

* more sanity of mind with copied from ...

* better

* better
parent b38450d5
...@@ -843,11 +843,11 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin): ...@@ -843,11 +843,11 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
if not (unet_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers): if not (unet_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers):
raise ValueError( raise ValueError(
"You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers` or `text_encoder_2_lora_layers`." "You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers`, `text_encoder_2_lora_layers`."
) )
if unet_lora_layers: if unet_lora_layers:
state_dict.update(cls.pack_weights(unet_lora_layers, "unet")) state_dict.update(cls.pack_weights(unet_lora_layers, cls.unet_name))
if text_encoder_lora_layers: if text_encoder_lora_layers:
state_dict.update(cls.pack_weights(text_encoder_lora_layers, "text_encoder")) state_dict.update(cls.pack_weights(text_encoder_lora_layers, "text_encoder"))
...@@ -1210,10 +1210,11 @@ class SD3LoraLoaderMixin(LoraBaseMixin): ...@@ -1210,10 +1210,11 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
) )
@classmethod @classmethod
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.save_lora_weights with unet->transformer
def save_lora_weights( def save_lora_weights(
cls, cls,
save_directory: Union[str, os.PathLike], save_directory: Union[str, os.PathLike],
transformer_lora_layers: Dict[str, torch.nn.Module] = None, transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
is_main_process: bool = True, is_main_process: bool = True,
...@@ -1262,7 +1263,6 @@ class SD3LoraLoaderMixin(LoraBaseMixin): ...@@ -1262,7 +1263,6 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
if text_encoder_2_lora_layers: if text_encoder_2_lora_layers:
state_dict.update(cls.pack_weights(text_encoder_2_lora_layers, "text_encoder_2")) state_dict.update(cls.pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
# Save the model
cls.write_lora_layers( cls.write_lora_layers(
state_dict=state_dict, state_dict=state_dict,
save_directory=save_directory, save_directory=save_directory,
...@@ -1272,6 +1272,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin): ...@@ -1272,6 +1272,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
safe_serialization=safe_serialization, safe_serialization=safe_serialization,
) )
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.fuse_lora with unet->transformer
def fuse_lora( def fuse_lora(
self, self,
components: List[str] = ["transformer", "text_encoder", "text_encoder_2"], components: List[str] = ["transformer", "text_encoder", "text_encoder_2"],
...@@ -1315,6 +1316,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin): ...@@ -1315,6 +1316,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
) )
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.unfuse_lora with unet->transformer
def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder", "text_encoder_2"], **kwargs): def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder", "text_encoder_2"], **kwargs):
r""" r"""
Reverses the effect of Reverses the effect of
...@@ -1328,7 +1330,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin): ...@@ -1328,7 +1330,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
Args: Args:
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
unfuse_unet (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters. unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
unfuse_text_encoder (`bool`, defaults to `True`): unfuse_text_encoder (`bool`, defaults to `True`):
Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
LoRA parameters then it won't have any effect. LoRA parameters then it won't have any effect.
...@@ -2833,6 +2835,7 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin): ...@@ -2833,6 +2835,7 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
safe_serialization=safe_serialization, safe_serialization=safe_serialization,
) )
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
def fuse_lora( def fuse_lora(
self, self,
components: List[str] = ["transformer"], components: List[str] = ["transformer"],
...@@ -2876,6 +2879,7 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin): ...@@ -2876,6 +2879,7 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
) )
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs): def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
r""" r"""
Reverses the effect of Reverses the effect of
...@@ -3136,6 +3140,7 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin): ...@@ -3136,6 +3140,7 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
safe_serialization=safe_serialization, safe_serialization=safe_serialization,
) )
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
def fuse_lora( def fuse_lora(
self, self,
components: List[str] = ["transformer"], components: List[str] = ["transformer"],
...@@ -3179,6 +3184,7 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin): ...@@ -3179,6 +3184,7 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
) )
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs): def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
r""" r"""
Reverses the effect of Reverses the effect of
...@@ -3439,6 +3445,7 @@ class SanaLoraLoaderMixin(LoraBaseMixin): ...@@ -3439,6 +3445,7 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
safe_serialization=safe_serialization, safe_serialization=safe_serialization,
) )
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
def fuse_lora( def fuse_lora(
self, self,
components: List[str] = ["transformer"], components: List[str] = ["transformer"],
...@@ -3482,6 +3489,7 @@ class SanaLoraLoaderMixin(LoraBaseMixin): ...@@ -3482,6 +3489,7 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
) )
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs): def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
r""" r"""
Reverses the effect of Reverses the effect of
...@@ -3745,6 +3753,7 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin): ...@@ -3745,6 +3753,7 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
safe_serialization=safe_serialization, safe_serialization=safe_serialization,
) )
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
def fuse_lora( def fuse_lora(
self, self,
components: List[str] = ["transformer"], components: List[str] = ["transformer"],
...@@ -3788,6 +3797,7 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin): ...@@ -3788,6 +3797,7 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
) )
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs): def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
r""" r"""
Reverses the effect of Reverses the effect of
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment