"...en/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "60d259add1113a98a856f58936f1689ac4717c26"
Unverified Commit 26149c0e authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

[LoRA] Improve warning messages when LoRA loading becomes a no-op (#10187)



* updates

* updates

* updates

* updates

* notebooks revert

* fix-copies.

* seeing

* fix

* revert

* fixes

* fixes

* fixes

* remove print

* fix

* conflicts ii.

* updates

* fixes

* better filtering of prefix.

---------
Co-authored-by: default avatarhlky <hlky@hlky.ac>
parent 0703ce88
...@@ -339,93 +339,93 @@ def _load_lora_into_text_encoder( ...@@ -339,93 +339,93 @@ def _load_lora_into_text_encoder(
# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918), # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
# then the `state_dict` keys should have `unet_name` and/or `text_encoder_name` as # then the `state_dict` keys should have `unet_name` and/or `text_encoder_name` as
# their prefixes. # their prefixes.
keys = list(state_dict.keys())
prefix = text_encoder_name if prefix is None else prefix prefix = text_encoder_name if prefix is None else prefix
# Safe prefix to check with. # Load the layers corresponding to text encoder and make necessary adjustments.
if any(text_encoder_name in key for key in keys): if prefix is not None:
# Load the layers corresponding to text encoder and make necessary adjustments. state_dict = {k[len(f"{prefix}.") :]: v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix]
text_encoder_lora_state_dict = { if len(state_dict) > 0:
k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys logger.info(f"Loading {prefix}.")
} rank = {}
state_dict = convert_state_dict_to_diffusers(state_dict)
# convert state dict
state_dict = convert_state_dict_to_peft(state_dict)
for name, _ in text_encoder_attn_modules(text_encoder):
for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
rank_key = f"{name}.{module}.lora_B.weight"
if rank_key not in state_dict:
continue
rank[rank_key] = state_dict[rank_key].shape[1]
for name, _ in text_encoder_mlp_modules(text_encoder):
for module in ("fc1", "fc2"):
rank_key = f"{name}.{module}.lora_B.weight"
if rank_key not in state_dict:
continue
rank[rank_key] = state_dict[rank_key].shape[1]
if network_alphas is not None:
alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=False)
if "use_dora" in lora_config_kwargs:
if lora_config_kwargs["use_dora"]:
if is_peft_version("<", "0.9.0"):
raise ValueError(
"You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
)
else:
if is_peft_version("<", "0.9.0"):
lora_config_kwargs.pop("use_dora")
if "lora_bias" in lora_config_kwargs:
if lora_config_kwargs["lora_bias"]:
if is_peft_version("<=", "0.13.2"):
raise ValueError(
"You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
)
else:
if is_peft_version("<=", "0.13.2"):
lora_config_kwargs.pop("lora_bias")
if len(text_encoder_lora_state_dict) > 0: lora_config = LoraConfig(**lora_config_kwargs)
logger.info(f"Loading {prefix}.")
rank = {}
text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict)
# convert state dict
text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)
for name, _ in text_encoder_attn_modules(text_encoder):
for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
rank_key = f"{name}.{module}.lora_B.weight"
if rank_key not in text_encoder_lora_state_dict:
continue
rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
for name, _ in text_encoder_mlp_modules(text_encoder):
for module in ("fc1", "fc2"):
rank_key = f"{name}.{module}.lora_B.weight"
if rank_key not in text_encoder_lora_state_dict:
continue
rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
if network_alphas is not None:
alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
if "use_dora" in lora_config_kwargs:
if lora_config_kwargs["use_dora"]:
if is_peft_version("<", "0.9.0"):
raise ValueError(
"You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
)
else:
if is_peft_version("<", "0.9.0"):
lora_config_kwargs.pop("use_dora")
if "lora_bias" in lora_config_kwargs:
if lora_config_kwargs["lora_bias"]:
if is_peft_version("<=", "0.13.2"):
raise ValueError(
"You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
)
else:
if is_peft_version("<=", "0.13.2"):
lora_config_kwargs.pop("lora_bias")
lora_config = LoraConfig(**lora_config_kwargs) # adapter_name
if adapter_name is None:
adapter_name = get_adapter_name(text_encoder)
# adapter_name is_model_cpu_offload, is_sequential_cpu_offload = _func_optionally_disable_offloading(_pipeline)
if adapter_name is None:
adapter_name = get_adapter_name(text_encoder)
is_model_cpu_offload, is_sequential_cpu_offload = _func_optionally_disable_offloading(_pipeline) # inject LoRA layers and load the state dict
# in transformers we automatically check whether the adapter name is already in use or not
text_encoder.load_adapter(
adapter_name=adapter_name,
adapter_state_dict=state_dict,
peft_config=lora_config,
**peft_kwargs,
)
# inject LoRA layers and load the state dict # scale LoRA layers with `lora_scale`
# in transformers we automatically check whether the adapter name is already in use or not scale_lora_layers(text_encoder, weight=lora_scale)
text_encoder.load_adapter(
adapter_name=adapter_name,
adapter_state_dict=text_encoder_lora_state_dict,
peft_config=lora_config,
**peft_kwargs,
)
# scale LoRA layers with `lora_scale` text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
scale_lora_layers(text_encoder, weight=lora_scale)
text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype) # Offload back.
if is_model_cpu_offload:
_pipeline.enable_model_cpu_offload()
elif is_sequential_cpu_offload:
_pipeline.enable_sequential_cpu_offload()
# Unsafe code />
# Offload back. if prefix is not None and not state_dict:
if is_model_cpu_offload: logger.info(
_pipeline.enable_model_cpu_offload() f"No LoRA keys associated to {text_encoder.__class__.__name__} found with the {prefix=}. This is safe to ignore if LoRA state dict didn't originally have any {text_encoder.__class__.__name__} related params. Open an issue if you think it's unexpected: https://github.com/huggingface/diffusers/issues/new"
elif is_sequential_cpu_offload: )
_pipeline.enable_sequential_cpu_offload()
# Unsafe code />
def _func_optionally_disable_offloading(_pipeline): def _func_optionally_disable_offloading(_pipeline):
......
...@@ -298,19 +298,15 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin): ...@@ -298,19 +298,15 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918), # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
# then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as # then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as
# their prefixes. # their prefixes.
keys = list(state_dict.keys()) logger.info(f"Loading {cls.unet_name}.")
only_text_encoder = all(key.startswith(cls.text_encoder_name) for key in keys) unet.load_lora_adapter(
if not only_text_encoder: state_dict,
# Load the layers corresponding to UNet. prefix=cls.unet_name,
logger.info(f"Loading {cls.unet_name}.") network_alphas=network_alphas,
unet.load_lora_adapter( adapter_name=adapter_name,
state_dict, _pipeline=_pipeline,
prefix=cls.unet_name, low_cpu_mem_usage=low_cpu_mem_usage,
network_alphas=network_alphas, )
adapter_name=adapter_name,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
)
@classmethod @classmethod
def load_lora_into_text_encoder( def load_lora_into_text_encoder(
...@@ -559,31 +555,26 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin): ...@@ -559,31 +555,26 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
_pipeline=self, _pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage, low_cpu_mem_usage=low_cpu_mem_usage,
) )
text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k} self.load_lora_into_text_encoder(
if len(text_encoder_state_dict) > 0: state_dict,
self.load_lora_into_text_encoder( network_alphas=network_alphas,
text_encoder_state_dict, text_encoder=self.text_encoder,
network_alphas=network_alphas, prefix=self.text_encoder_name,
text_encoder=self.text_encoder, lora_scale=self.lora_scale,
prefix="text_encoder", adapter_name=adapter_name,
lora_scale=self.lora_scale, _pipeline=self,
adapter_name=adapter_name, low_cpu_mem_usage=low_cpu_mem_usage,
_pipeline=self, )
low_cpu_mem_usage=low_cpu_mem_usage, self.load_lora_into_text_encoder(
) state_dict,
network_alphas=network_alphas,
text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k} text_encoder=self.text_encoder_2,
if len(text_encoder_2_state_dict) > 0: prefix=f"{self.text_encoder_name}_2",
self.load_lora_into_text_encoder( lora_scale=self.lora_scale,
text_encoder_2_state_dict, adapter_name=adapter_name,
network_alphas=network_alphas, _pipeline=self,
text_encoder=self.text_encoder_2, low_cpu_mem_usage=low_cpu_mem_usage,
prefix="text_encoder_2", )
lora_scale=self.lora_scale,
adapter_name=adapter_name,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
)
@classmethod @classmethod
@validate_hf_hub_args @validate_hf_hub_args
...@@ -738,19 +729,15 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin): ...@@ -738,19 +729,15 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918), # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
# then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as # then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as
# their prefixes. # their prefixes.
keys = list(state_dict.keys()) logger.info(f"Loading {cls.unet_name}.")
only_text_encoder = all(key.startswith(cls.text_encoder_name) for key in keys) unet.load_lora_adapter(
if not only_text_encoder: state_dict,
# Load the layers corresponding to UNet. prefix=cls.unet_name,
logger.info(f"Loading {cls.unet_name}.") network_alphas=network_alphas,
unet.load_lora_adapter( adapter_name=adapter_name,
state_dict, _pipeline=_pipeline,
prefix=cls.unet_name, low_cpu_mem_usage=low_cpu_mem_usage,
network_alphas=network_alphas, )
adapter_name=adapter_name,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
)
@classmethod @classmethod
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
...@@ -1085,43 +1072,33 @@ class SD3LoraLoaderMixin(LoraBaseMixin): ...@@ -1085,43 +1072,33 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
if not is_correct_format: if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint.") raise ValueError("Invalid LoRA checkpoint.")
transformer_state_dict = {k: v for k, v in state_dict.items() if "transformer." in k} self.load_lora_into_transformer(
if len(transformer_state_dict) > 0: state_dict,
self.load_lora_into_transformer( transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
state_dict, adapter_name=adapter_name,
transformer=getattr(self, self.transformer_name) _pipeline=self,
if not hasattr(self, "transformer") low_cpu_mem_usage=low_cpu_mem_usage,
else self.transformer, )
adapter_name=adapter_name, self.load_lora_into_text_encoder(
_pipeline=self, state_dict,
low_cpu_mem_usage=low_cpu_mem_usage, network_alphas=None,
) text_encoder=self.text_encoder,
prefix=self.text_encoder_name,
text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k} lora_scale=self.lora_scale,
if len(text_encoder_state_dict) > 0: adapter_name=adapter_name,
self.load_lora_into_text_encoder( _pipeline=self,
text_encoder_state_dict, low_cpu_mem_usage=low_cpu_mem_usage,
network_alphas=None, )
text_encoder=self.text_encoder, self.load_lora_into_text_encoder(
prefix="text_encoder", state_dict,
lora_scale=self.lora_scale, network_alphas=None,
adapter_name=adapter_name, text_encoder=self.text_encoder_2,
_pipeline=self, prefix=f"{self.text_encoder_name}_2",
low_cpu_mem_usage=low_cpu_mem_usage, lora_scale=self.lora_scale,
) adapter_name=adapter_name,
_pipeline=self,
text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k} low_cpu_mem_usage=low_cpu_mem_usage,
if len(text_encoder_2_state_dict) > 0: )
self.load_lora_into_text_encoder(
text_encoder_2_state_dict,
network_alphas=None,
text_encoder=self.text_encoder_2,
prefix="text_encoder_2",
lora_scale=self.lora_scale,
adapter_name=adapter_name,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
)
@classmethod @classmethod
def load_lora_into_transformer( def load_lora_into_transformer(
...@@ -1541,18 +1518,23 @@ class FluxLoraLoaderMixin(LoraBaseMixin): ...@@ -1541,18 +1518,23 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
raise ValueError("Invalid LoRA checkpoint.") raise ValueError("Invalid LoRA checkpoint.")
transformer_lora_state_dict = { transformer_lora_state_dict = {
k: state_dict.pop(k) for k in list(state_dict.keys()) if "transformer." in k and "lora" in k k: state_dict.get(k)
for k in list(state_dict.keys())
if k.startswith(f"{self.transformer_name}.") and "lora" in k
} }
transformer_norm_state_dict = { transformer_norm_state_dict = {
k: state_dict.pop(k) k: state_dict.pop(k)
for k in list(state_dict.keys()) for k in list(state_dict.keys())
if "transformer." in k and any(norm_key in k for norm_key in self._control_lora_supported_norm_keys) if k.startswith(f"{self.transformer_name}.")
and any(norm_key in k for norm_key in self._control_lora_supported_norm_keys)
} }
transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
has_param_with_expanded_shape = self._maybe_expand_transformer_param_shape_or_error_( has_param_with_expanded_shape = False
transformer, transformer_lora_state_dict, transformer_norm_state_dict if len(transformer_lora_state_dict) > 0:
) has_param_with_expanded_shape = self._maybe_expand_transformer_param_shape_or_error_(
transformer, transformer_lora_state_dict, transformer_norm_state_dict
)
if has_param_with_expanded_shape: if has_param_with_expanded_shape:
logger.info( logger.info(
...@@ -1560,19 +1542,21 @@ class FluxLoraLoaderMixin(LoraBaseMixin): ...@@ -1560,19 +1542,21 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
"As a result, the state_dict of the transformer has been expanded to match the LoRA parameter shapes. " "As a result, the state_dict of the transformer has been expanded to match the LoRA parameter shapes. "
"To get a comprehensive list of parameter names that were modified, enable debug logging." "To get a comprehensive list of parameter names that were modified, enable debug logging."
) )
transformer_lora_state_dict = self._maybe_expand_lora_state_dict(
transformer=transformer, lora_state_dict=transformer_lora_state_dict
)
if len(transformer_lora_state_dict) > 0: if len(transformer_lora_state_dict) > 0:
self.load_lora_into_transformer( transformer_lora_state_dict = self._maybe_expand_lora_state_dict(
transformer_lora_state_dict, transformer=transformer, lora_state_dict=transformer_lora_state_dict
network_alphas=network_alphas,
transformer=transformer,
adapter_name=adapter_name,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
) )
for k in transformer_lora_state_dict:
state_dict.update({k: transformer_lora_state_dict[k]})
self.load_lora_into_transformer(
state_dict,
network_alphas=network_alphas,
transformer=transformer,
adapter_name=adapter_name,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
)
if len(transformer_norm_state_dict) > 0: if len(transformer_norm_state_dict) > 0:
transformer._transformer_norm_layers = self._load_norm_into_transformer( transformer._transformer_norm_layers = self._load_norm_into_transformer(
...@@ -1581,18 +1565,16 @@ class FluxLoraLoaderMixin(LoraBaseMixin): ...@@ -1581,18 +1565,16 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
discard_original_layers=False, discard_original_layers=False,
) )
text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k} self.load_lora_into_text_encoder(
if len(text_encoder_state_dict) > 0: state_dict,
self.load_lora_into_text_encoder( network_alphas=network_alphas,
text_encoder_state_dict, text_encoder=self.text_encoder,
network_alphas=network_alphas, prefix=self.text_encoder_name,
text_encoder=self.text_encoder, lora_scale=self.lora_scale,
prefix="text_encoder", adapter_name=adapter_name,
lora_scale=self.lora_scale, _pipeline=self,
adapter_name=adapter_name, low_cpu_mem_usage=low_cpu_mem_usage,
_pipeline=self, )
low_cpu_mem_usage=low_cpu_mem_usage,
)
@classmethod @classmethod
def load_lora_into_transformer( def load_lora_into_transformer(
...@@ -1625,17 +1607,14 @@ class FluxLoraLoaderMixin(LoraBaseMixin): ...@@ -1625,17 +1607,14 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
) )
# Load the layers corresponding to transformer. # Load the layers corresponding to transformer.
keys = list(state_dict.keys()) logger.info(f"Loading {cls.transformer_name}.")
transformer_present = any(key.startswith(cls.transformer_name) for key in keys) transformer.load_lora_adapter(
if transformer_present: state_dict,
logger.info(f"Loading {cls.transformer_name}.") network_alphas=network_alphas,
transformer.load_lora_adapter( adapter_name=adapter_name,
state_dict, _pipeline=_pipeline,
network_alphas=network_alphas, low_cpu_mem_usage=low_cpu_mem_usage,
adapter_name=adapter_name, )
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
)
@classmethod @classmethod
def _load_norm_into_transformer( def _load_norm_into_transformer(
...@@ -2174,17 +2153,14 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin): ...@@ -2174,17 +2153,14 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
) )
# Load the layers corresponding to transformer. # Load the layers corresponding to transformer.
keys = list(state_dict.keys()) logger.info(f"Loading {cls.transformer_name}.")
transformer_present = any(key.startswith(cls.transformer_name) for key in keys) transformer.load_lora_adapter(
if transformer_present: state_dict,
logger.info(f"Loading {cls.transformer_name}.") network_alphas=network_alphas,
transformer.load_lora_adapter( adapter_name=adapter_name,
state_dict, _pipeline=_pipeline,
network_alphas=network_alphas, low_cpu_mem_usage=low_cpu_mem_usage,
adapter_name=adapter_name, )
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
)
@classmethod @classmethod
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
......
...@@ -235,10 +235,7 @@ class PeftAdapterMixin: ...@@ -235,10 +235,7 @@ class PeftAdapterMixin:
raise ValueError("`network_alphas` cannot be None when `prefix` is None.") raise ValueError("`network_alphas` cannot be None when `prefix` is None.")
if prefix is not None: if prefix is not None:
keys = list(state_dict.keys()) state_dict = {k[len(f"{prefix}.") :]: v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
model_keys = [k for k in keys if k.startswith(f"{prefix}.")]
if len(model_keys) > 0:
state_dict = {k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in model_keys}
if len(state_dict) > 0: if len(state_dict) > 0:
if adapter_name in getattr(self, "peft_config", {}): if adapter_name in getattr(self, "peft_config", {}):
...@@ -355,6 +352,11 @@ class PeftAdapterMixin: ...@@ -355,6 +352,11 @@ class PeftAdapterMixin:
_pipeline.enable_sequential_cpu_offload() _pipeline.enable_sequential_cpu_offload()
# Unsafe code /> # Unsafe code />
if prefix is not None and not state_dict:
logger.info(
f"No LoRA keys associated to {self.__class__.__name__} found with the {prefix=}. This is safe to ignore if LoRA state dict didn't originally have any {self.__class__.__name__} related params. Open an issue if you think it's unexpected: https://github.com/huggingface/diffusers/issues/new"
)
def save_lora_adapter( def save_lora_adapter(
self, self,
save_directory, save_directory,
......
...@@ -371,9 +371,8 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): ...@@ -371,9 +371,8 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
lora_load_output = pipe(**inputs, generator=torch.manual_seed(0))[0] lora_load_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue( self.assertTrue(
cap_logger.out.startswith( "The provided state dict contains normalization layers in addition to LoRA layers"
"The provided state dict contains normalization layers in addition to LoRA layers" in cap_logger.out
)
) )
self.assertTrue(len(pipe.transformer._transformer_norm_layers) > 0) self.assertTrue(len(pipe.transformer._transformer_norm_layers) > 0)
...@@ -392,7 +391,7 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): ...@@ -392,7 +391,7 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipe.load_lora_weights(norm_state_dict) pipe.load_lora_weights(norm_state_dict)
self.assertTrue( self.assertTrue(
cap_logger.out.startswith("Unsupported keys found in state dict when trying to load normalization layers") "Unsupported keys found in state dict when trying to load normalization layers" in cap_logger.out
) )
def test_lora_parameter_expanded_shapes(self): def test_lora_parameter_expanded_shapes(self):
......
...@@ -1948,6 +1948,50 @@ class PeftLoraLoaderMixinTests: ...@@ -1948,6 +1948,50 @@ class PeftLoraLoaderMixinTests:
_, _, inputs = self.get_dummy_inputs() _, _, inputs = self.get_dummy_inputs()
_ = pipe(**inputs)[0] _ = pipe(**inputs)[0]
def test_logs_info_when_no_lora_keys_found(self):
scheduler_cls = self.scheduler_classes[0]
# Skip text encoder check for now as that is handled with `transformers`.
components, _, _ = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
original_out = pipe(**inputs, generator=torch.manual_seed(0))[0]
no_op_state_dict = {"lora_foo": torch.tensor(2.0), "lora_bar": torch.tensor(3.0)}
logger = logging.get_logger("diffusers.loaders.peft")
logger.setLevel(logging.INFO)
with CaptureLogger(logger) as cap_logger:
pipe.load_lora_weights(no_op_state_dict)
out_after_lora_attempt = pipe(**inputs, generator=torch.manual_seed(0))[0]
denoiser = getattr(pipe, "unet") if self.unet_kwargs is not None else getattr(pipe, "transformer")
self.assertTrue(cap_logger.out.startswith(f"No LoRA keys associated to {denoiser.__class__.__name__}"))
self.assertTrue(np.allclose(original_out, out_after_lora_attempt, atol=1e-5, rtol=1e-5))
# test only for text encoder
for lora_module in self.pipeline_class._lora_loadable_modules:
if "text_encoder" in lora_module:
text_encoder = getattr(pipe, lora_module)
if lora_module == "text_encoder":
prefix = "text_encoder"
elif lora_module == "text_encoder_2":
prefix = "text_encoder_2"
logger = logging.get_logger("diffusers.loaders.lora_base")
logger.setLevel(logging.INFO)
with CaptureLogger(logger) as cap_logger:
self.pipeline_class.load_lora_into_text_encoder(
no_op_state_dict, network_alphas=None, text_encoder=text_encoder, prefix=prefix
)
self.assertTrue(
cap_logger.out.startswith(f"No LoRA keys associated to {text_encoder.__class__.__name__}")
)
def test_set_adapters_match_attention_kwargs(self): def test_set_adapters_match_attention_kwargs(self):
"""Test to check if outputs after `set_adapters()` and attention kwargs match.""" """Test to check if outputs after `set_adapters()` and attention kwargs match."""
call_signature_keys = inspect.signature(self.pipeline_class.__call__).parameters.keys() call_signature_keys = inspect.signature(self.pipeline_class.__call__).parameters.keys()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment