Unverified Commit f21da849 authored by Yang Yong (雍洋), committed by GitHub
parent 3efc43f5
{
"infer_steps": 50,
"transformer_model_name": "480p_t2v",
"fps": 24,
"target_video_length": 121,
"aspect_ratio": "16:9",
"vae_stride": [4, 16, 16],
"sample_shift": 7.0,
"sample_guide_scale": 6.0,
"enable_cfg": true,
"attn_type": "sage_attn2",
"cpu_offload": true,
"offload_granularity": "block",
"vae_cpu_offload": false,
"byt5_cpu_offload": false,
"qwen25vl_cpu_offload": true,
"siglip_cpu_offload": false
}
{
"infer_steps": 50,
"transformer_model_name": "480p_t2v",
"fps": 24,
"target_video_length": 121,
"aspect_ratio": "16:9",
"vae_stride": [4, 16, 16],
"sample_shift": 7.0,
"sample_guide_scale": 6.0,
"enable_cfg": true,
"attn_type": "flash_attn3",
"dit_quantized_ckpt": "/path/to/quant_model.safetensors",
"dit_quantized": true,
"dit_quant_scheme": "fp8-sgl"
}
{
"infer_steps": 50,
"transformer_model_name": "480p_i2v",
"fps": 24,
"target_video_length": 121,
"vae_stride": [4, 16, 16],
"sample_shift": 5.0,
"sample_guide_scale": 6.0,
"enable_cfg": true,
"attn_type": "flash_attn3",
"video_super_resolution": {
"sr_version": "720p_sr_distilled",
"flow_shift": 2.0,
"base_resolution": "480p",
"guidance_scale": 1.0,
"num_inference_steps": 6,
"use_meanflow": true
}
}
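# Hedged usage sketch (not part of the commit): the three JSON blocks above are inference
# configs for 480p t2v with CPU offload, an fp8-quantized DiT, and 480p i2v with the 720p
# super-resolution stage. They are plain JSON, so they can be inspected directly; the file
# path below is an assumption for illustration only.
import json

with open("configs/hunyuanvideo15/480p_t2v_offload.json") as f:  # hypothetical path
    cfg = json.load(f)
print(cfg["infer_steps"], cfg["attn_type"], cfg.get("cpu_offload", False))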
......@@ -46,6 +46,9 @@ class FlashAttn2Weight(AttnWeightTemplate):
bs = 1
elif len(q.shape) == 4:
bs = q.shape[0]
q = q.reshape(-1, q.shape[-2], q.shape[-1])
k = k.reshape(-1, k.shape[-2], k.shape[-1])
v = v.reshape(-1, v.shape[-2], v.shape[-1])
x = flash_attn_varlen_func(
q,
k,
......@@ -78,6 +81,9 @@ class FlashAttn3Weight(AttnWeightTemplate):
bs = 1
elif len(q.shape) == 4:
bs = q.shape[0]
q = q.reshape(-1, q.shape[-2], q.shape[-1])
k = k.reshape(-1, k.shape[-2], k.shape[-1])
v = v.reshape(-1, v.shape[-2], v.shape[-1])
x = flash_attn_varlen_func_v3(
q,
k,
......
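# Note on the reshape added in the two hunks above (a sketch, not part of the diff):
# flash_attn_varlen_func expects packed 3D tensors of shape [total_tokens, num_heads,
# head_dim], so a batched 4D input is flattened before the call. Plain-PyTorch shape check:
import torch

bs, seqlen, heads, head_dim = 2, 8, 4, 64
q = torch.randn(bs, seqlen, heads, head_dim)
q_packed = q.reshape(-1, q.shape[-2], q.shape[-1])
assert q_packed.shape == (bs * seqlen, heads, head_dim)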
......@@ -27,6 +27,11 @@ class UlyssesAttnWeight(AttnWeightTemplate):
Returns:
torch.Tensor: The computed attention result
"""
if len(q.shape) == 4:
q = q.reshape(-1, q.shape[-2], q.shape[-1])
k = k.reshape(-1, k.shape[-2], k.shape[-1])
v = v.reshape(-1, v.shape[-2], v.shape[-1])
# Get the current process's rank and the world size
world_size = dist.get_world_size(seq_p_group)
cur_rank = dist.get_rank(seq_p_group)
......@@ -134,9 +139,16 @@ class Ulysses4090AttnWeight(AttnWeightTemplate):
Returns:
torch.Tensor: The computed attention result
"""
if len(q.shape) == 4:
q = q.reshape(-1, q.shape[-2], q.shape[-1])
k = k.reshape(-1, k.shape[-2], k.shape[-1])
v = v.reshape(-1, v.shape[-2], v.shape[-1])
# Get the current process's rank and the world size
world_size = dist.get_world_size(seq_p_group)
cur_rank = dist.get_rank(seq_p_group)
global_world_size = dist.get_world_size()
global_rank = dist.get_global_rank(seq_p_group, cur_rank)
cfg_p_group_index = global_rank // world_size
# Get the sequence length and the text-related lengths
seq_len = q.shape[0]
......@@ -181,22 +193,23 @@ class Ulysses4090AttnWeight(AttnWeightTemplate):
# Launch communication asynchronously, then synchronize
for target_rank in range(world_size):
target_global_rank = cfg_p_group_index * world_size + target_rank
if target_rank != cur_rank:
# Avoid deadlock: decide the send/receive order by rank
if cur_rank < target_rank:
sendq_req = dist.isend(q_shards[target_rank], dst=target_rank, group=seq_p_group)
sendk_req = dist.isend(k_shards[target_rank], dst=target_rank, group=seq_p_group)
sendv_req = dist.isend(v_shards[target_rank], dst=target_rank, group=seq_p_group)
recvq_req = dist.irecv(gathered_q_shards[target_rank], src=target_rank, group=seq_p_group)
recvk_req = dist.irecv(gathered_k_shards[target_rank], src=target_rank, group=seq_p_group)
recvv_req = dist.irecv(gathered_v_shards[target_rank], src=target_rank, group=seq_p_group)
sendq_req = dist.isend(q_shards[target_rank], dst=target_global_rank, group=seq_p_group)
sendk_req = dist.isend(k_shards[target_rank], dst=target_global_rank, group=seq_p_group)
sendv_req = dist.isend(v_shards[target_rank], dst=target_global_rank, group=seq_p_group)
recvq_req = dist.irecv(gathered_q_shards[target_rank], src=target_global_rank, group=seq_p_group)
recvk_req = dist.irecv(gathered_k_shards[target_rank], src=target_global_rank, group=seq_p_group)
recvv_req = dist.irecv(gathered_v_shards[target_rank], src=target_global_rank, group=seq_p_group)
else:
recvq_req = dist.irecv(gathered_q_shards[target_rank], src=target_rank, group=seq_p_group)
recvk_req = dist.irecv(gathered_k_shards[target_rank], src=target_rank, group=seq_p_group)
recvv_req = dist.irecv(gathered_v_shards[target_rank], src=target_rank, group=seq_p_group)
sendq_req = dist.isend(q_shards[target_rank], dst=target_rank, group=seq_p_group)
sendk_req = dist.isend(k_shards[target_rank], dst=target_rank, group=seq_p_group)
sendv_req = dist.isend(v_shards[target_rank], dst=target_rank, group=seq_p_group)
recvq_req = dist.irecv(gathered_q_shards[target_rank], src=target_global_rank, group=seq_p_group)
recvk_req = dist.irecv(gathered_k_shards[target_rank], src=target_global_rank, group=seq_p_group)
recvv_req = dist.irecv(gathered_v_shards[target_rank], src=target_global_rank, group=seq_p_group)
sendq_req = dist.isend(q_shards[target_rank], dst=target_global_rank, group=seq_p_group)
sendk_req = dist.isend(k_shards[target_rank], dst=target_global_rank, group=seq_p_group)
sendv_req = dist.isend(v_shards[target_rank], dst=target_global_rank, group=seq_p_group)
sendq_req.wait()
sendk_req.wait()
sendv_req.wait()
......@@ -254,6 +267,9 @@ class Ulysses4090AttnWeight(AttnWeightTemplate):
@torch.compiler.disable
def _reshape_img_attn(self, img_attn, world_size, shard_seqlen, shard_heads, hidden_dims, seq_p_group):
cur_rank = dist.get_rank(seq_p_group)
global_world_size = dist.get_world_size()
global_rank = dist.get_global_rank(seq_p_group, cur_rank)
cfg_p_group_index = global_rank // world_size
img_attn = img_attn.reshape(world_size * shard_seqlen, shard_heads, hidden_dims)  # reshape the image attention result
......@@ -269,14 +285,15 @@ class Ulysses4090AttnWeight(AttnWeightTemplate):
# Launch communication asynchronously, then synchronize
for target_rank in range(world_size):
target_global_rank = cfg_p_group_index * world_size + target_rank
if target_rank != cur_rank:
# Avoid deadlock: decide the send/receive order by rank
if cur_rank < target_rank:
send_req = dist.isend(attn_shards[target_rank], dst=target_rank, group=seq_p_group)
recv_req = dist.irecv(gathered_attn_shards[target_rank], src=target_rank, group=seq_p_group)
send_req = dist.isend(attn_shards[target_rank], dst=target_global_rank, group=seq_p_group)
recv_req = dist.irecv(gathered_attn_shards[target_rank], src=target_global_rank, group=seq_p_group)
else:
recv_req = dist.irecv(gathered_attn_shards[target_rank], src=target_rank, group=seq_p_group)
send_req = dist.isend(attn_shards[target_rank], dst=target_rank, group=seq_p_group)
recv_req = dist.irecv(gathered_attn_shards[target_rank], src=target_global_rank, group=seq_p_group)
send_req = dist.isend(attn_shards[target_rank], dst=target_global_rank, group=seq_p_group)
send_req.wait()
recv_req.wait()
......
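# Minimal sketch of the pairwise-exchange pattern used in the hunks above (assumes
# torch.distributed is already initialized and `group` is the sequence-parallel subgroup).
# Two points the diff relies on: point-to-point dst/src must be global ranks even when a
# group is passed (hence target_global_rank), and the lower rank posts its sends first while
# the higher rank posts its receives first, so a pair of peers cannot deadlock.
import torch.distributed as dist

def exchange_shards(my_shards, recv_buffers, group):
    rank = dist.get_rank(group)
    world_size = dist.get_world_size(group)
    for peer in range(world_size):
        if peer == rank:
            continue
        peer_global = dist.get_global_rank(group, peer)
        if rank < peer:
            send_req = dist.isend(my_shards[peer], dst=peer_global, group=group)
            recv_req = dist.irecv(recv_buffers[peer], src=peer_global, group=group)
        else:
            recv_req = dist.irecv(recv_buffers[peer], src=peer_global, group=group)
            send_req = dist.isend(my_shards[peer], dst=peer_global, group=group)
        send_req.wait()
        recv_req.wait()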
......@@ -809,7 +809,7 @@ class MMWeightWfp8channelAfp8channeldynamicQ8F(MMWeightQuantTemplate):
self.weight_scale,
out_dtype=self.infer_dtype,
)
return output_tensor.squeeze(0)
return output_tensor.squeeze(0) if len(output_tensor.shape) == 3 else output_tensor
@MM_WEIGHT_REGISTER("int8-q8f")
......@@ -840,7 +840,7 @@ class MMWeightWint8channelAint8channeldynamicQ8F(MMWeightQuantTemplate):
fuse_gelu=False,
out_dtype=self.infer_dtype,
)
return output_tensor.squeeze(0)
return output_tensor.squeeze(0) if len(output_tensor.shape) == 3 else output_tensor
@MM_WEIGHT_REGISTER("fp8-b128-deepgemm")
......
......@@ -6,6 +6,8 @@ import torch
from lightx2v.utils.envs import *
from lightx2v.utils.registry_factory import LN_WEIGHT_REGISTER
from .triton_ops import norm_infer
class LNWeightTemplate(metaclass=ABCMeta):
def __init__(self, weight_name=None, bias_name=None, create_cuda_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False, eps=1e-6):
......@@ -165,3 +167,30 @@ class LNWeight(LNWeightTemplate):
input_tensor = torch.nn.functional.layer_norm(input_tensor, (input_tensor.shape[-1],), self.weight, self.bias, self.eps)
return input_tensor
@LN_WEIGHT_REGISTER("Triton")
class LNWeight(LNWeightTemplate):
def __init__(self, weight_name=None, bias_name=None, create_cuda_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False, eps=1e-6):
super().__init__(weight_name, bias_name, create_cuda_buffer, lazy_load, lazy_load_file, is_post_adapter, eps)
def load_from_disk(self):
if self.weight_name is not None:
if not torch._dynamo.is_compiling():
self.weight = self.lazy_load_file.get_tensor(self.weight_name).to(GET_DTYPE()).pin_memory()
else:
self.weight = self.lazy_load_file.get_tensor(self.weight_name).to(GET_DTYPE())
else:
self.weight = None
if self.bias_name is not None:
if not torch._dynamo.is_compiling():
self.bias = self.lazy_load_file.get_tensor(self.bias_name).to(GET_DTYPE()).pin_memory()
else:
self.bias = self.lazy_load_file.get_tensor(self.bias_name).to(GET_DTYPE())
else:
self.bias = None
def apply(self, input_tensor):
input_tensor = norm_infer(input_tensor, self.weight, self.bias, self.eps)
return input_tensor
......@@ -80,14 +80,14 @@ class RMSWeight(RMSWeightTemplate):
else:
self.weight = self.lazy_load_file.get_tensor(self.weight_name).to(GET_DTYPE())
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
def apply(self, input_tensor):
if GET_SENSITIVE_DTYPE() != GET_DTYPE():
input_tensor = input_tensor * torch.rsqrt(input_tensor.pow(2).mean(-1, keepdim=True) + self.eps)
input_tensor = input_tensor * self.weight
input_tensor = self._norm(input_tensor).type_as(input_tensor) * self.weight
else:
input_tensor = input_tensor * torch.rsqrt(input_tensor.float().pow(2).mean(-1, keepdim=True) + self.eps)
input_tensor = (input_tensor * self.weight).to(GET_DTYPE())
input_tensor = self._norm(input_tensor.float()).type_as(input_tensor) * self.weight
return input_tensor
def state_dict(self, destination=None):
......@@ -111,7 +111,15 @@ class RMSWeight(RMSWeightTemplate):
@RMS_WEIGHT_REGISTER("sgl-kernel")
class RMSWeightSgl(RMSWeight):
def __init__(self, weight_name, create_cuda_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False, eps=1e-6):
def __init__(
self,
weight_name,
create_cuda_buffer=False,
lazy_load=False,
lazy_load_file=None,
is_post_adapter=False,
eps=1e-6,
):
super().__init__(weight_name, create_cuda_buffer, lazy_load, lazy_load_file, is_post_adapter, eps)
def load_from_disk(self):
......
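# Minimal reference sketch (not part of the diff) of the RMSNorm math that RMSWeight.apply
# above now routes through _norm(): y = x * rsqrt(mean(x^2, dim=-1) + eps) scaled by the
# learned weight, with the reduction done in fp32 on the sensitive-dtype branch.
import torch

def rms_norm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    normed = x.float() * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + eps)
    return normed.type_as(x) * weight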
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo & https://github.com/sgl-project/sglang
# TODO: for temporary usage, expecting a refactor
from typing import Optional
import torch
import triton # type: ignore
import triton.language as tl # type: ignore
from torch import Tensor
@triton.autotune(
configs=[
triton.Config({"BLOCK_N": 64}, num_warps=2),
triton.Config({"BLOCK_N": 128}, num_warps=4),
triton.Config({"BLOCK_N": 256}, num_warps=4),
triton.Config({"BLOCK_N": 512}, num_warps=4),
triton.Config({"BLOCK_N": 1024}, num_warps=8),
],
key=["inner_dim"],
)
@triton.jit
def _fused_scale_shift_4d_kernel(
output_ptr,
normalized_ptr,
scale_ptr,
shift_ptr,
rows,
inner_dim,
seq_len,
num_frames,
frame_seqlen,
BLOCK_N: tl.constexpr,
):
pid_row = tl.program_id(0)
pid_col = tl.program_id(1)
col_offsets = pid_col * BLOCK_N + tl.arange(0, BLOCK_N)
mask = col_offsets < inner_dim
# Pointers for normalized and output
row_base = pid_row * inner_dim
norm_ptrs = normalized_ptr + row_base + col_offsets
out_ptrs = output_ptr + row_base + col_offsets
# Pointers for scale and shift for 4D
b_idx = pid_row // seq_len
t_idx = pid_row % seq_len
frame_idx_in_batch = t_idx // frame_seqlen
scale_row_idx = b_idx * num_frames + frame_idx_in_batch
scale_ptrs = scale_ptr + scale_row_idx * inner_dim + col_offsets
shift_ptrs = shift_ptr + scale_row_idx * inner_dim + col_offsets
normalized = tl.load(norm_ptrs, mask=mask, other=0.0)
scale = tl.load(scale_ptrs, mask=mask, other=0.0)
shift = tl.load(shift_ptrs, mask=mask, other=0.0)
one = tl.full([BLOCK_N], 1.0, dtype=scale.dtype)
output = normalized * (one + scale) + shift
tl.store(out_ptrs, output, mask=mask)
@triton.jit
def fuse_scale_shift_kernel_blc_opt(
x_ptr,
shift_ptr,
scale_ptr,
y_ptr,
B,
L,
C,
stride_x_b,
stride_x_l,
stride_x_c,
stride_s_b,
stride_s_l,
stride_s_c,
stride_sc_b,
stride_sc_l,
stride_sc_c,
SCALE_IS_SCALAR: tl.constexpr,
SHIFT_IS_SCALAR: tl.constexpr,
BLOCK_L: tl.constexpr,
BLOCK_C: tl.constexpr,
):
pid_l = tl.program_id(0)
pid_c = tl.program_id(1)
pid_b = tl.program_id(2)
l_offsets = pid_l * BLOCK_L + tl.arange(0, BLOCK_L)
c_offsets = pid_c * BLOCK_C + tl.arange(0, BLOCK_C)
mask_l = l_offsets < L
mask_c = c_offsets < C
mask = mask_l[:, None] & mask_c[None, :]
x_off = pid_b * stride_x_b + l_offsets[:, None] * stride_x_l + c_offsets[None, :] * stride_x_c
x = tl.load(x_ptr + x_off, mask=mask, other=0)
if SHIFT_IS_SCALAR:
shift_val = tl.load(shift_ptr)
shift = tl.full((BLOCK_L, BLOCK_C), shift_val, dtype=shift_val.dtype)
else:
s_off = pid_b * stride_s_b + l_offsets[:, None] * stride_s_l + c_offsets[None, :] * stride_s_c
shift = tl.load(shift_ptr + s_off, mask=mask, other=0)
if SCALE_IS_SCALAR:
scale_val = tl.load(scale_ptr)
scale = tl.full((BLOCK_L, BLOCK_C), scale_val, dtype=scale_val.dtype)
else:
sc_off = pid_b * stride_sc_b + l_offsets[:, None] * stride_sc_l + c_offsets[None, :] * stride_sc_c
scale = tl.load(scale_ptr + sc_off, mask=mask, other=0)
y = x * (1 + scale) + shift
tl.store(y_ptr + x_off, y, mask=mask)
def fuse_scale_shift_kernel(
x: torch.Tensor,
scale: torch.Tensor,
shift: torch.Tensor,
block_l: int = 128,
block_c: int = 128,
):
assert x.is_cuda and scale.is_cuda
assert x.is_contiguous()
B, L, C = x.shape
output = torch.empty_like(x)
if scale.dim() == 4:
# scale/shift: [B, F, 1, C]
rows = B * L
x_2d = x.view(rows, C)
output_2d = output.view(rows, C)
grid = lambda META: (rows, triton.cdiv(C, META["BLOCK_N"])) # noqa
num_frames = scale.shape[1]
assert L % num_frames == 0, "seq_len must be divisible by num_frames for 4D scale/shift"
frame_seqlen = L // num_frames
# Compact [B, F, C] without the singleton dim into [B*F, C]
scale_reshaped = scale.squeeze(2).reshape(-1, C).contiguous()
shift_reshaped = shift.squeeze(2).reshape(-1, C).contiguous()
_fused_scale_shift_4d_kernel[grid](
output_2d,
x_2d,
scale_reshaped,
shift_reshaped,
rows,
C,
L,
num_frames,
frame_seqlen,
)
else:
# 2D: [B, C] or [1, C] -> treat as [B, 1, C] and broadcast over L
# 3D: [B, L, C] (or broadcastable variants like [B, 1, C], [1, L, C], [1, 1, C])
# Also support scalar (0D or 1-element)
if scale.dim() == 0 or (scale.dim() == 1 and scale.numel() == 1):
scale_blc = scale.reshape(1)
elif scale.dim() == 2:
scale_blc = scale[:, None, :]
elif scale.dim() == 3:
scale_blc = scale
else:
raise ValueError("scale must be 0D/1D(1)/2D/3D or 4D")
if shift.dim() == 0 or (shift.dim() == 1 and shift.numel() == 1):
shift_blc = shift.reshape(1)
elif shift.dim() == 2:
shift_blc = shift[:, None, :]
elif shift.dim() == 3:
shift_blc = shift
else:
# broadcast later via expand if possible
shift_blc = shift
need_scale_scalar = scale_blc.dim() == 1 and scale_blc.numel() == 1
need_shift_scalar = shift_blc.dim() == 1 and shift_blc.numel() == 1
if not need_scale_scalar:
scale_exp = scale_blc.expand(B, L, C)
s_sb, s_sl, s_sc = scale_exp.stride()
else:
s_sb = s_sl = s_sc = 0
if not need_shift_scalar:
shift_exp = shift_blc.expand(B, L, C)
sh_sb, sh_sl, sh_sc = shift_exp.stride()
else:
sh_sb = sh_sl = sh_sc = 0
# If both scalars and both zero, copy fast-path
if need_scale_scalar and need_shift_scalar:
if (scale_blc.abs().max() == 0) and (shift_blc.abs().max() == 0):
output.copy_(x)
return output
grid = (triton.cdiv(L, block_l), triton.cdiv(C, block_c), B)
fuse_scale_shift_kernel_blc_opt[grid](
x,
shift_blc if need_shift_scalar else shift_exp,
scale_blc if need_scale_scalar else scale_exp,
output,
B,
L,
C,
x.stride(0),
x.stride(1),
x.stride(2),
sh_sb,
sh_sl,
sh_sc,
s_sb,
s_sl,
s_sc,
SCALE_IS_SCALAR=need_scale_scalar,
SHIFT_IS_SCALAR=need_shift_scalar,
BLOCK_L=block_l,
BLOCK_C=block_c,
num_warps=4,
num_stages=2,
)
return output
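# Hedged usage sketch for fuse_scale_shift_kernel (requires a CUDA device; torch and triton
# are already imported at the top of this module). It computes y = x * (1 + scale) + shift
# with scale/shift broadcast over the sequence dimension: [B, 1, C] modulation takes the
# 3D path, while [B, F, 1, C] inputs take the per-frame 4D path above.
def _example_fuse_scale_shift():
    B, L, C = 2, 16, 128
    x = torch.randn(B, L, C, device="cuda", dtype=torch.bfloat16)
    scale = torch.randn(B, 1, C, device="cuda", dtype=torch.bfloat16)
    shift = torch.randn(B, 1, C, device="cuda", dtype=torch.bfloat16)
    y = fuse_scale_shift_kernel(x, scale, shift)
    torch.testing.assert_close(y, x * (1 + scale) + shift, rtol=2e-2, atol=2e-2)
    return y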
@triton.autotune(
configs=[
triton.Config({"BLOCK_HS_HALF": 32}, num_warps=2),
triton.Config({"BLOCK_HS_HALF": 64}, num_warps=4),
triton.Config({"BLOCK_HS_HALF": 128}, num_warps=4),
triton.Config({"BLOCK_HS_HALF": 256}, num_warps=8),
],
key=["head_size", "interleaved"],
)
@triton.jit
def _rotary_embedding_kernel(
output_ptr,
x_ptr,
cos_ptr,
sin_ptr,
num_heads,
head_size,
num_tokens,
stride_x_row,
stride_cos_row,
stride_sin_row,
interleaved: tl.constexpr,
BLOCK_HS_HALF: tl.constexpr,
):
row_idx = tl.program_id(0)
token_idx = (row_idx // num_heads) % num_tokens
x_row_ptr = x_ptr + row_idx * stride_x_row
cos_row_ptr = cos_ptr + token_idx * stride_cos_row
sin_row_ptr = sin_ptr + token_idx * stride_sin_row
output_row_ptr = output_ptr + row_idx * stride_x_row
# half size for x1 and x2
head_size_half = head_size // 2
for block_start in range(0, head_size_half, BLOCK_HS_HALF):
offsets_half = block_start + tl.arange(0, BLOCK_HS_HALF)
mask = offsets_half < head_size_half
cos_vals = tl.load(cos_row_ptr + offsets_half, mask=mask, other=0.0)
sin_vals = tl.load(sin_row_ptr + offsets_half, mask=mask, other=0.0)
offsets_x1 = 2 * offsets_half
offsets_x2 = 2 * offsets_half + 1
x1_vals = tl.load(x_row_ptr + offsets_x1, mask=mask, other=0.0)
x2_vals = tl.load(x_row_ptr + offsets_x2, mask=mask, other=0.0)
x1_fp32 = x1_vals.to(tl.float32)
x2_fp32 = x2_vals.to(tl.float32)
cos_fp32 = cos_vals.to(tl.float32)
sin_fp32 = sin_vals.to(tl.float32)
o1_vals = tl.fma(-x2_fp32, sin_fp32, x1_fp32 * cos_fp32)
o2_vals = tl.fma(x1_fp32, sin_fp32, x2_fp32 * cos_fp32)
tl.store(output_row_ptr + offsets_x1, o1_vals.to(x1_vals.dtype), mask=mask)
tl.store(output_row_ptr + offsets_x2, o2_vals.to(x2_vals.dtype), mask=mask)
def apply_rotary_embedding(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved: bool = False) -> torch.Tensor:
output = torch.empty_like(x)
if x.dim() > 3:
bsz, num_tokens, num_heads, head_size = x.shape
else:
num_tokens, num_heads, head_size = x.shape
bsz = 1
assert head_size % 2 == 0, "head_size must be divisible by 2"
x_reshaped = x.view(-1, head_size)
output_reshaped = output.view(-1, head_size)
# num_tokens per head, 1 token per block
grid = (bsz * num_tokens * num_heads,)
if interleaved and cos.shape[-1] == head_size:
cos = cos[..., ::2].contiguous()
sin = sin[..., ::2].contiguous()
else:
cos = cos.contiguous()
sin = sin.contiguous()
_rotary_embedding_kernel[grid](
output_reshaped,
x_reshaped,
cos,
sin,
num_heads,
head_size,
num_tokens,
x_reshaped.stride(0),
cos.stride(0),
sin.stride(0),
interleaved,
)
return output
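# Hedged usage sketch for apply_rotary_embedding (requires CUDA). In the non-interleaved
# call below, cos/sin carry head_size // 2 values per token, matching the per-pair loads in
# _rotary_embedding_kernel; the RoPE base of 10000 is an illustrative assumption.
def _example_apply_rotary_embedding():
    num_tokens, num_heads, head_size = 32, 8, 64
    x = torch.randn(num_tokens, num_heads, head_size, device="cuda", dtype=torch.bfloat16)
    pos = torch.arange(num_tokens, device="cuda", dtype=torch.float32)
    inv_freq = 1.0 / (10000.0 ** (torch.arange(0, head_size, 2, device="cuda").float() / head_size))
    freqs = torch.outer(pos, inv_freq)  # [num_tokens, head_size // 2]
    return apply_rotary_embedding(x, freqs.cos(), freqs.sin())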
# RMSNorm-fp32
def maybe_contiguous_lastdim(x):
return x.contiguous() if x is not None and x.stride(-1) != 1 else x
def maybe_contiguous(x):
return x.contiguous() if x is not None else None
def triton_autotune_configs():
# Return configs with a valid warp count for the current device
configs = []
# The maximum threads per block is architecture-dependent in theory, but in practice it is 1024 on current devices
max_threads_per_block = 1024
# Default to warp size 32 if not defined by device
warp_size = getattr(torch.cuda.get_device_properties(torch.cuda.current_device()), "warp_size", 32)
# Autotune over warp counts that are powers of 2 and do not exceed the threads-per-block limit
return [triton.Config({}, num_warps=warp_count) for warp_count in [1, 2, 4, 8, 16, 32] if warp_count * warp_size <= max_threads_per_block]
# return [triton.Config({}, num_warps=8)]
# Copied from flash-attn
@triton.autotune(
configs=triton_autotune_configs(),
key=[
"N",
"HAS_RESIDUAL",
"STORE_RESIDUAL_OUT",
"IS_RMS_NORM",
"HAS_BIAS",
"HAS_WEIGHT",
"HAS_X1",
"HAS_W1",
"HAS_B1",
],
)
# torch compile doesn't like triton.heuristics, so we set these manually when calling the kernel
# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
# @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None})
# @triton.heuristics({"HAS_X1": lambda args: args["X1"] is not None})
# @triton.heuristics({"HAS_W1": lambda args: args["W1"] is not None})
# @triton.heuristics({"HAS_B1": lambda args: args["B1"] is not None})
@triton.jit
def _layer_norm_fwd_1pass_kernel(
X, # pointer to the input
Y, # pointer to the output
W, # pointer to the weights
B, # pointer to the biases
RESIDUAL, # pointer to the residual
X1,
W1,
B1,
Y1,
RESIDUAL_OUT, # pointer to the residual
ROWSCALE,
SEEDS, # Dropout seeds for each row
DROPOUT_MASK,
DROPOUT_MASK1,
Mean, # pointer to the mean
Rstd, # pointer to the 1/std
stride_x_row, # how much to increase the pointer when moving by 1 row
stride_y_row,
stride_res_row,
stride_res_out_row,
stride_x1_row,
stride_y1_row,
M, # number of rows in X
N, # number of columns in X
eps, # epsilon to avoid division by zero
dropout_p, # Dropout probability
zero_centered_weight, # If true, add 1.0 to the weight
IS_RMS_NORM: tl.constexpr,
BLOCK_N: tl.constexpr,
HAS_RESIDUAL: tl.constexpr,
STORE_RESIDUAL_OUT: tl.constexpr,
HAS_WEIGHT: tl.constexpr,
HAS_BIAS: tl.constexpr,
HAS_DROPOUT: tl.constexpr,
STORE_DROPOUT_MASK: tl.constexpr,
HAS_ROWSCALE: tl.constexpr,
HAS_X1: tl.constexpr,
HAS_W1: tl.constexpr,
HAS_B1: tl.constexpr,
):
# Map the program id to the row of X and Y it should compute.
row = tl.program_id(0)
X += row * stride_x_row
Y += row * stride_y_row
if HAS_RESIDUAL:
RESIDUAL += row * stride_res_row
if STORE_RESIDUAL_OUT:
RESIDUAL_OUT += row * stride_res_out_row
if HAS_X1:
X1 += row * stride_x1_row
if HAS_W1:
Y1 += row * stride_y1_row
# Compute mean and variance
cols = tl.arange(0, BLOCK_N)
x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
if HAS_ROWSCALE:
rowscale = tl.load(ROWSCALE + row).to(tl.float32)
x *= rowscale
if HAS_DROPOUT:
# Compute dropout mask
# 7 rounds is good enough, and reduces register pressure
keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
x = tl.where(keep_mask, x / (1.0 - dropout_p), 0.0)
if STORE_DROPOUT_MASK:
tl.store(DROPOUT_MASK + row * N + cols, keep_mask, mask=cols < N)
if HAS_X1:
x1 = tl.load(X1 + cols, mask=cols < N, other=0.0).to(tl.float32)
if HAS_ROWSCALE:
rowscale = tl.load(ROWSCALE + M + row).to(tl.float32)
x1 *= rowscale
if HAS_DROPOUT:
# Compute dropout mask
# 7 rounds is good enough, and reduces register pressure
keep_mask = tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
x1 = tl.where(keep_mask, x1 / (1.0 - dropout_p), 0.0)
if STORE_DROPOUT_MASK:
tl.store(DROPOUT_MASK1 + row * N + cols, keep_mask, mask=cols < N)
x += x1
if HAS_RESIDUAL:
residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)
x += residual
if STORE_RESIDUAL_OUT:
tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)
if not IS_RMS_NORM:
mean = tl.sum(x, axis=0) / N
tl.store(Mean + row, mean)
xbar = tl.where(cols < N, x - mean, 0.0)
var = tl.sum(xbar * xbar, axis=0) / N
else:
xbar = tl.where(cols < N, x, 0.0)
var = tl.sum(xbar * xbar, axis=0) / N
rstd = 1 / tl.sqrt(var + eps)
tl.store(Rstd + row, rstd)
# Normalize and apply linear transformation
mask = cols < N
if HAS_WEIGHT:
w = tl.load(W + cols, mask=mask).to(tl.float32)
if zero_centered_weight:
w += 1.0
if HAS_BIAS:
b = tl.load(B + cols, mask=mask).to(tl.float32)
x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
if HAS_WEIGHT:
y = x_hat * w + b if HAS_BIAS else x_hat * w
else:
y = x_hat + b if HAS_BIAS else x_hat
# Write output
tl.store(Y + cols, y, mask=mask)
if HAS_W1:
w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)
if zero_centered_weight:
w1 += 1.0
if HAS_B1:
b1 = tl.load(B1 + cols, mask=mask).to(tl.float32)
y1 = x_hat * w1 + b1 if HAS_B1 else x_hat * w1
tl.store(Y1 + cols, y1, mask=mask)
def _layer_norm_fwd(
x: Tensor,
weight: Tensor,
bias: Tensor,
eps: float,
residual: Optional[Tensor] = None,
x1: Optional[Tensor] = None,
weight1: Optional[Tensor] = None,
bias1: Optional[Tensor] = None,
dropout_p: float = 0.0,
rowscale: Optional[Tensor] = None,
out_dtype: Optional[torch.dtype] = None,
residual_dtype: Optional[torch.dtype] = None,
zero_centered_weight: bool = False,
is_rms_norm: bool = False,
return_dropout_mask: bool = False,
out: Optional[Tensor] = None,
residual_out: Optional[Tensor] = None,
) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor):
# Need to wrap to handle the case where residual_out is an alias of x, which makes torch.library
# and torch.compile unhappy. Also allocate memory for out and residual_out if they are None
# so that _layer_norm_fwd_impl doesn't have to return them.
if out is None:
out = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
if residual is not None:
residual_dtype = residual.dtype
if residual_out is None and (residual is not None or (residual_dtype is not None and residual_dtype != x.dtype) or dropout_p > 0.0 or rowscale is not None or x1 is not None):
residual_out = torch.empty_like(x, dtype=residual_dtype if residual_dtype is not None else x.dtype)
else:
residual_out = None
y1, mean, rstd, seeds, dropout_mask, dropout_mask1 = _layer_norm_fwd_impl(
x,
weight,
bias,
eps,
out,
residual=residual,
x1=x1,
weight1=weight1,
bias1=bias1,
dropout_p=dropout_p,
rowscale=rowscale,
zero_centered_weight=zero_centered_weight,
is_rms_norm=is_rms_norm,
return_dropout_mask=return_dropout_mask,
residual_out=residual_out,
)
# residual_out is None if residual is None and residual_dtype == input_dtype and dropout_p == 0.0
if residual_out is None:
residual_out = x
return out, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1
# [2025-04-28] torch.library.triton_op ignores the schema argument, but here we need the schema
# since we're returning a tuple of tensors
def _layer_norm_fwd_impl(
x: Tensor,
weight: Optional[Tensor],
bias: Tensor,
eps: float,
out: Tensor,
residual: Optional[Tensor] = None,
x1: Optional[Tensor] = None,
weight1: Optional[Tensor] = None,
bias1: Optional[Tensor] = None,
dropout_p: float = 0.0,
rowscale: Optional[Tensor] = None,
zero_centered_weight: bool = False,
is_rms_norm: bool = False,
return_dropout_mask: bool = False,
residual_out: Optional[Tensor] = None,
) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor):
M, N = x.shape
assert x.stride(-1) == 1
if residual is not None:
assert residual.stride(-1) == 1
assert residual.shape == (M, N)
if weight is not None:
assert weight.shape == (N,)
assert weight.stride(-1) == 1
if bias is not None:
assert bias.stride(-1) == 1
assert bias.shape == (N,)
if x1 is not None:
assert x1.shape == x.shape
assert rowscale is None
assert x1.stride(-1) == 1
if weight1 is not None:
assert weight1.shape == (N,)
assert weight1.stride(-1) == 1
if bias1 is not None:
assert bias1.shape == (N,)
assert bias1.stride(-1) == 1
if rowscale is not None:
assert rowscale.is_contiguous()
assert rowscale.shape == (M,)
assert out.shape == x.shape
assert out.stride(-1) == 1
if residual_out is not None:
assert residual_out.shape == x.shape
assert residual_out.stride(-1) == 1
if weight1 is not None:
y1 = torch.empty_like(out)
assert y1.stride(-1) == 1
else:
y1 = None
mean = torch.empty((M,), dtype=torch.float32, device=x.device) if not is_rms_norm else None
rstd = torch.empty((M,), dtype=torch.float32, device=x.device)
if dropout_p > 0.0:
seeds = torch.randint(2**32, (M if x1 is None else 2 * M,), device=x.device, dtype=torch.int64)
else:
seeds = None
if return_dropout_mask and dropout_p > 0.0:
dropout_mask = torch.empty(M, N, device=x.device, dtype=torch.bool)
if x1 is not None:
dropout_mask1 = torch.empty(M, N, device=x.device, dtype=torch.bool)
else:
dropout_mask1 = None
else:
dropout_mask, dropout_mask1 = None, None
# Less than 64KB per feature: enqueue fused kernel
MAX_FUSED_SIZE = 65536 // x.element_size()
BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
if N > BLOCK_N:
raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
with torch.cuda.device(x.device.index):
torch.library.wrap_triton(_layer_norm_fwd_1pass_kernel)[(M,)](
x,
out,
weight if weight is not None else x, # unused when HAS_WEIGHT == False
bias,
residual,
x1,
weight1,
bias1,
y1,
residual_out,
rowscale,
seeds,
dropout_mask,
dropout_mask1,
mean,
rstd,
x.stride(0),
out.stride(0),
residual.stride(0) if residual is not None else 0,
residual_out.stride(0) if residual_out is not None else 0,
x1.stride(0) if x1 is not None else 0,
y1.stride(0) if y1 is not None else 0,
M,
N,
eps,
dropout_p,
# Passing a bool makes torch inductor very unhappy since it then tries to compare it to int_max
int(zero_centered_weight),
is_rms_norm,
BLOCK_N,
residual is not None,
residual_out is not None,
weight is not None,
bias is not None,
dropout_p > 0.0,
dropout_mask is not None,
rowscale is not None,
HAS_X1=x1 is not None,
HAS_W1=weight1 is not None,
HAS_B1=bias1 is not None,
)
return y1, mean, rstd, seeds, dropout_mask, dropout_mask1
class LayerNormFn:
@staticmethod
def forward(
x,
weight,
bias,
residual=None,
x1=None,
weight1=None,
bias1=None,
eps=1e-6,
dropout_p=0.0,
rowscale=None,
prenorm=False,
residual_in_fp32=False,
zero_centered_weight=False,
is_rms_norm=False,
return_dropout_mask=False,
out_dtype=None,
out=None,
residual_out=None,
):
x_shape_og = x.shape
# reshape input data into 2D tensor
x = maybe_contiguous_lastdim(x.reshape(-1, x.shape[-1]))
if residual is not None:
assert residual.shape == x_shape_og
residual = maybe_contiguous_lastdim(residual.reshape(-1, residual.shape[-1]))
if x1 is not None:
assert x1.shape == x_shape_og
assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
x1 = maybe_contiguous_lastdim(x1.reshape(-1, x1.shape[-1]))
# weight can be None when elementwise_affine=False for LayerNorm
if weight is not None:
weight = weight.contiguous()
bias = maybe_contiguous(bias)
weight1 = maybe_contiguous(weight1)
bias1 = maybe_contiguous(bias1)
if rowscale is not None:
rowscale = rowscale.reshape(-1).contiguous()
residual_dtype = residual.dtype if residual is not None else (torch.float32 if residual_in_fp32 else None)
if out is not None:
out = out.reshape(-1, out.shape[-1])
if residual_out is not None:
residual_out = residual_out.reshape(-1, residual_out.shape[-1])
y, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 = _layer_norm_fwd(
x,
weight,
bias,
eps,
residual,
x1,
weight1,
bias1,
dropout_p=dropout_p,
rowscale=rowscale,
out_dtype=out_dtype,
residual_dtype=residual_dtype,
zero_centered_weight=zero_centered_weight,
is_rms_norm=is_rms_norm,
return_dropout_mask=return_dropout_mask,
out=out,
residual_out=residual_out,
)
y = y.reshape(x_shape_og)
return y
def layer_norm_fn(
x,
weight,
bias,
residual=None,
x1=None,
weight1=None,
bias1=None,
eps=1e-6,
dropout_p=0.0,
rowscale=None,
prenorm=False,
residual_in_fp32=False,
zero_centered_weight=False,
is_rms_norm=False,
return_dropout_mask=False,
out_dtype=None,
out=None,
residual_out=None,
):
return LayerNormFn.forward(
x,
weight,
bias,
residual,
x1,
weight1,
bias1,
eps,
dropout_p,
rowscale,
prenorm,
residual_in_fp32,
zero_centered_weight,
is_rms_norm,
return_dropout_mask,
out_dtype,
out,
residual_out,
)
@triton.jit
def _norm_infer_kernel(
X,
Y,
W,
B,
stride_x_row,
stride_y_row,
M,
N,
eps,
IS_RMS_NORM: tl.constexpr,
HAS_WEIGHT: tl.constexpr,
HAS_BIAS: tl.constexpr,
BLOCK_N: tl.constexpr,
):
row = tl.program_id(0)
X += row * stride_x_row
Y += row * stride_y_row
if HAS_WEIGHT:
W += 0
if HAS_BIAS:
B += 0
cols = tl.arange(0, BLOCK_N)
x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
if not IS_RMS_NORM:
mean = tl.sum(x, axis=0) / N
xbar = tl.where(cols < N, x - mean, 0.0)
var = tl.sum(xbar * xbar, axis=0) / N
else:
xbar = tl.where(cols < N, x, 0.0)
var = tl.sum(xbar * xbar, axis=0) / N
rstd = 1 / tl.sqrt(var + eps)
x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
if HAS_WEIGHT:
w = tl.load(W + cols, mask=cols < N, other=1.0).to(tl.float32)
y = x_hat * w
else:
y = x_hat
if HAS_BIAS:
b = tl.load(B + cols, mask=cols < N, other=0.0).to(tl.float32)
y += b
tl.store(Y + cols, y, mask=cols < N)
def norm_infer(
x: Tensor,
weight: Optional[Tensor],
bias: Optional[Tensor],
eps: float,
is_rms_norm: bool = False,
out: Optional[Tensor] = None,
):
M, N = x.shape
assert x.stride(-1) == 1
if weight is not None:
assert weight.shape == (N,)
assert weight.stride(-1) == 1
if bias is not None:
assert bias.shape == (N,)
assert bias.stride(-1) == 1
if out is None:
out = torch.empty_like(x)
MAX_FUSED_SIZE = 65536 // x.element_size()
BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
if N > BLOCK_N:
raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
num_warps = min(max(BLOCK_N // 256, 1), 8)
_norm_infer_kernel[(M,)](
x,
out,
weight if weight is not None else x, # dummy when HAS_WEIGHT=False
bias if bias is not None else x, # dummy when HAS_BIAS=False
x.stride(0),
out.stride(0),
M,
N,
eps,
IS_RMS_NORM=is_rms_norm,
HAS_WEIGHT=weight is not None,
HAS_BIAS=bias is not None,
BLOCK_N=BLOCK_N,
num_warps=num_warps,
)
return out
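# Hedged usage sketch for norm_infer (requires CUDA): one fused inference pass that covers
# both LayerNorm (the default) and RMSNorm (is_rms_norm=True); weight and bias may each be None.
def _example_norm_infer():
    M, N = 4, 1024
    x = torch.randn(M, N, device="cuda", dtype=torch.bfloat16)
    w = torch.randn(N, device="cuda", dtype=torch.bfloat16)
    b = torch.randn(N, device="cuda", dtype=torch.bfloat16)
    y_ln = norm_infer(x, w, b, eps=1e-6)                          # LayerNorm
    y_rms = norm_infer(x, w, None, eps=1e-6, is_rms_norm=True)    # RMSNorm
    return y_ln, y_rms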
def rms_norm_fn(
x,
weight,
bias,
residual=None,
x1=None,
weight1=None,
bias1=None,
eps=1e-6,
dropout_p=0.0,
rowscale=None,
prenorm=False,
residual_in_fp32=False,
zero_centered_weight=False,
return_dropout_mask=False,
out_dtype=None,
out=None,
residual_out=None,
):
return LayerNormFn.forward(
x,
weight,
bias,
residual,
x1,
weight1,
bias1,
eps,
dropout_p,
rowscale,
prenorm,
residual_in_fp32,
zero_centered_weight,
True,
return_dropout_mask,
out_dtype,
out,
residual_out,
)
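# Hedged usage sketch for rms_norm_fn (requires CUDA and a PyTorch build that provides
# torch.library.wrap_triton). In this forward-only port, LayerNormFn.forward returns only the
# normalized output, so flags such as prenorm and return_dropout_mask are accepted for API
# compatibility but only y is handed back.
def _example_rms_norm_fn():
    x = torch.randn(2, 128, 1024, device="cuda", dtype=torch.bfloat16)
    w = torch.ones(1024, device="cuda", dtype=torch.bfloat16)
    return rms_norm_fn(x, w, bias=None, eps=1e-6)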
......@@ -5,6 +5,7 @@ import torch.distributed as dist
from loguru import logger
from lightx2v.common.ops import *
from lightx2v.models.runners.hunyuan_video.hunyuan_video_15_runner import HunyuanVideo15Runner # noqa: F401
from lightx2v.models.runners.qwen_image.qwen_image_runner import QwenImageRunner # noqa: F401
from lightx2v.models.runners.wan.wan_animate_runner import WanAnimateRunner # noqa: F401
from lightx2v.models.runners.wan.wan_audio_runner import Wan22AudioRunner, WanAudioRunner # noqa: F401
......@@ -49,6 +50,7 @@ def main():
"wan2.2_moe_distill",
"qwen_image",
"wan2.2_animate",
"hunyuan_video_1.5",
],
default="wan2.1",
)
......
import json
def closest_color(requested_color):
import webcolors
min_colors = {}
for key, name in webcolors.CSS3_HEX_TO_NAMES.items():
r_c, g_c, b_c = webcolors.hex_to_rgb(key)
rd = (r_c - requested_color[0]) ** 2
gd = (g_c - requested_color[1]) ** 2
bd = (b_c - requested_color[2]) ** 2
min_colors[(rd + gd + bd)] = name
return min_colors[min(min_colors.keys())]
def convert_rgb_to_names(rgb_tuple):
try:
import webcolors
color_name = webcolors.rgb_to_name(rgb_tuple)
except ValueError:
color_name = closest_color(rgb_tuple)
return color_name
class MultilingualPromptFormat:
def __init__(
self,
font_path: str = "assets/glyph_sdxl_assets/multilingual_10-lang_idx.json",
color_path: str = "assets/glyph_sdxl_assets/color_idx.json",
):
with open(font_path, "r") as f:
self.font_dict = json.load(f)
with open(color_path, "r") as f:
self.color_dict = json.load(f)
def format_prompt(self, texts, styles):
"""
Text "{text}" in {color}, {type}.
"""
prompt = ""
for text, style in zip(texts, styles):
text_prompt = f'Text "{text}"'
attr_list = []
# format color
if style["color"] is not None:
import webcolors
hex_color = style["color"]
rgb_color = webcolors.hex_to_rgb(hex_color)
color_name = convert_rgb_to_names(rgb_color)
attr_list.append(f"<color-{self.color_dict[color_name]}>")
# format font
if style["font-family"] is not None:
attr_list.append(f"<{style['font-family'][:2]}-font-{self.font_dict[style['font-family']]}>")
attr_suffix = ", ".join(attr_list)
text_prompt += " in " + attr_suffix
text_prompt += ". "
else:
text_prompt += ". "
prompt = prompt + text_prompt
return prompt
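# Hedged usage sketch (relies on the default Glyph-SDXL-v2 asset paths above being present):
# each (text, style) pair contributes a clause of the form 'Text "..." in <color/font tokens>.',
# and a style with color and font-family both None falls back to the plain 'Text "...". ' form.
def _example_format_prompt():
    fmt = MultilingualPromptFormat()
    return fmt.format_prompt(["OPEN 24 HOURS"], [{"color": None, "font-family": None}])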
import glob
import json
import os
import re
import torch
import torch.nn as nn
from safetensors import safe_open
from transformers import AutoTokenizer, T5ForConditionalGeneration
from .format_prompt import MultilingualPromptFormat
def add_special_token(
tokenizer,
text_encoder,
add_color,
add_font,
color_ann_path,
font_ann_path,
multilingual=False,
):
"""
Add special tokens for color and font to tokenizer and text encoder.
Args:
tokenizer: Huggingface tokenizer.
text_encoder: Huggingface T5 encoder.
add_color (bool): Whether to add color tokens.
add_font (bool): Whether to add font tokens.
color_ann_path (str): Path to color annotation JSON.
font_ann_path (str): Path to font annotation JSON.
multilingual (bool): Whether to use multilingual font tokens.
"""
with open(font_ann_path, "r") as f:
idx_font_dict = json.load(f)
with open(color_ann_path, "r") as f:
idx_color_dict = json.load(f)
if multilingual:
font_token = [f"<{font_code[:2]}-font-{idx_font_dict[font_code]}>" for font_code in idx_font_dict]
else:
font_token = [f"<font-{i}>" for i in range(len(idx_font_dict))]
color_token = [f"<color-{i}>" for i in range(len(idx_color_dict))]
additional_special_tokens = []
if add_color:
additional_special_tokens += color_token
if add_font:
additional_special_tokens += font_token
tokenizer.add_tokens(additional_special_tokens, special_tokens=True)
# Set mean_resizing=False to avoid PyTorch LAPACK dependency
text_encoder.resize_token_embeddings(len(tokenizer), mean_resizing=False)
def load_byt5_and_byt5_tokenizer(
byt5_name="google/byt5-small",
special_token=False,
color_special_token=False,
font_special_token=False,
color_ann_path="assets/color_idx.json",
font_ann_path="assets/font_idx_512.json",
huggingface_cache_dir=None,
multilingual=False,
device=None,
):
"""
Load ByT5 encoder and tokenizer from Huggingface, and add special tokens if needed.
Args:
byt5_name (str): Model name or path.
special_token (bool): Whether to add special tokens.
color_special_token (bool): Whether to add color tokens.
font_special_token (bool): Whether to add font tokens.
color_ann_path (str): Path to color annotation JSON.
font_ann_path (str): Path to font annotation JSON.
huggingface_cache_dir (str): Huggingface cache directory.
multilingual (bool): Whether to use multilingual font tokens.
device (str or torch.device): Device to load the model onto.
Returns:
tuple: (byt5_text_encoder, byt5_tokenizer)
"""
byt5_tokenizer = AutoTokenizer.from_pretrained(
byt5_name,
cache_dir=huggingface_cache_dir,
)
byt5_text_encoder = T5ForConditionalGeneration.from_pretrained(
byt5_name,
cache_dir=huggingface_cache_dir,
).get_encoder()
if "cuda" not in str(device):
device = torch.device(device)
else:
device = torch.device(device)
byt5_text_encoder = byt5_text_encoder.to(device)
if special_token:
add_special_token(
byt5_tokenizer,
byt5_text_encoder,
add_color=color_special_token,
add_font=font_special_token,
color_ann_path=color_ann_path,
font_ann_path=font_ann_path,
multilingual=multilingual,
)
return byt5_text_encoder, byt5_tokenizer
class ByT5Mapper(nn.Module):
"""
ByT5Mapper: Maps ByT5 encoder outputs to a new space, with optional residual connection.
Args:
in_dim (int): Input dimension (must equal out_dim if use_residual).
out_dim (int): Output dimension after second linear layer.
hidden_dim (int): Hidden dimension for intermediate layer.
out_dim1 (int): Final output dimension.
use_residual (bool): Whether to use residual connection (default: True).
"""
def __init__(self, in_dim, out_dim, hidden_dim, out_dim1, use_residual=True):
super().__init__()
if use_residual:
assert in_dim == out_dim
self.layernorm = nn.LayerNorm(in_dim)
self.fc1 = nn.Linear(in_dim, hidden_dim)
self.fc2 = nn.Linear(hidden_dim, out_dim)
self.fc3 = nn.Linear(out_dim, out_dim1)
self.use_residual = use_residual
self.act_fn = nn.GELU()
def forward(self, x):
"""
Forward pass for ByT5Mapper.
Args:
x (Tensor): Input tensor of shape (..., in_dim).
Returns:
Tensor: Output tensor of shape (..., out_dim1).
"""
residual = x
x = self.layernorm(x)
x = self.fc1(x)
x = self.act_fn(x)
x = self.fc2(x)
x2 = self.act_fn(x)
x2 = self.fc3(x2)
if self.use_residual:
x2 = x2 + residual
return x2
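# Hedged shape-check sketch for ByT5Mapper, mirroring how ByT5TextEncoder constructs it
# below: 1472-dim ByT5-small features are mapped up to the DiT hidden size (out_dim1=2048
# here is an illustrative assumption).
def _example_byt5_mapper():
    mapper = ByT5Mapper(in_dim=1472, out_dim=2048, hidden_dim=2048, out_dim1=2048, use_residual=False)
    feats = torch.randn(1, 256, 1472)
    return mapper(feats)  # -> [1, 256, 2048]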
class ByT5TextEncoder:
def __init__(
self,
config,
device=torch.cuda.current_device(),
checkpoint_path=None,
byt5_max_length=256,
cpu_offload=False,
):
self.cpu_offload = cpu_offload
self.config = config
self.device = device
self.byt5_max_length = byt5_max_length
self.enable_cfg = config.get("enable_cfg", False)
byT5_google_path = os.path.join(checkpoint_path, "text_encoder", "byt5-small")
byT5_ckpt_path = os.path.join(checkpoint_path, "text_encoder", "Glyph-SDXL-v2", "checkpoints/byt5_model.pt")
multilingual_prompt_format_color_path = os.path.join(checkpoint_path, "text_encoder", "Glyph-SDXL-v2", "assets/color_idx.json")
multilingual_prompt_format_font_path = os.path.join(checkpoint_path, "text_encoder", "Glyph-SDXL-v2", "assets/multilingual_10-lang_idx.json")
byt5_args = dict(
byT5_google_path=byT5_google_path,
byT5_ckpt_path=byT5_ckpt_path,
multilingual_prompt_format_color_path=multilingual_prompt_format_color_path,
multilingual_prompt_format_font_path=multilingual_prompt_format_font_path,
byt5_max_length=byt5_max_length,
)
self.byt5_tokenizer, self.byt5_model, self.byt5_max_length = self.create_byt5(byt5_args, device)
self.byt5_model = self.byt5_model.to(device=device)
self.prompt_format = MultilingualPromptFormat(font_path=multilingual_prompt_format_font_path, color_path=multilingual_prompt_format_color_path)
self.byt5_mapper = ByT5Mapper(in_dim=1472, out_dim=2048, hidden_dim=2048, out_dim1=self.config["hidden_size"], use_residual=False).to(torch.bfloat16)
byt5_mapper_model_path = os.path.join(checkpoint_path, "transformer", self.config["transformer_model_name"])
safetensors_files = glob.glob(os.path.join(byt5_mapper_model_path, "*.safetensors"))
byt5_mapper_state_dict = {}
for safetensor_path in safetensors_files:
with safe_open(safetensor_path, framework="pt", device="cpu") as f:
byt5_mapper_state_dict.update({key.replace("byt5_in.", ""): f.get_tensor(key).to(torch.bfloat16) for key in f.keys() if "byt5_in" in key})
self.byt5_mapper.load_state_dict(byt5_mapper_state_dict)
self.byt5_mapper.to(device=device)
def create_byt5(self, args, device):
"""
Create ByT5 tokenizer and encoder, load weights if provided.
Args:
args (dict): Configuration dictionary.
device (str or torch.device): Device to load the model onto.
Returns:
tuple: (byt5_tokenizer, byt5_model, byt5_max_length)
"""
byt5_max_length = args["byt5_max_length"]
byt5_config = dict(
byt5_name=args["byT5_google_path"],
special_token=True,
color_special_token=True,
font_special_token=True,
color_ann_path=args["multilingual_prompt_format_color_path"],
font_ann_path=args["multilingual_prompt_format_font_path"],
multilingual=True,
)
huggingface_cache_dir = None
byt5_model, byt5_tokenizer = load_byt5_and_byt5_tokenizer(
**byt5_config,
huggingface_cache_dir=huggingface_cache_dir,
device=device,
)
# Load custom checkpoint if provided
if args["byT5_ckpt_path"] is not None:
if "cuda" not in str(device):
byt5_state_dict = torch.load(args["byT5_ckpt_path"], map_location=device)
else:
byt5_state_dict = torch.load(args["byT5_ckpt_path"], map_location=device)
if "state_dict" in byt5_state_dict:
sd = byt5_state_dict["state_dict"]
newsd = {}
for k, v in sd.items():
if k.startswith("module.text_tower.encoder."):
newsd[k[len("module.text_tower.encoder.") :]] = v
byt5_state_dict = newsd
byt5_model.load_state_dict(byt5_state_dict)
byt5_model.requires_grad_(False)
return byt5_tokenizer, byt5_model, byt5_max_length
def _extract_glyph_texts(self, prompt):
"""
Extract glyph texts from prompt using regex pattern.
Args:
prompt: Input prompt string
Returns:
List of extracted glyph texts
"""
pattern = r"\"(.*?)\"|“(.*?)”"
matches = re.findall(pattern, prompt)
result = [match[0] or match[1] for match in matches]
result = list(dict.fromkeys(result)) if len(result) > 1 else result
return result
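# Worked note on the pattern above: it captures both ASCII double quotes and CJK curly quotes, e.g.
#   re.findall(r"\"(.*?)\"|“(.*?)”", 'A sign reads "OPEN" and a banner says “欢迎”')
#   -> [("OPEN", ""), ("", "欢迎")]
# which the comprehension flattens to ["OPEN", "欢迎"], de-duplicating while preserving order
# whenever more than one match is found.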
def _process_single_byt5_prompt(self, prompt_text, device):
"""
Process a single prompt for byT5 encoding.
Args:
prompt_text: The prompt text to process
device: Target device for tensors
Returns:
Tuple of (byt5_embeddings, byt5_mask)
"""
byt5_embeddings = torch.zeros((1, self.byt5_max_length, 1472), device=device)
byt5_mask = torch.zeros((1, self.byt5_max_length), device=device, dtype=torch.int64)
glyph_texts = self._extract_glyph_texts(prompt_text)
if len(glyph_texts) > 0:
text_styles = [{"color": None, "font-family": None} for _ in range(len(glyph_texts))]
formatted_text = self.prompt_format.format_prompt(glyph_texts, text_styles)
text_ids, text_mask = self.get_byt5_text_tokens(self.byt5_tokenizer, self.byt5_max_length, formatted_text)
text_ids = text_ids.to("cuda")
text_mask = text_mask.to("cuda")
byt5_outputs = self.byt5_model(text_ids, attention_mask=text_mask.float())
byt5_embeddings = byt5_outputs[0]
byt5_mask = text_mask
return byt5_embeddings, byt5_mask
def _prepare_byt5_embeddings(self, prompts):
if isinstance(prompts, str):
prompt_list = [prompts]
elif isinstance(prompts, list):
prompt_list = prompts
else:
raise ValueError("prompts must be str or list of str")
positive_embeddings = []
positive_masks = []
negative_embeddings = []
negative_masks = []
for prompt in prompt_list:
pos_emb, pos_mask = self._process_single_byt5_prompt(prompt, "cuda")
positive_embeddings.append(pos_emb)
positive_masks.append(pos_mask)
if self.enable_cfg:  # TODO: split CFG handling out of this path; it would parallelize better
neg_emb, neg_mask = self._process_single_byt5_prompt("", "cuda")
negative_embeddings.append(neg_emb)
negative_masks.append(neg_mask)
byt5_positive = torch.cat(positive_embeddings, dim=0)
byt5_positive_mask = torch.cat(positive_masks, dim=0)
if self.enable_cfg:  # TODO: split CFG handling out of this path; it would parallelize better
byt5_negative = torch.cat(negative_embeddings, dim=0)
byt5_negative_mask = torch.cat(negative_masks, dim=0)
byt5_embeddings = torch.cat([byt5_negative, byt5_positive], dim=0)
byt5_masks = torch.cat([byt5_negative_mask, byt5_positive_mask], dim=0)
else:
byt5_embeddings = byt5_positive
byt5_masks = byt5_positive_mask
return byt5_embeddings, byt5_masks
@torch.no_grad()
def infer(self, prompts):
if self.cpu_offload:
self.byt5_model = self.byt5_model.to("cuda")
self.byt5_mapper = self.byt5_mapper.to("cuda")
byt5_embeddings, byt5_masks = self._prepare_byt5_embeddings(prompts)
byt5_features = self.byt5_mapper(byt5_embeddings.to(torch.bfloat16))
if self.cpu_offload:
self.byt5_model = self.byt5_model.to("cpu")
self.byt5_mapper = self.byt5_mapper.to("cpu")
return byt5_features, byt5_masks
if __name__ == "__main__":
byt5 = ByT5TextEncoder(config={"transformer_model_name": "480p_t2v", "hidden_size": 2048}, device="cuda", checkpoint_path="/data/nvme1/yongyang/models/HunyuanVideo-1.5/ckpts/hunyuanvideo-1.5")
prompt = "A close-up shot captures a scene on a polished, light-colored granite kitchen counter, illuminated by soft natural light from an unseen window. Initially, the frame focuses on a tall, clear glass filled with golden, translucent apple juice standing next to a single, shiny red apple with a green leaf still attached to its stem. The camera moves horizontally to the right. As the shot progresses, a white ceramic plate smoothly enters the frame, revealing a fresh arrangement of about seven or eight more apples, a mix of vibrant reds and greens, piled neatly upon it. A shallow depth of field keeps the focus sharply on the fruit and glass, while the kitchen backsplash in the background remains softly blurred. The scene is in a realistic style."
byt5_features, byt5_masks = byt5.infer(prompt)
print(byt5_features.shape, byt5_features.sum())
print(byt5_masks.shape, byt5_masks.sum())
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import gc
import sys
from copy import deepcopy
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Tuple
import loguru
import torch
import torch.nn as nn
from accelerate import init_empty_weights
from safetensors.torch import load_file
from transformers import (
AutoConfig,
AutoModel,
AutoTokenizer,
)
from transformers.utils import ModelOutput
current_dir = Path(__file__).resolve().parent
project_root = current_dir.parent.parent.parent.parent.parent.parent
if str(project_root) not in sys.path:
sys.path.insert(0, str(project_root))
from lightx2v.models.input_encoders.hf.q_linear import ( # noqa E402
Q8FQuantLinearFp8, # noqa E402
Q8FQuantLinearInt8, # noqa E402
SglQuantLinearFp8, # noqa E402
TorchaoQuantLinearInt8, # noqa E402
VllmQuantLinearInt8, # noqa E402
)
def use_default(value, default):
"""Utility: return value if not None, else default."""
return value if value is not None else default
# Prompt templates for different models and tasks
__all__ = [
"C_SCALE",
"PROMPT_TEMPLATE",
"MODEL_BASE",
]
# =================== Constant Values =====================
# Computation scale factor, 1P = 1_000_000_000_000_000. TensorBoard displays the value in PetaFLOPS to avoid
# overflow errors when logging values.
C_SCALE = 1_000_000_000_000_000
PROMPT_TEMPLATE_ENCODE_IMAGE_JSON = [
{
"role": "system",
"content": "You are a helpful assistant. Describe the image by detailing the following aspects: \
1. The main content and theme of the image. \
2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects. \
3. The background environment, light, style and atmosphere.",
},
{"role": "user", "content": "{}"},
]
PROMPT_TEMPLATE_ENCODE_VIDEO_JSON = [
{
"role": "system",
"content": "You are a helpful assistant. Describe the video by detailing the following aspects: \
1. The main content and theme of the video. \
2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects. \
3. Actions, events, behaviors temporal relationships, physical movement changes of the objects. \
4. background environment, light, style and atmosphere. \
5. camera angles, movements, and transitions used in the video.",
},
{"role": "user", "content": "{}"},
]
PROMPT_TEMPLATE = {
"li-dit-encode-image-json": {"template": PROMPT_TEMPLATE_ENCODE_IMAGE_JSON, "crop_start": -1}, # auto-calculate crop_start
"li-dit-encode-video-json": {"template": PROMPT_TEMPLATE_ENCODE_VIDEO_JSON, "crop_start": -1}, # auto-calculate crop_start
}
MODEL_BASE = os.getenv("MODEL_BASE", "")
TEXT_ENCODER_PATH = {
"qwen-2.5vl-7b": f"{MODEL_BASE}/Qwen2.5-VL-7B-Instruct",
}
TOKENIZER_PATH = {
"qwen-2.5vl-7b": f"{MODEL_BASE}/Qwen2.5-VL-7B-Instruct",
}
PRECISION_TO_TYPE = {
"fp32": torch.float32,
"fp16": torch.float16,
"bf16": torch.bfloat16,
}
def replace_linear(module, new_linear_cls):
for name, child in list(module.named_children()):
if isinstance(child, nn.Linear):
new_linear = new_linear_cls(child.in_features, child.out_features, bias=(child.bias is not None))
new_linear.to(device=next(child.parameters(), None).device if any(True for _ in child.parameters()) else torch.device("cpu"))
setattr(module, name, new_linear)
else:
replace_linear(child, new_linear_cls)
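# Hedged usage sketch for replace_linear: it walks the module tree and swaps every nn.Linear
# for another Linear-compatible class. A plain nn.Linear is used as the stand-in here; the
# real quantized classes above expect their weights to be loaded afterwards via
# load_state_dict(..., assign=True), as done in load_text_encoder below.
def _example_replace_linear():
    block = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Sequential(nn.Linear(32, 8)))
    replace_linear(block, nn.Linear)
    return block  # both Linear layers have been recreated (freshly initialized weights)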
def load_text_encoder(
text_encoder_type, text_encoder_precision=None, text_encoder_path=None, logger=None, device=None, text_encoder_quantized=False, text_encoder_quant_scheme=None, text_encoder_quant_ckpt=None
):
if text_encoder_path is None:
if text_encoder_type not in TEXT_ENCODER_PATH:
raise ValueError(f"Unsupported text encoder type: {text_encoder_type}")
text_encoder_path = TEXT_ENCODER_PATH[text_encoder_type]
if text_encoder_quantized:
config = AutoConfig.from_pretrained(text_encoder_path)
with init_empty_weights():
text_encoder = AutoModel.from_config(config)
text_encoder = text_encoder.language_model
if text_encoder_quant_scheme in ["int8", "int8-vllm"]:
linear_cls = VllmQuantLinearInt8
elif text_encoder_quant_scheme in ["fp8", "fp8-sgl"]:
linear_cls = SglQuantLinearFp8
elif text_encoder_quant_scheme == "int8-torchao":
linear_cls = TorchaoQuantLinearInt8
elif text_encoder_quant_scheme == "int8-q8f":
linear_cls = Q8FQuantLinearInt8
elif text_encoder_quant_scheme == "fp8-q8f":
linear_cls = Q8FQuantLinearFp8
else:
NotImplementedError(f"Unsupported Qwen25_vl quant scheme: {text_encoder_quant_scheme}")
replace_linear(text_encoder.layers, linear_cls)
weight_dict = load_file(text_encoder_quant_ckpt, device=str(device))
new_w_dict = {}
for key in weight_dict.keys():
if key == "lm_head.weight":
continue
new_w_dict[key.replace("model.", "")] = weight_dict[key]
del weight_dict
torch.cuda.empty_cache()
gc.collect()
text_encoder.load_state_dict(new_w_dict, assign=True)
else:
text_encoder = AutoModel.from_pretrained(text_encoder_path, low_cpu_mem_usage=True)
text_encoder = text_encoder.language_model
text_encoder.final_layer_norm = text_encoder.norm
# from_pretrained will ensure that the model is in eval mode.
if text_encoder_precision is not None:
text_encoder = text_encoder.to(dtype=PRECISION_TO_TYPE[text_encoder_precision])
text_encoder.requires_grad_(False)
if device is not None:
text_encoder = text_encoder.to(device)
return text_encoder, text_encoder_path
def load_tokenizer(tokenizer_type, tokenizer_path=None, padding_side="right", logger=None):
processor = None
if tokenizer_path is None:
if tokenizer_type not in TOKENIZER_PATH:
raise ValueError(f"Unsupported tokenizer type: {tokenizer_type}")
tokenizer_path = TOKENIZER_PATH[tokenizer_type]
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, padding_side=padding_side)
return tokenizer, tokenizer_path, processor
@dataclass
class TextEncoderModelOutput(ModelOutput):
"""
Base class for model's outputs that also contains a pooling of the last hidden states.
Args:
hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
hidden_states_list (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
text_outputs (`list`, *optional*, returned when `return_texts=True` is passed):
List of decoded texts.
"""
hidden_state: torch.FloatTensor = None
attention_mask: Optional[torch.LongTensor] = None
hidden_states_list: Optional[Tuple[torch.FloatTensor, ...]] = None
text_outputs: Optional[list] = None
image_features: Optional[list] = None
class TextEncoder(nn.Module):
def __init__(
self,
text_encoder_type: str,
max_length: int,
text_encoder_precision: Optional[str] = None,
text_encoder_path: Optional[str] = None,
tokenizer_type: Optional[str] = None,
tokenizer_path: Optional[str] = None,
output_key: Optional[str] = None,
use_attention_mask: bool = True,
prompt_template: Optional[dict] = None,
prompt_template_video: Optional[dict] = None,
hidden_state_skip_layer: Optional[int] = None,
apply_final_norm: bool = False,
reproduce: bool = False,
logger=None,
device=None,
qwen25vl_quantized=False,
qwen25vl_quant_scheme=None,
qwen25vl_quant_ckpt=None,
):
super().__init__()
self.text_encoder_type = text_encoder_type
self.max_length = max_length
self.precision = text_encoder_precision
self.model_path = text_encoder_path
self.tokenizer_type = tokenizer_type if tokenizer_type is not None else text_encoder_type
self.tokenizer_path = tokenizer_path if tokenizer_path is not None else text_encoder_path
self.use_attention_mask = use_attention_mask
if prompt_template_video is not None:
assert use_attention_mask is True, "use_attention_mask must be True when training on videos."
self.prompt_template = prompt_template
self.prompt_template_video = prompt_template_video
self.hidden_state_skip_layer = hidden_state_skip_layer
self.apply_final_norm = apply_final_norm
self.reproduce = reproduce
self.logger = logger
self.use_template = self.prompt_template is not None
if self.use_template:
assert isinstance(self.prompt_template, dict) and "template" in self.prompt_template, f"`prompt_template` must be a dictionary with a key 'template', got {self.prompt_template}"
assert "{}" in str(self.prompt_template["template"]), f"`prompt_template['template']` must contain a placeholder `{{}}` for the input text, got {self.prompt_template['template']}"
self.use_video_template = self.prompt_template_video is not None
if self.use_video_template:
if self.prompt_template_video is not None:
assert isinstance(self.prompt_template_video, dict) and "template" in self.prompt_template_video, (
f"`prompt_template_video` must be a dictionary with a key 'template', got {self.prompt_template_video}"
)
assert "{}" in str(self.prompt_template_video["template"]), (
f"`prompt_template_video['template']` must contain a placeholder `{{}}` for the input text, got {self.prompt_template_video['template']}"
)
if text_encoder_type != "qwen-2.5vl-7b":
raise ValueError(f"Unsupported text encoder type: {text_encoder_type}")
self.output_key = output_key or "last_hidden_state"
self.model, self.model_path = load_text_encoder(
text_encoder_type=self.text_encoder_type,
text_encoder_precision=self.precision,
text_encoder_path=self.model_path,
logger=self.logger,
device=device,
text_encoder_quantized=qwen25vl_quantized,
text_encoder_quant_scheme=qwen25vl_quant_scheme,
text_encoder_quant_ckpt=qwen25vl_quant_ckpt,
)
self.tokenizer, self.tokenizer_path, self.processor = load_tokenizer(
tokenizer_type=self.tokenizer_type,
tokenizer_path=self.tokenizer_path,
padding_side="right",
logger=self.logger,
)
# pre-calculate crop_start for image and video
if self.use_template and self.prompt_template is not None:
self.text2tokens("a photo of a cat", data_type="image")
# self.logger.info(f"crop_start for image: {self.prompt_template['crop_start']}")
if self.use_video_template and self.prompt_template_video is not None:
self.text2tokens("a photo of a cat", data_type="video")
# self.logger.info(f"crop_start for video: {self.prompt_template_video['crop_start']}")
@property
def dtype(self):
return self.model.dtype
@property
def device(self):
return self.model.device
def __repr__(self):
return f"{self.text_encoder_type} ({self.precision} - {self.model_path})"
@staticmethod
def apply_text_to_template(text, template, prevent_empty_text=True):
"""
Apply text to template.
Args:
text (str): Input text.
template (str or list): Template string or list of chat conversation.
prevent_empty_text (bool): If True, prevent the user text from being empty
by substituting a single space. Defaults to True.
"""
if isinstance(template, str):
# Will send string to tokenizer. Used for llm
return template.format(text)
elif isinstance(template, list):
# For JSON list template format (chat conversation)
# Create a deep copy to avoid modifying the original template
template_copy = deepcopy(template)
for item in template_copy:
if isinstance(item, dict) and "content" in item:
# Replace placeholder with text in the content field
item["content"] = item["content"].format(text if text else (" " if prevent_empty_text else ""))
return template_copy
else:
raise TypeError(f"Unsupported template type: {type(template)}")
def calculate_crop_start(self, tokenized_input):
"""
Automatically calculate the crop_start position based on identifying user tokens.
Args:
tokenized_input: The output from the tokenizer containing input_ids
Returns:
int: The position where the actual prompt content begins (after user markers)
"""
input_ids = tokenized_input["input_ids"][0].tolist() # Get the first example's tokens
# Qwen user marker
marker = "<|im_start|>user\n"
# Tokenize just the marker to get its token IDs
marker_tokens = self.tokenizer(marker, add_special_tokens=False)["input_ids"]
# Find the end position of the marker in the input sequence
for i in range(len(input_ids) - len(marker_tokens) + 1):
if input_ids[i : i + len(marker_tokens)] == marker_tokens:
# Return the position after the marker
# print(f"crop_start: {i + len(marker_tokens)}, {self.tokenizer.decode(tokenized_input["input_ids"][0][i:i+len(marker_tokens)+10])}") # check crop_start
return i + len(marker_tokens)
# If marker not found, try to find based on special tokens
if hasattr(self.tokenizer, "special_tokens_map"):
# Check for user token or any other special token that might indicate user input start
for token_name, token_value in self.tokenizer.special_tokens_map.items():
if "user" in token_name.lower():
user_token_id = self.tokenizer.convert_tokens_to_ids(token_value)
if user_token_id in input_ids:
return input_ids.index(user_token_id) + 1
# Default fallback: return 0 (no cropping)
return 0
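# Minimal sketch of the search above (token IDs are hypothetical, not real Qwen IDs): the method
# performs a plain sliding-window match of the encoded marker inside the prompt token sequence.
#   input_ids     = [1001, 872, 198, 32, 6180, 315]   # "<|im_start|>user\n" followed by the prompt
#   marker_tokens = [1001, 872, 198]
#   -> match at i = 0, so crop_start = 0 + 3 = 3: everything before index 3 is template prefix
#      and is later stripped from the hidden states in encode().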
def text2tokens(self, text, data_type="image", max_length=300):
"""
Tokenize the input text.
Args:
text (str or list): Input text.
data_type (str): "image" or "video"; selects which prompt template to apply. Defaults to "image".
max_length (int): Maximum number of prompt tokens to keep, excluding template tokens. Defaults to 300.
"""
tokenize_input_type = "str"
if self.use_template or self.use_video_template:
if data_type == "image":
prompt_template = self.prompt_template["template"]
crop_start = self.prompt_template.get("crop_start", -1)
elif data_type == "video":
prompt_template = self.prompt_template_video["template"]
crop_start = self.prompt_template_video.get("crop_start", -1)
else:
raise ValueError(f"Unsupported data type: {data_type}")
if isinstance(text, (list, tuple)):
text = [self.apply_text_to_template(one_text, prompt_template) for one_text in text]
if isinstance(text[0], list):
tokenize_input_type = "list"
elif isinstance(text, str):
text = self.apply_text_to_template(text, prompt_template)
if isinstance(text, list):
tokenize_input_type = "list"
else:
raise TypeError(f"Unsupported text type: {type(text)}")
# First pass: tokenize with arbitrary max_length to find crop_start
if crop_start == -1:
# Use temporary max_length for the first pass (large enough)
temp_kwargs = dict(
truncation=True,
max_length=256, # Temporary large value
padding="max_length",
return_tensors="pt",
)
# First tokenization pass to calculate crop_start
if tokenize_input_type == "str":
temp_tokenized = self.tokenizer(
text,
return_length=False,
return_overflowing_tokens=False,
return_attention_mask=True,
**temp_kwargs,
)
elif tokenize_input_type == "list":
temp_tokenized = self.tokenizer.apply_chat_template(
text,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
**temp_kwargs,
)
# Calculate the crop_start from this first pass
crop_start = self.calculate_crop_start(temp_tokenized)
# Store the calculated crop_start for future use
if data_type == "image":
self.prompt_template["crop_start"] = crop_start
else:
self.prompt_template_video["crop_start"] = crop_start
else:
crop_start = 0
# Second pass: tokenize with the proper max_length using the found crop_start
kwargs = dict(
truncation=True,
max_length=max_length + (crop_start if crop_start > 0 else 0),
padding="max_length",
return_tensors="pt",
)
if tokenize_input_type == "str":
tokenized_output = self.tokenizer(
text,
return_length=False,
return_overflowing_tokens=False,
return_attention_mask=True,
**kwargs,
)
elif tokenize_input_type == "list":
tokenized_output = self.tokenizer.apply_chat_template(
text,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
**kwargs,
)
else:
raise ValueError(f"Unsupported tokenize_input_type: {tokenize_input_type}")
return tokenized_output
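# Two-pass flow in brief (a sketch of the logic above, not exact token counts): pass 1 tokenizes
# the templated text with a temporary max_length of 256 only to locate "<|im_start|>user\n" and
# cache crop_start in the template dict; pass 2 re-tokenizes with max_length + crop_start so that,
# after encode() drops the first crop_start positions, up to max_length prompt tokens remain.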
def encode(
self,
batch_encoding,
use_attention_mask=None,
output_hidden_states=False,
do_sample=None,
hidden_state_skip_layer=None,
return_texts=False,
data_type="image",
device=None,
semantic_images=None,
is_uncond=False,
):
"""
Args:
batch_encoding (dict): Batch encoding from tokenizer.
use_attention_mask (bool): Whether to use attention mask. If None, use self.use_attention_mask.
Defaults to None.
output_hidden_states (bool): Whether to output hidden states. If False, return the value of
self.output_key. If True, return the entire output. If self.hidden_state_skip_layer is set,
output_hidden_states is forced to True. Defaults to False.
do_sample (bool): Whether to sample from the model. Used for decoder-only LLMs. Defaults to None.
When self.reproduce is False, do_sample defaults to True.
hidden_state_skip_layer (int): Number of layers to skip from the end when selecting the hidden state;
0 means the last layer. If None, self.hidden_state_skip_layer is used. Defaults to None.
return_texts (bool): Whether to return the decoded texts. Defaults to False.
"""
device = self.model.device if device is None else device
use_attention_mask = use_default(use_attention_mask, self.use_attention_mask)
hidden_state_skip_layer = use_default(hidden_state_skip_layer, self.hidden_state_skip_layer)
do_sample = use_default(do_sample, not self.reproduce)
attention_mask = batch_encoding["attention_mask"].to(device) if use_attention_mask else None
outputs = self.model(
input_ids=batch_encoding["input_ids"].to(device),
attention_mask=attention_mask,
output_hidden_states=output_hidden_states or hidden_state_skip_layer is not None,
)
if hidden_state_skip_layer is not None:
last_hidden_state = outputs.hidden_states[-(hidden_state_skip_layer + 1)]
# Real last hidden state already has layer norm applied. So here we only apply it
# for intermediate layers.
if hidden_state_skip_layer > 0 and self.apply_final_norm:
last_hidden_state = self.model.final_layer_norm(last_hidden_state)
else:
last_hidden_state = outputs[self.output_key]
# Remove hidden states of instruction tokens, only keep prompt tokens.
if self.use_template:
if data_type == "image":
crop_start = self.prompt_template.get("crop_start", 0)
elif data_type == "video":
crop_start = self.prompt_template_video.get("crop_start", 0)
else:
raise ValueError(f"Unsupported data type: {data_type}")
if crop_start > 0:
last_hidden_state = last_hidden_state[:, crop_start:]
attention_mask = attention_mask[:, crop_start:] if use_attention_mask else None
if output_hidden_states:
return TextEncoderModelOutput(last_hidden_state, attention_mask, outputs.hidden_states)
return TextEncoderModelOutput(last_hidden_state, attention_mask)
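# Indexing sketch for hidden_state_skip_layer (assuming a standard HF model with N transformer
# layers, so outputs.hidden_states holds N + 1 tensors with the embedding output first):
#   hidden_state_skip_layer = 0 -> hidden_states[-1] (last layer)
#   hidden_state_skip_layer = 2 -> hidden_states[-3] (two layers before the last; final layer norm
#                                   is applied only when apply_final_norm is True)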
def forward(
self,
text,
use_attention_mask=None,
output_hidden_states=False,
do_sample=False,
hidden_state_skip_layer=None,
return_texts=False,
):
batch_encoding = self.text2tokens(text, max_length=self.max_length)
return self.encode(
batch_encoding,
use_attention_mask=use_attention_mask,
output_hidden_states=output_hidden_states,
do_sample=do_sample,
hidden_state_skip_layer=hidden_state_skip_layer,
return_texts=return_texts,
)
class Qwen25VL_TextEncoder:
def __init__(
self,
text_len=1000,
dtype=torch.float16,
device=torch.cuda.current_device(),
checkpoint_path=None,
cpu_offload=False,
qwen25vl_quantized=False,
qwen25vl_quant_scheme=None,
qwen25vl_quant_ckpt=None,
):
self.text_len = text_len
self.dtype = dtype
self.device = device
self.cpu_offload = cpu_offload
self.qwen25vl_quantized = qwen25vl_quantized
self.qwen25vl_quant_scheme = qwen25vl_quant_scheme
if self.qwen25vl_quantized:
assert self.qwen25vl_quant_scheme is not None
self.qwen25vl_quant_ckpt = qwen25vl_quant_ckpt
self.num_videos_per_prompt = 1
self.text_encoder = TextEncoder(
text_encoder_type="qwen-2.5vl-7b", # TODO: 不要用 qwen, 改成 llm
tokenizer_type="qwen-2.5vl-7b",
text_encoder_path=checkpoint_path,
max_length=text_len,
text_encoder_precision="fp16",
prompt_template=PROMPT_TEMPLATE["li-dit-encode-image-json"],
prompt_template_video=PROMPT_TEMPLATE["li-dit-encode-video-json"],
hidden_state_skip_layer=2,
apply_final_norm=False,
reproduce=False,
logger=loguru.logger,
device=device,
qwen25vl_quantized=qwen25vl_quantized,
qwen25vl_quant_scheme=qwen25vl_quant_scheme,
qwen25vl_quant_ckpt=qwen25vl_quant_ckpt,
)
def infer(self, texts):
if self.cpu_offload:
self.text_encoder = self.text_encoder.to("cuda")
text_inputs = self.text_encoder.text2tokens(texts, data_type="video", max_length=self.text_len)
prompt_outputs = self.text_encoder.encode(text_inputs, data_type="video", device="cuda")
if self.cpu_offload:
self.text_encoder = self.text_encoder.to("cpu")
prompt_embeds = prompt_outputs.hidden_state
attention_mask = prompt_outputs.attention_mask
if attention_mask is not None:
attention_mask = attention_mask.cuda()
_, seq_len = attention_mask.shape
attention_mask = attention_mask.repeat(1, self.num_videos_per_prompt)
attention_mask = attention_mask.view(self.num_videos_per_prompt, seq_len)
prompt_embeds = prompt_embeds.to(dtype=self.dtype, device="cuda")
seq_len = prompt_embeds.shape[1]
# duplicate text embeddings for each generation per prompt, using an mps-friendly method
prompt_embeds = prompt_embeds.repeat(1, self.num_videos_per_prompt, 1)
prompt_embeds = prompt_embeds.view(self.num_videos_per_prompt, seq_len, -1)
return prompt_embeds, attention_mask
if __name__ == "__main__":
text_encoder_path = "/data/nvme0/models/hy1118/ckpts/hunyuanvideo-1.5/text_encoder/llm"
device = "cuda"
import torch.nn.functional as F
prompt = "A close-up shot captures a scene on a polished, light-colored granite kitchen counter, illuminated by soft natural light from an unseen window. Initially, the frame focuses on a tall, clear glass filled with golden, translucent apple juice standing next to a single, shiny red apple with a green leaf still attached to its stem. The camera moves horizontally to the right. As the shot progresses, a white ceramic plate smoothly enters the frame, revealing a fresh arrangement of about seven or eight more apples, a mix of vibrant reds and greens, piled neatly upon it. A shallow depth of field keeps the focus sharply on the fruit and glass, while the kitchen backsplash in the background remains softly blurred. The scene is in a realistic style."
negative_prompt = ""
model = Qwen25VL_TextEncoder(
text_len=1000,
dtype=torch.float16,
device=device,
checkpoint_path=text_encoder_path,
cpu_offload=False,
qwen25vl_quantized=True,
qwen25vl_quant_scheme="int8-q8f",
qwen25vl_quant_ckpt="/data/nvme0/models/hy1118/quant_ckpts/qwen25vl-llm-int8.safetensors",
)
prompt_embeds, attention_mask = model.infer([prompt])
print(f"prompt_embeds: {prompt_embeds}, {prompt_embeds.shape}")
a = torch.load("prompt_embeds.pth")
# print(f"attention_mask: {attention_mask}, {attention_mask.sum()}, {attention_mask.shape}")
print(F.cosine_similarity(prompt_embeds.flatten().unsqueeze(0), a.flatten().unsqueeze(0), dim=1))
negative_prompt_embeds, negative_attention_mask = model.infer([negative_prompt])
print(f"negative_prompt_embeds: {negative_prompt_embeds}, {negative_prompt_embeds.shape}")
b = torch.load("negative_prompt_embeds.pth")
print(F.cosine_similarity(negative_prompt_embeds.flatten().unsqueeze(0), b.flatten().unsqueeze(0), dim=1))
# print(f"negative_attention_mask: {negative_attention_mask}, {negative_attention_mask.sum()}, {negative_attention_mask.shape}")
import glob
import os
from dataclasses import dataclass
from typing import Optional, Tuple
import numpy as np
import torch
import torch.nn as nn
from safetensors.torch import safe_open
from transformers import SiglipImageProcessor, SiglipVisionModel
from transformers.utils import ModelOutput
PRECISION_TO_TYPE = {
"fp32": torch.float32,
"fp16": torch.float16,
"bf16": torch.bfloat16,
}
VISION_ENCODER_PATH = {}
def use_default(value, default):
return value if value is not None else default
def load_vision_encoder(
vision_encoder_type,
vision_encoder_precision=None,
vision_encoder_path=None,
logger=None,
device=None,
):
if vision_encoder_path is None:
vision_encoder_path = VISION_ENCODER_PATH[vision_encoder_type]
if vision_encoder_type == "siglip":
vision_encoder = SiglipVisionModel.from_pretrained(vision_encoder_path, subfolder="image_encoder")
else:
raise ValueError(f"Unsupported vision encoder type: {vision_encoder_type}")
# from_pretrained will ensure that the model is in eval mode.
if vision_encoder_precision is not None:
vision_encoder = vision_encoder.to(dtype=PRECISION_TO_TYPE[vision_encoder_precision])
vision_encoder.requires_grad_(False)
if device is not None:
vision_encoder = vision_encoder.to(device)
return vision_encoder, vision_encoder_path
def load_image_processor(processor_type, processor_path=None, logger=None):
if processor_path is None:
processor_path = VISION_ENCODER_PATH[processor_type]
if processor_type == "siglip":
processor = SiglipImageProcessor.from_pretrained(processor_path, subfolder="feature_extractor")
else:
raise ValueError(f"Unsupported processor type: {processor_type}")
return processor, processor_path
@dataclass
class VisionEncoderModelOutput(ModelOutput):
"""
Base class for vision encoder model's outputs.
Args:
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`, *optional*):
Last layer hidden-state of the first token of the sequence (classification token)
after further processing through the layers used for the auxiliary pretraining task.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
"""
last_hidden_state: torch.FloatTensor = None
pooler_output: Optional[torch.FloatTensor] = None
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
class VisionEncoder(nn.Module):
def __init__(
self,
vision_encoder_type: str,
vision_encoder_precision: Optional[str] = None,
vision_encoder_path: Optional[str] = None,
processor_type: Optional[str] = None,
processor_path: Optional[str] = None,
output_key: Optional[str] = None,
logger=None,
device=None,
cpu_offload=False,
):
super().__init__()
self.cpu_offload = cpu_offload
self.vision_encoder_type = vision_encoder_type
self.precision = vision_encoder_precision
self.model_path = vision_encoder_path
self.processor_type = processor_type if processor_type is not None else vision_encoder_type
self.processor_path = processor_path if processor_path is not None else vision_encoder_path
self.logger = logger
if "siglip" in vision_encoder_type:
self.output_key = output_key or "last_hidden_state"
else:
raise ValueError(f"Unsupported vision encoder type: {vision_encoder_type}")
self.model, self.model_path = load_vision_encoder(
vision_encoder_type=self.vision_encoder_type,
vision_encoder_precision=self.precision,
vision_encoder_path=self.model_path,
logger=self.logger,
device=device,
)
self.dtype = self.model.dtype
self.device = self.model.device
self.processor, self.processor_path = load_image_processor(
processor_type=self.processor_type,
processor_path=self.processor_path,
logger=self.logger,
)
def __repr__(self):
return f"{self.vision_encoder_type} ({self.precision} - {self.model_path})"
def encode_latents_to_images(self, latents, vae, reorg_token=False):
"""
Convert latents to images using VAE decoder.
Args:
latents: Input latents tensor
vae: VAE model for decoding
reorg_token: Whether to reorg the token
Returns:
images: Decoded images as numpy array
"""
# Handle both 4D and 5D latents (for video, take first frame)
first_image_latents = latents[:, :, 0, ...] if len(latents.shape) == 5 else latents
first_image_latents = 1 / vae.config.scaling_factor * first_image_latents
first_image = vae.decode(first_image_latents.unsqueeze(2).to(vae.dtype), return_dict=False)[0].cpu()
first_image = first_image[:, :, 0, :, :]
first_image = (first_image / 2 + 0.5).clamp(0, 1)
first_image = (first_image * 255.0).clamp(0, 255.0)
first_image = first_image.to(torch.uint8).numpy()
first_image = first_image.transpose(0, 2, 3, 1)
assert isinstance(first_image, np.ndarray)
assert first_image.ndim == 4 and first_image.shape[3] == 3
assert first_image.dtype == np.uint8
return first_image
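# Value-range sketch of the decode chain above (scaling_factor is read from the VAE config):
#   latents / scaling_factor -> vae.decode(...)      roughly in [-1, 1]
#   x / 2 + 0.5              -> clamped to [0, 1]
#   x * 255 -> uint8, NHWC   -> ready for the Siglip image processor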
def encode_images(self, images):
"""
Encode images using the vision encoder.
Args:
images: Input images (numpy array or preprocessed tensor)
Returns:
VisionEncoderModelOutput with encoded features
"""
if self.cpu_offload:
# Only the model carries device state; SiglipImageProcessor is not an nn.Module and has no .to().
self.model = self.model.to("cuda")
if isinstance(images, np.ndarray):
# Preprocess images if they're numpy arrays
preprocessed = self.processor.preprocess(images=images, return_tensors="pt").to(device="cuda", dtype=self.model.dtype)
else:
# Assume already preprocessed
preprocessed = images
outputs = self.model(**preprocessed)
if self.cpu_offload:
self.model = self.model.to("cpu")
return VisionEncoderModelOutput(
last_hidden_state=outputs.last_hidden_state,
pooler_output=outputs.pooler_output if hasattr(outputs, "pooler_output") else None,
hidden_states=outputs.hidden_states if hasattr(outputs, "hidden_states") else None,
)
def encode_latents(self, latents, vae, reorg_token=False):
"""
Encode latents by first converting to images, then encoding.
This is the main function that replaces sigclip_vision_encode.
Args:
latents: Input latent tensors
vae: VAE model for decoding latents to images
Returns:
Encoded image features
"""
# Convert latents to images
images = self.encode_latents_to_images(latents, vae, reorg_token)
# Encode images
outputs = self.encode_images(images)
return outputs.last_hidden_state
def forward(self, images):
"""
Forward pass for direct image encoding.
Args:
images: Input images
Returns:
VisionEncoderModelOutput with encoded features
"""
return self.encode_images(images)
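# Usage sketch (the checkpoint path is a placeholder; a Siglip checkpoint with "image_encoder" and
# "feature_extractor" subfolders and a CUDA device are assumed):
#   enc = VisionEncoder(vision_encoder_type="siglip", vision_encoder_precision="fp16",
#                       vision_encoder_path="/path/to/siglip", device="cuda")
#   feats = enc(np.zeros((1, 384, 384, 3), dtype=np.uint8)).last_hidden_state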
class SiglipVisionEncoder:
def __init__(
self,
config,
device=torch.cuda.current_device(),
checkpoint_path=None,
cpu_offload=False,
):
self.config = config
self.device = device
self.cpu_offload = cpu_offload
self.vision_states_dim = 1152
vision_encoder_path = os.path.join(checkpoint_path, "vision_encoder", "siglip")
self.vision_encoder = VisionEncoder(
vision_encoder_type="siglip",
vision_encoder_precision="fp16",
vision_encoder_path=vision_encoder_path,
processor_type=None,
processor_path=None,
output_key=None,
logger=None,
device=self.device,
cpu_offload=self.cpu_offload,
)
self.vision_in = VisionProjection(in_dim=self.vision_states_dim, out_dim=self.config["hidden_size"], flf_pos_emb=False).to(torch.bfloat16)
vision_in_model_path = os.path.join(checkpoint_path, "transformer", self.config["transformer_model_name"])
safetensors_files = glob.glob(os.path.join(vision_in_model_path, "*.safetensors"))
vision_in_state_dict = {}
for safetensor_path in safetensors_files:
with safe_open(safetensor_path, framework="pt", device="cpu") as f:
vision_in_state_dict.update({key.replace("vision_in.", ""): f.get_tensor(key).to(torch.bfloat16) for key in f.keys() if "vision_in" in key})
self.vision_in.load_state_dict(vision_in_state_dict)
self.vision_in.to(device=device)
@torch.no_grad()
def infer(self, vision_states):
if self.cpu_offload:
self.vision_in = self.vision_in.to("cuda")
vision_states = self.vision_in(vision_states)
if self.cpu_offload:
self.vision_in = self.vision_in.to("cpu")
return vision_states
@torch.no_grad()
def encode_images(self, images):
return self.vision_encoder.encode_images(images)
class VisionProjection(torch.nn.Module):
"""
Projects vision embeddings into the transformer's hidden space via a layer-normed MLP.
Adapted from https://github.com/Wan-Video/Wan2.1/blob/main/wan/modules/model.py#L488
"""
def __init__(self, in_dim, out_dim, flf_pos_emb=False):
super().__init__()
self.proj = torch.nn.Sequential(torch.nn.LayerNorm(in_dim), torch.nn.Linear(in_dim, in_dim), torch.nn.GELU(), torch.nn.Linear(in_dim, out_dim), torch.nn.LayerNorm(out_dim))
if flf_pos_emb: # NOTE: we only use this for `flf2v`
self.emb_pos = nn.Parameter(torch.zeros(1, FIRST_LAST_FRAME_CONTEXT_TOKEN_NUMBER, 1280))
@torch.no_grad()
def forward(self, image_embeds):
if hasattr(self, "emb_pos"):
bs, n, d = image_embeds.shape
image_embeds = image_embeds.view(-1, 2 * n, d)
image_embeds = image_embeds + self.emb_pos
clip_extra_context_tokens = self.proj(image_embeds)
return clip_extra_context_tokens
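# Minimal smoke test (a sketch: 1152 matches the Siglip hidden size used above, while the output
# dim 3072 and the 729-token length are illustrative assumptions, not values from a checkpoint).
if __name__ == "__main__":
    proj = VisionProjection(in_dim=1152, out_dim=3072, flf_pos_emb=False)
    dummy_embeds = torch.randn(1, 729, 1152)
    print(proj(dummy_embeds).shape)  # expected: torch.Size([1, 729, 3072])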
......@@ -162,7 +162,7 @@ class SglQuantLinearFp8(nn.Module):
input_tensor_quant,
self.weight.t(),
input_tensor_scale,
self.weight_scale,
self.weight_scale.float(),
dtype,
bias=self.bias,
)
......@@ -249,7 +249,7 @@ class Q8FQuantLinearInt8(nn.Module):
output_tensor = q8_linear(
input_tensor_quant,
self.weight,
self.bias if self.bias is not None else None,
self.bias.float() if self.bias is not None else None,
input_tensor_scale,
self.weight_scale.float(),
fuse_gelu=False,
......@@ -295,9 +295,9 @@ class Q8FQuantLinearFp8(nn.Module):
output_tensor = fp8_linear(
input_tensor_quant,
self.weight,
self.bias if self.bias is not None else None,
self.bias.float() if self.bias is not None else None,
input_tensor_scale,
self.weight_scale,
self.weight_scale.float(),
out_dtype=torch.bfloat16,
)
return output_tensor
......
......@@ -58,12 +58,12 @@ class Qwen25_VLForConditionalGeneration_TextEncoder:
self.VAE_IMAGE_SIZE = 1024 * 1024
self.cpu_offload = config.get("cpu_offload", False)
self.run_device = self.config.get("run_device", "cuda")
if self.cpu_offload:
self.device = torch.device("cpu")
else:
self.device = torch.device(self.config.get("run_device", "cuda"))
self.device = torch.device(self.run_device)
self.dtype = torch.bfloat16
self.load()
def load(self):
......@@ -95,7 +95,7 @@ class Qwen25_VLForConditionalGeneration_TextEncoder:
@torch.no_grad()
def infer(self, text, image_list=None):
if self.cpu_offload:
self.text_encoder.to(self.device)
self.text_encoder.to(self.run_device)
if image_list is not None:
condition_image_list = []
......@@ -130,7 +130,7 @@ class Qwen25_VLForConditionalGeneration_TextEncoder:
images=condition_image_list,
padding=True,
return_tensors="pt",
).to(torch.device(self.device))
).to(torch.device(self.run_device))
encoder_hidden_states = self.text_encoder(
input_ids=model_inputs.input_ids,
......@@ -153,7 +153,7 @@ class Qwen25_VLForConditionalGeneration_TextEncoder:
txt = [template.format(e) for e in text]
image_info = {}
model_inputs = self.tokenizer(txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt").to(self.device)
model_inputs = self.tokenizer(txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt").to(torch.device(self.run_device))
encoder_hidden_states = self.text_encoder(
input_ids=model_inputs.input_ids,
attention_mask=model_inputs.attention_mask,
......@@ -169,7 +169,7 @@ class Qwen25_VLForConditionalGeneration_TextEncoder:
prompt_embeds = torch.stack([torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states])
encoder_attention_mask = torch.stack([torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list])
prompt_embeds = prompt_embeds.to(dtype=self.dtype, device=self.device)
prompt_embeds = prompt_embeds.to(dtype=self.dtype, device=self.run_device)
prompt_embeds_mask = encoder_attention_mask
_, seq_len, _ = prompt_embeds.shape
......
......@@ -252,7 +252,7 @@ class AudioAdapter(nn.Module):
quantized: bool = False,
quant_scheme: str = None,
cpu_offload: bool = False,
device=torch.device("cuda"),
run_device=torch.device("cuda"),
):
super().__init__()
self.cpu_offload = cpu_offload
......@@ -263,7 +263,7 @@ class AudioAdapter(nn.Module):
mlp_dims=mlp_dims,
transformer_layers=projection_transformer_layers,
)
self.device = torch.device(device)
self.run_device = run_device
# self.num_tokens = num_tokens * 4
self.num_tokens_x4 = num_tokens * 4
self.audio_pe = nn.Parameter(torch.randn(self.num_tokens_x4, mlp_dims[-1] // num_tokens) * 0.02)
......@@ -302,10 +302,10 @@ class AudioAdapter(nn.Module):
@torch.no_grad()
def forward_audio_proj(self, audio_feat, latent_frame):
if self.cpu_offload:
self.audio_proj.to(self.device)
self.audio_proj.to(self.run_device)
x = self.audio_proj(audio_feat, latent_frame)
x = self.rearange_audio_features(x)
x = x + self.audio_pe.to(self.device)
x = x + self.audio_pe.to(self.run_device)
if self.cpu_offload:
self.audio_proj.to("cpu")
return x