Unverified Commit 10bee525 authored by Sayak Paul, committed by GitHub

[LoRA] use `removeprefix` to preserve sanity. (#11493)

* use removeprefix to preserve sanity.

* f-string.
parent d88ae1f5
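
For reference, a minimal sketch of the slicing-to-`removeprefix` change applied in most hunks below (the key name is made up for illustration; `str.removeprefix` needs Python 3.9+):

```python
prefix = "text_encoder"
key = "text_encoder.text_model.encoder.layers.0.q_proj.lora_A.weight"  # hypothetical LoRA key

# Both forms strip the leading "text_encoder." segment; removeprefix states the
# intent directly instead of recomputing the prefix length for a slice.
assert key[len(f"{prefix}.") :] == key.removeprefix(f"{prefix}.")

# Unlike blind slicing, removeprefix is a no-op when the prefix is absent
# (the call sites below additionally guard with startswith / split checks).
assert "transformer.foo".removeprefix(f"{prefix}.") == "transformer.foo"
```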
```diff
@@ -348,7 +348,7 @@ def _load_lora_into_text_encoder(
     # Load the layers corresponding to text encoder and make necessary adjustments.
     if prefix is not None:
-        state_dict = {k[len(f"{prefix}.") :]: v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
+        state_dict = {k.removeprefix(f"{prefix}."): v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
 
     if len(state_dict) > 0:
         logger.info(f"Loading {prefix}.")
@@ -374,7 +374,7 @@ def _load_lora_into_text_encoder(
         if network_alphas is not None:
             alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
-            network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
+            network_alphas = {k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys}
 
         lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=False)
```
```diff
@@ -2103,7 +2103,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
         prefix = prefix or cls.transformer_name
         for key in list(state_dict.keys()):
             if key.split(".")[0] == prefix:
-                state_dict[key[len(f"{prefix}.") :]] = state_dict.pop(key)
+                state_dict[key.removeprefix(f"{prefix}.")] = state_dict.pop(key)
 
         # Find invalid keys
         transformer_state_dict = transformer.state_dict()
@@ -2425,7 +2425,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
         prefix = prefix or cls.transformer_name
         for key in list(state_dict.keys()):
             if key.split(".")[0] == prefix:
-                state_dict[key[len(f"{prefix}.") :]] = state_dict.pop(key)
+                state_dict[key.removeprefix(f"{prefix}.")] = state_dict.pop(key)
 
         # Expand transformer parameter shapes if they don't match lora
         has_param_with_shape_update = False
```
```diff
@@ -230,7 +230,7 @@ class PeftAdapterMixin:
             raise ValueError("`network_alphas` cannot be None when `prefix` is None.")
 
         if prefix is not None:
-            state_dict = {k[len(f"{prefix}.") :]: v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
+            state_dict = {k.removeprefix(f"{prefix}."): v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
 
         if len(state_dict) > 0:
             if adapter_name in getattr(self, "peft_config", {}) and not hotswap:
@@ -261,7 +261,9 @@ class PeftAdapterMixin:
             if network_alphas is not None and len(network_alphas) >= 1:
                 alpha_keys = [k for k in network_alphas.keys() if k.startswith(f"{prefix}.")]
-                network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
+                network_alphas = {
+                    k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys
+                }
 
             lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict)
             _maybe_raise_error_for_ambiguity(lora_config_kwargs)
```
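
The `network_alphas` hunks also swap `str.replace` for `str.removeprefix`. A small sketch (made-up key name) of why that is the safer primitive here: `replace` strips every occurrence of the substring, while `removeprefix` only touches the leading one:

```python
prefix = "text_encoder"

# A hypothetical alpha key that happens to contain the prefix string again later on.
key = "text_encoder.blocks.text_encoder.alpha"

print(key.replace(f"{prefix}.", ""))   # 'blocks.alpha' -- both occurrences removed
print(key.removeprefix(f"{prefix}."))  # 'blocks.text_encoder.alpha' -- only the leading one
```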