"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "3c05b9f71c82e4cdaef579cb13f363b6c1d7964d"
Unverified Commit 97fda1b7 authored by Sayak Paul, committed by GitHub

[LoRA] feat: support non-diffusers lumina2 LoRAs. (#10909)

* feat: support non-diffusers lumina2 LoRAs.

* revert ipynb changes (but I don't know why this is required)

* empty

---------
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
parent cc220583
@@ -1276,3 +1276,74 @@ def _convert_hunyuan_video_lora_to_diffusers(original_state_dict):
converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key) converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
return converted_state_dict return converted_state_dict


def _convert_non_diffusers_lumina2_lora_to_diffusers(state_dict):
    # Remove "diffusion_model." prefix from keys.
    state_dict = {k[len("diffusion_model.") :]: v for k, v in state_dict.items()}
    converted_state_dict = {}

    def get_num_layers(keys, pattern):
        layers = set()
        for key in keys:
            match = re.search(pattern, key)
            if match:
                layers.add(int(match.group(1)))
        return len(layers)

    def process_block(prefix, index, convert_norm):
        # Process attention qkv: pop lora_A and lora_B weights.
        lora_down = state_dict.pop(f"{prefix}.{index}.attention.qkv.lora_A.weight")
        lora_up = state_dict.pop(f"{prefix}.{index}.attention.qkv.lora_B.weight")
        for attn_key in ["to_q", "to_k", "to_v"]:
            converted_state_dict[f"{prefix}.{index}.attn.{attn_key}.lora_A.weight"] = lora_down
        for attn_key, weight in zip(["to_q", "to_k", "to_v"], torch.split(lora_up, [2304, 768, 768], dim=0)):
            converted_state_dict[f"{prefix}.{index}.attn.{attn_key}.lora_B.weight"] = weight

        # Process attention out weights.
        converted_state_dict[f"{prefix}.{index}.attn.to_out.0.lora_A.weight"] = state_dict.pop(
            f"{prefix}.{index}.attention.out.lora_A.weight"
        )
        converted_state_dict[f"{prefix}.{index}.attn.to_out.0.lora_B.weight"] = state_dict.pop(
            f"{prefix}.{index}.attention.out.lora_B.weight"
        )

        # Process feed-forward weights for layers 1, 2, and 3.
        for layer in range(1, 4):
            converted_state_dict[f"{prefix}.{index}.feed_forward.linear_{layer}.lora_A.weight"] = state_dict.pop(
                f"{prefix}.{index}.feed_forward.w{layer}.lora_A.weight"
            )
            converted_state_dict[f"{prefix}.{index}.feed_forward.linear_{layer}.lora_B.weight"] = state_dict.pop(
                f"{prefix}.{index}.feed_forward.w{layer}.lora_B.weight"
            )

        if convert_norm:
            converted_state_dict[f"{prefix}.{index}.norm1.linear.lora_A.weight"] = state_dict.pop(
                f"{prefix}.{index}.adaLN_modulation.1.lora_A.weight"
            )
            converted_state_dict[f"{prefix}.{index}.norm1.linear.lora_B.weight"] = state_dict.pop(
                f"{prefix}.{index}.adaLN_modulation.1.lora_B.weight"
            )

    noise_refiner_pattern = r"noise_refiner\.(\d+)\."
    num_noise_refiner_layers = get_num_layers(state_dict.keys(), noise_refiner_pattern)
    for i in range(num_noise_refiner_layers):
        process_block("noise_refiner", i, convert_norm=True)

    context_refiner_pattern = r"context_refiner\.(\d+)\."
    num_context_refiner_layers = get_num_layers(state_dict.keys(), context_refiner_pattern)
    for i in range(num_context_refiner_layers):
        process_block("context_refiner", i, convert_norm=False)

    core_transformer_pattern = r"layers\.(\d+)\."
    num_core_transformer_layers = get_num_layers(state_dict.keys(), core_transformer_pattern)
    for i in range(num_core_transformer_layers):
        process_block("layers", i, convert_norm=True)

    if len(state_dict) > 0:
        raise ValueError(f"`state_dict` should be empty at this point but has {state_dict.keys()=}")

    for key in list(converted_state_dict.keys()):
        converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)

    return converted_state_dict
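
A note on the fused-qkv handling in process_block above: the original Lumina2 checkpoint keeps a single LoRA pair for the fused qkv projection. Because lora_A maps inputs down to the rank dimension, it can be copied verbatim to to_q, to_k, and to_v; lora_B maps rank up to outputs, so it must be split along dim 0 into the per-projection output widths. The [2304, 768, 768] sizes are taken from the code above and presumably correspond to Lumina2's query width and grouped-query key/value widths. A minimal sketch of the identity this relies on, with an assumed rank of 16:

import torch

rank = 16                                      # hypothetical LoRA rank
lora_A = torch.randn(rank, 2304)               # shared down-projection (rank x in_features)
lora_B = torch.randn(2304 + 768 + 768, rank)   # fused up-projection (out_features x rank)

to_q_B, to_k_B, to_v_B = torch.split(lora_B, [2304, 768, 768], dim=0)

# The fused LoRA update (lora_B @ lora_A) row-slices into the three per-projection
# updates, so sharing lora_A and splitting lora_B reproduces the fused behaviour exactly.
delta = lora_B @ lora_A
assert torch.allclose(to_q_B @ lora_A, delta[:2304])
assert torch.allclose(to_k_B @ lora_A, delta[2304:3072])
assert torch.allclose(to_v_B @ lora_A, delta[3072:])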
@@ -41,6 +41,7 @@ from .lora_conversion_utils import (
    _convert_hunyuan_video_lora_to_diffusers,
    _convert_kohya_flux_lora_to_diffusers,
    _convert_non_diffusers_lora_to_diffusers,
    _convert_non_diffusers_lumina2_lora_to_diffusers,
    _convert_xlabs_flux_lora_to_diffusers,
    _maybe_map_sgm_blocks_to_diffusers,
)
@@ -3815,7 +3816,6 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
    @classmethod
    @validate_hf_hub_args
    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict
    def lora_state_dict(
        cls,
        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
@@ -3909,6 +3909,11 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
        logger.warning(warn_msg)
        state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}

        # conversion.
        non_diffusers = any(k.startswith("diffusion_model.") for k in state_dict)
        if non_diffusers:
            state_dict = _convert_non_diffusers_lumina2_lora_to_diffusers(state_dict)

        return state_dict

    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
    ...
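
With the conversion wired into lora_state_dict, a non-diffusers Lumina2 LoRA (keys prefixed with "diffusion_model.") loads through the usual API. A hypothetical usage sketch; the LoRA repo id and weight name below are placeholders, and the pipeline class name is an assumption (the Lumina2 text-to-image pipeline is exported as Lumina2Pipeline in recent diffusers versions):

import torch
from diffusers import Lumina2Pipeline

pipe = Lumina2Pipeline.from_pretrained(
    "Alpha-VLLM/Lumina-Image-2.0", torch_dtype=torch.bfloat16
).to("cuda")

# load_lora_weights routes through Lumina2LoraLoaderMixin.lora_state_dict, which now
# detects the "diffusion_model." prefix and converts the keys before loading.
pipe.load_lora_weights(
    "some-user/some-lumina2-lora",                    # placeholder repo id
    weight_name="pytorch_lora_weights.safetensors",   # placeholder weight name
)
image = pipe("a watercolor painting of a lighthouse", num_inference_steps=30).images[0]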