Unverified Commit f442955c authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

[lora] support loading loras from `lightx2v/Qwen-Image-Lightning` (#12119)

* feat: support qwen lightning lora.

* add docs.

* fix
parent ff9a3876
......@@ -24,6 +24,63 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers)
</Tip>
## LoRA for faster inference
Use a LoRA from `lightx2v/Qwen-Image-Lightning` to speed up inference by reducing the
number of steps. Refer to the code snippet below:
<details>
<summary>Code</summary>
```py
from diffusers import DiffusionPipeline, FlowMatchEulerDiscreteScheduler
import torch
import math
ckpt_id = "Qwen/Qwen-Image"
# From
# https://github.com/ModelTC/Qwen-Image-Lightning/blob/342260e8f5468d2f24d084ce04f55e101007118b/generate_with_diffusers.py#L82C9-L97C10
scheduler_config = {
"base_image_seq_len": 256,
"base_shift": math.log(3), # We use shift=3 in distillation
"invert_sigmas": False,
"max_image_seq_len": 8192,
"max_shift": math.log(3), # We use shift=3 in distillation
"num_train_timesteps": 1000,
"shift": 1.0,
"shift_terminal": None, # set shift_terminal to None
"stochastic_sampling": False,
"time_shift_type": "exponential",
"use_beta_sigmas": False,
"use_dynamic_shifting": True,
"use_exponential_sigmas": False,
"use_karras_sigmas": False,
}
scheduler = FlowMatchEulerDiscreteScheduler.from_config(scheduler_config)
pipe = DiffusionPipeline.from_pretrained(
ckpt_id, scheduler=scheduler, torch_dtype=torch.bfloat16
).to("cuda")
pipe.load_lora_weights(
"lightx2v/Qwen-Image-Lightning", weight_name="Qwen-Image-Lightning-8steps-V1.0.safetensors"
)
prompt = "a tiny astronaut hatching from an egg on the moon, Ultra HD, 4K, cinematic composition."
negative_prompt = " "
image = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
width=1024,
height=1024,
num_inference_steps=8,
true_cfg_scale=1.0,
generator=torch.manual_seed(0),
).images[0]
image.save("qwen_fewsteps.png")
```
</details>
## QwenImagePipeline
[[autodoc]] QwenImagePipeline
......
......@@ -2077,3 +2077,39 @@ def _convert_non_diffusers_ltxv_lora_to_diffusers(state_dict, non_diffusers_pref
converted_state_dict = {k.removeprefix(f"{non_diffusers_prefix}."): v for k, v in state_dict.items()}
converted_state_dict = {f"transformer.{k}": v for k, v in converted_state_dict.items()}
return converted_state_dict
def _convert_non_diffusers_qwen_lora_to_diffusers(state_dict):
converted_state_dict = {}
all_keys = list(state_dict.keys())
down_key = ".lora_down.weight"
up_key = ".lora_up.weight"
def get_alpha_scales(down_weight, alpha_key):
rank = down_weight.shape[0]
alpha = state_dict.pop(alpha_key).item()
scale = alpha / rank # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
scale_down = scale
scale_up = 1.0
while scale_down * 2 < scale_up:
scale_down *= 2
scale_up /= 2
return scale_down, scale_up
for k in all_keys:
if k.endswith(down_key):
diffusers_down_key = k.replace(down_key, ".lora_A.weight")
diffusers_up_key = k.replace(down_key, up_key).replace(up_key, ".lora_B.weight")
alpha_key = k.replace(down_key, ".alpha")
down_weight = state_dict.pop(k)
up_weight = state_dict.pop(k.replace(down_key, up_key))
scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
converted_state_dict[diffusers_down_key] = down_weight * scale_down
converted_state_dict[diffusers_up_key] = up_weight * scale_up
if len(state_dict) > 0:
raise ValueError(f"`state_dict` should be empty at this point but has {state_dict.keys()=}")
converted_state_dict = {f"transformer.{k}": v for k, v in converted_state_dict.items()}
return converted_state_dict
......@@ -49,6 +49,7 @@ from .lora_conversion_utils import (
_convert_non_diffusers_lora_to_diffusers,
_convert_non_diffusers_ltxv_lora_to_diffusers,
_convert_non_diffusers_lumina2_lora_to_diffusers,
_convert_non_diffusers_qwen_lora_to_diffusers,
_convert_non_diffusers_wan_lora_to_diffusers,
_convert_xlabs_flux_lora_to_diffusers,
_maybe_map_sgm_blocks_to_diffusers,
......@@ -6548,7 +6549,6 @@ class QwenImageLoraLoaderMixin(LoraBaseMixin):
@classmethod
@validate_hf_hub_args
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
def lora_state_dict(
cls,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
......@@ -6642,6 +6642,10 @@ class QwenImageLoraLoaderMixin(LoraBaseMixin):
logger.warning(warn_msg)
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
has_alphas_in_sd = any(k.endswith(".alpha") for k in state_dict)
if has_alphas_in_sd:
state_dict = _convert_non_diffusers_qwen_lora_to_diffusers(state_dict)
out = (state_dict, metadata) if return_lora_metadata else state_dict
return out
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment