"...csrc/cpu/git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "1b83f46c42a062429ec0604ca8c1beae665790cd"
Unverified Commit 553b1384, authored by Sayak Paul and committed by GitHub

[LoRA] clean up `load_lora_into_text_encoder()` and `fuse_lora()` copied from (#10495)

* factor out text encoder loading.

* make fix-copies

* remove `Copied from` comments from fuse_lora and unfuse_lora as needed.

* remove unused imports
parent 7bc8b923
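
The gist of the change, before the diff: the text-encoder LoRA loading body that was duplicated across the SD, SDXL, SD3, Flux and Amused mixins is factored out into a single module-level helper in lora_base, and each mixin classmethod now simply forwards its arguments to it. A minimal, self-contained sketch of that delegation pattern is shown below; the helper name `_load_lora_into_text_encoder` and the `cls.text_encoder_name` attribute come from the diff, while `SDLikeMixin`, `DummyTextEncoder` and the simplified body are purely illustrative and not the actual diffusers implementation.

def _load_lora_into_text_encoder(state_dict, text_encoder, prefix=None,
                                 text_encoder_name="text_encoder", lora_scale=1.0):
    # Shared implementation: keep only the keys under `prefix` and pretend to load them.
    prefix = text_encoder_name if prefix is None else prefix
    filtered = {k[len(prefix) + 1:]: v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
    print(f"would load {len(filtered)} tensors into {type(text_encoder).__name__} at scale {lora_scale}")


class DummyTextEncoder:  # stand-in for a transformers text encoder
    pass


class SDLikeMixin:
    text_encoder_name = "text_encoder"

    @classmethod
    def load_lora_into_text_encoder(cls, state_dict, text_encoder, prefix=None, lora_scale=1.0):
        # After the refactor, every per-pipeline method is a thin wrapper like this.
        _load_lora_into_text_encoder(
            state_dict=state_dict,
            text_encoder=text_encoder,
            prefix=prefix,
            text_encoder_name=cls.text_encoder_name,
            lora_scale=lora_scale,
        )


if __name__ == "__main__":
    sd = {"text_encoder.q_proj.lora_A.weight": None, "unet.to_q.lora_A.weight": None}
    SDLikeMixin.load_lora_into_text_encoder(sd, DummyTextEncoder())
    # prints: would load 1 tensors into DummyTextEncoder at scale 1.0
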
@@ -28,13 +28,20 @@ from ..models.modeling_utils import ModelMixin, load_state_dict
 from ..utils import (
     USE_PEFT_BACKEND,
     _get_model_file,
+    convert_state_dict_to_diffusers,
+    convert_state_dict_to_peft,
     delete_adapter_layers,
     deprecate,
+    get_adapter_name,
+    get_peft_kwargs,
     is_accelerate_available,
     is_peft_available,
+    is_peft_version,
     is_transformers_available,
+    is_transformers_version,
     logging,
     recurse_remove_peft_layers,
+    scale_lora_layers,
     set_adapter_layers,
     set_weights_and_activate_adapters,
 )
@@ -43,6 +50,8 @@ from ..utils import (
 if is_transformers_available():
     from transformers import PreTrainedModel
+    from ..models.lora import text_encoder_attn_modules, text_encoder_mlp_modules
 if is_peft_available():
     from peft.tuners.tuners_utils import BaseTunerLayer
@@ -297,6 +306,152 @@ def _best_guess_weight_name(
     return weight_name
+def _load_lora_into_text_encoder(
+    state_dict,
+    network_alphas,
+    text_encoder,
+    prefix=None,
+    lora_scale=1.0,
+    text_encoder_name="text_encoder",
+    adapter_name=None,
+    _pipeline=None,
+    low_cpu_mem_usage=False,
+):
+    if not USE_PEFT_BACKEND:
+        raise ValueError("PEFT backend is required for this method.")
+    peft_kwargs = {}
+    if low_cpu_mem_usage:
+        if not is_peft_version(">=", "0.13.1"):
+            raise ValueError(
+                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+            )
+        if not is_transformers_version(">", "4.45.2"):
+            # Note from sayakpaul: It's not in `transformers` stable yet.
+            # https://github.com/huggingface/transformers/pull/33725/
+            raise ValueError(
+                "`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`."
+            )
+        peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
+    from peft import LoraConfig
+    # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
+    # then the `state_dict` keys should have `unet_name` and/or `text_encoder_name` as
+    # their prefixes.
+    keys = list(state_dict.keys())
+    prefix = text_encoder_name if prefix is None else prefix
+    # Safe prefix to check with.
+    if any(text_encoder_name in key for key in keys):
+        # Load the layers corresponding to text encoder and make necessary adjustments.
+        text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix]
+        text_encoder_lora_state_dict = {
+            k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
+        }
+        if len(text_encoder_lora_state_dict) > 0:
+            logger.info(f"Loading {prefix}.")
+            rank = {}
+            text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict)
+            # convert state dict
+            text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)
+            for name, _ in text_encoder_attn_modules(text_encoder):
+                for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
+                    rank_key = f"{name}.{module}.lora_B.weight"
+                    if rank_key not in text_encoder_lora_state_dict:
+                        continue
+                    rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
+            for name, _ in text_encoder_mlp_modules(text_encoder):
+                for module in ("fc1", "fc2"):
+                    rank_key = f"{name}.{module}.lora_B.weight"
+                    if rank_key not in text_encoder_lora_state_dict:
+                        continue
+                    rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
+            if network_alphas is not None:
+                alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
+                network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
+            lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
+            if "use_dora" in lora_config_kwargs:
+                if lora_config_kwargs["use_dora"]:
+                    if is_peft_version("<", "0.9.0"):
+                        raise ValueError(
+                            "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
+                        )
+                else:
+                    if is_peft_version("<", "0.9.0"):
+                        lora_config_kwargs.pop("use_dora")
+            if "lora_bias" in lora_config_kwargs:
+                if lora_config_kwargs["lora_bias"]:
+                    if is_peft_version("<=", "0.13.2"):
+                        raise ValueError(
+                            "You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
+                        )
+                else:
+                    if is_peft_version("<=", "0.13.2"):
+                        lora_config_kwargs.pop("lora_bias")
+            lora_config = LoraConfig(**lora_config_kwargs)
+            # adapter_name
+            if adapter_name is None:
+                adapter_name = get_adapter_name(text_encoder)
+            is_model_cpu_offload, is_sequential_cpu_offload = _func_optionally_disable_offloading(_pipeline)
+            # inject LoRA layers and load the state dict
+            # in transformers we automatically check whether the adapter name is already in use or not
+            text_encoder.load_adapter(
+                adapter_name=adapter_name,
+                adapter_state_dict=text_encoder_lora_state_dict,
+                peft_config=lora_config,
+                **peft_kwargs,
+            )
+            # scale LoRA layers with `lora_scale`
+            scale_lora_layers(text_encoder, weight=lora_scale)
+            text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
+            # Offload back.
+            if is_model_cpu_offload:
+                _pipeline.enable_model_cpu_offload()
+            elif is_sequential_cpu_offload:
+                _pipeline.enable_sequential_cpu_offload()
+            # Unsafe code />
+def _func_optionally_disable_offloading(_pipeline):
+    is_model_cpu_offload = False
+    is_sequential_cpu_offload = False
+    if _pipeline is not None and _pipeline.hf_device_map is None:
+        for _, component in _pipeline.components.items():
+            if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
+                if not is_model_cpu_offload:
+                    is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
+                if not is_sequential_cpu_offload:
+                    is_sequential_cpu_offload = (
+                        isinstance(component._hf_hook, AlignDevicesHook)
+                        or hasattr(component._hf_hook, "hooks")
+                        and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
+                    )
+                logger.info(
+                    "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
+                )
+                remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
+    return (is_model_cpu_offload, is_sequential_cpu_offload)
 class LoraBaseMixin:
     """Utility class for handling LoRAs."""
@@ -327,27 +482,7 @@ class LoraBaseMixin:
             tuple:
                 A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
         """
-        is_model_cpu_offload = False
-        is_sequential_cpu_offload = False
-        if _pipeline is not None and _pipeline.hf_device_map is None:
-            for _, component in _pipeline.components.items():
-                if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
-                    if not is_model_cpu_offload:
-                        is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
-                    if not is_sequential_cpu_offload:
-                        is_sequential_cpu_offload = (
-                            isinstance(component._hf_hook, AlignDevicesHook)
-                            or hasattr(component._hf_hook, "hooks")
-                            and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
-                        )
-                    logger.info(
-                        "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
-                    )
-                    remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
-        return (is_model_cpu_offload, is_sequential_cpu_offload)
+        return _func_optionally_disable_offloading(_pipeline=_pipeline)
     @classmethod
     def _fetch_state_dict(cls, *args, **kwargs):
@@ -20,20 +20,21 @@ from huggingface_hub.utils import validate_hf_hub_args
 from ..utils import (
     USE_PEFT_BACKEND,
-    convert_state_dict_to_diffusers,
-    convert_state_dict_to_peft,
     deprecate,
-    get_adapter_name,
-    get_peft_kwargs,
     is_peft_available,
     is_peft_version,
     is_torch_version,
     is_transformers_available,
     is_transformers_version,
     logging,
-    scale_lora_layers,
 )
-from .lora_base import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, LoraBaseMixin, _fetch_state_dict  # noqa
+from .lora_base import (  # noqa
+    LORA_WEIGHT_NAME,
+    LORA_WEIGHT_NAME_SAFE,
+    LoraBaseMixin,
+    _fetch_state_dict,
+    _load_lora_into_text_encoder,
+)
 from .lora_conversion_utils import (
     _convert_bfl_flux_control_lora_to_diffusers,
     _convert_hunyuan_video_lora_to_diffusers,
@@ -55,9 +56,6 @@ if is_torch_version(">=", "1.9.0"):
         _LOW_CPU_MEM_USAGE_DEFAULT_LORA = True
-if is_transformers_available():
-    from ..models.lora import text_encoder_attn_modules, text_encoder_mlp_modules
 logger = logging.get_logger(__name__)
 TEXT_ENCODER_NAME = "text_encoder"
@@ -349,119 +347,17 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
                 Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
                 weights.
         """
-        if not USE_PEFT_BACKEND:
-            raise ValueError("PEFT backend is required for this method.")
-        peft_kwargs = {}
-        if low_cpu_mem_usage:
-            if not is_peft_version(">=", "0.13.1"):
-                raise ValueError(
-                    "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
-                )
-            if not is_transformers_version(">", "4.45.2"):
-                # Note from sayakpaul: It's not in `transformers` stable yet.
-                # https://github.com/huggingface/transformers/pull/33725/
-                raise ValueError(
-                    "`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`."
-                )
-            peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
-        from peft import LoraConfig
-        # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
-        # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
-        # their prefixes.
-        keys = list(state_dict.keys())
-        prefix = cls.text_encoder_name if prefix is None else prefix
-        # Safe prefix to check with.
-        if any(cls.text_encoder_name in key for key in keys):
-            # Load the layers corresponding to text encoder and make necessary adjustments.
-            text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix]
-            text_encoder_lora_state_dict = {
-                k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
-            }
-            if len(text_encoder_lora_state_dict) > 0:
-                logger.info(f"Loading {prefix}.")
-                rank = {}
-                text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict)
-                # convert state dict
-                text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)
-                for name, _ in text_encoder_attn_modules(text_encoder):
-                    for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
-                        rank_key = f"{name}.{module}.lora_B.weight"
-                        if rank_key not in text_encoder_lora_state_dict:
-                            continue
-                        rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
-                for name, _ in text_encoder_mlp_modules(text_encoder):
-                    for module in ("fc1", "fc2"):
-                        rank_key = f"{name}.{module}.lora_B.weight"
-                        if rank_key not in text_encoder_lora_state_dict:
-                            continue
-                        rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
-                if network_alphas is not None:
-                    alpha_keys = [
-                        k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix
-                    ]
-                    network_alphas = {
-                        k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
-                    }
-                lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
-                if "use_dora" in lora_config_kwargs:
-                    if lora_config_kwargs["use_dora"]:
-                        if is_peft_version("<", "0.9.0"):
-                            raise ValueError(
-                                "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
-                            )
-                    else:
-                        if is_peft_version("<", "0.9.0"):
-                            lora_config_kwargs.pop("use_dora")
-                if "lora_bias" in lora_config_kwargs:
-                    if lora_config_kwargs["lora_bias"]:
-                        if is_peft_version("<=", "0.13.2"):
-                            raise ValueError(
-                                "You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
-                            )
-                    else:
-                        if is_peft_version("<=", "0.13.2"):
-                            lora_config_kwargs.pop("lora_bias")
-                lora_config = LoraConfig(**lora_config_kwargs)
-                # adapter_name
-                if adapter_name is None:
-                    adapter_name = get_adapter_name(text_encoder)
-                is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
-                # inject LoRA layers and load the state dict
-                # in transformers we automatically check whether the adapter name is already in use or not
-                text_encoder.load_adapter(
-                    adapter_name=adapter_name,
-                    adapter_state_dict=text_encoder_lora_state_dict,
-                    peft_config=lora_config,
-                    **peft_kwargs,
-                )
-                # scale LoRA layers with `lora_scale`
-                scale_lora_layers(text_encoder, weight=lora_scale)
-                text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
-                # Offload back.
-                if is_model_cpu_offload:
-                    _pipeline.enable_model_cpu_offload()
-                elif is_sequential_cpu_offload:
-                    _pipeline.enable_sequential_cpu_offload()
-                # Unsafe code />
+        _load_lora_into_text_encoder(
+            state_dict=state_dict,
+            network_alphas=network_alphas,
+            lora_scale=lora_scale,
+            text_encoder=text_encoder,
+            prefix=prefix,
+            text_encoder_name=cls.text_encoder_name,
+            adapter_name=adapter_name,
+            _pipeline=_pipeline,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+        )
     @classmethod
     def save_lora_weights(
@@ -892,119 +788,17 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
                 Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
                 weights.
         """
-        if not USE_PEFT_BACKEND:
-            raise ValueError("PEFT backend is required for this method.")
-        peft_kwargs = {}
-        if low_cpu_mem_usage:
-            if not is_peft_version(">=", "0.13.1"):
-                raise ValueError(
-                    "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
-                )
-            if not is_transformers_version(">", "4.45.2"):
-                # Note from sayakpaul: It's not in `transformers` stable yet.
-                # https://github.com/huggingface/transformers/pull/33725/
-                raise ValueError(
-                    "`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`."
-                )
-            peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
-        from peft import LoraConfig
-        # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
-        # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
-        # their prefixes.
-        keys = list(state_dict.keys())
-        prefix = cls.text_encoder_name if prefix is None else prefix
-        # Safe prefix to check with.
-        if any(cls.text_encoder_name in key for key in keys):
-            # Load the layers corresponding to text encoder and make necessary adjustments.
-            text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix]
-            text_encoder_lora_state_dict = {
-                k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
-            }
-            if len(text_encoder_lora_state_dict) > 0:
-                logger.info(f"Loading {prefix}.")
-                rank = {}
-                text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict)
-                # convert state dict
-                text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)
-                for name, _ in text_encoder_attn_modules(text_encoder):
-                    for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
-                        rank_key = f"{name}.{module}.lora_B.weight"
-                        if rank_key not in text_encoder_lora_state_dict:
-                            continue
-                        rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
-                for name, _ in text_encoder_mlp_modules(text_encoder):
-                    for module in ("fc1", "fc2"):
-                        rank_key = f"{name}.{module}.lora_B.weight"
-                        if rank_key not in text_encoder_lora_state_dict:
-                            continue
-                        rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
-                if network_alphas is not None:
-                    alpha_keys = [
-                        k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix
-                    ]
-                    network_alphas = {
-                        k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
-                    }
-                lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
-                if "use_dora" in lora_config_kwargs:
-                    if lora_config_kwargs["use_dora"]:
-                        if is_peft_version("<", "0.9.0"):
-                            raise ValueError(
-                                "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
-                            )
-                    else:
-                        if is_peft_version("<", "0.9.0"):
-                            lora_config_kwargs.pop("use_dora")
-                if "lora_bias" in lora_config_kwargs:
-                    if lora_config_kwargs["lora_bias"]:
-                        if is_peft_version("<=", "0.13.2"):
-                            raise ValueError(
-                                "You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
-                            )
-                    else:
-                        if is_peft_version("<=", "0.13.2"):
-                            lora_config_kwargs.pop("lora_bias")
-                lora_config = LoraConfig(**lora_config_kwargs)
-                # adapter_name
-                if adapter_name is None:
-                    adapter_name = get_adapter_name(text_encoder)
-                is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
-                # inject LoRA layers and load the state dict
-                # in transformers we automatically check whether the adapter name is already in use or not
-                text_encoder.load_adapter(
-                    adapter_name=adapter_name,
-                    adapter_state_dict=text_encoder_lora_state_dict,
-                    peft_config=lora_config,
-                    **peft_kwargs,
-                )
-                # scale LoRA layers with `lora_scale`
-                scale_lora_layers(text_encoder, weight=lora_scale)
-                text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
-                # Offload back.
-                if is_model_cpu_offload:
-                    _pipeline.enable_model_cpu_offload()
-                elif is_sequential_cpu_offload:
-                    _pipeline.enable_sequential_cpu_offload()
-                # Unsafe code />
+        _load_lora_into_text_encoder(
+            state_dict=state_dict,
+            network_alphas=network_alphas,
+            lora_scale=lora_scale,
+            text_encoder=text_encoder,
+            prefix=prefix,
+            text_encoder_name=cls.text_encoder_name,
+            adapter_name=adapter_name,
+            _pipeline=_pipeline,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+        )
     @classmethod
     def save_lora_weights(
@@ -1401,119 +1195,17 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
                 Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
                 weights.
         """
-        if not USE_PEFT_BACKEND:
-            raise ValueError("PEFT backend is required for this method.")
-        peft_kwargs = {}
-        if low_cpu_mem_usage:
-            if not is_peft_version(">=", "0.13.1"):
-                raise ValueError(
-                    "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
-                )
-            if not is_transformers_version(">", "4.45.2"):
-                # Note from sayakpaul: It's not in `transformers` stable yet.
-                # https://github.com/huggingface/transformers/pull/33725/
-                raise ValueError(
-                    "`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`."
-                )
-            peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
-        from peft import LoraConfig
-        # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
-        # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
-        # their prefixes.
-        keys = list(state_dict.keys())
-        prefix = cls.text_encoder_name if prefix is None else prefix
-        # Safe prefix to check with.
-        if any(cls.text_encoder_name in key for key in keys):
-            # Load the layers corresponding to text encoder and make necessary adjustments.
-            text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix]
-            text_encoder_lora_state_dict = {
-                k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
-            }
-            if len(text_encoder_lora_state_dict) > 0:
-                logger.info(f"Loading {prefix}.")
-                rank = {}
-                text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict)
-                # convert state dict
-                text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)
-                for name, _ in text_encoder_attn_modules(text_encoder):
-                    for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
-                        rank_key = f"{name}.{module}.lora_B.weight"
-                        if rank_key not in text_encoder_lora_state_dict:
-                            continue
-                        rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
-                for name, _ in text_encoder_mlp_modules(text_encoder):
-                    for module in ("fc1", "fc2"):
-                        rank_key = f"{name}.{module}.lora_B.weight"
-                        if rank_key not in text_encoder_lora_state_dict:
-                            continue
-                        rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
-                if network_alphas is not None:
-                    alpha_keys = [
-                        k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix
-                    ]
-                    network_alphas = {
-                        k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
-                    }
-                lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
-                if "use_dora" in lora_config_kwargs:
-                    if lora_config_kwargs["use_dora"]:
-                        if is_peft_version("<", "0.9.0"):
-                            raise ValueError(
-                                "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
-                            )
-                    else:
-                        if is_peft_version("<", "0.9.0"):
-                            lora_config_kwargs.pop("use_dora")
-                if "lora_bias" in lora_config_kwargs:
-                    if lora_config_kwargs["lora_bias"]:
-                        if is_peft_version("<=", "0.13.2"):
-                            raise ValueError(
-                                "You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
-                            )
-                    else:
-                        if is_peft_version("<=", "0.13.2"):
-                            lora_config_kwargs.pop("lora_bias")
-                lora_config = LoraConfig(**lora_config_kwargs)
-                # adapter_name
-                if adapter_name is None:
-                    adapter_name = get_adapter_name(text_encoder)
-                is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
-                # inject LoRA layers and load the state dict
-                # in transformers we automatically check whether the adapter name is already in use or not
-                text_encoder.load_adapter(
-                    adapter_name=adapter_name,
-                    adapter_state_dict=text_encoder_lora_state_dict,
-                    peft_config=lora_config,
-                    **peft_kwargs,
-                )
-                # scale LoRA layers with `lora_scale`
-                scale_lora_layers(text_encoder, weight=lora_scale)
-                text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
-                # Offload back.
-                if is_model_cpu_offload:
-                    _pipeline.enable_model_cpu_offload()
-                elif is_sequential_cpu_offload:
-                    _pipeline.enable_sequential_cpu_offload()
-                # Unsafe code />
+        _load_lora_into_text_encoder(
+            state_dict=state_dict,
+            network_alphas=network_alphas,
+            lora_scale=lora_scale,
+            text_encoder=text_encoder,
+            prefix=prefix,
+            text_encoder_name=cls.text_encoder_name,
+            adapter_name=adapter_name,
+            _pipeline=_pipeline,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+        )
     @classmethod
     def save_lora_weights(
@@ -2033,119 +1725,17 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
                 Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
                 weights.
         """
-        if not USE_PEFT_BACKEND:
-            raise ValueError("PEFT backend is required for this method.")
-        peft_kwargs = {}
-        if low_cpu_mem_usage:
-            if not is_peft_version(">=", "0.13.1"):
-                raise ValueError(
-                    "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
-                )
-            if not is_transformers_version(">", "4.45.2"):
-                # Note from sayakpaul: It's not in `transformers` stable yet.
-                # https://github.com/huggingface/transformers/pull/33725/
-                raise ValueError(
-                    "`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`."
-                )
-            peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
-        from peft import LoraConfig
-        # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
-        # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
-        # their prefixes.
-        keys = list(state_dict.keys())
-        prefix = cls.text_encoder_name if prefix is None else prefix
-        # Safe prefix to check with.
-        if any(cls.text_encoder_name in key for key in keys):
-            # Load the layers corresponding to text encoder and make necessary adjustments.
-            text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix]
-            text_encoder_lora_state_dict = {
-                k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
-            }
-            if len(text_encoder_lora_state_dict) > 0:
-                logger.info(f"Loading {prefix}.")
-                rank = {}
-                text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict)
-                # convert state dict
-                text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)
-                for name, _ in text_encoder_attn_modules(text_encoder):
-                    for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
-                        rank_key = f"{name}.{module}.lora_B.weight"
-                        if rank_key not in text_encoder_lora_state_dict:
-                            continue
-                        rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
-                for name, _ in text_encoder_mlp_modules(text_encoder):
-                    for module in ("fc1", "fc2"):
-                        rank_key = f"{name}.{module}.lora_B.weight"
-                        if rank_key not in text_encoder_lora_state_dict:
-                            continue
-                        rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
-                if network_alphas is not None:
-                    alpha_keys = [
-                        k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix
-                    ]
-                    network_alphas = {
-                        k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
-                    }
-                lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
-                if "use_dora" in lora_config_kwargs:
-                    if lora_config_kwargs["use_dora"]:
-                        if is_peft_version("<", "0.9.0"):
-                            raise ValueError(
-                                "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
-                            )
-                    else:
-                        if is_peft_version("<", "0.9.0"):
-                            lora_config_kwargs.pop("use_dora")
-                if "lora_bias" in lora_config_kwargs:
-                    if lora_config_kwargs["lora_bias"]:
-                        if is_peft_version("<=", "0.13.2"):
-                            raise ValueError(
-                                "You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
-                            )
-                    else:
-                        if is_peft_version("<=", "0.13.2"):
-                            lora_config_kwargs.pop("lora_bias")
-                lora_config = LoraConfig(**lora_config_kwargs)
-                # adapter_name
-                if adapter_name is None:
-                    adapter_name = get_adapter_name(text_encoder)
-                is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
-                # inject LoRA layers and load the state dict
-                # in transformers we automatically check whether the adapter name is already in use or not
-                text_encoder.load_adapter(
-                    adapter_name=adapter_name,
-                    adapter_state_dict=text_encoder_lora_state_dict,
-                    peft_config=lora_config,
-                    **peft_kwargs,
-                )
-                # scale LoRA layers with `lora_scale`
-                scale_lora_layers(text_encoder, weight=lora_scale)
-                text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
-                # Offload back.
-                if is_model_cpu_offload:
-                    _pipeline.enable_model_cpu_offload()
-                elif is_sequential_cpu_offload:
-                    _pipeline.enable_sequential_cpu_offload()
-                # Unsafe code />
+        _load_lora_into_text_encoder(
+            state_dict=state_dict,
+            network_alphas=network_alphas,
+            lora_scale=lora_scale,
+            text_encoder=text_encoder,
+            prefix=prefix,
+            text_encoder_name=cls.text_encoder_name,
+            adapter_name=adapter_name,
+            _pipeline=_pipeline,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+        )
     @classmethod
     # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.save_lora_weights with unet->transformer
@@ -2204,7 +1794,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
     def fuse_lora(
         self,
-        components: List[str] = ["transformer", "text_encoder"],
+        components: List[str] = ["transformer"],
         lora_scale: float = 1.0,
         safe_fusing: bool = False,
         adapter_names: Optional[List[str]] = None,
@@ -2598,119 +2188,17 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
                 Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
                 weights.
         """
-        if not USE_PEFT_BACKEND:
-            raise ValueError("PEFT backend is required for this method.")
-        peft_kwargs = {}
-        if low_cpu_mem_usage:
-            if not is_peft_version(">=", "0.13.1"):
-                raise ValueError(
-                    "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
-                )
-            if not is_transformers_version(">", "4.45.2"):
-                # Note from sayakpaul: It's not in `transformers` stable yet.
-                # https://github.com/huggingface/transformers/pull/33725/
-                raise ValueError(
-                    "`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`."
-                )
-            peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
-        from peft import LoraConfig
-        # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
-        # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
-        # their prefixes.
-        keys = list(state_dict.keys())
-        prefix = cls.text_encoder_name if prefix is None else prefix
-        # Safe prefix to check with.
-        if any(cls.text_encoder_name in key for key in keys):
-            # Load the layers corresponding to text encoder and make necessary adjustments.
-            text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix]
-            text_encoder_lora_state_dict = {
-                k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
-            }
-            if len(text_encoder_lora_state_dict) > 0:
-                logger.info(f"Loading {prefix}.")
-                rank = {}
-                text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict)
-                # convert state dict
-                text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)
-                for name, _ in text_encoder_attn_modules(text_encoder):
-                    for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
-                        rank_key = f"{name}.{module}.lora_B.weight"
-                        if rank_key not in text_encoder_lora_state_dict:
-                            continue
-                        rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
-                for name, _ in text_encoder_mlp_modules(text_encoder):
-                    for module in ("fc1", "fc2"):
-                        rank_key = f"{name}.{module}.lora_B.weight"
-                        if rank_key not in text_encoder_lora_state_dict:
-                            continue
-                        rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
-                if network_alphas is not None:
-                    alpha_keys = [
-                        k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix
-                    ]
-                    network_alphas = {
-                        k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
-                    }
-                lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
-                if "use_dora" in lora_config_kwargs:
-                    if lora_config_kwargs["use_dora"]:
-                        if is_peft_version("<", "0.9.0"):
-                            raise ValueError(
-                                "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
-                            )
-                    else:
-                        if is_peft_version("<", "0.9.0"):
-                            lora_config_kwargs.pop("use_dora")
-                if "lora_bias" in lora_config_kwargs:
-                    if lora_config_kwargs["lora_bias"]:
-                        if is_peft_version("<=", "0.13.2"):
-                            raise ValueError(
-                                "You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
-                            )
-                    else:
-                        if is_peft_version("<=", "0.13.2"):
-                            lora_config_kwargs.pop("lora_bias")
-                lora_config = LoraConfig(**lora_config_kwargs)
-                # adapter_name
-                if adapter_name is None:
-                    adapter_name = get_adapter_name(text_encoder)
-                is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
-                # inject LoRA layers and load the state dict
-                # in transformers we automatically check whether the adapter name is already in use or not
-                text_encoder.load_adapter(
-                    adapter_name=adapter_name,
-                    adapter_state_dict=text_encoder_lora_state_dict,
-                    peft_config=lora_config,
-                    **peft_kwargs,
-                )
-                # scale LoRA layers with `lora_scale`
-                scale_lora_layers(text_encoder, weight=lora_scale)
-                text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
-                # Offload back.
-                if is_model_cpu_offload:
-                    _pipeline.enable_model_cpu_offload()
-                elif is_sequential_cpu_offload:
-                    _pipeline.enable_sequential_cpu_offload()
-                # Unsafe code />
+        _load_lora_into_text_encoder(
+            state_dict=state_dict,
+            network_alphas=network_alphas,
+            lora_scale=lora_scale,
+            text_encoder=text_encoder,
+            prefix=prefix,
+            text_encoder_name=cls.text_encoder_name,
+            adapter_name=adapter_name,
+            _pipeline=_pipeline,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+        )
     @classmethod
     def save_lora_weights(
@@ -3008,10 +2496,9 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
             safe_serialization=safe_serialization,
         )
-    # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer
     def fuse_lora(
         self,
-        components: List[str] = ["transformer", "text_encoder"],
+        components: List[str] = ["transformer"],
         lora_scale: float = 1.0,
         safe_fusing: bool = False,
         adapter_names: Optional[List[str]] = None,
@@ -3052,8 +2539,7 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
             components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
         )
-    # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.unfuse_lora with unet->transformer
-    def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
+    def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
         r"""
         Reverses the effect of
         [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
@@ -3067,9 +2553,6 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
         Args:
             components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
             unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
-            unfuse_text_encoder (`bool`, defaults to `True`):
-                Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
-                LoRA parameters then it won't have any effect.
         """
         super().unfuse_lora(components=components)
@@ -3316,10 +2799,9 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
             safe_serialization=safe_serialization,
         )
-    # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer
     def fuse_lora(
         self,
-        components: List[str] = ["transformer", "text_encoder"],
+        components: List[str] = ["transformer"],
        lora_scale: float = 1.0,
         safe_fusing: bool = False,
         adapter_names: Optional[List[str]] = None,
@@ -3360,8 +2842,7 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
             components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
         )
-    # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.unfuse_lora with unet->transformer
-    def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
+    def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
         r"""
         Reverses the effect of
         [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
@@ -3375,9 +2856,6 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
         Args:
             components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
             unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
-            unfuse_text_encoder (`bool`, defaults to `True`):
-                Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
-                LoRA parameters then it won't have any effect.
         """
         super().unfuse_lora(components=components)
@@ -3624,10 +3102,9 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
             safe_serialization=safe_serialization,
         )
-    # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer
     def fuse_lora(
         self,
-        components: List[str] = ["transformer", "text_encoder"],
+        components: List[str] = ["transformer"],
         lora_scale: float = 1.0,
         safe_fusing: bool = False,
         adapter_names: Optional[List[str]] = None,
@@ -3668,8 +3145,7 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
             components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
         )
-    # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.unfuse_lora with unet->transformer
-    def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
+    def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
         r"""
         Reverses the effect of
         [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
@@ -3683,9 +3159,6 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
         Args:
            components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
             unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
-            unfuse_text_encoder (`bool`, defaults to `True`):
-                Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
-                LoRA parameters then it won't have any effect.
         """
         super().unfuse_lora(components=components)
@@ -3932,10 +3405,9 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
             safe_serialization=safe_serialization,
         )
-    # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer
     def fuse_lora(
         self,
-        components: List[str] = ["transformer", "text_encoder"],
+        components: List[str] = ["transformer"],
         lora_scale: float = 1.0,
         safe_fusing: bool = False,
         adapter_names: Optional[List[str]] = None,
@@ -3976,8 +3448,7 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
             components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
         )
-    # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.unfuse_lora with unet->transformer
-    def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
+    def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
         r"""
         Reverses the effect of
         [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
@@ -3991,9 +3462,6 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
         Args:
             components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
             unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
-            unfuse_text_encoder (`bool`, defaults to `True`):
-                Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
-                LoRA parameters then it won't have any effect.
         """
         super().unfuse_lora(components=components)
@@ -4300,9 +3768,6 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
         Args:
             components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
             unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
-            unfuse_text_encoder (`bool`, defaults to `True`):
-                Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
-                LoRA parameters then it won't have any effect.
         """
         super().unfuse_lora(components=components)
@@ -20,7 +20,6 @@ from typing import Dict, List, Optional, Union
 import safetensors
 import torch
-import torch.nn as nn
 from ..utils import (
     MIN_PEFT_VERSION,
@@ -30,20 +29,16 @@ from ..utils import (
     delete_adapter_layers,
     get_adapter_name,
     get_peft_kwargs,
-    is_accelerate_available,
     is_peft_available,
     is_peft_version,
     logging,
     set_adapter_layers,
     set_weights_and_activate_adapters,
 )
-from .lora_base import _fetch_state_dict
+from .lora_base import _fetch_state_dict, _func_optionally_disable_offloading
 from .unet_loader_utils import _maybe_expand_lora_scales
-if is_accelerate_available():
-    from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
 logger = logging.get_logger(__name__)
 _SET_ADAPTER_SCALE_FN_MAPPING = {
@@ -140,27 +135,7 @@ class PeftAdapterMixin:
             tuple:
                 A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
         """
-        is_model_cpu_offload = False
-        is_sequential_cpu_offload = False
-        if _pipeline is not None and _pipeline.hf_device_map is None:
-            for _, component in _pipeline.components.items():
-                if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
-                    if not is_model_cpu_offload:
-                        is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
-                    if not is_sequential_cpu_offload:
-                        is_sequential_cpu_offload = (
-                            isinstance(component._hf_hook, AlignDevicesHook)
-                            or hasattr(component._hf_hook, "hooks")
-                            and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
-                        )
-                    logger.info(
-                        "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
-                    )
-                    remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
-        return (is_model_cpu_offload, is_sequential_cpu_offload)
+        return _func_optionally_disable_offloading(_pipeline=_pipeline)
     def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="transformer", **kwargs):
         r"""
@@ -21,7 +21,6 @@ import safetensors
 import torch
 import torch.nn.functional as F
 from huggingface_hub.utils import validate_hf_hub_args
-from torch import nn
 from ..models.embeddings import (
     ImageProjection,
@@ -44,13 +43,11 @@ from ..utils import (
     is_torch_version,
     logging,
 )
+from .lora_base import _func_optionally_disable_offloading
 from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, TEXT_ENCODER_NAME, UNET_NAME
 from .utils import AttnProcsLayers
-if is_accelerate_available():
-    from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
 logger = logging.get_logger(__name__)
@@ -411,27 +408,7 @@ class UNet2DConditionLoadersMixin:
             tuple:
                 A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
         """
-        is_model_cpu_offload = False
-        is_sequential_cpu_offload = False
-        if _pipeline is not None and _pipeline.hf_device_map is None:
-            for _, component in _pipeline.components.items():
-                if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
-                    if not is_model_cpu_offload:
-                        is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
-                    if not is_sequential_cpu_offload:
-                        is_sequential_cpu_offload = (
-                            isinstance(component._hf_hook, AlignDevicesHook)
-                            or hasattr(component._hf_hook, "hooks")
-                            and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
-                        )
-                    logger.info(
-                        "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
-                    )
-                    remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
-        return (is_model_cpu_offload, is_sequential_cpu_offload)
+        return _func_optionally_disable_offloading(_pipeline=_pipeline)
     def save_attn_procs(
         self,