Unverified commit f21da849 authored by Yang Yong (雍洋), committed by GitHub
parent 3efc43f5
{
    "infer_steps": 50,
    "transformer_model_name": "480p_t2v",
    "fps": 24,
    "target_video_length": 121,
    "aspect_ratio": "16:9",
    "vae_stride": [4, 16, 16],
    "sample_shift": 7.0,
    "sample_guide_scale": 6.0,
    "enable_cfg": true,
    "attn_type": "sage_attn2",
    "cpu_offload": true,
    "offload_granularity": "block",
    "vae_cpu_offload": false,
    "byt5_cpu_offload": false,
    "qwen25vl_cpu_offload": true,
    "siglip_cpu_offload": false
}
{
    "infer_steps": 50,
    "transformer_model_name": "480p_t2v",
    "fps": 24,
    "target_video_length": 121,
    "aspect_ratio": "16:9",
    "vae_stride": [4, 16, 16],
    "sample_shift": 7.0,
    "sample_guide_scale": 6.0,
    "enable_cfg": true,
    "attn_type": "flash_attn3",
    "dit_quantized_ckpt": "/path/to/quant_model.safetensors",
    "dit_quantized": true,
    "dit_quant_scheme": "fp8-sgl"
}
{
    "infer_steps": 50,
    "transformer_model_name": "480p_i2v",
    "fps": 24,
    "target_video_length": 121,
    "vae_stride": [4, 16, 16],
    "sample_shift": 5.0,
    "sample_guide_scale": 6.0,
    "enable_cfg": true,
    "attn_type": "flash_attn3",
    "video_super_resolution": {
        "sr_version": "720p_sr_distilled",
        "flow_shift": 2.0,
        "base_resolution": "480p",
        "guidance_scale": 1.0,
        "num_inference_steps": 6,
        "use_meanflow": true
    }
}
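The three JSON files above are inference configs for the HunyuanVideo-1.5 runner: a CPU-offload text-to-video profile, an fp8-quantized text-to-video profile, and an image-to-video profile with a distilled 720p super-resolution stage. A minimal sketch of reading and sanity-checking such a config before launching inference is shown below; the load_config helper and the required-key list are illustrative assumptions, not the project's actual loader.

import json

# Assumed minimal key set, for illustration only.
REQUIRED_KEYS = {"infer_steps", "transformer_model_name", "fps", "target_video_length", "attn_type"}

def load_config(path):
    # Read the JSON config and fail early if an expected key is missing.
    with open(path, "r") as f:
        cfg = json.load(f)
    missing = REQUIRED_KEYS - cfg.keys()
    if missing:
        raise KeyError(f"config {path} is missing keys: {sorted(missing)}")
    return cfg

cfg = load_config("hunyuan_video_15_480p_i2v_sr.json")  # hypothetical file name
print(cfg.get("video_super_resolution", {}).get("sr_version"))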
......@@ -46,6 +46,9 @@ class FlashAttn2Weight(AttnWeightTemplate):
bs = 1
elif len(q.shape) == 4:
bs = q.shape[0]
q = q.reshape(-1, q.shape[-2], q.shape[-1])
k = k.reshape(-1, k.shape[-2], k.shape[-1])
v = v.reshape(-1, v.shape[-2], v.shape[-1])
x = flash_attn_varlen_func(
q,
k,
......@@ -78,6 +81,9 @@ class FlashAttn3Weight(AttnWeightTemplate):
bs = 1
elif len(q.shape) == 4:
bs = q.shape[0]
q = q.reshape(-1, q.shape[-2], q.shape[-1])
k = k.reshape(-1, k.shape[-2], k.shape[-1])
v = v.reshape(-1, v.shape[-2], v.shape[-1])
x = flash_attn_varlen_func_v3(
q,
k,
......
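The lines added to FlashAttn2Weight and FlashAttn3Weight flatten an optional batch dimension before the varlen kernel call: flash-attention varlen entry points expect packed (total_tokens, num_heads, head_dim) tensors plus cumulative sequence lengths, so a batched (bs, seq_len, num_heads, head_dim) input has to be collapsed first. A minimal sketch of that shape bookkeeping in plain PyTorch, with assumed example sizes:

import torch

bs, seq_len, num_heads, head_dim = 2, 8, 4, 64  # assumed example sizes
q = torch.randn(bs, seq_len, num_heads, head_dim)

# Collapse the batch dimension into the token dimension, matching the reshape added above.
q_packed = q.reshape(-1, q.shape[-2], q.shape[-1])  # (bs * seq_len, num_heads, head_dim)

# Cumulative sequence lengths tell a varlen kernel where each sequence starts and ends.
cu_seqlens = torch.arange(0, (bs + 1) * seq_len, seq_len, dtype=torch.int32)

assert q_packed.shape == (bs * seq_len, num_heads, head_dim)
print(cu_seqlens.tolist())  # [0, 8, 16]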
......@@ -27,6 +27,11 @@ class UlyssesAttnWeight(AttnWeightTemplate):
Returns:
    torch.Tensor: the computed attention result
"""
if len(q.shape) == 4:
q = q.reshape(-1, q.shape[-2], q.shape[-1])
k = k.reshape(-1, k.shape[-2], k.shape[-1])
v = v.reshape(-1, v.shape[-2], v.shape[-1])
# Get the current rank and world size of the sequence-parallel group
world_size = dist.get_world_size(seq_p_group)
cur_rank = dist.get_rank(seq_p_group)
......@@ -134,9 +139,16 @@ class Ulysses4090AttnWeight(AttnWeightTemplate):
Returns:
    torch.Tensor: the computed attention result
"""
if len(q.shape) == 4:
q = q.reshape(-1, q.shape[-2], q.shape[-1])
k = k.reshape(-1, k.shape[-2], k.shape[-1])
v = v.reshape(-1, v.shape[-2], v.shape[-1])
# Get the current rank and world size of the sequence-parallel group
world_size = dist.get_world_size(seq_p_group)
cur_rank = dist.get_rank(seq_p_group)
global_world_size = dist.get_world_size()
global_rank = dist.get_global_rank(seq_p_group, cur_rank)
cfg_p_group_index = global_rank // world_size
# Get the sequence length and the text-related lengths
seq_len = q.shape[0]
......@@ -181,22 +193,23 @@ class Ulysses4090AttnWeight(AttnWeightTemplate):
# Launch asynchronous communication, then synchronize
for target_rank in range(world_size):
target_global_rank = cfg_p_group_index * world_size + target_rank
if target_rank != cur_rank:
# Avoid deadlock: decide the send/recv ordering by rank
if cur_rank < target_rank:
sendq_req = dist.isend(q_shards[target_rank], dst=target_rank, group=seq_p_group)
sendk_req = dist.isend(k_shards[target_rank], dst=target_rank, group=seq_p_group)
sendv_req = dist.isend(v_shards[target_rank], dst=target_rank, group=seq_p_group)
recvq_req = dist.irecv(gathered_q_shards[target_rank], src=target_rank, group=seq_p_group)
recvk_req = dist.irecv(gathered_k_shards[target_rank], src=target_rank, group=seq_p_group)
recvv_req = dist.irecv(gathered_v_shards[target_rank], src=target_rank, group=seq_p_group)
sendq_req = dist.isend(q_shards[target_rank], dst=target_global_rank, group=seq_p_group)
sendk_req = dist.isend(k_shards[target_rank], dst=target_global_rank, group=seq_p_group)
sendv_req = dist.isend(v_shards[target_rank], dst=target_global_rank, group=seq_p_group)
recvq_req = dist.irecv(gathered_q_shards[target_rank], src=target_global_rank, group=seq_p_group)
recvk_req = dist.irecv(gathered_k_shards[target_rank], src=target_global_rank, group=seq_p_group)
recvv_req = dist.irecv(gathered_v_shards[target_rank], src=target_global_rank, group=seq_p_group)
else:
recvq_req = dist.irecv(gathered_q_shards[target_rank], src=target_rank, group=seq_p_group)
recvk_req = dist.irecv(gathered_k_shards[target_rank], src=target_rank, group=seq_p_group)
recvv_req = dist.irecv(gathered_v_shards[target_rank], src=target_rank, group=seq_p_group)
sendq_req = dist.isend(q_shards[target_rank], dst=target_rank, group=seq_p_group)
sendk_req = dist.isend(k_shards[target_rank], dst=target_rank, group=seq_p_group)
sendv_req = dist.isend(v_shards[target_rank], dst=target_rank, group=seq_p_group)
recvq_req = dist.irecv(gathered_q_shards[target_rank], src=target_global_rank, group=seq_p_group)
recvk_req = dist.irecv(gathered_k_shards[target_rank], src=target_global_rank, group=seq_p_group)
recvv_req = dist.irecv(gathered_v_shards[target_rank], src=target_global_rank, group=seq_p_group)
sendq_req = dist.isend(q_shards[target_rank], dst=target_global_rank, group=seq_p_group)
sendk_req = dist.isend(k_shards[target_rank], dst=target_global_rank, group=seq_p_group)
sendv_req = dist.isend(v_shards[target_rank], dst=target_global_rank, group=seq_p_group)
sendq_req.wait()
sendk_req.wait()
sendv_req.wait()
......@@ -254,6 +267,9 @@ class Ulysses4090AttnWeight(AttnWeightTemplate):
@torch.compiler.disable
def _reshape_img_attn(self, img_attn, world_size, shard_seqlen, shard_heads, hidden_dims, seq_p_group):
cur_rank = dist.get_rank(seq_p_group)
global_world_size = dist.get_world_size()
global_rank = dist.get_global_rank(seq_p_group, cur_rank)
cfg_p_group_index = global_rank // world_size
img_attn = img_attn.reshape(world_size * shard_seqlen, shard_heads, hidden_dims)  # reshape the image attention result
......@@ -269,14 +285,15 @@ class Ulysses4090AttnWeight(AttnWeightTemplate):
# Launch asynchronous communication, then synchronize
for target_rank in range(world_size):
target_global_rank = cfg_p_group_index * world_size + target_rank
if target_rank != cur_rank:
# Avoid deadlock: decide the send/recv ordering by rank
if cur_rank < target_rank:
send_req = dist.isend(attn_shards[target_rank], dst=target_rank, group=seq_p_group)
recv_req = dist.irecv(gathered_attn_shards[target_rank], src=target_rank, group=seq_p_group)
send_req = dist.isend(attn_shards[target_rank], dst=target_global_rank, group=seq_p_group)
recv_req = dist.irecv(gathered_attn_shards[target_rank], src=target_global_rank, group=seq_p_group)
else:
recv_req = dist.irecv(gathered_attn_shards[target_rank], src=target_rank, group=seq_p_group)
send_req = dist.isend(attn_shards[target_rank], dst=target_rank, group=seq_p_group)
recv_req = dist.irecv(gathered_attn_shards[target_rank], src=target_global_rank, group=seq_p_group)
send_req = dist.isend(attn_shards[target_rank], dst=target_global_rank, group=seq_p_group)
send_req.wait()
recv_req.wait()
......
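The substantive fix in Ulysses4090AttnWeight is rank addressing: PyTorch's point-to-point ops (dist.isend / dist.irecv) interpret dst and src as global ranks even when a process group is passed, so sending to the sequence-parallel group's local rank only worked when that group started at global rank 0. The added lines convert a group-local rank to its global rank, either via dist.get_global_rank or via the cfg_p_group_index arithmetic. A standalone sketch of the same mapping, assuming each sequence-parallel group covers consecutive global ranks:

import torch.distributed as dist

def local_to_global_rank(seq_p_group, target_local_rank):
    # Ask PyTorch directly for the global rank of a group member (requires an initialized group).
    return dist.get_global_rank(seq_p_group, target_local_rank)

def local_to_global_rank_flat(global_rank, group_world_size, target_local_rank):
    # Equivalent arithmetic under the consecutive-ranks assumption,
    # mirroring cfg_p_group_index * world_size + target_rank in the diff.
    group_index = global_rank // group_world_size
    return group_index * group_world_size + target_local_rank

# Example: global rank 5 belongs to the second group of size 4,
# so local rank 2 of that group maps to global rank 6.
assert local_to_global_rank_flat(global_rank=5, group_world_size=4, target_local_rank=2) == 6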
......@@ -809,7 +809,7 @@ class MMWeightWfp8channelAfp8channeldynamicQ8F(MMWeightQuantTemplate):
self.weight_scale,
out_dtype=self.infer_dtype,
)
return output_tensor.squeeze(0)
return output_tensor.squeeze(0) if len(output_tensor.shape) == 3 else output_tensor
@MM_WEIGHT_REGISTER("int8-q8f")
......@@ -840,7 +840,7 @@ class MMWeightWint8channelAint8channeldynamicQ8F(MMWeightQuantTemplate):
fuse_gelu=False,
out_dtype=self.infer_dtype,
)
return output_tensor.squeeze(0)
return output_tensor.squeeze(0) if len(output_tensor.shape) == 3 else output_tensor
@MM_WEIGHT_REGISTER("fp8-b128-deepgemm")
......
......@@ -6,6 +6,8 @@ import torch
from lightx2v.utils.envs import *
from lightx2v.utils.registry_factory import LN_WEIGHT_REGISTER
from .triton_ops import norm_infer
class LNWeightTemplate(metaclass=ABCMeta):
def __init__(self, weight_name=None, bias_name=None, create_cuda_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False, eps=1e-6):
......@@ -165,3 +167,30 @@ class LNWeight(LNWeightTemplate):
input_tensor = torch.nn.functional.layer_norm(input_tensor, (input_tensor.shape[-1],), self.weight, self.bias, self.eps)
return input_tensor
@LN_WEIGHT_REGISTER("Triton")
class LNWeight(LNWeightTemplate):
def __init__(self, weight_name=None, bias_name=None, create_cuda_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False, eps=1e-6):
super().__init__(weight_name, bias_name, create_cuda_buffer, lazy_load, lazy_load_file, is_post_adapter, eps)
def load_from_disk(self):
if self.weight_name is not None:
if not torch._dynamo.is_compiling():
self.weight = self.lazy_load_file.get_tensor(self.weight_name).to(GET_DTYPE()).pin_memory()
else:
self.weight = self.lazy_load_file.get_tensor(self.weight_name).to(GET_DTYPE())
else:
self.weight = None
if self.bias_name is not None:
if not torch._dynamo.is_compiling():
self.bias = self.lazy_load_file.get_tensor(self.bias_name).to(GET_DTYPE()).pin_memory()
else:
self.bias = self.lazy_load_file.get_tensor(self.bias_name).to(GET_DTYPE())
else:
self.bias = None
def apply(self, input_tensor):
input_tensor = norm_infer(input_tensor, self.weight, self.bias, self.eps)
return input_tensor
......@@ -80,14 +80,14 @@ class RMSWeight(RMSWeightTemplate):
else:
self.weight = self.lazy_load_file.get_tensor(self.weight_name).to(GET_DTYPE())
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
def apply(self, input_tensor):
if GET_SENSITIVE_DTYPE() != GET_DTYPE():
input_tensor = input_tensor * torch.rsqrt(input_tensor.pow(2).mean(-1, keepdim=True) + self.eps)
input_tensor = input_tensor * self.weight
input_tensor = self._norm(input_tensor).type_as(input_tensor) * self.weight
else:
input_tensor = input_tensor * torch.rsqrt(input_tensor.float().pow(2).mean(-1, keepdim=True) + self.eps)
input_tensor = (input_tensor * self.weight).to(GET_DTYPE())
input_tensor = self._norm(input_tensor.float()).type_as(input_tensor) * self.weight
return input_tensor
def state_dict(self, destination=None):
......@@ -111,7 +111,15 @@ class RMSWeight(RMSWeightTemplate):
@RMS_WEIGHT_REGISTER("sgl-kernel")
class RMSWeightSgl(RMSWeight):
def __init__(self, weight_name, create_cuda_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False, eps=1e-6):
def __init__(
self,
weight_name,
create_cuda_buffer=False,
lazy_load=False,
lazy_load_file=None,
is_post_adapter=False,
eps=1e-6,
):
super().__init__(weight_name, create_cuda_buffer, lazy_load, lazy_load_file, is_post_adapter, eps)
def load_from_disk(self):
......
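The RMSWeight change above factors the normalization into a _norm helper: RMSNorm scales each vector by the reciprocal root-mean-square of its elements, x * rsqrt(mean(x^2) + eps), and then multiplies by a learned per-channel weight; the two branches differ only in whether the statistics are computed in the working dtype or in float32. A minimal reference in plain PyTorch with assumed shapes, mirroring the float32 branch:

import torch

def rms_norm_ref(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Compute statistics in float32 for stability, then cast back to the input dtype.
    x_f32 = x.float()
    normed = x_f32 * torch.rsqrt(x_f32.pow(2).mean(dim=-1, keepdim=True) + eps)
    return normed.type_as(x) * weight

x = torch.randn(2, 16, dtype=torch.bfloat16)
w = torch.ones(16, dtype=torch.bfloat16)
out = rms_norm_ref(x, w)
print(out.shape, out.dtype)  # torch.Size([2, 16]) torch.bfloat16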
This diff is collapsed.
......@@ -5,6 +5,7 @@ import torch.distributed as dist
from loguru import logger
from lightx2v.common.ops import *
from lightx2v.models.runners.hunyuan_video.hunyuan_video_15_runner import HunyuanVideo15Runner # noqa: F401
from lightx2v.models.runners.qwen_image.qwen_image_runner import QwenImageRunner # noqa: F401
from lightx2v.models.runners.wan.wan_animate_runner import WanAnimateRunner # noqa: F401
from lightx2v.models.runners.wan.wan_audio_runner import Wan22AudioRunner, WanAudioRunner # noqa: F401
......@@ -49,6 +50,7 @@ def main():
"wan2.2_moe_distill",
"qwen_image",
"wan2.2_animate",
"hunyuan_video_1.5",
],
default="wan2.1",
)
......
import json
def closest_color(requested_color):
import webcolors
min_colors = {}
for key, name in webcolors.CSS3_HEX_TO_NAMES.items():
r_c, g_c, b_c = webcolors.hex_to_rgb(key)
rd = (r_c - requested_color[0]) ** 2
gd = (g_c - requested_color[1]) ** 2
bd = (b_c - requested_color[2]) ** 2
min_colors[(rd + gd + bd)] = name
return min_colors[min(min_colors.keys())]
def convert_rgb_to_names(rgb_tuple):
try:
import webcolors
color_name = webcolors.rgb_to_name(rgb_tuple)
except ValueError:
color_name = closest_color(rgb_tuple)
return color_name
class MultilingualPromptFormat:
def __init__(
self,
font_path: str = "assets/glyph_sdxl_assets/multilingual_10-lang_idx.json",
color_path: str = "assets/glyph_sdxl_assets/color_idx.json",
):
with open(font_path, "r") as f:
self.font_dict = json.load(f)
with open(color_path, "r") as f:
self.color_dict = json.load(f)
def format_prompt(self, texts, styles):
"""
Text "{text}" in {color}, {type}.
"""
prompt = ""
for text, style in zip(texts, styles):
text_prompt = f'Text "{text}"'
attr_list = []
# format color
if style["color"] is not None:
import webcolors
hex_color = style["color"]
rgb_color = webcolors.hex_to_rgb(hex_color)
color_name = convert_rgb_to_names(rgb_color)
attr_list.append(f"<color-{self.color_dict[color_name]}>")
# format font
if style["font-family"] is not None:
attr_list.append(f"<{style['font-family'][:2]}-font-{self.font_dict[style['font-family']]}>")
attr_suffix = ", ".join(attr_list)
text_prompt += " in " + attr_suffix
text_prompt += ". "
else:
text_prompt += ". "
prompt = prompt + text_prompt
return prompt
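MultilingualPromptFormat turns extracted glyph strings into templated prompts of the form Text "{text}" in <color-i>, <xx-font-j>. for the ByT5 glyph encoder, and closest_color falls back to a nearest-CSS3-color lookup when an exact name is unknown. The sketch below shows the same nearest-color idea against a tiny in-line palette instead of the webcolors CSS3 table, since the real class also needs the Glyph-SDXL-v2 color/font index JSONs; the palette and color names here are illustrative assumptions.

# Nearest-color matching by squared RGB distance, the same idea as closest_color above,
# but against a tiny hard-coded palette instead of webcolors' CSS3 table.
PALETTE = {"red": (255, 0, 0), "green": (0, 128, 0), "blue": (0, 0, 255), "white": (255, 255, 255)}

def closest_palette_color(rgb):
    def dist2(candidate):
        return sum((a - b) ** 2 for a, b in zip(candidate, rgb))
    return min(PALETTE, key=lambda name: dist2(PALETTE[name]))

print(closest_palette_color((250, 10, 10)))    # "red"
print(closest_palette_color((240, 240, 235)))  # "white"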
import glob
import json
import os
import re
import torch
import torch.nn as nn
from safetensors import safe_open
from transformers import AutoTokenizer, T5ForConditionalGeneration
from .format_prompt import MultilingualPromptFormat
def add_special_token(
tokenizer,
text_encoder,
add_color,
add_font,
color_ann_path,
font_ann_path,
multilingual=False,
):
"""
Add special tokens for color and font to tokenizer and text encoder.
Args:
tokenizer: Huggingface tokenizer.
text_encoder: Huggingface T5 encoder.
add_color (bool): Whether to add color tokens.
add_font (bool): Whether to add font tokens.
color_ann_path (str): Path to color annotation JSON.
font_ann_path (str): Path to font annotation JSON.
multilingual (bool): Whether to use multilingual font tokens.
"""
with open(font_ann_path, "r") as f:
idx_font_dict = json.load(f)
with open(color_ann_path, "r") as f:
idx_color_dict = json.load(f)
if multilingual:
font_token = [f"<{font_code[:2]}-font-{idx_font_dict[font_code]}>" for font_code in idx_font_dict]
else:
font_token = [f"<font-{i}>" for i in range(len(idx_font_dict))]
color_token = [f"<color-{i}>" for i in range(len(idx_color_dict))]
additional_special_tokens = []
if add_color:
additional_special_tokens += color_token
if add_font:
additional_special_tokens += font_token
tokenizer.add_tokens(additional_special_tokens, special_tokens=True)
# Set mean_resizing=False to avoid PyTorch LAPACK dependency
text_encoder.resize_token_embeddings(len(tokenizer), mean_resizing=False)
def load_byt5_and_byt5_tokenizer(
byt5_name="google/byt5-small",
special_token=False,
color_special_token=False,
font_special_token=False,
color_ann_path="assets/color_idx.json",
font_ann_path="assets/font_idx_512.json",
huggingface_cache_dir=None,
multilingual=False,
device=None,
):
"""
Load ByT5 encoder and tokenizer from Huggingface, and add special tokens if needed.
Args:
byt5_name (str): Model name or path.
special_token (bool): Whether to add special tokens.
color_special_token (bool): Whether to add color tokens.
font_special_token (bool): Whether to add font tokens.
color_ann_path (str): Path to color annotation JSON.
font_ann_path (str): Path to font annotation JSON.
huggingface_cache_dir (str): Huggingface cache directory.
multilingual (bool): Whether to use multilingual font tokens.
device (str or torch.device): Device to load the model onto.
Returns:
tuple: (byt5_text_encoder, byt5_tokenizer)
"""
byt5_tokenizer = AutoTokenizer.from_pretrained(
byt5_name,
cache_dir=huggingface_cache_dir,
)
byt5_text_encoder = T5ForConditionalGeneration.from_pretrained(
byt5_name,
cache_dir=huggingface_cache_dir,
).get_encoder()
if "cuda" not in str(device):
device = torch.device(device)
else:
device = torch.device(device)
byt5_text_encoder = byt5_text_encoder.to(device)
if special_token:
add_special_token(
byt5_tokenizer,
byt5_text_encoder,
add_color=color_special_token,
add_font=font_special_token,
color_ann_path=color_ann_path,
font_ann_path=font_ann_path,
multilingual=multilingual,
)
return byt5_text_encoder, byt5_tokenizer
class ByT5Mapper(nn.Module):
"""
ByT5Mapper: Maps ByT5 encoder outputs to a new space, with optional residual connection.
Args:
in_dim (int): Input dimension (must equal out_dim if use_residual).
out_dim (int): Output dimension after second linear layer.
hidden_dim (int): Hidden dimension for intermediate layer.
out_dim1 (int): Final output dimension.
use_residual (bool): Whether to use residual connection (default: True).
"""
def __init__(self, in_dim, out_dim, hidden_dim, out_dim1, use_residual=True):
super().__init__()
if use_residual:
assert in_dim == out_dim
self.layernorm = nn.LayerNorm(in_dim)
self.fc1 = nn.Linear(in_dim, hidden_dim)
self.fc2 = nn.Linear(hidden_dim, out_dim)
self.fc3 = nn.Linear(out_dim, out_dim1)
self.use_residual = use_residual
self.act_fn = nn.GELU()
def forward(self, x):
"""
Forward pass for ByT5Mapper.
Args:
x (Tensor): Input tensor of shape (..., in_dim).
Returns:
Tensor: Output tensor of shape (..., out_dim1).
"""
residual = x
x = self.layernorm(x)
x = self.fc1(x)
x = self.act_fn(x)
x = self.fc2(x)
x2 = self.act_fn(x)
x2 = self.fc3(x2)
if self.use_residual:
x2 = x2 + residual
return x2
class ByT5TextEncoder:
def __init__(
self,
config,
device=torch.cuda.current_device(),
checkpoint_path=None,
byt5_max_length=256,
cpu_offload=False,
):
self.cpu_offload = cpu_offload
self.config = config
self.device = device
self.byt5_max_length = byt5_max_length
self.enable_cfg = config.get("enable_cfg", False)
byT5_google_path = os.path.join(checkpoint_path, "text_encoder", "byt5-small")
byT5_ckpt_path = os.path.join(checkpoint_path, "text_encoder", "Glyph-SDXL-v2", "checkpoints/byt5_model.pt")
multilingual_prompt_format_color_path = os.path.join(checkpoint_path, "text_encoder", "Glyph-SDXL-v2", "assets/color_idx.json")
multilingual_prompt_format_font_path = os.path.join(checkpoint_path, "text_encoder", "Glyph-SDXL-v2", "assets/multilingual_10-lang_idx.json")
byt5_args = dict(
byT5_google_path=byT5_google_path,
byT5_ckpt_path=byT5_ckpt_path,
multilingual_prompt_format_color_path=multilingual_prompt_format_color_path,
multilingual_prompt_format_font_path=multilingual_prompt_format_font_path,
byt5_max_length=byt5_max_length,
)
self.byt5_tokenizer, self.byt5_model, self.byt5_max_length = self.create_byt5(byt5_args, device)
self.byt5_model = self.byt5_model.to(device=device)
self.prompt_format = MultilingualPromptFormat(font_path=multilingual_prompt_format_font_path, color_path=multilingual_prompt_format_color_path)
self.byt5_mapper = ByT5Mapper(in_dim=1472, out_dim=2048, hidden_dim=2048, out_dim1=self.config["hidden_size"], use_residual=False).to(torch.bfloat16)
byt5_mapper_model_path = os.path.join(checkpoint_path, "transformer", self.config["transformer_model_name"])
safetensors_files = glob.glob(os.path.join(byt5_mapper_model_path, "*.safetensors"))
byt5_mapper_state_dict = {}
for safetensor_path in safetensors_files:
with safe_open(safetensor_path, framework="pt", device="cpu") as f:
byt5_mapper_state_dict.update({key.replace("byt5_in.", ""): f.get_tensor(key).to(torch.bfloat16) for key in f.keys() if "byt5_in" in key})
self.byt5_mapper.load_state_dict(byt5_mapper_state_dict)
self.byt5_mapper.to(device=device)
def create_byt5(self, args, device):
"""
Create ByT5 tokenizer and encoder, load weights if provided.
Args:
args (dict): Configuration dictionary.
device (str or torch.device): Device to load the model onto.
Returns:
tuple: (byt5_tokenizer, byt5_model, byt5_max_length)
"""
byt5_max_length = args["byt5_max_length"]
byt5_config = dict(
byt5_name=args["byT5_google_path"],
special_token=True,
color_special_token=True,
font_special_token=True,
color_ann_path=args["multilingual_prompt_format_color_path"],
font_ann_path=args["multilingual_prompt_format_font_path"],
multilingual=True,
)
huggingface_cache_dir = None
byt5_model, byt5_tokenizer = load_byt5_and_byt5_tokenizer(
**byt5_config,
huggingface_cache_dir=huggingface_cache_dir,
device=device,
)
# Load custom checkpoint if provided
if args["byT5_ckpt_path"] is not None:
if "cuda" not in str(device):
byt5_state_dict = torch.load(args["byT5_ckpt_path"], map_location=device)
else:
byt5_state_dict = torch.load(args["byT5_ckpt_path"], map_location=device)
if "state_dict" in byt5_state_dict:
sd = byt5_state_dict["state_dict"]
newsd = {}
for k, v in sd.items():
if k.startswith("module.text_tower.encoder."):
newsd[k[len("module.text_tower.encoder.") :]] = v
byt5_state_dict = newsd
byt5_model.load_state_dict(byt5_state_dict)
byt5_model.requires_grad_(False)
return byt5_tokenizer, byt5_model, byt5_max_length
def _extract_glyph_texts(self, prompt):
"""
Extract glyph texts from prompt using regex pattern.
Args:
prompt: Input prompt string
Returns:
List of extracted glyph texts
"""
pattern = r"\"(.*?)\"|“(.*?)”"
matches = re.findall(pattern, prompt)
result = [match[0] or match[1] for match in matches]
result = list(dict.fromkeys(result)) if len(result) > 1 else result
return result
def _process_single_byt5_prompt(self, prompt_text, device):
"""
Process a single prompt for byT5 encoding.
Args:
prompt_text: The prompt text to process
device: Target device for tensors
Returns:
Tuple of (byt5_embeddings, byt5_mask)
"""
byt5_embeddings = torch.zeros((1, self.byt5_max_length, 1472), device=device)
byt5_mask = torch.zeros((1, self.byt5_max_length), device=device, dtype=torch.int64)
glyph_texts = self._extract_glyph_texts(prompt_text)
if len(glyph_texts) > 0:
text_styles = [{"color": None, "font-family": None} for _ in range(len(glyph_texts))]
formatted_text = self.prompt_format.format_prompt(glyph_texts, text_styles)
text_ids, text_mask = self.get_byt5_text_tokens(self.byt5_tokenizer, self.byt5_max_length, formatted_text)
text_ids = text_ids.to("cuda")
text_mask = text_mask.to("cuda")
byt5_outputs = self.byt5_model(text_ids, attention_mask=text_mask.float())
byt5_embeddings = byt5_outputs[0]
byt5_mask = text_mask
return byt5_embeddings, byt5_mask
def _prepare_byt5_embeddings(self, prompts):
if isinstance(prompts, str):
prompt_list = [prompts]
elif isinstance(prompts, list):
prompt_list = prompts
else:
raise ValueError("prompts must be str or list of str")
positive_embeddings = []
positive_masks = []
negative_embeddings = []
negative_masks = []
for prompt in prompt_list:
pos_emb, pos_mask = self._process_single_byt5_prompt(prompt, "cuda")
positive_embeddings.append(pos_emb)
positive_masks.append(pos_mask)
if self.enable_cfg:  # TODO: split CFG out of this loop; better suited for parallel execution
neg_emb, neg_mask = self._process_single_byt5_prompt("", "cuda")
negative_embeddings.append(neg_emb)
negative_masks.append(neg_mask)
byt5_positive = torch.cat(positive_embeddings, dim=0)
byt5_positive_mask = torch.cat(positive_masks, dim=0)
if self.enable_cfg:  # TODO: split CFG out of this loop; better suited for parallel execution
byt5_negative = torch.cat(negative_embeddings, dim=0)
byt5_negative_mask = torch.cat(negative_masks, dim=0)
byt5_embeddings = torch.cat([byt5_negative, byt5_positive], dim=0)
byt5_masks = torch.cat([byt5_negative_mask, byt5_positive_mask], dim=0)
else:
byt5_embeddings = byt5_positive
byt5_masks = byt5_positive_mask
return byt5_embeddings, byt5_masks
@torch.no_grad()
def infer(self, prompts):
if self.cpu_offload:
self.byt5_model = self.byt5_model.to("cuda")
self.byt5_mapper = self.byt5_mapper.to("cuda")
byt5_embeddings, byt5_masks = self._prepare_byt5_embeddings(prompts)
byt5_features = self.byt5_mapper(byt5_embeddings.to(torch.bfloat16))
if self.cpu_offload:
self.byt5_model = self.byt5_model.to("cpu")
self.byt5_mapper = self.byt5_mapper.to("cpu")
return byt5_features, byt5_masks
if __name__ == "__main__":
byt5 = ByT5TextEncoder(config={"transformer_model_name": "480p_t2v", "hidden_size": 2048}, device="cuda", checkpoint_path="/data/nvme1/yongyang/models/HunyuanVideo-1.5/ckpts/hunyuanvideo-1.5")
prompt = "A close-up shot captures a scene on a polished, light-colored granite kitchen counter, illuminated by soft natural light from an unseen window. Initially, the frame focuses on a tall, clear glass filled with golden, translucent apple juice standing next to a single, shiny red apple with a green leaf still attached to its stem. The camera moves horizontally to the right. As the shot progresses, a white ceramic plate smoothly enters the frame, revealing a fresh arrangement of about seven or eight more apples, a mix of vibrant reds and greens, piled neatly upon it. A shallow depth of field keeps the focus sharply on the fruit and glass, while the kitchen backsplash in the background remains softly blurred. The scene is in a realistic style."
byt5_features, byt5_masks = byt5.infer(prompt)
print(byt5_features.shape, byt5_features.sum())
print(byt5_masks.shape, byt5_masks.sum())
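The glyph strings that feed ByT5 are pulled out of the user prompt by _extract_glyph_texts, which matches both ASCII and full-width quotation marks and de-duplicates while preserving order. A small standalone check of that regex (the prompt string is illustrative):

import re

pattern = r"\"(.*?)\"|“(.*?)”"
prompt = 'A neon sign reading "OPEN 24H" next to a poster that says “欢迎光临”, plus "OPEN 24H" again.'

matches = re.findall(pattern, prompt)
texts = [m[0] or m[1] for m in matches]
texts = list(dict.fromkeys(texts)) if len(texts) > 1 else texts  # de-duplicate, keep order
print(texts)  # ['OPEN 24H', '欢迎光临']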
This diff is collapsed.
import glob
import os
from dataclasses import dataclass
from typing import Optional, Tuple
import numpy as np
import torch
import torch.nn as nn
from safetensors.torch import safe_open
from transformers import SiglipImageProcessor, SiglipVisionModel
from transformers.utils import ModelOutput
PRECISION_TO_TYPE = {
"fp32": torch.float32,
"fp16": torch.float16,
"bf16": torch.bfloat16,
}
VISION_ENCODER_PATH = {}
def use_default(value, default):
return value if value is not None else default
def load_vision_encoder(
vision_encoder_type,
vision_encoder_precision=None,
vision_encoder_path=None,
logger=None,
device=None,
):
if vision_encoder_path is None:
vision_encoder_path = VISION_ENCODER_PATH[vision_encoder_type]
if vision_encoder_type == "siglip":
vision_encoder = SiglipVisionModel.from_pretrained(vision_encoder_path, subfolder="image_encoder")
else:
raise ValueError(f"Unsupported vision encoder type: {vision_encoder_type}")
# from_pretrained will ensure that the model is in eval mode.
if vision_encoder_precision is not None:
vision_encoder = vision_encoder.to(dtype=PRECISION_TO_TYPE[vision_encoder_precision])
vision_encoder.requires_grad_(False)
if device is not None:
vision_encoder = vision_encoder.to(device)
return vision_encoder, vision_encoder_path
def load_image_processor(processor_type, processor_path=None, logger=None):
if processor_path is None:
processor_path = VISION_ENCODER_PATH[processor_type]
if processor_type == "siglip":
processor = SiglipImageProcessor.from_pretrained(processor_path, subfolder="feature_extractor")
else:
raise ValueError(f"Unsupported processor type: {processor_type}")
return processor, processor_path
@dataclass
class VisionEncoderModelOutput(ModelOutput):
"""
Base class for vision encoder model's outputs.
Args:
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`, *optional*):
Last layer hidden-state of the first token of the sequence (classification token)
after further processing through the layers used for the auxiliary pretraining task.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
"""
last_hidden_state: torch.FloatTensor = None
pooler_output: Optional[torch.FloatTensor] = None
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
class VisionEncoder(nn.Module):
def __init__(
self,
vision_encoder_type: str,
vision_encoder_precision: Optional[str] = None,
vision_encoder_path: Optional[str] = None,
processor_type: Optional[str] = None,
processor_path: Optional[str] = None,
output_key: Optional[str] = None,
logger=None,
device=None,
cpu_offload=False,
):
super().__init__()
self.cpu_offload = cpu_offload
self.vision_encoder_type = vision_encoder_type
self.precision = vision_encoder_precision
self.model_path = vision_encoder_path
self.processor_type = processor_type if processor_type is not None else vision_encoder_type
self.processor_path = processor_path if processor_path is not None else vision_encoder_path
self.logger = logger
if "siglip" in vision_encoder_type:
self.output_key = output_key or "last_hidden_state"
else:
raise ValueError(f"Unsupported vision encoder type: {vision_encoder_type}")
self.model, self.model_path = load_vision_encoder(
vision_encoder_type=self.vision_encoder_type,
vision_encoder_precision=self.precision,
vision_encoder_path=self.model_path,
logger=self.logger,
device=device,
)
self.dtype = self.model.dtype
self.device = self.model.device
self.processor, self.processor_path = load_image_processor(
processor_type=self.processor_type,
processor_path=self.processor_path,
logger=self.logger,
)
def __repr__(self):
return f"{self.vision_encoder_type} ({self.precision} - {self.model_path})"
def encode_latents_to_images(self, latents, vae, reorg_token=False):
"""
Convert latents to images using VAE decoder.
Args:
latents: Input latents tensor
vae: VAE model for decoding
reorg_token: Whether to reorg the token
Returns:
images: Decoded images as numpy array
"""
# Handle both 4D and 5D latents (for video, take first frame)
first_image_latents = latents[:, :, 0, ...] if len(latents.shape) == 5 else latents
first_image_latents = 1 / vae.config.scaling_factor * first_image_latents
first_image = vae.decode(first_image_latents.unsqueeze(2).to(vae.dtype), return_dict=False)[0].cpu()
first_image = first_image[:, :, 0, :, :]
first_image = (first_image / 2 + 0.5).clamp(0, 1)
first_image = (first_image * 255.0).clamp(0, 255.0)
first_image = first_image.to(torch.uint8).numpy()
first_image = first_image.transpose(0, 2, 3, 1)
assert isinstance(first_image, np.ndarray)
assert first_image.ndim == 4 and first_image.shape[3] == 3
assert first_image.dtype == np.uint8
return first_image
def encode_images(self, images):
"""
Encode images using the vision encoder.
Args:
images: Input images (numpy array or preprocessed tensor)
Returns:
VisionEncoderModelOutput with encoded features
"""
if self.cpu_offload:
self.model = self.model.to("cuda")
# NOTE: the image processor holds no tensors (it is not an nn.Module), so it is not moved between devices.
if isinstance(images, np.ndarray):
# Preprocess images if they're numpy arrays
preprocessed = self.processor.preprocess(images=images, return_tensors="pt").to(device="cuda", dtype=self.model.dtype)
else:
# Assume already preprocessed
preprocessed = images
outputs = self.model(**preprocessed)
if self.cpu_offload:
self.model = self.model.to("cpu")
return VisionEncoderModelOutput(
last_hidden_state=outputs.last_hidden_state,
pooler_output=outputs.pooler_output if hasattr(outputs, "pooler_output") else None,
hidden_states=outputs.hidden_states if hasattr(outputs, "hidden_states") else None,
)
def encode_latents(self, latents, vae, reorg_token=False):
"""
Encode latents by first converting to images, then encoding.
This is the main function that replaces sigclip_vision_encode.
Args:
latents: Input latent tensors
vae: VAE model for decoding latents to images
Returns:
Encoded image features
"""
# Convert latents to images
images = self.encode_latents_to_images(latents, vae, reorg_token)
# Encode images
outputs = self.encode_images(images)
return outputs.last_hidden_state
def forward(self, images):
"""
Forward pass for direct image encoding.
Args:
images: Input images
Returns:
VisionEncoderModelOutput with encoded features
"""
return self.encode_images(images)
class SiglipVisionEncoder:
def __init__(
self,
config,
device=torch.cuda.current_device(),
checkpoint_path=None,
cpu_offload=False,
):
self.config = config
self.device = device
self.cpu_offload = cpu_offload
self.vision_states_dim = 1152
vision_encoder_path = os.path.join(checkpoint_path, "vision_encoder", "siglip")
self.vision_encoder = VisionEncoder(
vision_encoder_type="siglip",
vision_encoder_precision="fp16",
vision_encoder_path=vision_encoder_path,
processor_type=None,
processor_path=None,
output_key=None,
logger=None,
device=self.device,
cpu_offload=self.cpu_offload,
)
self.vision_in = VisionProjection(in_dim=self.vision_states_dim, out_dim=self.config["hidden_size"], flf_pos_emb=False).to(torch.bfloat16)
vision_in_model_path = os.path.join(checkpoint_path, "transformer", self.config["transformer_model_name"])
safetensors_files = glob.glob(os.path.join(vision_in_model_path, "*.safetensors"))
vision_in_state_dict = {}
for safetensor_path in safetensors_files:
with safe_open(safetensor_path, framework="pt", device="cpu") as f:
vision_in_state_dict.update({key.replace("vision_in.", ""): f.get_tensor(key).to(torch.bfloat16) for key in f.keys() if "vision_in" in key})
self.vision_in.load_state_dict(vision_in_state_dict)
self.vision_in.to(device=device)
@torch.no_grad()
def infer(self, vision_states):
if self.cpu_offload:
self.vision_in = self.vision_in.to("cuda")
vision_states = self.vision_in(vision_states)
if self.cpu_offload:
self.vision_in = self.vision_in.to("cpu")
return vision_states
@torch.no_grad()
def encode_images(self, images):
return self.vision_encoder.encode_images(images)
class VisionProjection(torch.nn.Module):
"""
Projects vision embeddings. Also handles dropout for classifier-free guidance.
Adapted from https://github.com/Wan-Video/Wan2.1/blob/main/wan/modules/model.py#L488
"""
def __init__(self, in_dim, out_dim, flf_pos_emb=False):
super().__init__()
self.proj = torch.nn.Sequential(torch.nn.LayerNorm(in_dim), torch.nn.Linear(in_dim, in_dim), torch.nn.GELU(), torch.nn.Linear(in_dim, out_dim), torch.nn.LayerNorm(out_dim))
if flf_pos_emb: # NOTE: we only use this for `flf2v`
self.emb_pos = nn.Parameter(torch.zeros(1, FIRST_LAST_FRAME_CONTEXT_TOKEN_NUMBER, 1280))
@torch.no_grad()
def forward(self, image_embeds):
if hasattr(self, "emb_pos"):
bs, n, d = image_embeds.shape
image_embeds = image_embeds.view(-1, 2 * n, d)
image_embeds = image_embeds + self.emb_pos
clip_extra_context_tokens = self.proj(image_embeds)
return clip_extra_context_tokens
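SiglipVisionEncoder pairs a stock SigLIP vision tower with a small VisionProjection MLP whose weights are read from the transformer checkpoint (keys prefixed vision_in.). A minimal sketch of the underlying Hugging Face calls on a dummy image, using the public google/siglip-base-patch16-224 checkpoint as a stand-in for the encoder bundled with HunyuanVideo-1.5 (whose hidden size is 1152):

import numpy as np
import torch
from transformers import SiglipImageProcessor, SiglipVisionModel

model_id = "google/siglip-base-patch16-224"  # public stand-in, not the repo's bundled checkpoint
processor = SiglipImageProcessor.from_pretrained(model_id)
model = SiglipVisionModel.from_pretrained(model_id).eval()

image = (np.random.rand(224, 224, 3) * 255).astype(np.uint8)  # dummy HWC uint8 image
inputs = processor(images=image, return_tensors="pt")

with torch.no_grad():
    features = model(**inputs).last_hidden_state  # (1, num_patches, hidden_size)

print(features.shape)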
......@@ -162,7 +162,7 @@ class SglQuantLinearFp8(nn.Module):
input_tensor_quant,
self.weight.t(),
input_tensor_scale,
self.weight_scale,
self.weight_scale.float(),
dtype,
bias=self.bias,
)
......@@ -249,7 +249,7 @@ class Q8FQuantLinearInt8(nn.Module):
output_tensor = q8_linear(
input_tensor_quant,
self.weight,
self.bias if self.bias is not None else None,
self.bias.float() if self.bias is not None else None,
input_tensor_scale,
self.weight_scale.float(),
fuse_gelu=False,
......@@ -295,9 +295,9 @@ class Q8FQuantLinearFp8(nn.Module):
output_tensor = fp8_linear(
input_tensor_quant,
self.weight,
self.bias if self.bias is not None else None,
self.bias.float() if self.bias is not None else None,
input_tensor_scale,
self.weight_scale,
self.weight_scale.float(),
out_dtype=torch.bfloat16,
)
return output_tensor
......
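The two quantization fixes above cast weight_scale (and bias) to float32 before calling the SGL/Q8F kernels, since low-precision matmul kernels typically rescale their accumulators with fp32 per-channel scales. A small sketch of per-channel int8 weight quantization with fp32 scales and the matching dequantizing matmul, in plain PyTorch (shapes are illustrative):

import torch

torch.manual_seed(0)
weight = torch.randn(128, 64, dtype=torch.bfloat16)  # (out_features, in_features)
x = torch.randn(4, 64, dtype=torch.bfloat16)

# Per-output-channel symmetric int8 quantization; scales are kept in float32.
weight_scale = weight.abs().amax(dim=1, keepdim=True).float() / 127.0
weight_int8 = torch.clamp(torch.round(weight.float() / weight_scale), -128, 127).to(torch.int8)

# Dequantizing matmul: accumulate in fp32, rescale per channel, cast back at the end.
out = (x.float() @ (weight_int8.float() * weight_scale).t()).to(torch.bfloat16)

ref = x.float() @ weight.float().t()
print((out.float() - ref).abs().max())  # small quantization error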
......@@ -58,12 +58,12 @@ class Qwen25_VLForConditionalGeneration_TextEncoder:
self.VAE_IMAGE_SIZE = 1024 * 1024
self.cpu_offload = config.get("cpu_offload", False)
self.run_device = self.config.get("run_device", "cuda")
if self.cpu_offload:
self.device = torch.device("cpu")
else:
self.device = torch.device(self.config.get("run_device", "cuda"))
self.device = torch.device(self.run_device)
self.dtype = torch.bfloat16
self.load()
def load(self):
......@@ -95,7 +95,7 @@ class Qwen25_VLForConditionalGeneration_TextEncoder:
@torch.no_grad()
def infer(self, text, image_list=None):
if self.cpu_offload:
self.text_encoder.to(self.device)
self.text_encoder.to(self.run_device)
if image_list is not None:
condition_image_list = []
......@@ -130,7 +130,7 @@ class Qwen25_VLForConditionalGeneration_TextEncoder:
images=condition_image_list,
padding=True,
return_tensors="pt",
).to(torch.device(self.device))
).to(torch.device(self.run_device))
encoder_hidden_states = self.text_encoder(
input_ids=model_inputs.input_ids,
......@@ -153,7 +153,7 @@ class Qwen25_VLForConditionalGeneration_TextEncoder:
txt = [template.format(e) for e in text]
image_info = {}
model_inputs = self.tokenizer(txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt").to(self.device)
model_inputs = self.tokenizer(txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt").to(torch.device(self.run_device))
encoder_hidden_states = self.text_encoder(
input_ids=model_inputs.input_ids,
attention_mask=model_inputs.attention_mask,
......@@ -169,7 +169,7 @@ class Qwen25_VLForConditionalGeneration_TextEncoder:
prompt_embeds = torch.stack([torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states])
encoder_attention_mask = torch.stack([torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list])
prompt_embeds = prompt_embeds.to(dtype=self.dtype, device=self.device)
prompt_embeds = prompt_embeds.to(dtype=self.dtype, device=self.run_device)
prompt_embeds_mask = encoder_attention_mask
_, seq_len, _ = prompt_embeds.shape
......
......@@ -252,7 +252,7 @@ class AudioAdapter(nn.Module):
quantized: bool = False,
quant_scheme: str = None,
cpu_offload: bool = False,
device=torch.device("cuda"),
run_device=torch.device("cuda"),
):
super().__init__()
self.cpu_offload = cpu_offload
......@@ -263,7 +263,7 @@ class AudioAdapter(nn.Module):
mlp_dims=mlp_dims,
transformer_layers=projection_transformer_layers,
)
self.device = torch.device(device)
self.run_device = run_device
# self.num_tokens = num_tokens * 4
self.num_tokens_x4 = num_tokens * 4
self.audio_pe = nn.Parameter(torch.randn(self.num_tokens_x4, mlp_dims[-1] // num_tokens) * 0.02)
......@@ -302,10 +302,10 @@ class AudioAdapter(nn.Module):
@torch.no_grad()
def forward_audio_proj(self, audio_feat, latent_frame):
if self.cpu_offload:
self.audio_proj.to(self.device)
self.audio_proj.to(self.run_device)
x = self.audio_proj(audio_feat, latent_frame)
x = self.rearange_audio_features(x)
x = x + self.audio_pe.to(self.device)
x = x + self.audio_pe.to(self.run_device)
if self.cpu_offload:
self.audio_proj.to("cpu")
return x