Unverified Commit cef4f65c authored by Sayak Paul, committed by GitHub

[LoRA] log a warning when there are missing keys in the LoRA loading. (#9622)



* log a warning when there are missing keys in the LoRA loading.

* handle missing keys and unexpected keys better.

* add tests

* fix-copies.

* updates

* tests

* concat warning.

* Add Differential Diffusion to Kolors (#9423)

* Added diff diff support for Kolors img2img

* Fixed relative imports

* Fixed relative imports

* Added diff diff support for Kolors

* Fixed import issues

* Added map

* Fixed import issues

* Fixed naming issues

* Added diffdiff support for Kolors img2img pipeline

* Removed example docstrings

* Added map input

* Updated latents
Co-authored-by: Álvaro Somoza <asomoza@users.noreply.github.com>

* Updated `original_with_noise`
Co-authored-by: Álvaro Somoza <asomoza@users.noreply.github.com>

* Improved code quality

---------
Co-authored-by: Álvaro Somoza <asomoza@users.noreply.github.com>

* FluxMultiControlNetModel (#9647)

* tests

* Update src/diffusers/loaders/lora_pipeline.py
Co-authored-by: YiYi Xu <yixu310@gmail.com>

* fix

---------
Co-authored-by: M Saqlain <118016760+saqlain2204@users.noreply.github.com>
Co-authored-by: Álvaro Somoza <asomoza@users.noreply.github.com>
Co-authored-by: hlky <hlky@hlky.ac>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
parent 29a2c5d1
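In short: when PEFT's `set_peft_model_state_dict` reports incompatible keys, the LoRA loaders now warn about adapter-specific missing keys in addition to unexpected ones, instead of silently ignoring gaps in the checkpoint. A minimal sketch of how the new warning would surface for a user (the local checkpoint path is hypothetical; passing an in-memory state dict to `load_lora_weights` is existing diffusers behavior):

```python
import torch
from diffusers import StableDiffusion3Pipeline

pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
)

# Hypothetical local LoRA checkpoint with one tensor dropped to simulate a corrupt file.
state_dict = torch.load("my_lora/pytorch_lora_weights.bin", weights_only=True)
del state_dict[next(k for k in state_dict if "lora_A" in k)]

# Before this commit the gap was silent; now a warning like
# "Loading adapter weights from state_dict led to missing keys in the model: ..."
# is logged for keys that belong to the adapter being loaded.
pipe.load_lora_weights(state_dict)
```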
@@ -1358,14 +1358,30 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
         inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name, **peft_kwargs)
         incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name, **peft_kwargs)
 
+        warn_msg = ""
         if incompatible_keys is not None:
-            # check only for unexpected keys
+            # Check only for unexpected keys.
             unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
             if unexpected_keys:
-                logger.warning(
-                    f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
-                    f" {unexpected_keys}. "
-                )
+                lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k]
+                if lora_unexpected_keys:
+                    warn_msg = (
+                        f"Loading adapter weights from state_dict led to unexpected keys found in the model:"
+                        f" {', '.join(lora_unexpected_keys)}. "
+                    )
+
+            # Filter missing keys specific to the current adapter.
+            missing_keys = getattr(incompatible_keys, "missing_keys", None)
+            if missing_keys:
+                lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
+                if lora_missing_keys:
+                    warn_msg += (
+                        f"Loading adapter weights from state_dict led to missing keys in the model:"
+                        f" {', '.join(lora_missing_keys)}."
+                    )
+
+        if warn_msg:
+            logger.warning(warn_msg)
 
         # Offload back.
         if is_model_cpu_offload:
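The new `lora_unexpected_keys` / `lora_missing_keys` filters restrict the warning to keys that belong to the adapter actually being loaded, so base-model keys reported by PEFT do not produce noise. A rough illustration, with key names assumed to follow PEFT's `<module>.lora_A.<adapter_name>.weight` scheme:

```python
adapter_name = "default_0"
missing_keys = [
    "transformer_blocks.0.attn.to_q.lora_A.default_0.weight",  # this adapter -> warned about
    "transformer_blocks.0.attn.to_q.base_layer.weight",  # base weight -> filtered out
    "transformer_blocks.0.attn.to_q.lora_A.other.weight",  # different adapter -> filtered out
]
lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
assert lora_missing_keys == ["transformer_blocks.0.attn.to_q.lora_A.default_0.weight"]
```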
@@ -1932,14 +1948,30 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
         inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name, **peft_kwargs)
         incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name, **peft_kwargs)
 
+        warn_msg = ""
         if incompatible_keys is not None:
-            # check only for unexpected keys
+            # Check only for unexpected keys.
             unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
             if unexpected_keys:
-                logger.warning(
-                    f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
-                    f" {unexpected_keys}. "
-                )
+                lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k]
+                if lora_unexpected_keys:
+                    warn_msg = (
+                        f"Loading adapter weights from state_dict led to unexpected keys found in the model:"
+                        f" {', '.join(lora_unexpected_keys)}. "
+                    )
+
+            # Filter missing keys specific to the current adapter.
+            missing_keys = getattr(incompatible_keys, "missing_keys", None)
+            if missing_keys:
+                lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
+                if lora_missing_keys:
+                    warn_msg += (
+                        f"Loading adapter weights from state_dict led to missing keys in the model:"
+                        f" {', '.join(lora_missing_keys)}."
+                    )
+
+        if warn_msg:
+            logger.warning(warn_msg)
 
         # Offload back.
         if is_model_cpu_offload:
@@ -2279,14 +2311,30 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
         inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name)
         incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name)
 
+        warn_msg = ""
         if incompatible_keys is not None:
-            # check only for unexpected keys
+            # Check only for unexpected keys.
             unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
             if unexpected_keys:
-                logger.warning(
-                    f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
-                    f" {unexpected_keys}. "
-                )
+                lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k]
+                if lora_unexpected_keys:
+                    warn_msg = (
+                        f"Loading adapter weights from state_dict led to unexpected keys found in the model:"
+                        f" {', '.join(lora_unexpected_keys)}. "
+                    )
+
+            # Filter missing keys specific to the current adapter.
+            missing_keys = getattr(incompatible_keys, "missing_keys", None)
+            if missing_keys:
+                lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
+                if lora_missing_keys:
+                    warn_msg += (
+                        f"Loading adapter weights from state_dict led to missing keys in the model:"
+                        f" {', '.join(lora_missing_keys)}."
+                    )
+
+        if warn_msg:
+            logger.warning(warn_msg)
 
         # Offload back.
         if is_model_cpu_offload:
@@ -2717,14 +2765,30 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
         inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name, **peft_kwargs)
         incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name, **peft_kwargs)
 
+        warn_msg = ""
         if incompatible_keys is not None:
-            # check only for unexpected keys
+            # Check only for unexpected keys.
             unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
             if unexpected_keys:
-                logger.warning(
-                    f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
-                    f" {unexpected_keys}. "
-                )
+                lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k]
+                if lora_unexpected_keys:
+                    warn_msg = (
+                        f"Loading adapter weights from state_dict led to unexpected keys found in the model:"
+                        f" {', '.join(lora_unexpected_keys)}. "
+                    )
+
+            # Filter missing keys specific to the current adapter.
+            missing_keys = getattr(incompatible_keys, "missing_keys", None)
+            if missing_keys:
+                lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
+                if lora_missing_keys:
+                    warn_msg += (
+                        f"Loading adapter weights from state_dict led to missing keys in the model:"
+                        f" {', '.join(lora_missing_keys)}."
+                    )
+
+        if warn_msg:
+            logger.warning(warn_msg)
 
         # Offload back.
         if is_model_cpu_offload:
@@ -354,14 +354,30 @@ class UNet2DConditionLoadersMixin:
         inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs)
         incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs)
 
+        warn_msg = ""
         if incompatible_keys is not None:
-            # check only for unexpected keys
+            # Check only for unexpected keys.
             unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
             if unexpected_keys:
-                logger.warning(
-                    f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
-                    f" {unexpected_keys}. "
-                )
+                lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k]
+                if lora_unexpected_keys:
+                    warn_msg = (
+                        f"Loading adapter weights from state_dict led to unexpected keys found in the model:"
+                        f" {', '.join(lora_unexpected_keys)}. "
+                    )
+
+            # Filter missing keys specific to the current adapter.
+            missing_keys = getattr(incompatible_keys, "missing_keys", None)
+            if missing_keys:
+                lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
+                if lora_missing_keys:
+                    warn_msg += (
+                        f"Loading adapter weights from state_dict led to missing keys in the model:"
+                        f" {', '.join(lora_missing_keys)}."
+                    )
+
+        if warn_msg:
+            logger.warning(warn_msg)
 
         return is_model_cpu_offload, is_sequential_cpu_offload
@@ -27,6 +27,7 @@ from diffusers import FlowMatchEulerDiscreteScheduler, FluxPipeline, FluxTransfo
 from diffusers.utils.testing_utils import (
     floats_tensor,
     is_peft_available,
+    numpy_cosine_similarity_distance,
     require_peft_backend,
     require_torch_gpu,
     slow,
@@ -166,7 +167,7 @@ class FluxLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
 @slow
 @require_torch_gpu
 @require_peft_backend
-@unittest.skip("We cannot run inference on this model with the current CI hardware")
+# @unittest.skip("We cannot run inference on this model with the current CI hardware")
 # TODO (DN6, sayakpaul): move these tests to a beefier GPU
 class FluxLoRAIntegrationTests(unittest.TestCase):
     """internal note: The integration slices were obtained on audace.
@@ -208,9 +209,11 @@ class FluxLoRAIntegrationTests(unittest.TestCase):
             generator=torch.manual_seed(self.seed),
         ).images
+
         out_slice = out[0, -3:, -3:, -1].flatten()
-        expected_slice = np.array([0.1719, 0.1719, 0.1699, 0.1719, 0.1719, 0.1738, 0.1641, 0.1621, 0.2090])
+        expected_slice = np.array([0.1855, 0.1855, 0.1836, 0.1855, 0.1836, 0.1875, 0.1777, 0.1758, 0.2246])
 
-        assert np.allclose(out_slice, expected_slice, atol=1e-4, rtol=1e-4)
+        max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice)
+        assert max_diff < 1e-3
 
     def test_flux_kohya(self):
         self.pipeline.load_lora_weights("Norod78/brain-slug-flux")
@@ -230,7 +233,9 @@ class FluxLoRAIntegrationTests(unittest.TestCase):
         out_slice = out[0, -3:, -3:, -1].flatten()
         expected_slice = np.array([0.6367, 0.6367, 0.6328, 0.6367, 0.6328, 0.6289, 0.6367, 0.6328, 0.6484])
 
-        assert np.allclose(out_slice, expected_slice, atol=1e-4, rtol=1e-4)
+        max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice)
+        assert max_diff < 1e-3
 
     def test_flux_kohya_with_text_encoder(self):
         self.pipeline.load_lora_weights("cocktailpeanut/optimus", weight_name="optimus.safetensors")
@@ -248,9 +253,11 @@ class FluxLoRAIntegrationTests(unittest.TestCase):
         ).images
+
         out_slice = out[0, -3:, -3:, -1].flatten()
-        expected_slice = np.array([0.4023, 0.4043, 0.4023, 0.3965, 0.3984, 0.3984, 0.3906, 0.3906, 0.4219])
+        expected_slice = np.array([0.4023, 0.4023, 0.4023, 0.3965, 0.3984, 0.3965, 0.3926, 0.3906, 0.4219])
 
-        assert np.allclose(out_slice, expected_slice, atol=1e-4, rtol=1e-4)
+        max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice)
+        assert max_diff < 1e-3
 
     def test_flux_xlabs(self):
         self.pipeline.load_lora_weights("XLabs-AI/flux-lora-collection", weight_name="disney_lora.safetensors")
@@ -268,6 +275,8 @@ class FluxLoRAIntegrationTests(unittest.TestCase):
             generator=torch.manual_seed(self.seed),
         ).images
         out_slice = out[0, -3:, -3:, -1].flatten()
-        expected_slice = np.array([0.3984, 0.4199, 0.4453, 0.4102, 0.4375, 0.4590, 0.4141, 0.4355, 0.4980])
+        expected_slice = np.array([0.3965, 0.4180, 0.4434, 0.4082, 0.4375, 0.4590, 0.4141, 0.4375, 0.4980])
 
-        assert np.allclose(out_slice, expected_slice, atol=1e-4, rtol=1e-4)
+        max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice)
+        assert max_diff < 1e-3
@@ -27,8 +27,10 @@ from diffusers import (
     LCMScheduler,
     UNet2DConditionModel,
 )
+from diffusers.utils import logging
 from diffusers.utils.import_utils import is_peft_available
 from diffusers.utils.testing_utils import (
+    CaptureLogger,
     floats_tensor,
     require_peft_backend,
     require_peft_version_greater,
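`CaptureLogger` is the diffusers testing helper the new tests below rely on; as used there, it is a context manager that collects everything emitted on a given logger into `.out`. A minimal usage sketch (assuming the `diffusers.utils.testing_utils` implementation):

```python
from diffusers.utils import logging
from diffusers.utils.testing_utils import CaptureLogger

logger = logging.get_logger("diffusers.loaders.lora_pipeline")
logger.setLevel(30)  # logging.WARNING

with CaptureLogger(logger) as cap_logger:
    logger.warning("missing keys in the model: foo.lora_A.default_0.weight")

assert "lora_A" in cap_logger.out
```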
@@ -219,10 +221,18 @@ class PeftLoraLoaderMixinTests:
         modules_to_save = {}
         lora_loadable_modules = self.pipeline_class._lora_loadable_modules
-        if "text_encoder" in lora_loadable_modules and hasattr(pipe, "text_encoder"):
+        if (
+            "text_encoder" in lora_loadable_modules
+            and hasattr(pipe, "text_encoder")
+            and getattr(pipe.text_encoder, "peft_config", None) is not None
+        ):
             modules_to_save["text_encoder"] = pipe.text_encoder
 
-        if "text_encoder_2" in lora_loadable_modules and hasattr(pipe, "text_encoder_2"):
+        if (
+            "text_encoder_2" in lora_loadable_modules
+            and hasattr(pipe, "text_encoder_2")
+            and getattr(pipe.text_encoder_2, "peft_config", None) is not None
+        ):
             modules_to_save["text_encoder_2"] = pipe.text_encoder_2
 
         if has_denoiser:
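The extra `peft_config` guard means a text encoder is only collected for saving when PEFT has actually injected an adapter into it (PEFT sets a `peft_config` attribute on injection). A small stand-in illustration of the predicate:

```python
class PlainEncoder:  # stand-in: text encoder without any adapter
    pass

class AdaptedEncoder:  # stand-in: PEFT attaches `peft_config` when injecting
    peft_config = {"default_0": object()}

def has_adapter(module) -> bool:
    return getattr(module, "peft_config", None) is not None

assert not has_adapter(PlainEncoder())
assert has_adapter(AdaptedEncoder())
```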
@@ -1747,6 +1757,83 @@ class PeftLoraLoaderMixinTests:
             "DoRA lora should change the output",
         )
 
+    def test_missing_keys_warning(self):
+        scheduler_cls = self.scheduler_classes[0]
+        # Skip text encoder check for now as that is handled with `transformers`.
+        components, _, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+        denoiser.add_adapter(denoiser_lora_config)
+        self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
+            lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
+            self.pipeline_class.save_lora_weights(
+                save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts
+            )
+            pipe.unload_lora_weights()
+            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
+            state_dict = torch.load(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), weights_only=True)
+
+            # To make things dynamic since we cannot settle with a single key for all the models where we
+            # offer PEFT support.
+            missing_key = [k for k in state_dict if "lora_A" in k][0]
+            del state_dict[missing_key]
+
+            logger = (
+                logging.get_logger("diffusers.loaders.unet")
+                if self.unet_kwargs is not None
+                else logging.get_logger("diffusers.loaders.lora_pipeline")
+            )
+            logger.setLevel(30)
+            with CaptureLogger(logger) as cap_logger:
+                pipe.load_lora_weights(state_dict)
+
+        # Since the missing key won't contain the adapter name ("default_0").
+        # Also strip out the component prefix (such as "unet." from `missing_key`).
+        component = list({k.split(".")[0] for k in state_dict})[0]
+        self.assertTrue(missing_key.replace(f"{component}.", "") in cap_logger.out.replace("default_0.", ""))
+
+    def test_unexpected_keys_warning(self):
+        scheduler_cls = self.scheduler_classes[0]
+        # Skip text encoder check for now as that is handled with `transformers`.
+        components, _, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+        denoiser.add_adapter(denoiser_lora_config)
+        self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
+            lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
+            self.pipeline_class.save_lora_weights(
+                save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts
+            )
+            pipe.unload_lora_weights()
+            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
+            state_dict = torch.load(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), weights_only=True)
+
+            unexpected_key = [k for k in state_dict if "lora_A" in k][0] + ".diffusers_cat"
+            state_dict[unexpected_key] = torch.tensor(1.0, device=torch_device)
+
+            logger = (
+                logging.get_logger("diffusers.loaders.unet")
+                if self.unet_kwargs is not None
+                else logging.get_logger("diffusers.loaders.lora_pipeline")
+            )
+            logger.setLevel(30)
+            with CaptureLogger(logger) as cap_logger:
+                pipe.load_lora_weights(state_dict)
+
+        self.assertTrue(".diffusers_cat" in cap_logger.out)
+
     @unittest.skip("This is failing for now - need to investigate")
     def test_simple_inference_with_text_denoiser_lora_unfused_torch_compile(self):
         """