Unverified Commit fb57c76a authored by Sayak Paul, committed by GitHub

[LoRA] refactor lora loading at the model-level (#11719)

* factor out stuff from load_lora_adapter().

* simplifying text encoder lora loading.

* fix peft.py

* fix logging locations.

* formatting

* fix

* update

* update

* update
parent 7251bb4f
@@ -34,7 +34,6 @@ from ..utils import (
     delete_adapter_layers,
     deprecate,
     get_adapter_name,
-    get_peft_kwargs,
     is_accelerate_available,
     is_peft_available,
     is_peft_version,
@@ -46,14 +45,13 @@ from ..utils import (
     set_adapter_layers,
     set_weights_and_activate_adapters,
 )
+from ..utils.peft_utils import _create_lora_config
 from ..utils.state_dict_utils import _load_sft_state_dict_metadata


 if is_transformers_available():
     from transformers import PreTrainedModel

-    from ..models.lora import text_encoder_attn_modules, text_encoder_mlp_modules

 if is_peft_available():
     from peft.tuners.tuners_utils import BaseTunerLayer
@@ -352,8 +350,6 @@ def _load_lora_into_text_encoder(
             )
         peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage

-    from peft import LoraConfig
-
     # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
     # then the `state_dict` keys should have `unet_name` and/or `text_encoder_name` as
     # their prefixes.
@@ -377,60 +373,25 @@ def _load_lora_into_text_encoder(
         # convert state dict
         state_dict = convert_state_dict_to_peft(state_dict)

-        for name, _ in text_encoder_attn_modules(text_encoder):
-            for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
-                rank_key = f"{name}.{module}.lora_B.weight"
-                if rank_key not in state_dict:
-                    continue
-                rank[rank_key] = state_dict[rank_key].shape[1]
-
-        for name, _ in text_encoder_mlp_modules(text_encoder):
-            for module in ("fc1", "fc2"):
-                rank_key = f"{name}.{module}.lora_B.weight"
-                if rank_key not in state_dict:
-                    continue
-                rank[rank_key] = state_dict[rank_key].shape[1]
+        for name, _ in text_encoder.named_modules():
+            if name.endswith((".q_proj", ".k_proj", ".v_proj", ".out_proj", ".fc1", ".fc2")):
+                rank_key = f"{name}.lora_B.weight"
+                if rank_key in state_dict:
+                    rank[rank_key] = state_dict[rank_key].shape[1]

         if network_alphas is not None:
             alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
             network_alphas = {k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys}
-        if metadata is not None:
-            lora_config_kwargs = metadata
-        else:
-            lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=False)
-
-        if "use_dora" in lora_config_kwargs:
-            if lora_config_kwargs["use_dora"]:
-                if is_peft_version("<", "0.9.0"):
-                    raise ValueError(
-                        "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
-                    )
-            else:
-                if is_peft_version("<", "0.9.0"):
-                    lora_config_kwargs.pop("use_dora")
-
-        if "lora_bias" in lora_config_kwargs:
-            if lora_config_kwargs["lora_bias"]:
-                if is_peft_version("<=", "0.13.2"):
-                    raise ValueError(
-                        "You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
-                    )
-            else:
-                if is_peft_version("<=", "0.13.2"):
-                    lora_config_kwargs.pop("lora_bias")
-
-        try:
-            lora_config = LoraConfig(**lora_config_kwargs)
-        except TypeError as e:
-            raise TypeError("`LoraConfig` class could not be instantiated.") from e
+        # create `LoraConfig`
+        lora_config = _create_lora_config(state_dict, network_alphas, metadata, rank, is_unet=False)
         # adapter_name
         if adapter_name is None:
             adapter_name = get_adapter_name(text_encoder)

-        # <Unsafe code
         is_model_cpu_offload, is_sequential_cpu_offload = _func_optionally_disable_offloading(_pipeline)

         # inject LoRA layers and load the state dict
         # in transformers we automatically check whether the adapter name is already in use or not
         text_encoder.load_adapter(
@@ -442,7 +403,6 @@ def _load_lora_into_text_encoder(
         # scale LoRA layers with `lora_scale`
         scale_lora_layers(text_encoder, weight=lora_scale)

         text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)

         # Offload back.
@@ -453,10 +413,11 @@ def _load_lora_into_text_encoder(
         # Unsafe code />

     if prefix is not None and not state_dict:
+        model_class_name = text_encoder.__class__.__name__
         logger.warning(
-            f"No LoRA keys associated to {text_encoder.__class__.__name__} found with the {prefix=}. "
+            f"No LoRA keys associated to {model_class_name} found with the {prefix=}. "
             "This is safe to ignore if LoRA state dict didn't originally have any "
-            f"{text_encoder.__class__.__name__} related params. You can also try specifying `prefix=None` "
+            f"{model_class_name} related params. You can also try specifying `prefix=None` "
             "to resolve the warning. Otherwise, open an issue if you think it's unexpected: "
             "https://github.com/huggingface/diffusers/issues/new"
         )
...
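The rewritten loop above infers each module's LoRA rank from the second dimension of its `lora_B` weight. A small, self-contained sketch of that inference (the state-dict keys and shapes below are invented for illustration; the real loader walks `text_encoder.named_modules()` instead of the state dict):

import torch

# Hypothetical PEFT-style LoRA state dict; `lora_B` weights have shape (out_features, rank).
state_dict = {
    "text_model.encoder.layers.0.self_attn.q_proj.lora_A.weight": torch.zeros(4, 768),
    "text_model.encoder.layers.0.self_attn.q_proj.lora_B.weight": torch.zeros(768, 4),
    "text_model.encoder.layers.0.mlp.fc1.lora_B.weight": torch.zeros(3072, 8),
}

rank = {}
for key in state_dict:
    if not key.endswith(".lora_B.weight"):
        continue
    module_name = key.removesuffix(".lora_B.weight")
    # Same module filter as the loader: attention projections plus the two MLP linears.
    if module_name.endswith((".q_proj", ".k_proj", ".v_proj", ".out_proj", ".fc1", ".fc2")):
        rank[key] = state_dict[key].shape[1]  # rank == second dim of lora_B

print(rank)  # q_proj entry maps to 4, fc1 entry maps to 8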
@@ -29,13 +29,13 @@ from ..utils import (
     convert_unet_state_dict_to_peft,
     delete_adapter_layers,
     get_adapter_name,
-    get_peft_kwargs,
     is_peft_available,
     is_peft_version,
     logging,
     set_adapter_layers,
     set_weights_and_activate_adapters,
 )
+from ..utils.peft_utils import _create_lora_config, _maybe_warn_for_unhandled_keys
 from .lora_base import _fetch_state_dict, _func_optionally_disable_offloading
 from .unet_loader_utils import _maybe_expand_lora_scales
@@ -64,26 +64,6 @@ _SET_ADAPTER_SCALE_FN_MAPPING = {
 }


-def _maybe_raise_error_for_ambiguity(config):
-    rank_pattern = config["rank_pattern"].copy()
-    target_modules = config["target_modules"]
-
-    for key in list(rank_pattern.keys()):
-        # try to detect ambiguity
-        # `target_modules` can also be a str, in which case this loop would loop
-        # over the chars of the str. The technically correct way to match LoRA keys
-        # in PEFT is to use LoraModel._check_target_module_exists (lora_config, key).
-        # But this cuts it for now.
-        exact_matches = [mod for mod in target_modules if mod == key]
-        substring_matches = [mod for mod in target_modules if key in mod and mod != key]
-        if exact_matches and substring_matches:
-            if is_peft_version("<", "0.14.1"):
-                raise ValueError(
-                    "There are ambiguous keys present in this LoRA. To load it, please update your `peft` installation - `pip install -U peft`."
-                )
-
-
 class PeftAdapterMixin:
     """
     A class containing all functions for loading and using adapters weights that are supported in PEFT library. For
@@ -191,7 +171,7 @@ class PeftAdapterMixin:
                 LoRA adapter metadata. When supplied, the metadata inferred through the state dict isn't used to
                 initialize `LoraConfig`.
         """
-        from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
+        from peft import inject_adapter_in_model, set_peft_model_state_dict
         from peft.tuners.tuners_utils import BaseTunerLayer
         cache_dir = kwargs.pop("cache_dir", None)
@@ -216,7 +196,6 @@ class PeftAdapterMixin:
         )
         user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}

         state_dict, metadata = _fetch_state_dict(
             pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
             weight_name=weight_name,
@@ -275,38 +254,8 @@ class PeftAdapterMixin:
                 k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys
             }
-            if metadata is not None:
-                lora_config_kwargs = metadata
-            else:
-                lora_config_kwargs = get_peft_kwargs(
-                    rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict
-                )
-            _maybe_raise_error_for_ambiguity(lora_config_kwargs)
-
-            if "use_dora" in lora_config_kwargs:
-                if lora_config_kwargs["use_dora"]:
-                    if is_peft_version("<", "0.9.0"):
-                        raise ValueError(
-                            "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
-                        )
-                else:
-                    if is_peft_version("<", "0.9.0"):
-                        lora_config_kwargs.pop("use_dora")
-
-            if "lora_bias" in lora_config_kwargs:
-                if lora_config_kwargs["lora_bias"]:
-                    if is_peft_version("<=", "0.13.2"):
-                        raise ValueError(
-                            "You need `peft` 0.14.0 at least to use `lora_bias` in LoRAs. Please upgrade your installation of `peft`."
-                        )
-                else:
-                    if is_peft_version("<=", "0.13.2"):
-                        lora_config_kwargs.pop("lora_bias")
-
-            try:
-                lora_config = LoraConfig(**lora_config_kwargs)
-            except TypeError as e:
-                raise TypeError("`LoraConfig` class could not be instantiated.") from e
+            # create LoraConfig
+            lora_config = _create_lora_config(state_dict, network_alphas, metadata, rank)
             # adapter_name
             if adapter_name is None:
@@ -317,9 +266,8 @@ class PeftAdapterMixin:
             # Now we remove any existing hooks to `_pipeline`.
             # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
-            # otherwise loading LoRA weights will lead to an error
+            # otherwise loading LoRA weights will lead to an error.
             is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline)

             peft_kwargs = {}
             if is_peft_version(">=", "0.13.1"):
                 peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
@@ -403,30 +351,7 @@ class PeftAdapterMixin:
                 logger.error(f"Loading {adapter_name} was unsuccessful with the following error: \n{e}")
                 raise

-            warn_msg = ""
-            if incompatible_keys is not None:
-                # Check only for unexpected keys.
-                unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
-                if unexpected_keys:
-                    lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k]
-                    if lora_unexpected_keys:
-                        warn_msg = (
-                            f"Loading adapter weights from state_dict led to unexpected keys found in the model:"
-                            f" {', '.join(lora_unexpected_keys)}. "
-                        )
-
-                # Filter missing keys specific to the current adapter.
-                missing_keys = getattr(incompatible_keys, "missing_keys", None)
-                if missing_keys:
-                    lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
-                    if lora_missing_keys:
-                        warn_msg += (
-                            f"Loading adapter weights from state_dict led to missing keys in the model:"
-                            f" {', '.join(lora_missing_keys)}."
-                        )
-
-            if warn_msg:
-                logger.warning(warn_msg)
+            _maybe_warn_for_unhandled_keys(incompatible_keys, adapter_name)
             # Offload back.
             if is_model_cpu_offload:
@@ -436,10 +361,11 @@ class PeftAdapterMixin:
             # Unsafe code />

         if prefix is not None and not state_dict:
+            model_class_name = self.__class__.__name__
             logger.warning(
-                f"No LoRA keys associated to {self.__class__.__name__} found with the {prefix=}. "
+                f"No LoRA keys associated to {model_class_name} found with the {prefix=}. "
                 "This is safe to ignore if LoRA state dict didn't originally have any "
-                f"{self.__class__.__name__} related params. You can also try specifying `prefix=None` "
+                f"{model_class_name} related params. You can also try specifying `prefix=None` "
                 "to resolve the warning. Otherwise, open an issue if you think it's unexpected: "
                 "https://github.com/huggingface/diffusers/issues/new"
             )
...
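The inline missing/unexpected-key reporting removed above now lives in `_maybe_warn_for_unhandled_keys` (added in the next file). A minimal runnable sketch of the filtering it applies, using a hypothetical namedtuple in place of the object returned by PEFT's `set_peft_model_state_dict`:

from collections import namedtuple

# Stand-in for the incompatible-keys result; the helper reads these two fields via getattr.
IncompatibleKeys = namedtuple("IncompatibleKeys", ["missing_keys", "unexpected_keys"])

adapter_name = "default_0"
incompatible_keys = IncompatibleKeys(
    missing_keys=["blocks.0.attn.to_q.lora_A.default_0.weight"],
    unexpected_keys=[
        "blocks.0.attn.to_q.lora_B.default_0.weight.diffusers_cat",  # bogus key for this adapter
        "blocks.0.attn.to_q.lora_B.other_adapter.weight",  # belongs to a different adapter
    ],
)

# Only LoRA keys that belong to the current adapter are reported.
unexpected = [k for k in incompatible_keys.unexpected_keys if "lora_" in k and adapter_name in k]
missing = [k for k in incompatible_keys.missing_keys if "lora_" in k and adapter_name in k]
print(unexpected)  # ['blocks.0.attn.to_q.lora_B.default_0.weight.diffusers_cat']
print(missing)     # ['blocks.0.attn.to_q.lora_A.default_0.weight']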
@@ -21,9 +21,12 @@ from typing import Optional
 from packaging import version

-from .import_utils import is_peft_available, is_torch_available
+from . import logging
+from .import_utils import is_peft_available, is_peft_version, is_torch_available


+logger = logging.get_logger(__name__)


 if is_torch_available():
     import torch
@@ -288,3 +291,83 @@ def check_peft_version(min_version: str) -> None:
             f"The version of PEFT you are using is not compatible, please use a version that is greater"
             f" than {min_version}"
         )
+
+
+def _create_lora_config(
+    state_dict,
+    network_alphas,
+    metadata,
+    rank_pattern_dict,
+    is_unet: bool = True,
+):
+    from peft import LoraConfig
+
+    if metadata is not None:
+        lora_config_kwargs = metadata
+    else:
+        lora_config_kwargs = get_peft_kwargs(
+            rank_pattern_dict, network_alpha_dict=network_alphas, peft_state_dict=state_dict, is_unet=is_unet
+        )
+
+    _maybe_raise_error_for_ambiguous_keys(lora_config_kwargs)
+
+    # Version checks for DoRA and lora_bias
+    if "use_dora" in lora_config_kwargs and lora_config_kwargs["use_dora"]:
+        if is_peft_version("<", "0.9.0"):
+            raise ValueError("DoRA requires PEFT >= 0.9.0. Please upgrade.")
+
+    if "lora_bias" in lora_config_kwargs and lora_config_kwargs["lora_bias"]:
+        if is_peft_version("<=", "0.13.2"):
+            raise ValueError("lora_bias requires PEFT >= 0.14.0. Please upgrade.")
+
+    try:
+        return LoraConfig(**lora_config_kwargs)
+    except TypeError as e:
+        raise TypeError("`LoraConfig` class could not be instantiated.") from e
+
+
+def _maybe_raise_error_for_ambiguous_keys(config):
+    rank_pattern = config["rank_pattern"].copy()
+    target_modules = config["target_modules"]
+
+    for key in list(rank_pattern.keys()):
+        # try to detect ambiguity
+        # `target_modules` can also be a str, in which case this loop would loop
+        # over the chars of the str. The technically correct way to match LoRA keys
+        # in PEFT is to use LoraModel._check_target_module_exists (lora_config, key).
+        # But this cuts it for now.
+        exact_matches = [mod for mod in target_modules if mod == key]
+        substring_matches = [mod for mod in target_modules if key in mod and mod != key]
+        if exact_matches and substring_matches:
+            if is_peft_version("<", "0.14.1"):
+                raise ValueError(
+                    "There are ambiguous keys present in this LoRA. To load it, please update your `peft` installation - `pip install -U peft`."
+                )
+
+
+def _maybe_warn_for_unhandled_keys(incompatible_keys, adapter_name):
+    warn_msg = ""
+    if incompatible_keys is not None:
+        # Check only for unexpected keys.
+        unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+        if unexpected_keys:
+            lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k]
+            if lora_unexpected_keys:
+                warn_msg = (
+                    f"Loading adapter weights from state_dict led to unexpected keys found in the model:"
+                    f" {', '.join(lora_unexpected_keys)}. "
+                )
+
+        # Filter missing keys specific to the current adapter.
+        missing_keys = getattr(incompatible_keys, "missing_keys", None)
+        if missing_keys:
+            lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
+            if lora_missing_keys:
+                warn_msg += (
+                    f"Loading adapter weights from state_dict led to missing keys in the model:"
+                    f" {', '.join(lora_missing_keys)}."
+                )
+
+    if warn_msg:
+        logger.warning(warn_msg)
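For context, the ambiguity that `_maybe_raise_error_for_ambiguous_keys` (carried over from the old `_maybe_raise_error_for_ambiguity` in `peft.py`) guards against is a `rank_pattern` key that matches one `target_modules` entry exactly while also being a substring of another, which older `peft` versions resolve incorrectly. A standalone illustration with made-up module names:

def has_ambiguous_key(rank_pattern, target_modules):
    # Mirrors the check above: an exact match plus a strict-substring match for the
    # same key is only resolved correctly by peft >= 0.14.1.
    for key in rank_pattern:
        exact = [m for m in target_modules if m == key]
        substrings = [m for m in target_modules if key in m and m != key]
        if exact and substrings:
            return True
    return False

print(has_ambiguous_key({"proj": 8}, ["proj", "proj_out"]))  # True: "proj" is also inside "proj_out"
print(has_ambiguous_key({"to_q": 4}, ["to_q", "to_k", "to_v"]))  # False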
@@ -1794,7 +1794,7 @@ class PeftLoraLoaderMixinTests:
         missing_key = [k for k in state_dict if "lora_A" in k][0]
         del state_dict[missing_key]

-        logger = logging.get_logger("diffusers.loaders.peft")
+        logger = logging.get_logger("diffusers.utils.peft_utils")
         logger.setLevel(30)
         with CaptureLogger(logger) as cap_logger:
             pipe.load_lora_weights(state_dict)
@@ -1829,7 +1829,7 @@ class PeftLoraLoaderMixinTests:
         unexpected_key = [k for k in state_dict if "lora_A" in k][0] + ".diffusers_cat"
         state_dict[unexpected_key] = torch.tensor(1.0, device=torch_device)

-        logger = logging.get_logger("diffusers.loaders.peft")
+        logger = logging.get_logger("diffusers.utils.peft_utils")
         logger.setLevel(30)
         with CaptureLogger(logger) as cap_logger:
             pipe.load_lora_weights(state_dict)
@@ -2006,9 +2006,6 @@ class PeftLoraLoaderMixinTests:
         _, _, inputs = self.get_dummy_inputs(with_generator=False)

-        logger = logging.get_logger("diffusers.loaders.lora_pipeline")
-        logger.setLevel(logging.INFO)
-
         original_output = pipe(**inputs, generator=torch.manual_seed(0))[0]

         denoiser_lora_config.lora_bias = False
...
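Because the key-mismatch warning is now emitted from the utils module, the tests capture the `diffusers.utils.peft_utils` logger rather than `diffusers.loaders.peft`. A minimal sketch of pointing at that logger (level 30 equals `logging.WARNING`):

from diffusers.utils import logging

# The missing/unexpected-key warning now originates in diffusers/utils/peft_utils.py.
logger = logging.get_logger("diffusers.utils.peft_utils")
logger.setLevel(logging.WARNING)  # same as setLevel(30) in the tests
print(logger.name)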