Unverified Commit 0c47c954 authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

[LoRA] support non-diffusers hidream loras (#11532)

* support non-diffusers hidream loras

* make fix-copies
parent 7acf8345
...@@ -1704,3 +1704,11 @@ def _convert_musubi_wan_lora_to_diffusers(state_dict): ...@@ -1704,3 +1704,11 @@ def _convert_musubi_wan_lora_to_diffusers(state_dict):
converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key) converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
return converted_state_dict return converted_state_dict
def _convert_non_diffusers_hidream_lora_to_diffusers(state_dict, non_diffusers_prefix="diffusion_model"):
if not all(k.startswith(non_diffusers_prefix) for k in state_dict):
raise ValueError("Invalid LoRA state dict for HiDream.")
converted_state_dict = {k.removeprefix(f"{non_diffusers_prefix}."): v for k, v in state_dict.items()}
converted_state_dict = {f"transformer.{k}": v for k, v in converted_state_dict.items()}
return converted_state_dict
...@@ -43,6 +43,7 @@ from .lora_conversion_utils import ( ...@@ -43,6 +43,7 @@ from .lora_conversion_utils import (
_convert_hunyuan_video_lora_to_diffusers, _convert_hunyuan_video_lora_to_diffusers,
_convert_kohya_flux_lora_to_diffusers, _convert_kohya_flux_lora_to_diffusers,
_convert_musubi_wan_lora_to_diffusers, _convert_musubi_wan_lora_to_diffusers,
_convert_non_diffusers_hidream_lora_to_diffusers,
_convert_non_diffusers_lora_to_diffusers, _convert_non_diffusers_lora_to_diffusers,
_convert_non_diffusers_lumina2_lora_to_diffusers, _convert_non_diffusers_lumina2_lora_to_diffusers,
_convert_non_diffusers_wan_lora_to_diffusers, _convert_non_diffusers_wan_lora_to_diffusers,
...@@ -5371,7 +5372,6 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin): ...@@ -5371,7 +5372,6 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin):
@classmethod @classmethod
@validate_hf_hub_args @validate_hf_hub_args
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict
def lora_state_dict( def lora_state_dict(
cls, cls,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
...@@ -5465,6 +5465,10 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin): ...@@ -5465,6 +5465,10 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin):
logger.warning(warn_msg) logger.warning(warn_msg)
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
is_non_diffusers_format = any("diffusion_model" in k for k in state_dict)
if is_non_diffusers_format:
state_dict = _convert_non_diffusers_hidream_lora_to_diffusers(state_dict)
return state_dict return state_dict
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment