Unverified Commit d87ce2ce authored by CyberVy's avatar CyberVy Committed by GitHub
Browse files

Fix missing **kwargs in lora_pipeline.py (#11011)



* Update lora_pipeline.py

* Apply style fixes

* fix-copies

---------
Co-authored-by: default avatarhlky <hlky@hlky.ac>
Co-authored-by: default avatargithub-actions[bot] <github-actions[bot]@users.noreply.github.com>
parent 36d0553a
...@@ -452,7 +452,11 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin): ...@@ -452,7 +452,11 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
``` ```
""" """
super().fuse_lora( super().fuse_lora(
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names components=components,
lora_scale=lora_scale,
safe_fusing=safe_fusing,
adapter_names=adapter_names,
**kwargs,
) )
def unfuse_lora(self, components: List[str] = ["unet", "text_encoder"], **kwargs): def unfuse_lora(self, components: List[str] = ["unet", "text_encoder"], **kwargs):
...@@ -473,7 +477,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin): ...@@ -473,7 +477,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
LoRA parameters then it won't have any effect. LoRA parameters then it won't have any effect.
""" """
super().unfuse_lora(components=components) super().unfuse_lora(components=components, **kwargs)
class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin): class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
...@@ -892,7 +896,11 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin): ...@@ -892,7 +896,11 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
``` ```
""" """
super().fuse_lora( super().fuse_lora(
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names components=components,
lora_scale=lora_scale,
safe_fusing=safe_fusing,
adapter_names=adapter_names,
**kwargs,
) )
def unfuse_lora(self, components: List[str] = ["unet", "text_encoder", "text_encoder_2"], **kwargs): def unfuse_lora(self, components: List[str] = ["unet", "text_encoder", "text_encoder_2"], **kwargs):
...@@ -913,7 +921,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin): ...@@ -913,7 +921,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
LoRA parameters then it won't have any effect. LoRA parameters then it won't have any effect.
""" """
super().unfuse_lora(components=components) super().unfuse_lora(components=components, **kwargs)
class SD3LoraLoaderMixin(LoraBaseMixin): class SD3LoraLoaderMixin(LoraBaseMixin):
...@@ -1291,7 +1299,11 @@ class SD3LoraLoaderMixin(LoraBaseMixin): ...@@ -1291,7 +1299,11 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
``` ```
""" """
super().fuse_lora( super().fuse_lora(
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names components=components,
lora_scale=lora_scale,
safe_fusing=safe_fusing,
adapter_names=adapter_names,
**kwargs,
) )
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.unfuse_lora with unet->transformer # Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.unfuse_lora with unet->transformer
...@@ -1313,7 +1325,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin): ...@@ -1313,7 +1325,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
LoRA parameters then it won't have any effect. LoRA parameters then it won't have any effect.
""" """
super().unfuse_lora(components=components) super().unfuse_lora(components=components, **kwargs)
class FluxLoraLoaderMixin(LoraBaseMixin): class FluxLoraLoaderMixin(LoraBaseMixin):
...@@ -1829,7 +1841,11 @@ class FluxLoraLoaderMixin(LoraBaseMixin): ...@@ -1829,7 +1841,11 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
) )
super().fuse_lora( super().fuse_lora(
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names components=components,
lora_scale=lora_scale,
safe_fusing=safe_fusing,
adapter_names=adapter_names,
**kwargs,
) )
def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs): def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
...@@ -1850,7 +1866,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin): ...@@ -1850,7 +1866,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
if hasattr(transformer, "_transformer_norm_layers") and transformer._transformer_norm_layers: if hasattr(transformer, "_transformer_norm_layers") and transformer._transformer_norm_layers:
transformer.load_state_dict(transformer._transformer_norm_layers, strict=False) transformer.load_state_dict(transformer._transformer_norm_layers, strict=False)
super().unfuse_lora(components=components) super().unfuse_lora(components=components, **kwargs)
# We override this here account for `_transformer_norm_layers` and `_overwritten_params`. # We override this here account for `_transformer_norm_layers` and `_overwritten_params`.
def unload_lora_weights(self, reset_to_overwritten_params=False): def unload_lora_weights(self, reset_to_overwritten_params=False):
...@@ -2549,7 +2565,11 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin): ...@@ -2549,7 +2565,11 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
``` ```
""" """
super().fuse_lora( super().fuse_lora(
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names components=components,
lora_scale=lora_scale,
safe_fusing=safe_fusing,
adapter_names=adapter_names,
**kwargs,
) )
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs): def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
...@@ -2567,7 +2587,7 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin): ...@@ -2567,7 +2587,7 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters. unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
""" """
super().unfuse_lora(components=components) super().unfuse_lora(components=components, **kwargs)
class Mochi1LoraLoaderMixin(LoraBaseMixin): class Mochi1LoraLoaderMixin(LoraBaseMixin):
...@@ -2853,7 +2873,11 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin): ...@@ -2853,7 +2873,11 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
``` ```
""" """
super().fuse_lora( super().fuse_lora(
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names components=components,
lora_scale=lora_scale,
safe_fusing=safe_fusing,
adapter_names=adapter_names,
**kwargs,
) )
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
...@@ -2872,7 +2896,7 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin): ...@@ -2872,7 +2896,7 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters. unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
""" """
super().unfuse_lora(components=components) super().unfuse_lora(components=components, **kwargs)
class LTXVideoLoraLoaderMixin(LoraBaseMixin): class LTXVideoLoraLoaderMixin(LoraBaseMixin):
...@@ -3158,7 +3182,11 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin): ...@@ -3158,7 +3182,11 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
``` ```
""" """
super().fuse_lora( super().fuse_lora(
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names components=components,
lora_scale=lora_scale,
safe_fusing=safe_fusing,
adapter_names=adapter_names,
**kwargs,
) )
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
...@@ -3177,7 +3205,7 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin): ...@@ -3177,7 +3205,7 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters. unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
""" """
super().unfuse_lora(components=components) super().unfuse_lora(components=components, **kwargs)
class SanaLoraLoaderMixin(LoraBaseMixin): class SanaLoraLoaderMixin(LoraBaseMixin):
...@@ -3463,7 +3491,11 @@ class SanaLoraLoaderMixin(LoraBaseMixin): ...@@ -3463,7 +3491,11 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
``` ```
""" """
super().fuse_lora( super().fuse_lora(
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names components=components,
lora_scale=lora_scale,
safe_fusing=safe_fusing,
adapter_names=adapter_names,
**kwargs,
) )
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
...@@ -3482,7 +3514,7 @@ class SanaLoraLoaderMixin(LoraBaseMixin): ...@@ -3482,7 +3514,7 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters. unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
""" """
super().unfuse_lora(components=components) super().unfuse_lora(components=components, **kwargs)
class HunyuanVideoLoraLoaderMixin(LoraBaseMixin): class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
...@@ -3771,7 +3803,11 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin): ...@@ -3771,7 +3803,11 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
``` ```
""" """
super().fuse_lora( super().fuse_lora(
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names components=components,
lora_scale=lora_scale,
safe_fusing=safe_fusing,
adapter_names=adapter_names,
**kwargs,
) )
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
...@@ -3790,7 +3826,7 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin): ...@@ -3790,7 +3826,7 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters. unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
""" """
super().unfuse_lora(components=components) super().unfuse_lora(components=components, **kwargs)
class Lumina2LoraLoaderMixin(LoraBaseMixin): class Lumina2LoraLoaderMixin(LoraBaseMixin):
...@@ -4080,7 +4116,11 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin): ...@@ -4080,7 +4116,11 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
``` ```
""" """
super().fuse_lora( super().fuse_lora(
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names components=components,
lora_scale=lora_scale,
safe_fusing=safe_fusing,
adapter_names=adapter_names,
**kwargs,
) )
# Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.unfuse_lora # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.unfuse_lora
...@@ -4099,7 +4139,7 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin): ...@@ -4099,7 +4139,7 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters. unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
""" """
super().unfuse_lora(components=components) super().unfuse_lora(components=components, **kwargs)
class WanLoraLoaderMixin(LoraBaseMixin): class WanLoraLoaderMixin(LoraBaseMixin):
...@@ -4386,7 +4426,11 @@ class WanLoraLoaderMixin(LoraBaseMixin): ...@@ -4386,7 +4426,11 @@ class WanLoraLoaderMixin(LoraBaseMixin):
``` ```
""" """
super().fuse_lora( super().fuse_lora(
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names components=components,
lora_scale=lora_scale,
safe_fusing=safe_fusing,
adapter_names=adapter_names,
**kwargs,
) )
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
...@@ -4405,7 +4449,7 @@ class WanLoraLoaderMixin(LoraBaseMixin): ...@@ -4405,7 +4449,7 @@ class WanLoraLoaderMixin(LoraBaseMixin):
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters. unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
""" """
super().unfuse_lora(components=components) super().unfuse_lora(components=components, **kwargs)
class CogView4LoraLoaderMixin(LoraBaseMixin): class CogView4LoraLoaderMixin(LoraBaseMixin):
...@@ -4691,7 +4735,11 @@ class CogView4LoraLoaderMixin(LoraBaseMixin): ...@@ -4691,7 +4735,11 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
``` ```
""" """
super().fuse_lora( super().fuse_lora(
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names components=components,
lora_scale=lora_scale,
safe_fusing=safe_fusing,
adapter_names=adapter_names,
**kwargs,
) )
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
...@@ -4710,7 +4758,7 @@ class CogView4LoraLoaderMixin(LoraBaseMixin): ...@@ -4710,7 +4758,7 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters. unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
""" """
super().unfuse_lora(components=components) super().unfuse_lora(components=components, **kwargs)
class LoraLoaderMixin(StableDiffusionLoraLoaderMixin): class LoraLoaderMixin(StableDiffusionLoraLoaderMixin):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment