Unverified Commit b0550a66 authored by Sayak Paul, committed by GitHub

[LoRA] restrict certain keys to be checked for peft config update. (#10808)

* restrict certain keys to be checked for peft config update.

* updates

* finish./

* finish 2.

* updates
parent 6f74ef55
@@ -63,6 +63,9 @@ def _maybe_adjust_config(config):
     method removes the ambiguity by following what is described here:
     https://github.com/huggingface/diffusers/pull/9985#issuecomment-2493840028.
     """
+    # Track keys that have been explicitly removed to prevent re-adding them.
+    deleted_keys = set()
+
     rank_pattern = config["rank_pattern"].copy()
     target_modules = config["target_modules"]
     original_r = config["r"]
@@ -80,21 +83,22 @@ def _maybe_adjust_config(config):
         ambiguous_key = key
         if exact_matches and substring_matches:
-            # if ambiguous we update the rank associated with the ambiguous key (`proj_out`, for example)
+            # if ambiguous, update the rank associated with the ambiguous key (`proj_out`, for example)
             config["r"] = key_rank
-            # remove the ambiguous key from `rank_pattern` and update its rank to `r`, instead
+            # remove the ambiguous key from `rank_pattern` and record it as deleted
             del config["rank_pattern"][key]
+            deleted_keys.add(key)

+            # For substring matches, add them with the original rank only if they haven't been assigned already
             for mod in substring_matches:
-                # avoid overwriting if the module already has a specific rank
-                if mod not in config["rank_pattern"]:
+                if mod not in config["rank_pattern"] and mod not in deleted_keys:
                     config["rank_pattern"][mod] = original_r

-            # update the rest of the keys with the `original_r`
+            # Update the rest of the target modules with the original rank if not already set and not deleted
             for mod in target_modules:
-                if mod != ambiguous_key and mod not in config["rank_pattern"]:
+                if mod != ambiguous_key and mod not in config["rank_pattern"] and mod not in deleted_keys:
                     config["rank_pattern"][mod] = original_r

-    # handle alphas to deal with cases like
+    # Handle alphas to deal with cases like:
     # https://github.com/huggingface/diffusers/pull/9999#issuecomment-2516180777
     has_different_ranks = len(config["rank_pattern"]) > 1 and list(config["rank_pattern"])[0] != config["r"]
     if has_different_ranks:
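For readers skimming the hunk above, here is a minimal standalone sketch of the disambiguation logic it implements. It is not the diffusers code itself: `adjust_rank_pattern` and its flattened signature are made up for illustration, and the config dict is reduced to `(r, rank_pattern, target_modules)`.

# Minimal sketch of the ambiguity fix-up, with `deleted_keys` preventing re-adds.
def adjust_rank_pattern(r, rank_pattern, target_modules):
    original_r = r
    deleted_keys = set()
    for key, key_rank in list(rank_pattern.items()):
        exact_matches = [mod for mod in target_modules if mod == key]
        substring_matches = [mod for mod in target_modules if key in mod and mod != key]
        if exact_matches and substring_matches:
            # The ambiguous key's rank becomes the default rank `r` ...
            r = key_rank
            # ... and the key itself is dropped and remembered so it is never re-added.
            del rank_pattern[key]
            deleted_keys.add(key)
            # Longer names that merely contain the key keep the original default rank.
            for mod in substring_matches:
                if mod not in rank_pattern and mod not in deleted_keys:
                    rank_pattern[mod] = original_r
            # Every other target module falls back to the original default rank too.
            for mod in target_modules:
                if mod != key and mod not in rank_pattern and mod not in deleted_keys:
                    rank_pattern[mod] = original_r
    return r, rank_pattern

# Example: "proj_out" matches one target module exactly and is a substring of another.
# adjust_rank_pattern(4, {"proj_out": 8}, ["proj_out", "blocks.0.proj_out", "to_q"])
# -> (8, {"blocks.0.proj_out": 4, "to_q": 4})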
@@ -187,6 +191,11 @@ class PeftAdapterMixin:
         from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
         from peft.tuners.tuners_utils import BaseTunerLayer

+        try:
+            from peft.utils.constants import FULLY_QUALIFIED_PATTERN_KEY_PREFIX
+        except ImportError:
+            FULLY_QUALIFIED_PATTERN_KEY_PREFIX = None
+
         cache_dir = kwargs.pop("cache_dir", None)
         force_download = kwargs.pop("force_download", False)
         proxies = kwargs.pop("proxies", None)
@@ -251,14 +260,22 @@ class PeftAdapterMixin:
                 # Cannot figure out rank from lora layers that don't have atleast 2 dimensions.
                 # Bias layers in LoRA only have a single dimension
                 if "lora_B" in key and val.ndim > 1:
-                    rank[key] = val.shape[1]
+                    # Support to handle cases where layer patterns are treated as full layer names
+                    # was added later in PEFT. So, we handle it accordingly.
+                    # TODO: when we fix the minimal PEFT version for Diffusers,
+                    # we should remove `_maybe_adjust_config()`.
+                    if FULLY_QUALIFIED_PATTERN_KEY_PREFIX:
+                        rank[f"{FULLY_QUALIFIED_PATTERN_KEY_PREFIX}{key}"] = val.shape[1]
+                    else:
+                        rank[key] = val.shape[1]

             if network_alphas is not None and len(network_alphas) >= 1:
                 alpha_keys = [k for k in network_alphas.keys() if k.startswith(f"{prefix}.")]
                 network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}

             lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict)
-            lora_config_kwargs = _maybe_adjust_config(lora_config_kwargs)
+            if not FULLY_QUALIFIED_PATTERN_KEY_PREFIX:
+                lora_config_kwargs = _maybe_adjust_config(lora_config_kwargs)

             if "use_dora" in lora_config_kwargs:
                 if lora_config_kwargs["use_dora"]:
...
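A condensed sketch of how the changes in the last hunk fit together. This is a standalone illustration, not the actual `load_lora_adapter` body: `collect_ranks` is a hypothetical helper, and only the rank-collection and config-adjustment decision are shown. When PEFT exposes `FULLY_QUALIFIED_PATTERN_KEY_PREFIX`, rank keys are stored as fully qualified module names and the `_maybe_adjust_config` fix-up is skipped.

try:
    # Available in newer PEFT releases; per the hunk above, prefixing a pattern key
    # with this marker makes PEFT treat it as a full layer name instead of a pattern.
    from peft.utils.constants import FULLY_QUALIFIED_PATTERN_KEY_PREFIX
except ImportError:
    FULLY_QUALIFIED_PATTERN_KEY_PREFIX = None


def collect_ranks(state_dict):
    rank = {}
    for key, val in state_dict.items():
        # Only 2-D lora_B weights carry a rank; LoRA bias entries are 1-D.
        if "lora_B" in key and val.ndim > 1:
            if FULLY_QUALIFIED_PATTERN_KEY_PREFIX:
                # Fully qualified keys are unambiguous, so no later fix-up is needed.
                rank[f"{FULLY_QUALIFIED_PATTERN_KEY_PREFIX}{key}"] = val.shape[1]
            else:
                rank[key] = val.shape[1]
    return rank

# With older PEFT (prefix unavailable), the ambiguity fix-up still runs downstream:
#     if not FULLY_QUALIFIED_PATTERN_KEY_PREFIX:
#         lora_config_kwargs = _maybe_adjust_config(lora_config_kwargs)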