"vscode:/vscode.git/clone" did not exist on "5ecac15adeed3cc3452167f507bc9fa773608c35"
Unverified commit a889bb7e authored by Watebear, committed by GitHub

[feat]: support lora in qwen-image-edit (#570)

parent 13bba9df
{
  "batchsize": 1,
  "num_channels_latents": 16,
  "vae_scale_factor": 8,
  "infer_steps": 8,
  "guidance_embeds": false,
  "num_images_per_prompt": 1,
  "vae_latents_mean": [
    -0.7571,
    -0.7089,
    -0.9113,
    0.1075,
    -0.1745,
    0.9653,
    -0.1517,
    1.5508,
    0.4134,
    -0.0715,
    0.5517,
    -0.3632,
    -0.1922,
    -0.9497,
    0.2503,
    -0.2921
  ],
  "vae_latents_std": [
    2.8184,
    1.4541,
    2.3275,
    2.6558,
    1.2196,
    1.7708,
    2.6052,
    2.0743,
    3.2687,
    2.1526,
    2.8652,
    1.5579,
    1.6382,
    1.1253,
    2.8251,
    1.916
  ],
  "vae_z_dim": 16,
  "feature_caching": "NoCaching",
  "transformer_in_channels": 64,
  "prompt_template_encode": "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n",
  "prompt_template_encode_start_idx": 64,
  "_auto_resize": true,
  "num_layers": 60,
  "attention_out_dim": 3072,
  "attention_dim_head": 128,
  "axes_dims_rope": [
    16,
    56,
    56
  ],
  "_comment_attn": "in [torch_sdpa, flash_attn3, sage_attn2]",
  "attn_type": "flash_attn3",
  "do_true_cfg": true,
  "true_cfg_scale": 4.0,
  "CONDITION_IMAGE_SIZE": 1048576,
  "USE_IMAGE_ID_IN_PROMPT": false,
  "lora_configs": [
    {
      "path": "/path/to/Qwen-Image-Edit-Lightning-4steps-V1.0.safetensors",
      "strength": 1.0
    }
  ]
}
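Each lora_configs entry above points the runner at a LoRA safetensors file and a fusion strength. For orientation, the adapter added below reads per-layer tensors keyed by transformer block index and projection name; the key pattern comes from _apply_lora_weights further down, while the concrete example keys and shape comments here are assumptions:

    transformer_blocks.0.attn.to_q.lora_down.weight   # [rank, in_features]
    transformer_blocks.0.attn.to_q.lora_up.weight     # [out_features, rank]
    transformer_blocks.0.attn.to_q.alpha              # scalar LoRA alpha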
import os
import torch
from loguru import logger
from safetensors import safe_open
from lightx2v.utils.envs import *
from lightx2v_platform.base.global_var import AI_DEVICE
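
# Shape convention assumed here (standard LoRA layout; not stated explicitly in this diff):
# for a base weight W of shape [out_features, in_features], lora_down is
# [rank, in_features] and lora_up is [out_features, rank], so torch.mm(lora_up, lora_down)
# matches W's shape and the fused weight is W + (alpha / rank) * (lora_up @ lora_down).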
def fuse_lora_weights(original_weight, lora_down, lora_up, alpha):
    rank = lora_down.shape[0]
    lora_delta = torch.mm(lora_up, lora_down)  # W_up × W_down
    scaling = alpha / rank
    lora_delta = lora_delta * scaling
    fused_weight = original_weight + lora_delta
    return fused_weight


class QwenImageLoraWrapper:
    def __init__(self, qwenimage_model):
        self.model = qwenimage_model
        self.lora_metadata = {}
        self.device = torch.device(AI_DEVICE) if not self.model.config.get("cpu_offload", False) else torch.device("cpu")

    def load_lora(self, lora_path, lora_name=None):
        if lora_name is None:
            lora_name = os.path.basename(lora_path).split(".")[0]
        if lora_name in self.lora_metadata:
            logger.info(f"LoRA {lora_name} already loaded, skipping...")
            return lora_name
        self.lora_metadata[lora_name] = {"path": lora_path}
        logger.info(f"Registered LoRA metadata for: {lora_name} from {lora_path}")
        return lora_name

    def _load_lora_file(self, file_path):
        with safe_open(file_path, framework="pt") as f:
            tensor_dict = {key: f.get_tensor(key).to(GET_DTYPE()).to(self.device) for key in f.keys()}
        return tensor_dict

    def apply_lora(self, lora_name, alpha=1.0):
        if lora_name not in self.lora_metadata:
            logger.error(f"LoRA {lora_name} not found. Please load it first.")
            return False
        if not hasattr(self.model, "original_weight_dict"):
            logger.error("Model does not have 'original_weight_dict'. Cannot apply LoRA.")
            return False
        lora_weights = self._load_lora_file(self.lora_metadata[lora_name]["path"])
        weight_dict = self.model.original_weight_dict
        weight_dict = self._apply_lora_weights(weight_dict, lora_weights, alpha)
        self.model._apply_weights(weight_dict)
        logger.info(f"Applied LoRA: {lora_name} with alpha={alpha}")
        del lora_weights
        return True

    @torch.no_grad()
    def _apply_lora_weights(self, weight_dict, lora_weights, alpha):
        lora_prefixs = [
            "attn.add_k_proj",
            "attn.add_q_proj",
            "attn.add_v_proj",
            "attn.to_add_out",
            "attn.to_k",
            "attn.to_q",
            "attn.to_v",
            "attn.to_out.0",
            "img_mlp.net.0.proj",
            "txt_mlp.net.0.proj",
            "txt_mlp.net.2",
        ]
        for prefix in lora_prefixs:
            for idx in range(self.model.config["num_layers"]):
                prefix_name = f"transformer_blocks.{idx}.{prefix}"
                lora_up = lora_weights[f"{prefix_name}.lora_up.weight"]
                lora_down = lora_weights[f"{prefix_name}.lora_down.weight"]
                lora_alpha = lora_weights[f"{prefix_name}.alpha"]
                origin = weight_dict[f"{prefix_name}.weight"]
                # Scale the per-layer LoRA alpha by the caller-provided strength so that
                # apply_lora(name, strength) actually controls the size of the fused delta.
                weight_dict[f"{prefix_name}.weight"] = fuse_lora_weights(origin, lora_down, lora_up, lora_alpha * alpha)
        return weight_dict
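For illustration, a minimal usage sketch of the wrapper above, mirroring what the runner change below does at transformer load time (the config object and LoRA path are placeholders; it assumes the model exposes original_weight_dict and _apply_weights as used in apply_lora):

    model = QwenImageTransformerModel(config)                         # hypothetical model/config
    lora_wrapper = QwenImageLoraWrapper(model)
    lora_name = lora_wrapper.load_lora("/path/to/lora.safetensors")   # registers the path only
    lora_wrapper.apply_lora(lora_name, alpha=1.0)                     # loads the tensors and fuses them into the weights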
@@ -7,6 +7,7 @@ from PIL import Image
from loguru import logger
from lightx2v.models.input_encoders.hf.qwen25.qwen25_vlforconditionalgeneration import Qwen25_VLForConditionalGeneration_TextEncoder
from lightx2v.models.networks.qwen_image.lora_adapter import QwenImageLoraWrapper
from lightx2v.models.networks.qwen_image.model import QwenImageTransformerModel
from lightx2v.models.runners.default_runner import DefaultRunner
from lightx2v.models.schedulers.qwen_image.scheduler import QwenImageScheduler
@@ -46,6 +47,15 @@ class QwenImageRunner(DefaultRunner):
    def load_transformer(self):
        model = QwenImageTransformerModel(self.config)
        if self.config.get("lora_configs") and self.config.lora_configs:
            assert not self.config.get("dit_quantized", False)
            lora_wrapper = QwenImageLoraWrapper(model)
            for lora_config in self.config.lora_configs:
                lora_path = lora_config["path"]
                strength = lora_config.get("strength", 1.0)
                lora_name = lora_wrapper.load_lora(lora_path)
                lora_wrapper.apply_lora(lora_name, strength)
                logger.info(f"Loaded LoRA: {lora_name} with strength: {strength}")
        return model

    def load_text_encoder(self):
......
#!/bin/bash
export CUDA_VISIBLE_DEVICES=
# set paths first
export lightx2v_path=
export model_path=

# check section
if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
    cuda_devices=0
    echo "Warning: CUDA_VISIBLE_DEVICES is not set, using default value: ${cuda_devices}. Set it in this script or via an environment variable."
    export CUDA_VISIBLE_DEVICES=${cuda_devices}
fi

if [ -z "${lightx2v_path}" ]; then
    echo "Error: lightx2v_path is not set. Please set this variable first."
    exit 1
fi

if [ -z "${model_path}" ]; then
    echo "Error: model_path is not set. Please set this variable first."
    exit 1
fi
export TOKENIZERS_PARALLELISM=false
export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export DTYPE=BF16
export PROFILING_DEBUG_LEVEL=2
export CUDA_VISIBLE_DEVICES=0
# set environment variables
source ${lightx2v_path}/scripts/base/base.sh
python -m lightx2v.infer \
--model_cls qwen_image \
......
#!/bin/bash
export CUDA_VISIBLE_DEVICES=
# set paths first
export lightx2v_path=
export model_path=

# check section
if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
    cuda_devices=0
    echo "Warning: CUDA_VISIBLE_DEVICES is not set, using default value: ${cuda_devices}. Set it in this script or via an environment variable."
    export CUDA_VISIBLE_DEVICES=${cuda_devices}
fi

if [ -z "${lightx2v_path}" ]; then
    echo "Error: lightx2v_path is not set. Please set this variable first."
    exit 1
fi

if [ -z "${model_path}" ]; then
    echo "Error: model_path is not set. Please set this variable first."
    exit 1
fi
export TOKENIZERS_PARALLELISM=false
export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export DTYPE=BF16
export PROFILING_DEBUG_LEVEL=2
export CUDA_VISIBLE_DEVICES=0
# set environment variables
source ${lightx2v_path}/scripts/base/base.sh
python -m lightx2v.infer \
--model_cls qwen_image \
......
#!/bin/bash
export CUDA_VISIBLE_DEVICES=
# set paths first
export lightx2v_path=
export model_path=

# check section
if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
    cuda_devices=0
    echo "Warning: CUDA_VISIBLE_DEVICES is not set, using default value: ${cuda_devices}. Set it in this script or via an environment variable."
    export CUDA_VISIBLE_DEVICES=${cuda_devices}
fi

if [ -z "${lightx2v_path}" ]; then
    echo "Error: lightx2v_path is not set. Please set this variable first."
    exit 1
fi

if [ -z "${model_path}" ]; then
    echo "Error: model_path is not set. Please set this variable first."
    exit 1
fi
export TOKENIZERS_PARALLELISM=false
export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export DTYPE=BF16
export PROFILING_DEBUG_LEVEL=2
export CUDA_VISIBLE_DEVICES=
# set environment variables
source ${lightx2v_path}/scripts/base/base.sh
python -m lightx2v.infer \
--model_cls qwen_image \
......
#!/bin/bash
export CUDA_VISIBLE_DEVICES=
# set paths first
export lightx2v_path=
export model_path=

# check section
if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
    cuda_devices=0
    echo "Warning: CUDA_VISIBLE_DEVICES is not set, using default value: ${cuda_devices}. Set it in this script or via an environment variable."
    export CUDA_VISIBLE_DEVICES=${cuda_devices}
fi

if [ -z "${lightx2v_path}" ]; then
    echo "Error: lightx2v_path is not set. Please set this variable first."
    exit 1
fi

if [ -z "${model_path}" ]; then
    echo "Error: model_path is not set. Please set this variable first."
    exit 1
fi
export TOKENIZERS_PARALLELISM=false
export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export DTYPE=BF16
export PROFILING_DEBUG_LEVEL=2
export ENABLE_GRAPH_MODE=false
export CUDA_VISIBLE_DEVICES=0
# set environment variables
source ${lightx2v_path}/scripts/base/base.sh
python -m lightx2v.infer \
--model_cls qwen_image \
......
#!/bin/bash
# set paths first
export lightx2v_path=
export model_path=
export CUDA_VISIBLE_DEVICES=0
# set environment variables
source ${lightx2v_path}/scripts/base/base.sh
python -m lightx2v.infer \
--model_cls qwen_image \
--task i2i \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/qwen_image/qwen_image_i2i_lora.json \
--prompt "Change the person to a standing position, bending over to hold the dog's front paws." \
--negative_prompt " " \
--image_path qwen_image_edit/qwen_edit1.webp \
--save_result_path ${lightx2v_path}/save_results/qwen_image_i2i.png \
--seed 0
#!/bin/bash
export CUDA_VISIBLE_DEVICES=
# set paths first
export lightx2v_path=
export model_path=

# check section
if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
    cuda_devices=0
    echo "Warning: CUDA_VISIBLE_DEVICES is not set, using default value: ${cuda_devices}. Set it in this script or via an environment variable."
    export CUDA_VISIBLE_DEVICES=${cuda_devices}
fi

if [ -z "${lightx2v_path}" ]; then
    echo "Error: lightx2v_path is not set. Please set this variable first."
    exit 1
fi

if [ -z "${model_path}" ]; then
    echo "Error: model_path is not set. Please set this variable first."
    exit 1
fi
export TOKENIZERS_PARALLELISM=false
export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export DTYPE=BF16
export PROFILING_DEBUG_LEVEL=2
export CUDA_VISIBLE_DEVICES=0
# set environment variables
source ${lightx2v_path}/scripts/base/base.sh
python -m lightx2v.infer \
--model_cls qwen_image \
......
#!/bin/bash
export CUDA_VISIBLE_DEVICES=
# set paths first
export lightx2v_path=
export model_path=

# check section
if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
    cuda_devices=0
    echo "Warning: CUDA_VISIBLE_DEVICES is not set, using default value: ${cuda_devices}. Set it in this script or via an environment variable."
    export CUDA_VISIBLE_DEVICES=${cuda_devices}
fi

if [ -z "${lightx2v_path}" ]; then
    echo "Error: lightx2v_path is not set. Please set this variable first."
    exit 1
fi

if [ -z "${model_path}" ]; then
    echo "Error: model_path is not set. Please set this variable first."
    exit 1
fi
export TOKENIZERS_PARALLELISM=false
export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export DTYPE=BF16
export PROFILING_DEBUG_LEVEL=2
export CUDA_VISIBLE_DEVICES=0
# set environment variables
source ${lightx2v_path}/scripts/base/base.sh
python -m lightx2v.infer \
--model_cls qwen_image \
......