Unverified Commit 02eeb8e7 authored by Sayak Paul, committed by GitHub

[LoRA] Handle DoRA better (#9547)

* handle dora.

* print test

* debug

* fix

* fix-copies

* update logits

* add warning in the test.

* make is_dora check consistent.

* fix-copies
parent 66eef9a6
@@ -99,7 +99,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
         # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
         state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
-        is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
+        is_correct_format = all("lora" in key for key in state_dict.keys())
         if not is_correct_format:
             raise ValueError("Invalid LoRA checkpoint.")
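The format check can drop its `dora_scale` allowance because the `lora_state_dict` hunks below now strip those keys before this check ever runs. A minimal sketch of the relaxed check, using hypothetical key names rather than a real checkpoint:

# Toy state dict with hypothetical LoRA key names.
state_dict = {
    "unet.down_blocks.0.attentions.0.lora_A.weight": None,
    "unet.down_blocks.0.attentions.0.lora_B.weight": None,
}
is_correct_format = all("lora" in key for key in state_dict.keys())
assert is_correct_format
# Any leftover key without "lora" in it would fail the check, and
# load_lora_weights would raise ValueError("Invalid LoRA checkpoint.").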
@@ -211,6 +211,11 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
             user_agent=user_agent,
             allow_pickle=allow_pickle,
         )
+        is_dora_scale_present = any("dora_scale" in k for k in state_dict)
+        if is_dora_scale_present:
+            warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
+            logger.warning(warn_msg)
+            state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
 
         network_alphas = None
         # TODO: replace it with a method from `state_dict_utils`
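For context, this is how the new filter behaves on a toy state dict (key names are hypothetical; the real loader also emits the warning shown above before filtering):

# Hypothetical keys: two LoRA matrices plus a DoRA magnitude entry.
state_dict = {
    "unet.mid_block.lora_A.weight": "A",
    "unet.mid_block.lora_B.weight": "B",
    "unet.mid_block.lora_magnitude_vector.dora_scale": "scale",
}
is_dora_scale_present = any("dora_scale" in k for k in state_dict)
if is_dora_scale_present:
    # the real loader calls logger.warning(warn_msg) here
    state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
assert sorted(state_dict) == [
    "unet.mid_block.lora_A.weight",
    "unet.mid_block.lora_B.weight",
]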
@@ -562,7 +567,8 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
             unet_config=self.unet.config,
             **kwargs,
         )
-        is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
+
+        is_correct_format = all("lora" in key for key in state_dict.keys())
         if not is_correct_format:
             raise ValueError("Invalid LoRA checkpoint.")
@@ -684,6 +690,11 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
             user_agent=user_agent,
             allow_pickle=allow_pickle,
         )
+        is_dora_scale_present = any("dora_scale" in k for k in state_dict)
+        if is_dora_scale_present:
+            warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
+            logger.warning(warn_msg)
+            state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
 
         network_alphas = None
         # TODO: replace it with a method from `state_dict_utils`
@@ -1089,6 +1100,12 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
             allow_pickle=allow_pickle,
         )
+        is_dora_scale_present = any("dora_scale" in k for k in state_dict)
+        if is_dora_scale_present:
+            warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
+            logger.warning(warn_msg)
+            state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
+
         return state_dict
 
     def load_lora_weights(
@@ -1125,7 +1142,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
         # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
         state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
-        is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
+        is_correct_format = all("lora" in key for key in state_dict.keys())
         if not is_correct_format:
             raise ValueError("Invalid LoRA checkpoint.")
@@ -1587,9 +1604,13 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
             user_agent=user_agent,
             allow_pickle=allow_pickle,
         )
+        is_dora_scale_present = any("dora_scale" in k for k in state_dict)
+        if is_dora_scale_present:
+            warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
+            logger.warning(warn_msg)
+            state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
 
         # TODO (sayakpaul): to a follow-up to clean and try to unify the conditions.
         is_kohya = any(".lora_down.weight" in k for k in state_dict)
         if is_kohya:
             state_dict = _convert_kohya_flux_lora_to_diffusers(state_dict)
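Note the ordering in this hunk: the `dora_scale` entries are dropped before the Kohya-format detection, so a Kohya-trained DoRA checkpoint (like the one exercised in the test below) is still routed through the converter on the strength of its plain LoRA keys. A small sketch with hypothetical key names:

# Hypothetical Kohya-style Flux keys, including a DoRA scale entry.
state_dict = {
    "lora_unet_single_blocks_0_linear1.lora_down.weight": "down",
    "lora_unet_single_blocks_0_linear1.lora_up.weight": "up",
    "lora_unet_single_blocks_0_linear1.dora_scale": "scale",
}
# 1) DoRA entries are filtered out first (with a warning in the real loader)...
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
# 2) ...so the Kohya detection then sees only plain LoRA keys.
is_kohya = any(".lora_down.weight" in k for k in state_dict)
assert is_kohya  # would then go through _convert_kohya_flux_lora_to_diffusers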
@@ -1659,7 +1680,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
             pretrained_model_name_or_path_or_dict, return_alphas=True, **kwargs
         )
-        is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
+        is_correct_format = all("lora" in key for key in state_dict.keys())
         if not is_correct_format:
             raise ValueError("Invalid LoRA checkpoint.")
@@ -2374,6 +2395,12 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
             allow_pickle=allow_pickle,
         )
+        is_dora_scale_present = any("dora_scale" in k for k in state_dict)
+        if is_dora_scale_present:
+            warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
+            logger.warning(warn_msg)
+            state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
+
         return state_dict
 
     def load_lora_weights(
@@ -2405,7 +2432,7 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
         # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
         state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
-        is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
+        is_correct_format = all("lora" in key for key in state_dict.keys())
         if not is_correct_format:
             raise ValueError("Invalid LoRA checkpoint.")
...
@@ -33,8 +33,10 @@ from diffusers import (
     StableDiffusionXLPipeline,
     T2IAdapter,
 )
+from diffusers.utils import logging
 from diffusers.utils.import_utils import is_accelerate_available
 from diffusers.utils.testing_utils import (
+    CaptureLogger,
     load_image,
     nightly,
     numpy_cosine_similarity_distance,
@@ -620,14 +622,18 @@ class LoraSDXLIntegrationTests(unittest.TestCase):
         pipeline.load_lora_weights("hf-internal-testing/dora-trained-on-kohya")
         pipeline.enable_model_cpu_offload()
 
-        images = pipeline(
-            "photo of ohwx dog",
-            num_inference_steps=10,
-            generator=torch.manual_seed(0),
-            output_type="np",
-        ).images
+        logger = logging.get_logger("diffusers.loaders.lora_pipeline")
+        logger.setLevel(30)
+        with CaptureLogger(logger) as cap_logger:
+            images = pipeline(
+                "photo of ohwx dog",
+                num_inference_steps=10,
+                generator=torch.manual_seed(0),
+                output_type="np",
+            ).images
+        assert "It seems like you are using a DoRA checkpoint" in cap_logger.out
 
         predicted_slice = images[0, -3:, -3:, -1].flatten()
-        expected_slice_scale = np.array([0.3932, 0.3742, 0.4429, 0.3737, 0.3504, 0.433, 0.3948, 0.3769, 0.4516])
+        expected_slice_scale = np.array([0.1817, 0.0697, 0.2346, 0.0900, 0.1261, 0.2279, 0.1767, 0.1991, 0.2886])
 
         max_diff = numpy_cosine_similarity_distance(expected_slice_scale, predicted_slice)
         assert max_diff < 1e-3
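For reference, the log-capture pattern the updated test relies on, reduced to a self-contained sketch; the warning is emitted manually here instead of by `load_lora_weights`:

from diffusers.utils import logging
from diffusers.utils.testing_utils import CaptureLogger

# Grab the module logger the warning is emitted on and make sure
# WARNING-level records are not filtered out.
logger = logging.get_logger("diffusers.loaders.lora_pipeline")
logger.setLevel(30)  # 30 == logging.WARNING

with CaptureLogger(logger) as cap_logger:
    logger.warning("It seems like you are using a DoRA checkpoint ...")

# Everything logged on this logger inside the block lands in cap_logger.out.
assert "DoRA checkpoint" in cap_logger.out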