Commit a40ffb3f authored by Watebear, committed by GitHub

refactor qwen-image (#297)

parent 701075f4
{
    "seed": 42,
    "batchsize": 1,
    "num_channels_latents": 16,
    "vae_scale_factor": 8,
    "infer_steps": 50,
    "guidance_embeds": false,
    "num_images_per_prompt": 1,
    "vae_latents_mean": [
        -0.7571,
        -0.7089,
        -0.9113,
        0.1075,
        -0.1745,
        0.9653,
        -0.1517,
        1.5508,
        0.4134,
        -0.0715,
        0.5517,
        -0.3632,
        -0.1922,
        -0.9497,
        0.2503,
        -0.2921
    ],
    "vae_latents_std": [
        2.8184,
        1.4541,
        2.3275,
        2.6558,
        1.2196,
        1.7708,
        2.6052,
        2.0743,
        3.2687,
        2.1526,
        2.8652,
        1.5579,
        1.6382,
        1.1253,
        2.8251,
        1.916
    ],
    "vae_z_dim": 16,
    "feature_caching": "NoCaching",
    "transformer_in_channels": 64,
    "prompt_template_encode": "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n",
    "prompt_template_encode_start_idx": 64,
    "_auto_resize": true,
    "cpu_offload": true,
    "offload_granularity": "block",
    "num_layers": 60,
    "attention_out_dim": 3072,
    "attention_dim_head": 128,
    "axes_dims_rope": [
        16,
        56,
        56
    ],
    "_comment_attn": "in [torch_sdpa, flash_attn3, sage_attn2]",
    "attn_type": "flash_attn3"
}
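
# --- Editor's note (illustrative sketch, not part of the commit) ---
# The "_comment_attn" key above documents the valid values for "attn_type". A minimal
# validation a loader could do up front; the config path below is hypothetical.
import json

with open("configs/qwen_image/qwen_image_i2i.json") as f:  # hypothetical path
    cfg = json.load(f)
assert cfg["attn_type"] in ("torch_sdpa", "flash_attn3", "sage_attn2"), cfg["attn_type"]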
@@ -57,5 +57,15 @@
    "prompt_template_encode_start_idx": 34,
    "_auto_resize": false,
    "cpu_offload": true,
    "offload_granularity": "block",
    "num_layers": 60,
    "attention_out_dim": 3072,
    "attention_dim_head": 128,
    "axes_dims_rope": [
        16,
        56,
        56
    ],
    "_comment_attn": "in [torch_sdpa, flash_attn3, sage_attn2]",
    "attn_type": "flash_attn3"
}
@@ -47,5 +47,15 @@
    "transformer_in_channels": 64,
    "prompt_template_encode": "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n",
    "prompt_template_encode_start_idx": 64,
    "_auto_resize": true,
    "num_layers": 60,
    "attention_out_dim": 3072,
    "attention_dim_head": 128,
    "axes_dims_rope": [
        16,
        56,
        56
    ],
    "_comment_attn": "in [torch_sdpa, flash_attn3, sage_attn2]",
    "attn_type": "flash_attn3"
}
@@ -55,5 +55,15 @@
    "feature_caching": "NoCaching",
    "prompt_template_encode": "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n",
    "prompt_template_encode_start_idx": 34,
    "_auto_resize": false,
    "num_layers": 60,
    "attention_out_dim": 3072,
    "attention_dim_head": 128,
    "axes_dims_rope": [
        16,
        56,
        56
    ],
    "_comment_attn": "in [torch_sdpa, flash_attn3, sage_attn2]",
    "attn_type": "flash_attn3"
}
@@ -51,7 +51,7 @@ class SageAttn2Weight(AttnWeightTemplate):
            )
            x = torch.cat((x1, x2), dim=1)
            x = x.view(max_seqlen_q, -1)
        elif model_cls in ["wan2.1", "wan2.1_distill", "wan2.1_causvid", "wan2.1_df", "seko_talk", "wan2.2", "wan2.1_vace", "wan2.2_moe", "wan2.2_moe_distill", "qwen_image"]:
            x = sageattn(
                q.unsqueeze(0),
                k.unsqueeze(0),
...
@@ -119,3 +119,20 @@ class RMSWeightSgl(RMSWeight):
        input_tensor = input_tensor * self.weight
        return input_tensor


@RMS_WEIGHT_REGISTER("fp32_variance")
class RMSWeightFP32(RMSWeight):
    def __init__(self, weight_name, lazy_load=False, lazy_load_file=None, eps=1e-6):
        super().__init__(weight_name, lazy_load, lazy_load_file, eps)

    def apply(self, input_tensor):
        input_dtype = input_tensor.dtype
        variance = input_tensor.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = input_tensor * torch.rsqrt(variance + self.eps)
        if self.weight is not None:
            hidden_states = hidden_states * self.weight
        hidden_states = hidden_states.to(input_dtype)
        return hidden_states
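
# Editor's sketch (not part of the commit): the fp32-variance RMS norm above should agree
# with torch's reference rms_norm (torch >= 2.4) up to float32 rounding, since both operate
# on the same bf16-quantized values. `weight` here is a stand-in for the loaded weight tensor.
import torch

x = torch.randn(2, 8, dtype=torch.bfloat16)
weight = torch.ones(8, dtype=torch.bfloat16)
variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True)
out = (x * torch.rsqrt(variance + 1e-6)) * weight
ref = torch.nn.functional.rms_norm(x.float(), (8,), weight.float(), eps=1e-6)
print((out.float() - ref).abs().max())  # ~1e-7, i.e. float32 rounding noise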
import gc
import math
import os
@@ -45,20 +46,26 @@ def calculate_dimensions(target_area, ratio):
class Qwen25_VLForConditionalGeneration_TextEncoder:
    def __init__(self, config):
        self.config = config
        self.tokenizer_max_length = 1024
        self.prompt_template_encode = config.prompt_template_encode
        self.prompt_template_encode_start_idx = config.prompt_template_encode_start_idx
        self.cpu_offload = config.get("cpu_offload", False)
        if self.cpu_offload:
            self.device = torch.device("cpu")
        else:
            self.device = torch.device("cuda")
        self.dtype = torch.bfloat16
        self.load()

    def load(self):
        self.text_encoder = Qwen2_5_VLForConditionalGeneration.from_pretrained(os.path.join(self.config.model_path, "text_encoder")).to(self.device).to(self.dtype)
        self.tokenizer = Qwen2Tokenizer.from_pretrained(os.path.join(self.config.model_path, "tokenizer"))
        if self.config.task == "i2i":
            self.image_processor = VaeImageProcessor(vae_scale_factor=self.config["vae_scale_factor"] * 2)
            self.processor = Qwen2VLProcessor.from_pretrained(os.path.join(self.config.model_path, "processor"))

    def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor):
        bool_mask = mask.bool()
        valid_lengths = bool_mask.sum(dim=1)
@@ -92,7 +99,11 @@ class Qwen25_VLForConditionalGeneration_TextEncoder:
        image = image.unsqueeze(2)
        return prompt_image, image, (image_height, image_width)

    @torch.no_grad()
    def infer(self, text, image=None):
        if self.cpu_offload:
            self.text_encoder.to(torch.device("cuda"))
        template = self.prompt_template_encode
        drop_idx = self.prompt_template_encode_start_idx
        txt = [template.format(e) for e in text]
@@ -104,7 +115,7 @@ class Qwen25_VLForConditionalGeneration_TextEncoder:
                images=prompt_image,
                padding=True,
                return_tensors="pt",
            ).to(torch.device("cuda"))
            encoder_hidden_states = self.text_encoder(
                input_ids=model_inputs.input_ids,
                attention_mask=model_inputs.attention_mask,
@@ -114,7 +125,7 @@ class Qwen25_VLForConditionalGeneration_TextEncoder:
            )
        else:
            prompt_image, image, image_info = None, None, None
            model_inputs = self.tokenizer(txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt").to(torch.device("cuda"))
            encoder_hidden_states = self.text_encoder(
                input_ids=model_inputs.input_ids,
                attention_mask=model_inputs.attention_mask,
@@ -129,7 +140,7 @@ class Qwen25_VLForConditionalGeneration_TextEncoder:
        prompt_embeds = torch.stack([torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states])
        encoder_attention_mask = torch.stack([torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list])
        prompt_embeds = prompt_embeds.to(dtype=self.dtype, device=torch.device("cuda"))
        prompt_embeds_mask = encoder_attention_mask
        _, seq_len, _ = prompt_embeds.shape
@@ -137,4 +148,10 @@ class Qwen25_VLForConditionalGeneration_TextEncoder:
        prompt_embeds = prompt_embeds.view(self.config.batchsize * self.config.num_images_per_prompt, seq_len, -1)
        prompt_embeds_mask = prompt_embeds_mask.repeat(1, self.config.num_images_per_prompt, 1)
        prompt_embeds_mask = prompt_embeds_mask.view(self.config.batchsize * self.config.num_images_per_prompt, seq_len)

        if self.cpu_offload:
            self.text_encoder.to(torch.device("cpu"))
            torch.cuda.empty_cache()
            gc.collect()
        return prompt_embeds, prompt_embeds_mask, image, image_info
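
# Editor's illustration (not part of the commit): `prompt_template_encode_start_idx` marks
# where the fixed template prefix ends, so infer() drops the first drop_idx positions from
# the encoder output and keeps only the user-prompt tokens.
template = "<|im_start|>system\n...<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"  # abbreviated
txt = [template.format(p) for p in ["a cat wearing a hat"]]
drop_idx = 34  # prompt_template_encode_start_idx in the t2i config above
# hidden_states[:, drop_idx:, :] is what the masked extraction keeps per sample.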
@@ -5,9 +5,10 @@ from lightx2v.models.networks.qwen_image.infer.transformer_infer import QwenImageTransformerInfer


class QwenImageOffloadTransformerInfer(QwenImageTransformerInfer):
    def __init__(self, config):
        super().__init__(config)
        self.phases_num = 3
        self.num_blocks = config["num_layers"]
        if self.config.get("cpu_offload", False):
            if "offload_ratio" in self.config:
                self.offload_ratio = self.config["offload_ratio"]
@@ -19,36 +20,28 @@ class QwenImageOffloadTransformerInfer(QwenImageTransformerInfer):
                    self.infer_func = self.infer_with_blocks_offload
                else:
                    assert NotImplementedError
            else:
                assert NotImplementedError
            if offload_granularity != "model":
                self.weights_stream_mgr = WeightAsyncStreamManager(blocks_num=self.num_blocks, offload_ratio=self.offload_ratio, phases_num=self.phases_num)
            else:
                assert NotImplementedError

    def infer_with_blocks_offload(self, block_weights, hidden_states, encoder_hidden_states, temb, image_rotary_emb):
        for block_idx in range(self.num_blocks):
            self.block_idx = block_idx
            if block_idx == 0:
                self.weights_stream_mgr.active_weights[0] = block_weights.blocks[0]
                self.weights_stream_mgr.active_weights[0].to_cuda()
            if block_idx < self.num_blocks - 1:
                self.weights_stream_mgr.prefetch_weights(block_idx + 1, block_weights.blocks)
            with torch.cuda.stream(self.weights_stream_mgr.compute_stream):
                encoder_hidden_states, hidden_states = self.infer_block(
                    block_weight=block_weights.blocks[block_idx], hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb, image_rotary_emb=image_rotary_emb
                )
            self.weights_stream_mgr.swap_weights()
        return encoder_hidden_states, hidden_states
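
# Editor's sketch (not part of the commit) of the double-buffering idea behind
# infer_with_blocks_offload: while block i runs on the compute stream, block i+1's weights
# are copied host-to-device on a second stream; swap_weights() then flips the buffers.
# Requires a CUDA device; the explicit synchronize() is a crude stand-in for the manager's
# event-based ordering.
import torch

compute, transfer = torch.cuda.Stream(), torch.cuda.Stream()
cpu_blocks = [torch.randn(1024, 1024, pin_memory=True) for _ in range(4)]
active = cpu_blocks[0].cuda()
for i in range(len(cpu_blocks)):
    if i + 1 < len(cpu_blocks):
        with torch.cuda.stream(transfer):
            prefetched = cpu_blocks[i + 1].to("cuda", non_blocking=True)
    with torch.cuda.stream(compute):
        _ = active @ active  # stand-in for the transformer block forward
    torch.cuda.synchronize()
    if i + 1 < len(cpu_blocks):
        active = prefetched  # swap: the prefetched weights become active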
import torch
import torch.nn.functional as F


class QwenImagePostInfer:
    def __init__(self, config):
        self.config = config
        self.cpu_offload = config.get("cpu_offload", False)

    def set_scheduler(self, scheduler):
        self.scheduler = scheduler

    def infer(self, weights, hidden_states, temb):
        temb1 = F.silu(temb)
        temb1 = weights.norm_out_linear.apply(temb1)
        scale, shift = torch.chunk(temb1, 2, dim=1)
        hidden_states = weights.norm_out.apply(hidden_states) * (1 + scale) + shift
        output = weights.proj_out_linear.apply(hidden_states.squeeze(0))
        return output.unsqueeze(0)
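
# Editor's note: the rewritten infer() above unfuses diffusers' AdaLayerNormContinuous into
# raw weight calls: silu(temb) -> linear -> chunk into (scale, shift), then
# LayerNorm(x) * (1 + scale) + shift. A tiny numeric sketch of the broadcast:
import torch

x = torch.randn(1, 4, 8)
emb = torch.randn(1, 16)  # linear output with 2 * dim channels
scale, shift = torch.chunk(emb, 2, dim=1)  # two [1, 8] halves
out = torch.nn.functional.layer_norm(x, (8,)) * (1 + scale) + shift  # broadcasts over the sequence dim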
import functools
import math
from typing import List

import torch
from torch import nn

from lightx2v.utils.envs import *


def get_timestep_embedding(
    timesteps: torch.Tensor,
    embedding_dim: int = 256,
    flip_sin_to_cos: bool = True,
    downscale_freq_shift: float = 0,
    scale: float = 1000,
    max_period: int = 10000,
) -> torch.Tensor:
    """
    Create sinusoidal timestep embeddings. This matches the implementation in Denoising Diffusion Probabilistic
    Models.

    Args:
        timesteps (torch.Tensor):
            A 1-D tensor of N indices, one per batch element. These may be fractional.
        embedding_dim (int):
            The dimension of the output.
        flip_sin_to_cos (bool):
            Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False).
        downscale_freq_shift (float):
            Controls the delta between frequencies between dimensions.
        scale (float):
            Scaling factor applied to the embeddings.
        max_period (int):
            Controls the maximum frequency of the embeddings.

    Returns:
        torch.Tensor: an [N x dim] tensor of positional embeddings.
    """
    assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"

    half_dim = embedding_dim // 2
    exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device)
    exponent = exponent / (half_dim - downscale_freq_shift)

    emb = torch.exp(exponent)
    emb = timesteps[:, None].float() * emb[None, :]

    # scale embeddings
    emb = scale * emb

    # concat sine and cosine embeddings
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)

    # flip sine and cosine embeddings
    if flip_sin_to_cos:
        emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)

    # zero pad
    if embedding_dim % 2 == 1:
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb
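
# Editor's usage note: QwenImagePreInfer calls this with the defaults above, so a batch of N
# (possibly fractional) timesteps maps to an [N, 256] table, scaled by 1000, with the cos
# half first because flip_sin_to_cos=True. For example:
#   t = torch.tensor([0.25, 1.0])
#   emb = get_timestep_embedding(t)  # emb.shape == torch.Size([2, 256])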
class QwenEmbedRope(nn.Module):
    def __init__(self, theta: int, axes_dim: List[int], scale_rope=False):
        super().__init__()
        self.theta = theta
        self.axes_dim = axes_dim
        pos_index = torch.arange(4096)
        neg_index = torch.arange(4096).flip(0) * -1 - 1
        self.pos_freqs = torch.cat(
            [
                self.rope_params(pos_index, self.axes_dim[0], self.theta),
                self.rope_params(pos_index, self.axes_dim[1], self.theta),
                self.rope_params(pos_index, self.axes_dim[2], self.theta),
            ],
            dim=1,
        )
        self.neg_freqs = torch.cat(
            [
                self.rope_params(neg_index, self.axes_dim[0], self.theta),
                self.rope_params(neg_index, self.axes_dim[1], self.theta),
                self.rope_params(neg_index, self.axes_dim[2], self.theta),
            ],
            dim=1,
        )
        self.rope_cache = {}

        # Do not use register_buffer here: it would cause the complex tensors to lose their imaginary part.
        self.scale_rope = scale_rope

    def rope_params(self, index, dim, theta=10000):
        """
        Args:
            index: 1-D tensor of token position indices, e.g. [0, 1, 2, 3].
        """
        assert dim % 2 == 0
        freqs = torch.outer(index, 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim)))
        freqs = torch.polar(torch.ones_like(freqs), freqs)
        return freqs
    def forward(self, video_fhw, txt_seq_lens, device):
        """
        Args:
            video_fhw: [frame, height, width] a list of 3 integers representing the shape of the video
            txt_seq_lens: [bs] a list of integers representing the length of the text
        """
        if self.pos_freqs.device != device:
            self.pos_freqs = self.pos_freqs.to(device)
            self.neg_freqs = self.neg_freqs.to(device)

        if isinstance(video_fhw, list):
            video_fhw = video_fhw[0]
        if not isinstance(video_fhw, list):
            video_fhw = [video_fhw]

        vid_freqs = []
        max_vid_index = 0
        for idx, fhw in enumerate(video_fhw):
            frame, height, width = fhw
            rope_key = f"{idx}_{height}_{width}"

            if not torch.compiler.is_compiling():
                if rope_key not in self.rope_cache:
                    self.rope_cache[rope_key] = self._compute_video_freqs(frame, height, width, idx)
                video_freq = self.rope_cache[rope_key]
            else:
                video_freq = self._compute_video_freqs(frame, height, width, idx)
            video_freq = video_freq.to(device)
            vid_freqs.append(video_freq)

            if self.scale_rope:
                max_vid_index = max(height // 2, width // 2, max_vid_index)
            else:
                max_vid_index = max(height, width, max_vid_index)

        max_len = txt_seq_lens
        txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...]
        vid_freqs = torch.cat(vid_freqs, dim=0)

        return vid_freqs, txt_freqs
    @functools.lru_cache(maxsize=None)
    def _compute_video_freqs(self, frame, height, width, idx=0):
        seq_lens = frame * height * width
        freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
        freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)

        freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
        if self.scale_rope:
            freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0)
            freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
            freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)
            freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
        else:
            freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
            freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)

        freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
        return freqs.clone().contiguous()
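
# Editor's note: rope_params returns complex exponentials (torch.polar), so RoPE application
# is a complex multiply. rope_params(torch.arange(4), dim=8) has shape [4, 4] (dim/2 complex
# values per position), and pos_freqs concatenates the three axes along dim=1:
# 8 + 28 + 28 = 64 complex entries per position for axes_dim = [16, 56, 56], i.e. a
# 128-dim real head, matching attention_dim_head in the configs above.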
class QwenImagePreInfer:
    def __init__(self, config):
        self.config = config
        self.attention_kwargs = {}
        self.cpu_offload = config.get("cpu_offload", False)
        self.pos_embed = QwenEmbedRope(theta=10000, axes_dim=list(config.axes_dims_rope), scale_rope=True)

    def set_scheduler(self, scheduler):
        self.scheduler = scheduler

    def infer(self, weights, hidden_states, timestep, guidance, encoder_hidden_states_mask, encoder_hidden_states, img_shapes, txt_seq_lens, attention_kwargs):
        hidden_states = hidden_states.squeeze(0)
        hidden_states = weights.img_in.apply(hidden_states)
        timestep = timestep.to(hidden_states.dtype)
        encoder_hidden_states = encoder_hidden_states.squeeze(0)
        encoder_hidden_states = weights.txt_norm.apply(encoder_hidden_states)
        encoder_hidden_states = weights.txt_in.apply(encoder_hidden_states)

        timesteps_proj = get_timestep_embedding(timestep).to(torch.bfloat16)
        embed = weights.time_text_embed_timestep_embedder_linear_1.apply(timesteps_proj)
        embed0 = torch.nn.functional.silu(embed)
        embed0 = weights.time_text_embed_timestep_embedder_linear_2.apply(embed0)

        image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens[0], device=hidden_states.device)

        return hidden_states, encoder_hidden_states, encoder_hidden_states_mask, (embed0, image_rotary_emb)
from typing import Tuple, Union

import torch
import torch.nn.functional as F

from lightx2v.common.transformer_infer.transformer_infer import BaseTransformerInfer
def apply_rotary_emb_qwen(
    x: torch.Tensor,
    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
    use_real: bool = True,
    use_real_unbind_dim: int = -1,
) -> torch.Tensor:
    """
    Apply rotary embeddings to an input tensor using the given frequency tensor. This function applies rotary
    embeddings to the given query or key tensor `x` using the provided frequency tensor `freqs_cis`. The input
    tensor is reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility.
    The result contains the rotary embeddings and is returned as a real tensor.

    Args:
        x (`torch.Tensor`):
            Query or key tensor to apply rotary embeddings to. [B, S, H, D]
        freqs_cis (`Tuple[torch.Tensor]`):
            Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)

    Returns:
        torch.Tensor: The query or key tensor with rotary embeddings applied.
    """
    if use_real:
        cos, sin = freqs_cis  # [S, D]
        cos = cos[None, None]
        sin = sin[None, None]
        cos, sin = cos.to(x.device), sin.to(x.device)

        if use_real_unbind_dim == -1:
            # Used for flux, cogvideox, hunyuan-dit
            x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)  # [B, S, H, D//2]
            x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
        elif use_real_unbind_dim == -2:
            # Used for Stable Audio, OmniGen, CogView4 and Cosmos
            x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2)  # [B, S, H, D//2]
            x_rotated = torch.cat([-x_imag, x_real], dim=-1)
        else:
            raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")

        out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
        return out
    else:
        x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
        freqs_cis = freqs_cis.unsqueeze(1)
        x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
        return x_out.type_as(x)
def calculate_q_k_len(q, k_lens):
    q_lens = torch.tensor([q.size(0)], dtype=torch.int32, device=q.device)
    cu_seqlens_q = torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(0, dtype=torch.int32)
    cu_seqlens_k = torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(0, dtype=torch.int32)
    return cu_seqlens_q, cu_seqlens_k
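
# Worked example (editor's addition): for a packed query of 6 tokens and a single key
# sequence of length 9, k_lens = torch.tensor([9]) yields cu_seqlens_q = [0, 6] and
# cu_seqlens_k = [0, 9], the cumulative-length format that varlen flash-attention kernels
# expect.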
def apply_attn(block_weight, hidden_states, encoder_hidden_states, image_rotary_emb, attn_type):
    seq_txt = encoder_hidden_states.shape[1]

    # Compute QKV for image stream (sample projections)
    img_query = block_weight.attn.to_q.apply(hidden_states[0])
    img_key = block_weight.attn.to_k.apply(hidden_states[0])
    img_value = block_weight.attn.to_v.apply(hidden_states[0])

    # Compute QKV for text stream (context projections)
    txt_query = block_weight.attn.add_q_proj.apply(encoder_hidden_states[0])
    txt_key = block_weight.attn.add_k_proj.apply(encoder_hidden_states[0])
    txt_value = block_weight.attn.add_v_proj.apply(encoder_hidden_states[0])

    # Reshape for multi-head attention
    img_query = img_query.unflatten(-1, (block_weight.attn.heads, -1))
    img_key = img_key.unflatten(-1, (block_weight.attn.heads, -1))
    img_value = img_value.unflatten(-1, (block_weight.attn.heads, -1))
    txt_query = txt_query.unflatten(-1, (block_weight.attn.heads, -1))
    txt_key = txt_key.unflatten(-1, (block_weight.attn.heads, -1))
    txt_value = txt_value.unflatten(-1, (block_weight.attn.heads, -1))

    # Apply QK normalization
    if block_weight.attn.norm_q is not None:
        img_query = block_weight.attn.norm_q.apply(img_query)
    if block_weight.attn.norm_k is not None:
        img_key = block_weight.attn.norm_k.apply(img_key)
    if block_weight.attn.norm_added_q is not None:
        txt_query = block_weight.attn.norm_added_q.apply(txt_query)
    if block_weight.attn.norm_added_k is not None:
        txt_key = block_weight.attn.norm_added_k.apply(txt_key)

    # Apply RoPE
    if image_rotary_emb is not None:
        img_freqs, txt_freqs1 = image_rotary_emb
        img_query = apply_rotary_emb_qwen(img_query.unsqueeze(0), img_freqs, use_real=False)
        img_key = apply_rotary_emb_qwen(img_key.unsqueeze(0), img_freqs, use_real=False)
        txt_query = apply_rotary_emb_qwen(txt_query.unsqueeze(0), txt_freqs1, use_real=False)
        txt_key = apply_rotary_emb_qwen(txt_key.unsqueeze(0), txt_freqs1, use_real=False)

    # Concatenate for joint attention
    # Order: [text, image]
    joint_query = torch.cat([txt_query, img_query], dim=1)
    joint_key = torch.cat([txt_key, img_key], dim=1)
    joint_value = torch.cat([txt_value.unsqueeze(0), img_value.unsqueeze(0)], dim=1)

    # Compute joint attention
    if attn_type == "torch_sdpa":
        joint_hidden_states = block_weight.attn.calculate.apply(q=joint_query, k=joint_key, v=joint_value)
    elif attn_type in ["flash_attn3", "sage_attn2"]:
        joint_query = joint_query.squeeze(0)
        joint_key = joint_key.squeeze(0)
        joint_value = joint_value.squeeze(0)
        k_lens = torch.tensor([joint_key.size(0)], dtype=torch.int32, device=joint_key.device)
        cu_seqlens_q, cu_seqlens_k = calculate_q_k_len(joint_query, k_lens=k_lens)
        joint_hidden_states = block_weight.attn.calculate.apply(
            q=joint_query, k=joint_key, v=joint_value, cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_k, max_seqlen_q=joint_query.size(0), max_seqlen_kv=joint_key.size(0), model_cls="qwen_image"
        )

    # Split attention outputs back
    txt_attn_output = joint_hidden_states[:seq_txt, :]  # Text part
    img_attn_output = joint_hidden_states[seq_txt:, :]  # Image part

    # Apply output projections
    img_attn_output = block_weight.attn.to_out.apply(img_attn_output)
    txt_attn_output = block_weight.attn.to_add_out.apply(txt_attn_output)

    return img_attn_output, txt_attn_output
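
# Editor's note on the [text, image] ordering above: text tokens are concatenated first, so
# the joint output splits back with text states at [:seq_txt] and image states at
# [seq_txt:]. Shape sketch for one prompt of 20 tokens and a 32x32 latent grid (1024 image
# tokens): joint_query is [1044, heads, dim_head] after the squeeze taken for
# flash_attn3/sage_attn2.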
class QwenImageTransformerInfer(BaseTransformerInfer):
    def __init__(self, config):
        self.config = config
        self.infer_conditional = True
        self.clean_cuda_cache = self.config.get("clean_cuda_cache", False)
        self.infer_func = self.infer_calculating
        self.attn_type = config.get("attn_type", "flash_attn3")

    def set_scheduler(self, scheduler):
        self.scheduler = scheduler

    def _modulate(self, x, mod_params):
        """Apply modulation to the input tensor."""
        shift, scale, gate = mod_params.chunk(3, dim=-1)
        return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1), gate.unsqueeze(1)

    def infer_block(self, block_weight, hidden_states, encoder_hidden_states, temb, image_rotary_emb):
        # Get modulation parameters for both streams
        img_mod_params = block_weight.img_mod.apply(F.silu(temb))
        txt_mod_params = block_weight.txt_mod.apply(F.silu(temb))

        # Split modulation parameters for norm1 and norm2
        img_mod1, img_mod2 = img_mod_params.chunk(2, dim=-1)
        txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1)

        # Process image stream - norm1 + modulation
        img_normed = block_weight.img_norm1.apply(hidden_states)
        img_modulated, img_gate1 = self._modulate(img_normed, img_mod1)

        # Process text stream - norm1 + modulation
        txt_normed = block_weight.txt_norm1.apply(encoder_hidden_states)
        txt_modulated, txt_gate1 = self._modulate(txt_normed, txt_mod1)

        # Use QwenAttnProcessor2_0 for joint attention computation
        # This directly implements the DoubleStreamLayerMegatron logic:
@@ -37,13 +173,12 @@ class QwenImageTransformerInfer(BaseTransformerInfer):
        # 2. Applies QK normalization and RoPE
        # 3. Concatenates and runs joint attention
        # 4. Splits results back to separate streams
        attn_output = apply_attn(
            block_weight=block_weight,
            hidden_states=img_modulated,  # Image stream (will be processed as "sample")
            encoder_hidden_states=txt_modulated,  # Text stream (will be processed as "context")
            image_rotary_emb=image_rotary_emb,
            attn_type=self.attn_type,
        )

        # QwenAttnProcessor2_0 returns (img_output, txt_output) when encoder_hidden_states is provided
@@ -54,15 +189,17 @@ class QwenImageTransformerInfer(BaseTransformerInfer):
        encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output

        # Process image stream - norm2 + MLP
        img_normed2 = block_weight.img_norm2.apply(hidden_states)
        img_modulated2, img_gate2 = self._modulate(img_normed2, img_mod2)
        img_mlp_output = F.silu(block_weight.img_mlp.mlp_0.apply(img_modulated2.squeeze(0)))
        img_mlp_output = block_weight.img_mlp.mlp_2.apply(img_mlp_output)
        hidden_states = hidden_states + img_gate2 * img_mlp_output

        # Process text stream - norm2 + MLP
        txt_normed2 = block_weight.txt_norm2.apply(encoder_hidden_states)
        txt_modulated2, txt_gate2 = self._modulate(txt_normed2, txt_mod2)
        txt_mlp_output = F.silu(block_weight.txt_mlp.mlp_0.apply(txt_modulated2.squeeze(0)))
        txt_mlp_output = block_weight.txt_mlp.mlp_2.apply(txt_mlp_output)
        encoder_hidden_states = encoder_hidden_states + txt_gate2 * txt_mlp_output

        # Clip to prevent overflow for fp16
@@ -73,20 +210,15 @@ class QwenImageTransformerInfer(BaseTransformerInfer):
        return encoder_hidden_states, hidden_states

    def infer_calculating(self, block_weights, hidden_states, encoder_hidden_states, temb, image_rotary_emb):
        for idx in range(len(block_weights.blocks)):
            block_weight = block_weights.blocks[idx]
            encoder_hidden_states, hidden_states = self.infer_block(
                block_weight=block_weight, hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb, image_rotary_emb=image_rotary_emb
            )
        return encoder_hidden_states, hidden_states

    def infer(self, hidden_states, encoder_hidden_states, pre_infer_out, block_weights):
        temb, image_rotary_emb = pre_infer_out
        encoder_hidden_states, hidden_states = self.infer_func(block_weights, hidden_states, encoder_hidden_states, temb, image_rotary_emb)
        return encoder_hidden_states, hidden_states
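
# Editor's sketch of the modulation pattern used by infer_block: each stream's mod vector is
# chunked into (shift, scale, gate) and applied as
#   x_mod = x * (1 + scale) + shift    # pre-attention / pre-MLP modulation
#   x_out = x + gate * f(x_mod)        # gated residual update
# with one (shift, scale, gate) triple for norm1/attention and another for norm2/MLP.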
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn.functional as F
from diffusers.utils import deprecate
from diffusers.utils.import_utils import is_torch_npu_available, is_torch_version
from torch import nn
if is_torch_npu_available():
    import torch_npu
ACT2CLS = {
    "swish": nn.SiLU,
    "silu": nn.SiLU,
    "mish": nn.Mish,
    "gelu": nn.GELU,
    "relu": nn.ReLU,
}


def get_activation(act_fn: str) -> nn.Module:
    """Helper function to get activation function from string.

    Args:
        act_fn (str): Name of activation function.

    Returns:
        nn.Module: Activation function.
    """
    act_fn = act_fn.lower()
    if act_fn in ACT2CLS:
        return ACT2CLS[act_fn]()
    else:
        raise ValueError(f"activation function {act_fn} not found in ACT2CLS mapping {list(ACT2CLS.keys())}")
class FP32SiLU(nn.Module):
    r"""
    SiLU activation function with input upcasted to torch.float32.
    """

    def __init__(self):
        super().__init__()

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        return F.silu(inputs.float(), inplace=False).to(inputs.dtype)


class GELU(nn.Module):
    r"""
    GELU activation function with tanh approximation support with `approximate="tanh"`.

    Parameters:
        dim_in (`int`): The number of channels in the input.
        dim_out (`int`): The number of channels in the output.
        approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation.
        bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
    """

    def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out, bias=bias)
        self.approximate = approximate

    def gelu(self, gate: torch.Tensor) -> torch.Tensor:
        if gate.device.type == "mps" and is_torch_version("<", "2.0.0"):
            # fp16 gelu not supported on mps before torch 2.0
            return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype)
        return F.gelu(gate, approximate=self.approximate)

    def forward(self, hidden_states):
        hidden_states = self.proj(hidden_states)
        hidden_states = self.gelu(hidden_states)
        return hidden_states


class GEGLU(nn.Module):
    r"""
    A [variant](https://huggingface.co/papers/2002.05202) of the gated linear unit activation function.

    Parameters:
        dim_in (`int`): The number of channels in the input.
        dim_out (`int`): The number of channels in the output.
        bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
    """

    def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias)

    def gelu(self, gate: torch.Tensor) -> torch.Tensor:
        if gate.device.type == "mps" and is_torch_version("<", "2.0.0"):
            # fp16 gelu not supported on mps before torch 2.0
            return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
        return F.gelu(gate)

    def forward(self, hidden_states, *args, **kwargs):
        if len(args) > 0 or kwargs.get("scale", None) is not None:
            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
            deprecate("scale", "1.0.0", deprecation_message)
        hidden_states = self.proj(hidden_states)
        if is_torch_npu_available():
            # using torch_npu.npu_geglu can run faster and save memory on NPU.
            return torch_npu.npu_geglu(hidden_states, dim=-1, approximate=1)[0]
        else:
            hidden_states, gate = hidden_states.chunk(2, dim=-1)
            return hidden_states * self.gelu(gate)


class SwiGLU(nn.Module):
    r"""
    A [variant](https://huggingface.co/papers/2002.05202) of the gated linear unit activation function. It's similar
    to `GEGLU` but uses SiLU / Swish instead of GeLU.

    Parameters:
        dim_in (`int`): The number of channels in the input.
        dim_out (`int`): The number of channels in the output.
        bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
    """

    def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias)
        self.activation = nn.SiLU()

    def forward(self, hidden_states):
        hidden_states = self.proj(hidden_states)
        hidden_states, gate = hidden_states.chunk(2, dim=-1)
        return hidden_states * self.activation(gate)


class ApproximateGELU(nn.Module):
    r"""
    The approximate form of the Gaussian Error Linear Unit (GELU). For more details, see section 2 of this
    [paper](https://huggingface.co/papers/1606.08415).

    Parameters:
        dim_in (`int`): The number of channels in the input.
        dim_out (`int`): The number of channels in the output.
        bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
    """

    def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.proj(x)
        return x * torch.sigmoid(1.702 * x)


class LinearActivation(nn.Module):
    def __init__(self, dim_in: int, dim_out: int, bias: bool = True, activation: str = "silu"):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out, bias=bias)
        self.activation = get_activation(activation)

    def forward(self, hidden_states):
        hidden_states = self.proj(hidden_states)
        return self.activation(hidden_states)
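
# A short usage sketch (editor's addition, not part of the file):
if __name__ == "__main__":
    act = get_activation("swish")        # nn.SiLU via the ACT2CLS alias table
    ff = SwiGLU(dim_in=64, dim_out=128)  # projects to 2 * dim_out, gates with SiLU
    y = ff(torch.randn(2, 10, 64))
    print(type(act).__name__, y.shape)   # SiLU torch.Size([2, 10, 128])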
@@ -2,36 +2,49 @@ import json
import os

import torch
from safetensors import safe_open

from lightx2v.utils.envs import *
from lightx2v.utils.utils import *

from .infer.offload.transformer_infer import QwenImageOffloadTransformerInfer
from .infer.post_infer import QwenImagePostInfer
from .infer.pre_infer import QwenImagePreInfer
from .infer.transformer_infer import QwenImageTransformerInfer
from .weights.post_weights import QwenImagePostWeights
from .weights.pre_weights import QwenImagePreWeights
from .weights.transformer_weights import QwenImageTransformerWeights


class QwenImageTransformerModel:
    pre_weight_class = QwenImagePreWeights
    transformer_weight_class = QwenImageTransformerWeights
    post_weight_class = QwenImagePostWeights

    def __init__(self, config):
        self.config = config
        self.model_path = os.path.join(config.model_path, "transformer")
        self.cpu_offload = config.get("cpu_offload", False)
        self.offload_granularity = self.config.get("offload_granularity", "block")
        self.device = torch.device("cpu") if self.cpu_offload else torch.device("cuda")
        with open(os.path.join(config.model_path, "transformer", "config.json"), "r") as f:
            transformer_config = json.load(f)
        self.in_channels = transformer_config["in_channels"]
        self.attention_kwargs = {}
        self.dit_quantized = self.config.mm_config.get("mm_type", "Default") != "Default"
        self.weight_auto_quant = self.config.mm_config.get("weight_auto_quant", False)
        self._init_infer_class()
        self._init_weights()
        self._init_infer()

    def set_scheduler(self, scheduler):
        self.scheduler = scheduler
        self.pre_infer.set_scheduler(scheduler)
        self.transformer_infer.set_scheduler(scheduler)
        self.post_infer.set_scheduler(scheduler)

    def _init_infer_class(self):
        if self.config["feature_caching"] == "NoCaching":
@@ -41,13 +54,145 @@ class QwenImageTransformerModel:
            self.pre_infer_class = QwenImagePreInfer
            self.post_infer_class = QwenImagePostInfer
    def _init_weights(self, weight_dict=None):
        unified_dtype = GET_DTYPE() == GET_SENSITIVE_DTYPE()
        # Some layers run with float32 to achieve high accuracy
        sensitive_layer = {}

        if weight_dict is None:
            is_weight_loader = self._should_load_weights()
            if is_weight_loader:
                if not self.dit_quantized or self.weight_auto_quant:
                    # Load original weights
                    weight_dict = self._load_ckpt(unified_dtype, sensitive_layer)
                else:
                    # Load quantized weights
                    assert NotImplementedError
            if self.config.get("device_mesh") is not None:
                weight_dict = self._load_weights_distribute(weight_dict, is_weight_loader)
            self.original_weight_dict = weight_dict
        else:
            self.original_weight_dict = weight_dict

        # Initialize weight containers
        self.pre_weight = self.pre_weight_class(self.config)
        self.transformer_weights = self.transformer_weight_class(self.config)
        self.post_weight = self.post_weight_class(self.config)

        # Load weights into containers
        self.pre_weight.load(self.original_weight_dict)
        self.transformer_weights.load(self.original_weight_dict)
        self.post_weight.load(self.original_weight_dict)

    def _should_load_weights(self):
        """Determine if the current rank should load weights from disk."""
        if self.config.get("device_mesh") is None:
            # Single GPU mode
            return True
        elif dist.is_initialized():
            # Multi-GPU mode, only rank 0 loads
            if dist.get_rank() == 0:
                logger.info(f"Loading weights from {self.model_path}")
                return True
        return False

    def _load_safetensor_to_dict(self, file_path, unified_dtype, sensitive_layer):
        with safe_open(file_path, framework="pt") as f:
            return {
                key: (f.get_tensor(key).to(GET_DTYPE()) if unified_dtype or all(s not in key for s in sensitive_layer) else f.get_tensor(key).to(GET_SENSITIVE_DTYPE())).pin_memory().to(self.device)
                for key in f.keys()
            }

    def _load_ckpt(self, unified_dtype, sensitive_layer):
        safetensors_files = glob.glob(os.path.join(self.model_path, "*.safetensors"))
        weight_dict = {}
        for file_path in safetensors_files:
            file_weights = self._load_safetensor_to_dict(file_path, unified_dtype, sensitive_layer)
            weight_dict.update(file_weights)
        return weight_dict

    def _load_weights_distribute(self, weight_dict, is_weight_loader):
        global_src_rank = 0
        target_device = "cpu" if self.cpu_offload else "cuda"

        if is_weight_loader:
            meta_dict = {}
            for key, tensor in weight_dict.items():
                meta_dict[key] = {"shape": tensor.shape, "dtype": tensor.dtype}
            obj_list = [meta_dict]
            dist.broadcast_object_list(obj_list, src=global_src_rank)
            synced_meta_dict = obj_list[0]
        else:
            obj_list = [None]
            dist.broadcast_object_list(obj_list, src=global_src_rank)
            synced_meta_dict = obj_list[0]

        distributed_weight_dict = {}
        for key, meta in synced_meta_dict.items():
            distributed_weight_dict[key] = torch.empty(meta["shape"], dtype=meta["dtype"], device=target_device)

        if target_device == "cuda":
            dist.barrier(device_ids=[torch.cuda.current_device()])

        for key in sorted(synced_meta_dict.keys()):
            if is_weight_loader:
                distributed_weight_dict[key].copy_(weight_dict[key], non_blocking=True)
            if target_device == "cpu":
                if is_weight_loader:
                    gpu_tensor = distributed_weight_dict[key].cuda()
                    dist.broadcast(gpu_tensor, src=global_src_rank)
                    distributed_weight_dict[key].copy_(gpu_tensor.cpu(), non_blocking=True)
                    del gpu_tensor
                    torch.cuda.empty_cache()
                else:
                    gpu_tensor = torch.empty_like(distributed_weight_dict[key], device="cuda")
                    dist.broadcast(gpu_tensor, src=global_src_rank)
                    distributed_weight_dict[key].copy_(gpu_tensor.cpu(), non_blocking=True)
                    del gpu_tensor
                    torch.cuda.empty_cache()
                if distributed_weight_dict[key].is_pinned():
                    distributed_weight_dict[key].copy_(distributed_weight_dict[key], non_blocking=True)
            else:
                dist.broadcast(distributed_weight_dict[key], src=global_src_rank)

        if target_device == "cuda":
            torch.cuda.synchronize()
        else:
            for tensor in distributed_weight_dict.values():
                if tensor.is_pinned():
                    tensor.copy_(tensor, non_blocking=False)

        logger.info(f"Weights distributed across {dist.get_world_size()} devices on {target_device}")
        return distributed_weight_dict
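
    # Editor's sketch of the metadata-then-tensor broadcast pattern used above, assuming an
    # initialized process group (`dist` and `logger` are expected to come from the
    # star-imported utils):
    #   obj = [meta_dict if dist.get_rank() == 0 else None]
    #   dist.broadcast_object_list(obj, src=0)
    #   buf = torch.empty(obj[0][name]["shape"], dtype=obj[0][name]["dtype"], device="cuda")
    #   dist.broadcast(buf, src=0)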
    def _init_infer(self):
        self.transformer_infer = self.transformer_infer_class(self.config)
        self.pre_infer = self.pre_infer_class(self.config)
        self.post_infer = self.post_infer_class(self.config)

    def to_cpu(self):
        self.pre_weight.to_cpu()
        self.transformer_weights.to_cpu()
        self.post_weight.to_cpu()

    def to_cuda(self):
        self.pre_weight.to_cuda()
        self.transformer_weights.to_cuda()
        self.post_weight.to_cuda()

    @torch.no_grad()
    def infer(self, inputs):
        if self.cpu_offload:
            if self.offload_granularity == "model" and self.scheduler.step_index == 0:
                self.to_cuda()
            elif self.offload_granularity != "model":
                self.pre_weight.to_cuda()
                self.post_weight.to_cuda()
        t = self.scheduler.timesteps[self.scheduler.step_index]
        latents = self.scheduler.latents
        if self.config.task == "i2i":
@@ -63,7 +208,9 @@ class QwenImageTransformerModel:
        prompt_embeds_mask = inputs["text_encoder_output"]["prompt_embeds_mask"]
        txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None

        hidden_states, encoder_hidden_states, _, pre_infer_out = self.pre_infer.infer(
            weights=self.pre_weight,
            hidden_states=latents_input,
            timestep=timestep / 1000,
            guidance=self.scheduler.guidance,
@@ -75,14 +222,14 @@ class QwenImageTransformerModel:
        )
        encoder_hidden_states, hidden_states = self.transformer_infer.infer(
            block_weights=self.transformer_weights,
            hidden_states=hidden_states.unsqueeze(0),
            encoder_hidden_states=encoder_hidden_states.unsqueeze(0),
            pre_infer_out=pre_infer_out,
        )
        noise_pred = self.post_infer.infer(self.post_weight, hidden_states, pre_infer_out[0])
        if self.config.task == "i2i":
            noise_pred = noise_pred[:, : latents.size(1)]
...
from lightx2v.common.modules.weight_module import WeightModule
from lightx2v.utils.registry_factory import (
    LN_WEIGHT_REGISTER,
    MM_WEIGHT_REGISTER,
)


class QwenImagePostWeights(WeightModule):
    def __init__(self, config):
        super().__init__()
        self.task = config["task"]
        self.config = config
        if config["do_mm_calib"]:
            self.mm_type = "Calib"
        else:
            self.mm_type = config["mm_config"].get("mm_type", "Default") if config["mm_config"] else "Default"
        self.lazy_load = self.config.get("lazy_load", False)
        if self.lazy_load:
            assert NotImplementedError
        self.lazy_load_file = False

        # norm_out
        self.add_module(
            "norm_out_linear",
            MM_WEIGHT_REGISTER[self.mm_type](
                "norm_out.linear.weight",
                "norm_out.linear.bias",
                self.lazy_load,
                self.lazy_load_file,
            ),
        )
        self.add_module("norm_out", LN_WEIGHT_REGISTER["Default"](eps=1e-6))

        # proj_out
        self.add_module(
            "proj_out_linear",
            MM_WEIGHT_REGISTER[self.mm_type](
                "proj_out.weight",
                "proj_out.bias",
                self.lazy_load,
                self.lazy_load_file,
            ),
        )
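
# Editor's note (illustrative, hypothetical variable names): once load(weight_dict) resolves
# these entries, QwenImagePostInfer.infer() runs the whole post step on registry-backed
# apply() calls instead of nn.Module forwards:
#   post_weights = QwenImagePostWeights(config)
#   post_weights.load(weight_dict)
#   temb1 = post_weights.norm_out_linear.apply(F.silu(temb))  # then chunk into scale/shift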