Unverified commit 699dfb08, authored by Sayak Paul and committed by GitHub

feat: support DoRA LoRA from community (#7371)

* feat: support dora loras from community

* safe-guard dora operations under peft version.

* pop use_dora when False

* make dora lora from kohya work.

* fix: kohya conversion utils.

* add a fast test for DoRA compatibility.

* add a nightly test.
parent 484c8ef3
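In practice, this change lets `load_lora_weights` consume community DoRA LoRAs — Kohya-style checkpoints that ship an extra `dora_scale` tensor per module — as long as `peft>=0.9.0` is installed. A minimal usage sketch; the checkpoint id below is a placeholder, not a real repository:

```python
import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
# Placeholder repo id; any Kohya-trained DoRA LoRA should load the same way.
pipe.load_lora_weights("some-user/some-dora-lora")
image = pipe("photo of ohwx dog", num_inference_steps=30).images[0]
```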
@@ -36,6 +36,7 @@ from ..utils import (
     get_adapter_name,
     get_peft_kwargs,
     is_accelerate_available,
+    is_peft_version,
     is_transformers_available,
     logging,
     recurse_remove_peft_layers,
@@ -113,7 +114,7 @@ class LoraLoaderMixin:
         # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
         state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
 
-        is_correct_format = all("lora" in key for key in state_dict.keys())
+        is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
         if not is_correct_format:
             raise ValueError("Invalid LoRA checkpoint.")
@@ -451,6 +452,15 @@ class LoraLoaderMixin:
                 rank[key] = val.shape[1]
 
         lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=True)
+        if "use_dora" in lora_config_kwargs:
+            if lora_config_kwargs["use_dora"]:
+                if is_peft_version("<", "0.9.0"):
+                    raise ValueError(
+                        "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
+                    )
+            else:
+                if is_peft_version("<", "0.9.0"):
+                    lora_config_kwargs.pop("use_dora")
         lora_config = LoraConfig(**lora_config_kwargs)
 
         # adapter_name
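The asymmetry in the gate above is deliberate: `LoraConfig` in `peft` versions before 0.9.0 has no `use_dora` parameter at all, so a `True` value must hard-fail while a `False` value can simply be dropped before the config is built. A standalone sketch of the same pattern; the kwargs dict is illustrative:

```python
from diffusers.utils import is_peft_version
from peft import LoraConfig

lora_config_kwargs = {"r": 4, "lora_alpha": 4, "use_dora": False}  # illustrative values
if "use_dora" in lora_config_kwargs:
    if lora_config_kwargs["use_dora"]:
        if is_peft_version("<", "0.9.0"):
            raise ValueError("DoRA-enabled LoRAs need `peft` >= 0.9.0.")
    elif is_peft_version("<", "0.9.0"):
        # An older LoraConfig signature would reject the unknown kwarg.
        lora_config_kwargs.pop("use_dora")
lora_config = LoraConfig(**lora_config_kwargs)
```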
@@ -572,6 +582,15 @@ class LoraLoaderMixin:
             }
 
             lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
+            if "use_dora" in lora_config_kwargs:
+                if lora_config_kwargs["use_dora"]:
+                    if is_peft_version("<", "0.9.0"):
+                        raise ValueError(
+                            "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
+                        )
+                else:
+                    if is_peft_version("<", "0.9.0"):
+                        lora_config_kwargs.pop("use_dora")
             lora_config = LoraConfig(**lora_config_kwargs)
 
             # adapter_name
@@ -654,6 +673,13 @@ class LoraLoaderMixin:
             rank[key] = val.shape[1]
 
         lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict)
+        if "use_dora" in lora_config_kwargs:
+            if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"):
+                raise ValueError(
+                    "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
+                )
+            else:
+                lora_config_kwargs.pop("use_dora")
         lora_config = LoraConfig(**lora_config_kwargs)
 
         # adapter_name
@@ -1243,7 +1269,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
             unet_config=self.unet.config,
             **kwargs,
         )
-        is_correct_format = all("lora" in key for key in state_dict.keys())
+        is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
         if not is_correct_format:
             raise ValueError("Invalid LoRA checkpoint.")
...
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
import re import re
from ..utils import logging from ..utils import is_peft_version, logging
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
@@ -128,6 +128,15 @@ def _convert_kohya_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_
     te_state_dict = {}
     te2_state_dict = {}
     network_alphas = {}
+    is_unet_dora_lora = any("dora_scale" in k and "lora_unet_" in k for k in state_dict)
+    is_te_dora_lora = any("dora_scale" in k and ("lora_te_" in k or "lora_te1_" in k) for k in state_dict)
+    is_te2_dora_lora = any("dora_scale" in k and "lora_te2_" in k for k in state_dict)
+
+    if is_unet_dora_lora or is_te_dora_lora or is_te2_dora_lora:
+        if is_peft_version("<", "0.9.0"):
+            raise ValueError(
+                "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
+            )
 
     # every down weight has a corresponding up weight and potentially an alpha weight
     lora_keys = [k for k in state_dict.keys() if k.endswith("lora_down.weight")]
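Kohya-format DoRA checkpoints are recognizable by `dora_scale` tensors sitting next to the usual `lora_down`/`lora_up` pairs, which is exactly what the detection above keys on. A sketch with made-up key names and placeholder values:

```python
# Illustrative Kohya-style DoRA keys; real values are torch tensors.
state_dict = {
    "lora_unet_mid_block_attentions_0_to_q.lora_down.weight": None,
    "lora_unet_mid_block_attentions_0_to_q.lora_up.weight": None,
    "lora_unet_mid_block_attentions_0_to_q.dora_scale": None,
}
is_unet_dora_lora = any("dora_scale" in k and "lora_unet_" in k for k in state_dict)
print(is_unet_dora_lora)  # True
```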
@@ -198,6 +207,12 @@ def _convert_kohya_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_
             unet_state_dict[diffusers_name] = state_dict.pop(key)
             unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
 
+            if is_unet_dora_lora:
+                dora_scale_key_to_replace = "_lora.down." if "_lora.down." in diffusers_name else ".lora.down."
+                unet_state_dict[
+                    diffusers_name.replace(dora_scale_key_to_replace, ".lora_magnitude_vector.")
+                ] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
+
         elif lora_name.startswith(("lora_te_", "lora_te1_", "lora_te2_")):
             if lora_name.startswith(("lora_te_", "lora_te1_")):
                 key_to_replace = "lora_te_" if lora_name.startswith("lora_te_") else "lora_te1_"
@@ -229,6 +244,19 @@ def _convert_kohya_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_
                 te2_state_dict[diffusers_name] = state_dict.pop(key)
                 te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
 
+            if (is_te_dora_lora or is_te2_dora_lora) and lora_name.startswith(("lora_te_", "lora_te1_", "lora_te2_")):
+                dora_scale_key_to_replace_te = (
+                    "_lora.down." if "_lora.down." in diffusers_name else ".lora_linear_layer."
+                )
+                if lora_name.startswith(("lora_te_", "lora_te1_")):
+                    te_state_dict[
+                        diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")
+                    ] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
+                elif lora_name.startswith("lora_te2_"):
+                    te2_state_dict[
+                        diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")
+                    ] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
+
         # Rename the alphas so that they can be mapped appropriately.
         if lora_name_alpha in state_dict:
             alpha = state_dict.pop(lora_name_alpha).item()
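The net effect of both the UNet and text-encoder branches is a pure key rename: the Kohya `dora_scale` tensor ends up under `peft`'s `lora_magnitude_vector` name alongside the converted up/down weights. A sketch of the string surgery with illustrative names:

```python
# Illustrative key names; the replaces mirror the UNet branch of the diff.
key = "lora_unet_mid_block_attentions_0_to_q.lora_down.weight"
diffusers_name = "mid_block.attentions.0.to_q.lora.down.weight"  # after Kohya -> diffusers mapping
dora_scale_key = key.replace("lora_down.weight", "dora_scale")
magnitude_key = diffusers_name.replace(".lora.down.", ".lora_magnitude_vector.")
print(dora_scale_key)  # lora_unet_mid_block_attentions_0_to_q.dora_scale
print(magnitude_key)   # mid_block.attentions.0.to_q.lora_magnitude_vector.weight
```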
...
@@ -69,6 +69,7 @@ from .import_utils import (
     is_note_seq_available,
     is_onnx_available,
     is_peft_available,
+    is_peft_version,
     is_scipy_available,
     is_tensorboard_available,
     is_torch_available,
...
@@ -628,6 +628,20 @@ def is_accelerate_version(operation: str, version: str):
     return compare_versions(parse(_accelerate_version), operation, version)
 
 
+def is_peft_version(operation: str, version: str):
+    """
+    Compares the current PEFT version to a given reference with an operation.
+
+    Args:
+        operation (`str`):
+            A string representation of an operator, such as `">"` or `"<="`
+        version (`str`):
+            A version string
+    """
+    if not _peft_version:
+        return False
+    return compare_versions(parse(_peft_version), operation, version)
+
+
 def is_k_diffusion_version(operation: str, version: str):
     """
     Args:
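Like its `is_accelerate_version` sibling, the helper returns a plain boolean and degrades gracefully: when `peft` is not installed, `_peft_version` is falsy and the function returns `False` for any comparison. Assumed usage:

```python
from diffusers.utils import is_peft_version

# False when peft is absent, so the check is safe to run unconditionally.
if is_peft_version("<", "0.9.0"):
    raise ValueError("DoRA-enabled LoRAs need `peft` >= 0.9.0.")
```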
...
@@ -171,6 +171,7 @@ def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True
     # layer names without the Diffusers specific
     target_modules = list({name.split(".lora")[0] for name in peft_state_dict.keys()})
+    use_dora = any("lora_magnitude_vector" in k for k in peft_state_dict)
 
     lora_config_kwargs = {
         "r": r,
@@ -178,6 +179,7 @@ def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True
         "rank_pattern": rank_pattern,
         "alpha_pattern": alpha_pattern,
         "target_modules": target_modules,
+        "use_dora": use_dora,
     }
 
     return lora_config_kwargs
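`get_peft_kwargs` thus infers DoRA from the converted state dict itself: the presence of any `lora_magnitude_vector` entry flips `use_dora` in the resulting `LoraConfig` kwargs. A sketch with placeholder tensors:

```python
# Placeholder entries; real values are torch tensors.
peft_state_dict = {
    "mid_block.attentions.0.to_q.lora_A.weight": None,
    "mid_block.attentions.0.to_q.lora_B.weight": None,
    "mid_block.attentions.0.to_q.lora_magnitude_vector.weight": None,
}
use_dora = any("lora_magnitude_vector" in k for k in peft_state_dict)
print(use_dora)  # True
```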
...
@@ -47,6 +47,7 @@ UNET_TO_DIFFUSERS = {
     ".to_v_lora.up": ".to_v.lora_B",
     ".lora.up": ".lora_B",
     ".lora.down": ".lora_A",
+    ".to_out.lora_magnitude_vector": ".to_out.0.lora_magnitude_vector",
 }
 
@@ -104,6 +105,10 @@ DIFFUSERS_OLD_TO_DIFFUSERS = {
     ".to_v_lora.down": ".v_proj.lora_linear_layer.down",
     ".to_out_lora.up": ".out_proj.lora_linear_layer.up",
     ".to_out_lora.down": ".out_proj.lora_linear_layer.down",
+    ".to_k.lora_magnitude_vector": ".k_proj.lora_magnitude_vector",
+    ".to_v.lora_magnitude_vector": ".v_proj.lora_magnitude_vector",
+    ".to_q.lora_magnitude_vector": ".q_proj.lora_magnitude_vector",
+    ".to_out.lora_magnitude_vector": ".out_proj.lora_magnitude_vector",
 }
 
 PEFT_TO_KOHYA_SS = {
@@ -315,6 +320,9 @@ def convert_state_dict_to_kohya(state_dict, original_type=None, **kwargs):
             kohya_key = kohya_key.replace("text_encoder.", "lora_te1.")
         elif "unet" in kohya_key:
             kohya_key = kohya_key.replace("unet", "lora_unet")
+        elif "lora_magnitude_vector" in kohya_key:
+            kohya_key = kohya_key.replace("lora_magnitude_vector", "dora_scale")
+
         kohya_key = kohya_key.replace(".", "_", kohya_key.count(".") - 2)
         kohya_key = kohya_key.replace(peft_adapter_name, "")  # Kohya doesn't take names
         kohya_ss_state_dict[kohya_key] = weight
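Going the other way, export to Kohya format maps `peft`'s `lora_magnitude_vector` back to Kohya's `dora_scale` with a plain string replace. Sketch (the key name is illustrative):

```python
peft_key = "lora_unet_mid_block_attentions_0_to_q.lora_magnitude_vector"
kohya_key = peft_key.replace("lora_magnitude_vector", "dora_scale")
print(kohya_key)  # lora_unet_mid_block_attentions_0_to_q.dora_scale
```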
...
@@ -630,3 +630,21 @@ class LoraSDXLIntegrationTests(unittest.TestCase):
         expected_slice_scale = np.array([0.5456, 0.5466, 0.5487, 0.5458, 0.5469, 0.5454, 0.5446, 0.5479, 0.5487])
 
         max_diff = numpy_cosine_similarity_distance(expected_slice_scale, predicted_slice)
         assert max_diff < 1e-3
+
+    @nightly
+    def test_integration_logits_for_dora_lora(self):
+        pipeline = StableDiffusionXLPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
+        pipeline.load_lora_weights("hf-internal-testing/dora-trained-on-kohya")
+        pipeline.enable_model_cpu_offload()
+
+        images = pipeline(
+            "photo of ohwx dog",
+            num_inference_steps=10,
+            generator=torch.manual_seed(0),
+            output_type="np",
+        ).images
+
+        predicted_slice = images[0, -3:, -3:, -1].flatten()
+        expected_slice_scale = np.array([0.3932, 0.3742, 0.4429, 0.3737, 0.3504, 0.433, 0.3948, 0.3769, 0.4516])
+
+        max_diff = numpy_cosine_similarity_distance(expected_slice_scale, predicted_slice)
+        assert max_diff < 1e-3
...@@ -72,7 +72,7 @@ class PeftLoraLoaderMixinTests: ...@@ -72,7 +72,7 @@ class PeftLoraLoaderMixinTests:
unet_kwargs = None unet_kwargs = None
vae_kwargs = None vae_kwargs = None
def get_dummy_components(self, scheduler_cls=None): def get_dummy_components(self, scheduler_cls=None, use_dora=False):
scheduler_cls = self.scheduler_cls if scheduler_cls is None else scheduler_cls scheduler_cls = self.scheduler_cls if scheduler_cls is None else scheduler_cls
rank = 4 rank = 4
...@@ -96,10 +96,15 @@ class PeftLoraLoaderMixinTests: ...@@ -96,10 +96,15 @@ class PeftLoraLoaderMixinTests:
lora_alpha=rank, lora_alpha=rank,
target_modules=["q_proj", "k_proj", "v_proj", "out_proj"], target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
init_lora_weights=False, init_lora_weights=False,
use_dora=use_dora,
) )
unet_lora_config = LoraConfig( unet_lora_config = LoraConfig(
r=rank, lora_alpha=rank, target_modules=["to_q", "to_k", "to_v", "to_out.0"], init_lora_weights=False r=rank,
lora_alpha=rank,
target_modules=["to_q", "to_k", "to_v", "to_out.0"],
init_lora_weights=False,
use_dora=use_dora,
) )
if self.has_two_text_encoders: if self.has_two_text_encoders:
...@@ -1074,6 +1079,37 @@ class PeftLoraLoaderMixinTests: ...@@ -1074,6 +1079,37 @@ class PeftLoraLoaderMixinTests:
"Fused lora should not change the output", "Fused lora should not change the output",
) )
@require_peft_version_greater(peft_version="0.9.0")
def test_simple_inference_with_dora(self):
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls, use_dora=True)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_dora_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
self.assertTrue(output_no_dora_lora.shape == (1, 64, 64, 3))
pipe.text_encoder.add_adapter(text_lora_config)
pipe.unet.add_adapter(unet_lora_config)
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
self.assertTrue(check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")
if self.has_two_text_encoders:
pipe.text_encoder_2.add_adapter(text_lora_config)
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
)
output_dora_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
self.assertFalse(
np.allclose(output_dora_lora, output_no_dora_lora, atol=1e-3, rtol=1e-3),
"DoRA lora should change the output",
)
@unittest.skip("This is failing for now - need to investigate") @unittest.skip("This is failing for now - need to investigate")
def test_simple_inference_with_text_unet_lora_unfused_torch_compile(self): def test_simple_inference_with_text_unet_lora_unfused_torch_compile(self):
""" """
...