"git@developer.sourcefind.cn:modelzoo/resnet50_tensorflow.git" did not exist on "5bac2ced5b669586c9d294782e78f93d9f5df90a"
Commit c6e31fa0 authored by Watebear's avatar Watebear Committed by GitHub

feature: qwen-image support cpu offload (block) and refactor transformer model (#260)

* feature: qwen-image support cpu offload (block) and refactor transformer model

* bugfix
parent 3652b385
{
    "seed": 42,
    "batchsize": 1,
    "_comment": "Format: 'aspect ratio': [width, height]",
    "aspect_ratios": {
        "1:1": [1328, 1328],
        "16:9": [1664, 928],
        "9:16": [928, 1664],
        "4:3": [1472, 1140],
        "3:4": [1140, 1472]
    },
    "aspect_ratio": "16:9",
    "num_channels_latents": 16,
    "vae_scale_factor": 8,
    "infer_steps": 50,
    "guidance_embeds": false,
    "num_images_per_prompt": 1,
    "vae_latents_mean": [
        -0.7571, -0.7089, -0.9113, 0.1075,
        -0.1745, 0.9653, -0.1517, 1.5508,
        0.4134, -0.0715, 0.5517, -0.3632,
        -0.1922, -0.9497, 0.2503, -0.2921
    ],
    "vae_latents_std": [
        2.8184, 1.4541, 2.3275, 2.6558,
        1.2196, 1.7708, 2.6052, 2.0743,
        3.2687, 2.1526, 2.8652, 1.5579,
        1.6382, 1.1253, 2.8251, 1.916
    ],
    "vae_z_dim": 16,
    "feature_caching": "NoCaching",
    "prompt_template_encode": "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n",
    "prompt_template_encode_start_idx": 34,
    "_auto_resize": false,
    "cpu_offload": true,
    "offload_granularity": "block"
}
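For reference, a minimal sketch (not part of this commit) of how such a config resolves the configured aspect ratio into pixel and latent dimensions; the file name and the helper function are hypothetical illustration only:

import json

# Hypothetical helper: turn the config's aspect ratio into pixel and latent
# shapes. The file name "qwen_image_t2i_block.json" is an assumption.
def resolve_shapes(config_path: str):
    with open(config_path) as f:
        cfg = json.load(f)
    width, height = cfg["aspect_ratios"][cfg["aspect_ratio"]]
    # Each VAE latent cell covers vae_scale_factor pixels per side.
    lat_w = width // cfg["vae_scale_factor"]
    lat_h = height // cfg["vae_scale_factor"]
    return (width, height), (cfg["num_channels_latents"], lat_h, lat_w)

if __name__ == "__main__":
    pixels, latents = resolve_shapes("qwen_image_t2i_block.json")
    print(f"image: {pixels}, latent shape (C, H, W): {latents}")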
import torch

from lightx2v.common.offload.manager import WeightAsyncStreamManager
from lightx2v.models.networks.qwen_image.infer.transformer_infer import QwenImageTransformerInfer


class QwenImageOffloadTransformerInfer(QwenImageTransformerInfer):
    def __init__(self, config, blocks):
        super().__init__(config, blocks)
        self.phases_num = 3
        if self.config.get("cpu_offload", False):
            self.offload_ratio = self.config.get("offload_ratio", 1)
            offload_granularity = self.config.get("offload_granularity", "block")
            if offload_granularity == "block":
                if not self.config.get("lazy_load", False):
                    self.infer_func = self.infer_with_blocks_offload
                else:
                    raise NotImplementedError
            elif offload_granularity == "phase":
                raise NotImplementedError
            else:
                raise NotImplementedError
            if offload_granularity != "model":
                self.weights_stream_mgr = WeightAsyncStreamManager(blocks_num=len(self.blocks), offload_ratio=self.offload_ratio, phases_num=self.phases_num)
            else:
                raise NotImplementedError

    def infer_with_blocks_offload(self, hidden_states, encoder_hidden_states, encoder_hidden_states_mask, temb, image_rotary_emb, attention_kwargs):
        for block_idx in range(len(self.blocks)):
            self.block_idx = block_idx
            if block_idx == 0:
                # Load the first block onto the GPU synchronously.
                self.weights_stream_mgr.active_weights[0] = self.blocks[0]
                self.weights_stream_mgr.active_weights[0].to_cuda()
            if block_idx < len(self.blocks) - 1:
                # Prefetch the next block's weights while this block computes.
                self.weights_stream_mgr.prefetch_weights(block_idx + 1, self.blocks)
            with torch.cuda.stream(self.weights_stream_mgr.compute_stream):
                encoder_hidden_states, hidden_states = self.infer_block(
                    block=self.blocks[block_idx],
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_hidden_states_mask=encoder_hidden_states_mask,
                    temb=temb,
                    image_rotary_emb=image_rotary_emb,
                    joint_attention_kwargs=attention_kwargs,
                )
            # Retire the finished block to CPU and promote the prefetched one.
            self.weights_stream_mgr.swap_weights()
        return encoder_hidden_states, hidden_states
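WeightAsyncStreamManager itself is not shown in this diff. As a rough mental model only, here is a minimal double-buffering sketch of the copy/compute overlap it appears to provide; the class and method names below are hypothetical, not lightx2v's actual API:

import torch

# Sketch (assumption): overlap host-to-device weight copies with compute
# by issuing copies on a side stream and synchronizing before each swap.
class TwoStreamPrefetcher:
    def __init__(self):
        self.compute_stream = torch.cuda.current_stream()
        self.copy_stream = torch.cuda.Stream()
        self.active = [None, None]  # [current block, prefetched block]

    def prefetch(self, block):
        # Issue the next block's H2D copy on the side stream.
        with torch.cuda.stream(self.copy_stream):
            self.active[1] = block.to("cuda", non_blocking=True)

    def swap(self):
        # Ensure the prefetch finished before compute uses the new block,
        # then demote the finished block back to host memory.
        self.compute_stream.wait_stream(self.copy_stream)
        old, self.active = self.active[0], [self.active[1], None]
        if old is not None:
            old.to("cpu", non_blocking=True)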
+import torch
+
 class QwenImagePostInfer:
     def __init__(self, config, norm_out, proj_out):
         self.config = config
         self.norm_out = norm_out
         self.proj_out = proj_out
+        self.cpu_offload = config.get("cpu_offload", False)
+        if self.cpu_offload:
+            self.init_cpu_offload()
+
+    def init_cpu_offload(self):
+        # Under block offload, the small post modules stay resident on the GPU.
+        self.norm_out = self.norm_out.to(torch.device("cuda"))
+        self.proj_out = self.proj_out.to(torch.device("cuda"))
+
     def set_scheduler(self, scheduler):
         self.scheduler = scheduler
......
@@ -10,6 +10,16 @@ class QwenImagePreInfer:
         self.time_text_embed = time_text_embed
         self.pos_embed = pos_embed
         self.attention_kwargs = {}
+        self.cpu_offload = config.get("cpu_offload", False)
+        if self.cpu_offload:
+            self.init_cpu_offload()
+
+    def init_cpu_offload(self):
+        self.img_in = self.img_in.to(torch.device("cuda"))
+        self.txt_norm = self.txt_norm.to(torch.device("cuda"))
+        self.txt_in = self.txt_in.to(torch.device("cuda"))
+        self.time_text_embed = self.time_text_embed.to(torch.device("cuda"))
+        self.pos_embed = self.pos_embed.to(torch.device("cuda"))
+
     def set_scheduler(self, scheduler):
         self.scheduler = scheduler
......
@@ -9,6 +9,7 @@ class QwenImageTransformerInfer(BaseTransformerInfer):
         self.blocks = blocks
         self.infer_conditional = True
         self.clean_cuda_cache = self.config.get("clean_cuda_cache", False)
+        self.infer_func = self.infer_calculating

     def set_scheduler(self, scheduler):
         self.scheduler = scheduler
@@ -87,5 +88,5 @@ class QwenImageTransformerInfer(BaseTransformerInfer):
     def infer(self, hidden_states, encoder_hidden_states, encoder_hidden_states_mask, pre_infer_out, attention_kwargs):
         _, temb, image_rotary_emb = pre_infer_out
-        encoder_hidden_states, hidden_states = self.infer_calculating(hidden_states, encoder_hidden_states, encoder_hidden_states_mask, temb, image_rotary_emb, attention_kwargs)
+        encoder_hidden_states, hidden_states = self.infer_func(hidden_states, encoder_hidden_states, encoder_hidden_states_mask, temb, image_rotary_emb, attention_kwargs)
         return encoder_hidden_states, hidden_states
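This refactor routes infer() through self.infer_func so a subclass (such as the offload variant above) can swap in a different block loop without overriding infer(). A minimal, self-contained sketch of the pattern with hypothetical names:

# Sketch of the dispatch pattern: the base class binds infer_func to its
# default loop; a subclass rebinds it in __init__.
class BaseInfer:
    def __init__(self):
        self.infer_func = self.infer_calculating

    def infer(self, x):
        return self.infer_func(x)

    def infer_calculating(self, x):
        return f"default path: {x}"

class OffloadInfer(BaseInfer):
    def __init__(self):
        super().__init__()
        self.infer_func = self.infer_with_blocks_offload

    def infer_with_blocks_offload(self, x):
        return f"offload path: {x}"

print(BaseInfer().infer("t"))     # default path: t
print(OffloadInfer().infer("t"))  # offload path: t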
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn.functional as F
from torch import nn

from diffusers.utils import deprecate
from diffusers.utils.import_utils import is_torch_npu_available, is_torch_version

if is_torch_npu_available():
    import torch_npu

ACT2CLS = {
    "swish": nn.SiLU,
    "silu": nn.SiLU,
    "mish": nn.Mish,
    "gelu": nn.GELU,
    "relu": nn.ReLU,
}


def get_activation(act_fn: str) -> nn.Module:
    """Helper function to get activation function from string.

    Args:
        act_fn (str): Name of activation function.

    Returns:
        nn.Module: Activation function.
    """
    act_fn = act_fn.lower()
    if act_fn in ACT2CLS:
        return ACT2CLS[act_fn]()
    else:
        raise ValueError(f"activation function {act_fn} not found in ACT2CLS mapping {list(ACT2CLS.keys())}")


class FP32SiLU(nn.Module):
    r"""
    SiLU activation function with input upcasted to torch.float32.
    """

    def __init__(self):
        super().__init__()

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        return F.silu(inputs.float(), inplace=False).to(inputs.dtype)
class GELU(nn.Module):
    r"""
    GELU activation function with tanh approximation support with `approximate="tanh"`.

    Parameters:
        dim_in (`int`): The number of channels in the input.
        dim_out (`int`): The number of channels in the output.
        approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation.
        bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
    """

    def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out, bias=bias)
        self.approximate = approximate

    def gelu(self, gate: torch.Tensor) -> torch.Tensor:
        if gate.device.type == "mps" and is_torch_version("<", "2.0.0"):
            # fp16 gelu not supported on mps before torch 2.0
            return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype)
        return F.gelu(gate, approximate=self.approximate)

    def forward(self, hidden_states):
        hidden_states = self.proj(hidden_states)
        hidden_states = self.gelu(hidden_states)
        return hidden_states
class GEGLU(nn.Module):
    r"""
    A [variant](https://huggingface.co/papers/2002.05202) of the gated linear unit activation function.

    Parameters:
        dim_in (`int`): The number of channels in the input.
        dim_out (`int`): The number of channels in the output.
        bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
    """

    def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias)

    def gelu(self, gate: torch.Tensor) -> torch.Tensor:
        if gate.device.type == "mps" and is_torch_version("<", "2.0.0"):
            # fp16 gelu not supported on mps before torch 2.0
            return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
        return F.gelu(gate)

    def forward(self, hidden_states, *args, **kwargs):
        if len(args) > 0 or kwargs.get("scale", None) is not None:
            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
            deprecate("scale", "1.0.0", deprecation_message)
        hidden_states = self.proj(hidden_states)
        if is_torch_npu_available():
            # using torch_npu.npu_geglu can run faster and save memory on NPU.
            return torch_npu.npu_geglu(hidden_states, dim=-1, approximate=1)[0]
        else:
            hidden_states, gate = hidden_states.chunk(2, dim=-1)
            return hidden_states * self.gelu(gate)
class SwiGLU(nn.Module):
    r"""
    A [variant](https://huggingface.co/papers/2002.05202) of the gated linear unit activation function. It's similar to
    `GEGLU` but uses SiLU / Swish instead of GeLU.

    Parameters:
        dim_in (`int`): The number of channels in the input.
        dim_out (`int`): The number of channels in the output.
        bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
    """

    def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias)
        self.activation = nn.SiLU()

    def forward(self, hidden_states):
        hidden_states = self.proj(hidden_states)
        hidden_states, gate = hidden_states.chunk(2, dim=-1)
        return hidden_states * self.activation(gate)
class ApproximateGELU(nn.Module):
    r"""
    The approximate form of the Gaussian Error Linear Unit (GELU). For more details, see section 2 of this
    [paper](https://huggingface.co/papers/1606.08415).

    Parameters:
        dim_in (`int`): The number of channels in the input.
        dim_out (`int`): The number of channels in the output.
        bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
    """

    def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.proj(x)
        return x * torch.sigmoid(1.702 * x)


class LinearActivation(nn.Module):
    def __init__(self, dim_in: int, dim_out: int, bias: bool = True, activation: str = "silu"):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out, bias=bias)
        self.activation = get_activation(activation)

    def forward(self, hidden_states):
        hidden_states = self.proj(hidden_states)
        return self.activation(hidden_states)
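For reference, a quick usage sketch of the activation helpers above; the dimensions are arbitrary illustration values:

import torch

# get_activation and SwiGLU are defined in the file above.
act = get_activation("silu")            # returns an nn.SiLU() instance
swiglu = SwiGLU(dim_in=64, dim_out=128)

x = torch.randn(2, 16, 64)
print(act(x).shape)                     # torch.Size([2, 16, 64])
print(swiglu(x).shape)                  # torch.Size([2, 16, 128])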
from typing import Type

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor


class DefaultLinear(nn.Linear):
    def forward(self, input: Tensor) -> Tensor:
        return F.linear(input, self.weight, self.bias)


def replace_linear_with_custom(model: nn.Module, CustomLinear: Type[nn.Module]) -> nn.Module:
    """Recursively replace every nn.Linear in `model` with `CustomLinear`, copying its weights."""
    for name, module in model.named_children():
        if isinstance(module, nn.Linear):
            in_features = module.in_features
            out_features = module.out_features
            bias = module.bias is not None
            custom_linear = CustomLinear(in_features=in_features, out_features=out_features, bias=bias)
            with torch.no_grad():
                custom_linear.weight.copy_(module.weight)
                if bias:
                    custom_linear.bias.copy_(module.bias)
            setattr(model, name, custom_linear)
        else:
            replace_linear_with_custom(module, CustomLinear)
    return model
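A small usage sketch: swapping every nn.Linear in a toy model for DefaultLinear. Behavior is unchanged; the hook point exists so customized (e.g. quantized) linears can be dropped in later. The toy model is illustration only:

import torch
import torch.nn as nn

# Toy model for illustration.
model = nn.Sequential(nn.Linear(8, 16), nn.SiLU(), nn.Linear(16, 4))
model = replace_linear_with_custom(model, DefaultLinear)

# Every linear is now a DefaultLinear; outputs are numerically identical.
assert all(isinstance(m, DefaultLinear) for m in model if isinstance(m, nn.Linear))
print(model(torch.randn(2, 8)).shape)  # torch.Size([2, 4])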
@@ -4,26 +4,24 @@ import os

 import torch

 try:
-    from diffusers.models.transformers.transformer_qwenimage import QwenImageTransformer2DModel
+    from .transformer_qwenimage import QwenImageTransformer2DModel
 except ImportError:
     QwenImageTransformer2DModel = None

+from .infer.offload.transformer_infer import QwenImageOffloadTransformerInfer
 from .infer.post_infer import QwenImagePostInfer
 from .infer.pre_infer import QwenImagePreInfer
 from .infer.transformer_infer import QwenImageTransformerInfer
-from .layers.linear import DefaultLinear, replace_linear_with_custom
-from .layers.normalization import DefaultLayerNorm, DefaultRMSNorm, replace_layernorm_with_custom, replace_rmsnorm_with_custom
+from .transformer_qwenimage import QwenImageTransformer2DModel


 class QwenImageTransformerModel:
     def __init__(self, config):
         self.config = config
         self.transformer = QwenImageTransformer2DModel.from_pretrained(os.path.join(config.model_path, "transformer"))
-        # repalce linear & normalization
-        self.transformer = replace_linear_with_custom(self.transformer, DefaultLinear)
-        self.transformer = replace_layernorm_with_custom(self.transformer, DefaultLayerNorm)
-        self.transformer = replace_rmsnorm_with_custom(self.transformer, DefaultRMSNorm)
-        self.transformer.to(torch.device("cuda")).to(torch.bfloat16)
+        self.cpu_offload = config.get("cpu_offload", False)
+        self.target_device = torch.device("cpu") if self.cpu_offload else torch.device("cuda")
+        self.transformer.to(self.target_device).to(torch.bfloat16)
         with open(os.path.join(config.model_path, "transformer", "config.json"), "r") as f:
             transformer_config = json.load(f)
@@ -38,7 +36,7 @@ class QwenImageTransformerModel:
     def _init_infer_class(self):
         if self.config["feature_caching"] == "NoCaching":
-            self.transformer_infer_class = QwenImageTransformerInfer
+            self.transformer_infer_class = QwenImageTransformerInfer if not self.cpu_offload else QwenImageOffloadTransformerInfer
         else:
             assert NotImplementedError
......
@@ -60,6 +60,8 @@ class QwenImageRunner(DefaultRunner):
             self.load_model()
         elif self.config.get("lazy_load", False):
             assert self.config.get("cpu_offload", False)
+        self.run_dit = self._run_dit_local
+        self.run_vae_decoder = self._run_vae_decoder_local
         if self.config["task"] == "t2i":
             self.run_input_encoder = self._run_input_encoder_local_t2i
         elif self.config["task"] == "i2i":
@@ -67,6 +69,18 @@ class QwenImageRunner(DefaultRunner):
         else:
             assert NotImplementedError

+    @ProfilingContext("Run DiT")
+    def _run_dit_local(self, total_steps=None):
+        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
+            self.model = self.load_transformer()
+        self.init_scheduler()
+        self.model.scheduler.prepare(self.inputs["image_encoder_output"])
+        if self.config.get("model_cls") == "wan2.2" and self.config["task"] == "i2v":
+            self.inputs["image_encoder_output"]["vae_encoder_out"] = None
+        latents, generator = self.run(total_steps)
+        self.end_run()
+        return latents, generator
+
     @ProfilingContext("Run Encoders")
     def _run_input_encoder_local_t2i(self):
         prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
@@ -110,6 +124,28 @@ class QwenImageRunner(DefaultRunner):
         image_latents = self.vae.encode_vae_image(image)
         return {"image_latents": image_latents}

+    def run(self, total_steps=None):
+        from lightx2v.utils.profiler import ProfilingContext4Debug
+
+        if total_steps is None:
+            total_steps = self.model.scheduler.infer_steps
+        for step_index in range(total_steps):
+            logger.info(f"==> step_index: {step_index + 1} / {total_steps}")
+
+            with ProfilingContext4Debug("step_pre"):
+                self.model.scheduler.step_pre(step_index=step_index)
+
+            with ProfilingContext4Debug("🚀 infer_main"):
+                self.model.infer(self.inputs)
+
+            with ProfilingContext4Debug("step_post"):
+                self.model.scheduler.step_post()
+
+            if self.progress_callback:
+                self.progress_callback(((step_index + 1) / total_steps) * 100, 100)
+
+        return self.model.scheduler.latents, self.model.scheduler.generator
+
     def set_target_shape(self):
         if not self.config._auto_resize:
             width, height = self.config.aspect_ratios[self.config.aspect_ratio]
@@ -154,7 +190,7 @@ class QwenImageRunner(DefaultRunner):
         self.vfi_model = self.load_vfi_model() if "video_frame_interpolation" in self.config else None

     @ProfilingContext("Run VAE Decoder")
-    def run_vae_decoder(self, latents):
+    def _run_vae_decoder_local(self, latents, generator):
         if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
             self.vae_decoder = self.load_vae()
         images = self.vae.decode(latents)
@@ -172,7 +208,7 @@ class QwenImageRunner(DefaultRunner):
         self.set_target_shape()
         latents, generator = self.run_dit()
-        images = self.run_vae_decoder(latents)
+        images = self.run_vae_decoder(latents, generator)
         image = images[0]
         image.save(f"{self.config.save_video_path}")
......
#!/bin/bash

export CUDA_VISIBLE_DEVICES=

# set paths first
export lightx2v_path=
export model_path=

# check section
if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
    cuda_devices=0
    echo "Warning: CUDA_VISIBLE_DEVICES is not set, using default value: ${cuda_devices}. Change it in this script or set the env variable."
    export CUDA_VISIBLE_DEVICES=${cuda_devices}
fi

if [ -z "${lightx2v_path}" ]; then
    echo "Error: lightx2v_path is not set. Please set this variable first."
    exit 1
fi

if [ -z "${model_path}" ]; then
    echo "Error: model_path is not set. Please set this variable first."
    exit 1
fi

export TOKENIZERS_PARALLELISM=false
export PYTHONPATH=${lightx2v_path}:$PYTHONPATH

export DTYPE=BF16
export ENABLE_PROFILING_DEBUG=true
export ENABLE_GRAPH_MODE=false

python -m lightx2v.infer \
    --model_cls qwen_image \
    --task t2i \
    --model_path $model_path \
    --config_json ${lightx2v_path}/configs/offload/block/qwen_image_t2i_block.json \
    --prompt 'A coffee shop entrance features a chalkboard sign reading "Qwen Coffee 😊 $2 per cup," with a neon light beside it displaying "通义千问". Next to it hangs a poster showing a beautiful Chinese woman, and beneath the poster is written "π≈3.1415926-53589793-23846264-33832795-02384197". Ultra HD, 4K, cinematic composition.' \
    --save_video_path ${lightx2v_path}/save_results/qwen_image_t2i.png