Unverified commit 69c2f650 authored by Yang Yong (雍洋), committed by GitHub

Remove outdated models (#348)

parent 08d2f46a
from lightx2v.common.ops.mm.mm_weight import MMWeightTemplate
from lightx2v.common.ops.norm.layer_norm_weight import LNWeightTemplate
from lightx2v.utils.registry_factory import LN_WEIGHT_REGISTER, MM_WEIGHT_REGISTER
class CogvideoxPostWeights:
def __init__(self, config, mm_type="Default"):
self.config = config
self.mm_type = mm_type
def load_weights(self, weight_dict):
self.norm_out_linear = MM_WEIGHT_REGISTER[self.mm_type]("norm_out.linear.weight", "norm_out.linear.bias")
self.proj_out = MM_WEIGHT_REGISTER[self.mm_type]("proj_out.weight", "proj_out.bias")
self.norm_final = LN_WEIGHT_REGISTER[self.mm_type]("norm_final.weight", "norm_final.bias")
self.norm_out_norm = LN_WEIGHT_REGISTER[self.mm_type]("norm_out.norm.weight", "norm_out.norm.bias", eps=1e-5)
self.weight_list = [self.norm_out_linear, self.proj_out, self.norm_final, self.norm_out_norm]
for mm_weight in self.weight_list:
if isinstance(mm_weight, (MMWeightTemplate, LNWeightTemplate)):
mm_weight.load(weight_dict)
def to_cpu(self):
for mm_weight in self.weight_list:
if isinstance(mm_weight, (MMWeightTemplate, LNWeightTemplate)):
mm_weight.to_cpu()
def to_cuda(self):
for mm_weight in self.weight_list:
if isinstance(mm_weight, (MMWeightTemplate, LNWeightTemplate)):
mm_weight.to_cuda()
from lightx2v.common.ops.mm.mm_weight import MMWeightTemplate
from lightx2v.common.ops.norm.layer_norm_weight import LNWeightTemplate
from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER
class CogvideoxPreWeights:
def __init__(self, config):
self.config = config
def load_weights(self, weight_dict):
self.time_embedding_linear_1 = MM_WEIGHT_REGISTER["Default"]("time_embedding.linear_1.weight", "time_embedding.linear_1.bias")
self.time_embedding_linear_2 = MM_WEIGHT_REGISTER["Default"]("time_embedding.linear_2.weight", "time_embedding.linear_2.bias")
self.patch_embed_proj = MM_WEIGHT_REGISTER["Default"]("patch_embed.proj.weight", "patch_embed.proj.bias")
self.patch_embed_text_proj = MM_WEIGHT_REGISTER["Default"]("patch_embed.text_proj.weight", "patch_embed.text_proj.bias")
self.weight_list = [self.time_embedding_linear_1, self.time_embedding_linear_2, self.patch_embed_proj, self.patch_embed_text_proj]
for mm_weight in self.weight_list:
mm_weight.set_config(self.config)
mm_weight.load(weight_dict)
def to_cpu(self):
for mm_weight in self.weight_list:
if isinstance(mm_weight, (MMWeightTemplate, LNWeightTemplate)):
mm_weight.to_cpu()
def to_cuda(self):
for mm_weight in self.weight_list:
if isinstance(mm_weight, (MMWeightTemplate, LNWeightTemplate)):
mm_weight.to_cuda()
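# Illustrative note (not part of the original files): the weight wrappers used by
# the classes in this commit all follow the same registry pattern. A hypothetical
# hand-written entry would look like:
#
#     w = MM_WEIGHT_REGISTER["Default"]("patch_embed.proj.weight", "patch_embed.proj.bias")
#     w.set_config(config)   # used by some weight types before loading
#     w.load(weight_dict)    # fetch the named tensors from the checkpoint dict
#     w.to_cuda()            # or w.to_cpu() when offloading
#     y = w.apply(x)         # run the underlying linear projection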
from lightx2v.common.ops.mm.mm_weight import MMWeightTemplate
from lightx2v.common.ops.norm.layer_norm_weight import LNWeightTemplate
from lightx2v.utils.registry_factory import LN_WEIGHT_REGISTER, MM_WEIGHT_REGISTER
class CogvideoxTransformerWeights:
def __init__(self, config, task="t2v", mm_type="Default"):
self.config = config
self.task = task
self.mm_type = mm_type
self.init()
def init(self):
self.num_layers = self.config["num_layers"]
def load_weights(self, weight_dict):
self.blocks_weights = [CogVideoXBlock(i, self.task, self.mm_type) for i in range(self.num_layers)]
for block in self.blocks_weights:
block.load_weights(weight_dict)
def to_cpu(self):
for block in self.blocks_weights:
block.to_cpu()
def to_cuda(self):
for block in self.blocks_weights:
block.to_cuda()
class CogVideoXBlock:
def __init__(self, block_index, task="t2v", mm_type="Default"):
super().__init__()
self.block_index = block_index
self.mm_type = mm_type
self.task = task
def load_weights(self, weight_dict):
self.attn1_to_k = MM_WEIGHT_REGISTER[self.mm_type](f"transformer_blocks.{self.block_index}.attn1.to_k.weight", f"transformer_blocks.{self.block_index}.attn1.to_k.bias")
self.attn1_to_q = MM_WEIGHT_REGISTER[self.mm_type](f"transformer_blocks.{self.block_index}.attn1.to_q.weight", f"transformer_blocks.{self.block_index}.attn1.to_q.bias")
self.attn1_to_v = MM_WEIGHT_REGISTER[self.mm_type](f"transformer_blocks.{self.block_index}.attn1.to_v.weight", f"transformer_blocks.{self.block_index}.attn1.to_v.bias")
self.attn1_to_out = MM_WEIGHT_REGISTER[self.mm_type](f"transformer_blocks.{self.block_index}.attn1.to_out.0.weight", f"transformer_blocks.{self.block_index}.attn1.to_out.0.bias")
self.ff_net_0_proj = MM_WEIGHT_REGISTER[self.mm_type](f"transformer_blocks.{self.block_index}.ff.net.0.proj.weight", f"transformer_blocks.{self.block_index}.ff.net.0.proj.bias")
self.ff_net_2_proj = MM_WEIGHT_REGISTER[self.mm_type](f"transformer_blocks.{self.block_index}.ff.net.2.weight", f"transformer_blocks.{self.block_index}.ff.net.2.bias")
self.norm1_linear = MM_WEIGHT_REGISTER[self.mm_type](f"transformer_blocks.{self.block_index}.norm1.linear.weight", f"transformer_blocks.{self.block_index}.norm1.linear.bias")
self.norm2_linear = MM_WEIGHT_REGISTER[self.mm_type](f"transformer_blocks.{self.block_index}.norm2.linear.weight", f"transformer_blocks.{self.block_index}.norm2.linear.bias")
self.attn1_norm_k = LN_WEIGHT_REGISTER[self.mm_type](f"transformer_blocks.{self.block_index}.attn1.norm_k.weight", f"transformer_blocks.{self.block_index}.attn1.norm_k.bias")
self.attn1_norm_q = LN_WEIGHT_REGISTER[self.mm_type](f"transformer_blocks.{self.block_index}.attn1.norm_q.weight", f"transformer_blocks.{self.block_index}.attn1.norm_q.bias")
self.norm1_norm = LN_WEIGHT_REGISTER[self.mm_type](f"transformer_blocks.{self.block_index}.norm1.norm.weight", f"transformer_blocks.{self.block_index}.norm1.norm.bias", eps=1e-05)
self.norm2_norm = LN_WEIGHT_REGISTER[self.mm_type](f"transformer_blocks.{self.block_index}.norm2.norm.weight", f"transformer_blocks.{self.block_index}.norm2.norm.bias", eps=1e-05)
self.weight_list = [
self.attn1_to_k,
self.attn1_to_q,
self.attn1_to_v,
self.attn1_to_out,
self.ff_net_0_proj,
self.ff_net_2_proj,
self.norm1_linear,
self.norm2_linear,
self.attn1_norm_k,
self.attn1_norm_q,
self.norm1_norm,
self.norm2_norm,
]
for mm_weight in self.weight_list:
if isinstance(mm_weight, (MMWeightTemplate, LNWeightTemplate)):
mm_weight.load(weight_dict)
def to_cpu(self):
for mm_weight in self.weight_list:
if isinstance(mm_weight, (MMWeightTemplate, LNWeightTemplate)):
mm_weight.to_cpu()
def to_cuda(self):
for mm_weight in self.weight_list:
if isinstance(mm_weight, (MMWeightTemplate, LNWeightTemplate)):
mm_weight.to_cuda()
import math
from typing import Dict
import torch
def taylor_cache_init(cache_dic: Dict, current: Dict):
"""
Initialize the Taylor cache, allocating per-module storage for the Taylor-series derivative terms
:param cache_dic: Cache dictionary
:param current: Information of the current step
"""
if current["step"] == 0:
cache_dic["cache"][-1][current["stream"]][current["layer"]][current["module"]] = {}
def derivative_approximation(cache_dic: Dict, current: Dict, feature: torch.Tensor):
"""
Compute finite-difference approximations of the feature derivatives and update the cache
:param cache_dic: Cache dictionary
:param current: Information of the current step
:param feature: Feature tensor fully computed at the current activated step
"""
difference_distance = current["activated_steps"][-1] - current["activated_steps"][-2]
# difference_distance = current['activated_times'][-1] - current['activated_times'][-2]
updated_taylor_factors = {}
updated_taylor_factors[0] = feature
for i in range(cache_dic["max_order"]):
if (cache_dic["cache"][-1][current["stream"]][current["layer"]][current["module"]].get(i, None) is not None) and (current["step"] > cache_dic["first_enhance"] - 2):
updated_taylor_factors[i + 1] = (updated_taylor_factors[i] - cache_dic["cache"][-1][current["stream"]][current["layer"]][current["module"]][i]) / difference_distance
else:
break
cache_dic["cache"][-1][current["stream"]][current["layer"]][current["module"]] = updated_taylor_factors
def taylor_formula(cache_dic: Dict, current: Dict) -> torch.Tensor:
"""
Evaluate the cached Taylor expansion to predict the feature at the current step
:param cache_dic: Cache dictionary
:param current: Information of the current step
"""
x = current["step"] - current["activated_steps"][-1]
# x = current['t'] - current['activated_times'][-1]
output = 0
for i in range(len(cache_dic["cache"][-1][current["stream"]][current["layer"]][current["module"]])):
output += (1 / math.factorial(i)) * cache_dic["cache"][-1][current["stream"]][current["layer"]][current["module"]][i] * (x**i)
return output
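# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original file) of how the three helpers
# above cooperate in TaylorSeer-style feature caching.  The dictionary layout
# ("cache" / -1 / stream / layer / module, "max_order", "first_enhance",
# "activated_steps") mirrors the keys the functions read; the concrete stream,
# layer and module names and all values below are made up for demonstration.
if __name__ == "__main__":
    feat_a = torch.ones(4)           # feature fully computed at step 0
    feat_b = torch.full((4,), 3.0)   # feature fully computed at step 2

    cache_dic = {
        "cache": {-1: {"double_stream": {0: {"attn": {}}}}},
        "max_order": 1,      # keep terms up to the first derivative
        "first_enhance": 1,
    }
    # activated_steps is padded with a duplicate entry because the helper
    # indexes activated_steps[-2] even on the first call
    current = {"step": 0, "stream": "double_stream", "layer": 0, "module": "attn", "activated_steps": [0, 0]}

    taylor_cache_init(cache_dic, current)
    derivative_approximation(cache_dic, current, feat_a)   # caches the 0th-order term

    current["step"] = 2
    current["activated_steps"].append(2)
    derivative_approximation(cache_dic, current, feat_b)   # caches 0th- and 1st-order terms

    current["step"] = 3                                    # one step past the last full compute
    print(taylor_formula(cache_dic, current))              # feat_b + (feat_b - feat_a) / 2 -> tensor([4., 4., 4., 4.])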
import torch
class HunyuanPostInfer:
def __init__(self, config):
self.config = config
def set_scheduler(self, scheduler):
self.scheduler = scheduler
def infer(self, weights, img, vec):
out = torch.nn.functional.silu(vec)
out = weights.final_layer_adaLN_modulation_1.apply(out)
shift, scale = out.chunk(2, dim=1)
out = torch.nn.functional.layer_norm(img, (img.shape[1],), None, None, 1e-6)
out = out * (1 + scale) + shift
out = weights.final_layer_linear.apply(out.to(torch.float32))
_, _, ot, oh, ow = self.scheduler.latents.shape
patch_size = [1, 2, 2]
tt, th, tw = (
ot // patch_size[0],
oh // patch_size[1],
ow // patch_size[2],
)
c = 16
pt, ph, pw = patch_size
out = out.reshape(shape=(1, tt, th, tw, c, pt, ph, pw))
out = torch.einsum("nthwcopq->nctohpwq", out)
out = out.reshape(shape=(1, c, tt * pt, th * ph, tw * pw))
return out
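# Illustrative sketch (not part of the original file): the reshape + einsum above
# is an "unpatchify" step that turns the per-token channel blocks back into a
# latent video of shape (1, c, tt*pt, th*ph, tw*pw).  Tiny made-up shapes below
# just check where a single element ends up.
if __name__ == "__main__":
    tt, th, tw, c = 2, 3, 4, 16
    pt, ph, pw = 1, 2, 2
    tokens = torch.randn(1, tt * th * tw, c * pt * ph * pw)
    x = tokens.reshape(1, tt, th, tw, c, pt, ph, pw)
    x = torch.einsum("nthwcopq->nctohpwq", x)
    video = x.reshape(1, c, tt * pt, th * ph, tw * pw)
    # token (t, h, w) with in-patch offset (o, p, q) of channel ch lands at
    # pixel (ch, t*pt + o, h*ph + p, w*pw + q)
    t, h, w, ch, o, p, q = 1, 2, 3, 5, 0, 1, 0
    assert torch.equal(
        video[0, ch, t * pt + o, h * ph + p, w * pw + q],
        tokens[0, (t * th + h) * tw + w].reshape(c, pt, ph, pw)[ch, o, p, q],
    )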
import math
import torch
from einops import rearrange
from lightx2v.utils.envs import *
class HunyuanPreInfer:
def __init__(self, config):
self.heads_num = 24
self.config = config
def set_scheduler(self, scheduler):
self.scheduler = scheduler
def infer(self, weights, inputs):
x = self.scheduler.latents
t = self.scheduler.timesteps[self.scheduler.step_index]
freqs_cos = self.scheduler.freqs_cos
freqs_sin = self.scheduler.freqs_sin
guidance = self.scheduler.guidance
text_states = inputs["text_encoder_output"]["text_encoder_1_text_states"]
text_mask = inputs["text_encoder_output"]["text_encoder_1_attention_mask"]
text_states_2 = inputs["text_encoder_output"]["text_encoder_2_text_states"]
if self.config["task"] == "i2v":
token_replace_t = torch.zeros_like(t)
token_replace_vec = self.infer_time_in(weights, token_replace_t)
th = x.shape[-2] // 2
tw = x.shape[-1] // 2
frist_frame_token_num = th * tw
time_out = self.infer_time_in(weights, t)
img_out = self.infer_img_in(weights, x)
infer_text_out = self.infer_text_in(weights, text_states, text_mask, t)
infer_vector_out = self.infer_vector_in(weights, text_states_2)
vec = time_out + infer_vector_out
if self.config["task"] == "i2v":
token_replace_vec = token_replace_vec + infer_vector_out
guidance_out = self.infer_guidance_in(weights, guidance)
vec = vec + guidance_out
txt_seq_len = infer_text_out.shape[0]
img_seq_len = img_out.shape[1]
batch_size = text_mask.shape[0]
text_len = text_mask.sum(dim=1)
max_len = text_mask.shape[1] + img_seq_len
cu_seqlens_qkv = torch.zeros([2 * batch_size + 1], dtype=torch.int32, device="cuda")
for i in range(batch_size):
s = text_len[i] + img_seq_len
s1 = i * max_len + s
s2 = (i + 1) * max_len
cu_seqlens_qkv[2 * i + 1] = s1
cu_seqlens_qkv[2 * i + 2] = s2
max_seqlen_qkv = img_seq_len + txt_seq_len
if self.config["task"] == "i2v":
return img_out[0], infer_text_out, vec, cu_seqlens_qkv, max_seqlen_qkv, (freqs_cos, freqs_sin), token_replace_vec, frist_frame_token_num
return img_out[0], infer_text_out, vec, cu_seqlens_qkv, max_seqlen_qkv, (freqs_cos, freqs_sin)
def infer_time_in(self, weights, t):
freqs = torch.exp(-math.log(10000) * torch.arange(start=0, end=128, dtype=torch.float32) / 128).to(device=t.device)
args = t.unsqueeze(0).unsqueeze(0).float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1).to(dtype=GET_DTYPE())
out = weights.time_in_mlp_0.apply(embedding)
out = torch.nn.functional.silu(out)
out = weights.time_in_mlp_2.apply(out)
return out
def infer_img_in(self, weights, x):
out = weights.img_in_proj.apply(x)
out = out.flatten(2).transpose(1, 2)
return out
def infer_text_in(self, weights, text_states, text_mask, t):
freqs = torch.exp(-math.log(10000) * torch.arange(start=0, end=128, dtype=torch.float32) / 128).to(device=t.device)
args = t.unsqueeze(0).unsqueeze(0).float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1).to(dtype=GET_DTYPE())
out = weights.txt_in_t_embedder_mlp_0.apply(embedding)
out = torch.nn.functional.silu(out)
timestep_aware_representations = weights.txt_in_t_embedder_mlp_2.apply(out)
mask_float = text_mask.float().unsqueeze(-1).to(GET_DTYPE()) # [b, s1, 1]
context_aware_representations = (text_states * mask_float).sum(dim=1) / mask_float.sum(dim=1)
out = weights.txt_in_c_embedder_linear_1.apply(context_aware_representations)
out = torch.nn.functional.silu(out)
context_aware_representations = weights.txt_in_c_embedder_linear_2.apply(out)
c = timestep_aware_representations + context_aware_representations
txt_in_input_embed = weights.txt_in_input_embedder.apply(text_states[0])
batch_size = text_mask.shape[0]
seq_len = text_mask.shape[1]
self_attn_mask_1 = text_mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1)
self_attn_mask_2 = self_attn_mask_1.transpose(2, 3)
self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool()
self_attn_mask[:, :, :, 0] = True
cx = torch.nn.functional.silu(c)
cx = weights.txt_in_individual_token_refiner_blocks_0_adaLN_modulation_1.apply(cx)
gate_msa, gate_mlp = cx.chunk(2, dim=1)
normx = weights.txt_in_individual_token_refiner_blocks_0_norm1.apply(txt_in_input_embed)
qkv = weights.txt_in_individual_token_refiner_blocks_0_self_attn_qkv.apply(normx)
q, k, v = rearrange(qkv.unsqueeze(0), "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
attn = weights.txt_in_attn_1.apply(q=q, k=k, v=v, attn_mask=self_attn_mask)[0]
out = weights.txt_in_individual_token_refiner_blocks_0_self_attn_proj.apply(attn)
out_1 = txt_in_input_embed + out * gate_msa
out = weights.txt_in_individual_token_refiner_blocks_0_norm2.apply(out_1)
# mlp
out = weights.txt_in_individual_token_refiner_blocks_0_mlp_fc1.apply(out)
out = torch.nn.functional.silu(out)
out = weights.txt_in_individual_token_refiner_blocks_0_mlp_fc2.apply(out)
txt_in_input_embed = out_1 + out * gate_mlp
cx = torch.nn.functional.silu(c)
cx = weights.txt_in_individual_token_refiner_blocks_1_adaLN_modulation_1.apply(cx)
gate_msa, gate_mlp = cx.chunk(2, dim=1)
normx = weights.txt_in_individual_token_refiner_blocks_1_norm1.apply(txt_in_input_embed)
qkv = weights.txt_in_individual_token_refiner_blocks_1_self_attn_qkv.apply(normx)
q, k, v = rearrange(qkv.unsqueeze(0), "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
attn = weights.txt_in_attn_1.apply(q=q, k=k, v=v, attn_mask=self_attn_mask)[0]
out = weights.txt_in_individual_token_refiner_blocks_1_self_attn_proj.apply(attn)
out_1 = txt_in_input_embed + out * gate_msa
out = weights.txt_in_individual_token_refiner_blocks_1_norm2.apply(out_1)
# mlp
out = weights.txt_in_individual_token_refiner_blocks_1_mlp_fc1.apply(out)
out = torch.nn.functional.silu(out)
out = weights.txt_in_individual_token_refiner_blocks_1_mlp_fc2.apply(out)
out = out_1 + out * gate_mlp
return out
def infer_vector_in(self, weights, text_states_2):
out = weights.vector_in_in_layer.apply(text_states_2)
out = torch.nn.functional.silu(out)
out = weights.vector_in_out_layer.apply(out)
return out
def infer_guidance_in(self, weights, guidance):
freqs = torch.exp(-math.log(10000) * torch.arange(start=0, end=128, dtype=torch.float32) / 128).to(device=guidance.device)
args = guidance.float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1).to(dtype=GET_DTYPE())
out = weights.guidance_in_mlp_0.apply(embedding)
out = torch.nn.functional.silu(out)
out = weights.guidance_in_mlp_2.apply(out)
return out
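# Illustrative sketch (not part of the original file): infer_time_in, infer_text_in
# and infer_guidance_in above all start from the same 256-dim sinusoidal timestep
# embedding (128 cosine terms followed by 128 sine terms) before their two-layer
# MLPs.  A standalone version, assuming a 1-D tensor of timesteps and float32
# output instead of GET_DTYPE():
def sinusoidal_embedding_256(t: torch.Tensor) -> torch.Tensor:
    # frequencies follow exp(-log(10000) * k / 128) for k = 0..127
    freqs = torch.exp(-math.log(10000) * torch.arange(0, 128, dtype=torch.float32) / 128).to(t.device)
    args = t.float().reshape(-1, 1) * freqs[None]                  # (batch, 128)
    return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)   # (batch, 256)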
import torch
from einops import rearrange
from lightx2v.common.offload.manager import WeightAsyncStreamManager
from lightx2v.common.transformer_infer.transformer_infer import BaseTransformerInfer
from lightx2v.utils.envs import *
from .utils_bf16 import apply_rotary_emb
class HunyuanTransformerInfer(BaseTransformerInfer):
def __init__(self, config):
self.config = config
self.attention_type = config.get("attention_type", "flash_attn2")
self.double_blocks_num = 20
self.single_blocks_num = 40
self.heads_num = 24
self.hidden_size = 3072
self.mlp_hidden_dim = 12288
self.parallel_attention = None
if self.config["cpu_offload"]:
if "offload_ratio" in self.config:
offload_ratio = self.config["offload_ratio"]
else:
offload_ratio = 1
self.double_weights_stream_mgr = WeightAsyncStreamManager(blocks_num=self.double_blocks_num, offload_ratio=offload_ratio)
self.single_weights_stream_mgr = WeightAsyncStreamManager(blocks_num=self.single_blocks_num, offload_ratio=offload_ratio)
self.infer_func = self._infer_with_offload
else:
self.infer_func = self._infer_without_offload
def infer(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec=None, frist_frame_token_num=None):
return self.infer_func(weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num)
def _infer_with_offload(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num):
txt_seq_len = txt.shape[0]
img_seq_len = img.shape[0]
for double_block_idx in range(self.double_blocks_num):
if double_block_idx == 0:
self.double_weights_stream_mgr.active_weights[0] = weights.double_blocks[0]
self.double_weights_stream_mgr.active_weights[0].to_cuda()
with torch.cuda.stream(self.double_weights_stream_mgr.compute_stream):
img, txt = self.infer_double_block(self.double_weights_stream_mgr.active_weights[0], img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num)
if double_block_idx < self.double_blocks_num - 1:
self.double_weights_stream_mgr.prefetch_weights(double_block_idx + 1, weights.double_blocks)
self.double_weights_stream_mgr.swap_weights()
x = torch.cat((img, txt), 0)
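# move the now-unused img/txt activations off the GPU (their contents live on in x)
# and drop the references so the following empty_cache() can reclaim their memory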
img = img.cpu()
txt = txt.cpu()
del img, txt
torch.cuda.empty_cache()
for single_block_idx in range(self.single_blocks_num):
if single_block_idx == 0:
self.single_weights_stream_mgr.active_weights[0] = weights.single_blocks[0]
self.single_weights_stream_mgr.active_weights[0].to_cuda()
with torch.cuda.stream(self.single_weights_stream_mgr.compute_stream):
x = self.infer_single_block(self.single_weights_stream_mgr.active_weights[0], x, vec, txt_seq_len, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num)
if single_block_idx < self.single_blocks_num - 1:
self.single_weights_stream_mgr.prefetch_weights(single_block_idx + 1, weights.single_blocks)
self.single_weights_stream_mgr.swap_weights()
torch.cuda.empty_cache()
img = x[:img_seq_len, ...]
return img, vec
def _infer_without_offload(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num):
txt_seq_len = txt.shape[0]
img_seq_len = img.shape[0]
for i in range(self.double_blocks_num):
img, txt = self.infer_double_block(weights.double_blocks[i], img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num)
x = torch.cat((img, txt), 0)
for i in range(self.single_blocks_num):
x = self.infer_single_block(weights.single_blocks[i], x, vec, txt_seq_len, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num)
img = x[:img_seq_len, ...]
return img, vec
def infer_double_block_phase_1(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num):
vec_silu = torch.nn.functional.silu(vec)
img_mod_out = weights.img_mod.apply(vec_silu)
(
img_mod1_shift,
img_mod1_scale,
img_mod1_gate,
img_mod2_shift,
img_mod2_scale,
img_mod2_gate,
) = img_mod_out.chunk(6, dim=-1)
if token_replace_vec is not None:
token_replace_vec_silu = torch.nn.functional.silu(token_replace_vec)
token_replace_vec_img_mod_out = weights.img_mod.apply(token_replace_vec_silu)
(tr_img_mod1_shift, tr_img_mod1_scale, tr_img_mod1_gate, tr_img_mod2_shift, tr_img_mod2_scale, tr_img_mod2_gate) = token_replace_vec_img_mod_out.chunk(6, dim=-1)
else:
(tr_img_mod1_shift, tr_img_mod1_scale, tr_img_mod1_gate, tr_img_mod2_shift, tr_img_mod2_scale, tr_img_mod2_gate) = None, None, None, None, None, None
txt_mod_out = weights.txt_mod.apply(vec_silu)
(
txt_mod1_shift,
txt_mod1_scale,
txt_mod1_gate,
txt_mod2_shift,
txt_mod2_scale,
txt_mod2_gate,
) = txt_mod_out.chunk(6, dim=-1)
img_q, img_k, img_v = self.infer_double_block_img_pre_atten(weights, img, img_mod1_scale, img_mod1_shift, tr_img_mod1_scale, tr_img_mod1_shift, frist_frame_token_num, freqs_cis)
txt_q, txt_k, txt_v = self.infer_double_block_txt_pre_atten(weights, txt, txt_mod1_scale, txt_mod1_shift)
q = torch.cat((img_q, txt_q), dim=0)
k = torch.cat((img_k, txt_k), dim=0)
v = torch.cat((img_v, txt_v), dim=0)
if not self.parallel_attention:
attn = weights.double_attn.apply(
q=q,
k=k,
v=v,
cu_seqlens_q=cu_seqlens_qkv,
cu_seqlens_kv=cu_seqlens_qkv,
max_seqlen_q=max_seqlen_qkv,
max_seqlen_kv=max_seqlen_qkv,
)
else:
# world_size = dist.get_world_size()
attn = self.parallel_attention(
attention_type=self.attention_type,
q=q,
k=k,
v=v,
img_qkv_len=img_q.shape[0],
cu_seqlens_qkv=cu_seqlens_qkv,
# cu_seqlens_qkv=cu_seqlens_qkv,
# max_seqlen_qkv=max_seqlen_qkv,
)
img_attn, txt_attn = attn[: img.shape[0]], attn[img.shape[0] :]
img_out = weights.img_attn_proj.apply(img_attn)
txt_out = weights.txt_attn_proj.apply(txt_attn)
return (
img_out,
txt_out,
img_mod1_gate,
img_mod2_shift,
img_mod2_scale,
img_mod2_gate,
tr_img_mod1_gate,
tr_img_mod2_shift,
tr_img_mod2_scale,
tr_img_mod2_gate,
txt_mod1_gate,
txt_mod2_shift,
txt_mod2_scale,
txt_mod2_gate,
)
def infer_double_block_phase_2(
self,
weights,
img,
txt,
vec,
cu_seqlens_qkv,
max_seqlen_qkv,
freqs_cis,
token_replace_vec,
frist_frame_token_num,
img_out,
txt_out,
img_mod1_gate,
img_mod2_shift,
img_mod2_scale,
img_mod2_gate,
tr_img_mod1_gate,
tr_img_mod2_shift,
tr_img_mod2_scale,
tr_img_mod2_gate,
txt_mod1_gate,
txt_mod2_shift,
txt_mod2_scale,
txt_mod2_gate,
):
if tr_img_mod1_gate is not None:
x_zero = img_out[:frist_frame_token_num] * tr_img_mod1_gate
x_orig = img_out[frist_frame_token_num:] * img_mod1_gate
img_out = torch.concat((x_zero, x_orig), dim=0)
else:
img_out = img_out * img_mod1_gate
img = img + img_out
img_out = torch.nn.functional.layer_norm(img, (img.shape[1],), None, None, 1e-6)
if tr_img_mod1_gate is not None:
x_zero = img_out[:frist_frame_token_num] * (1 + tr_img_mod2_scale) + tr_img_mod2_shift
x_orig = img_out[frist_frame_token_num:] * (1 + img_mod2_scale) + img_mod2_shift
img_out = torch.concat((x_zero, x_orig), dim=0)
else:
img_out = img_out * (1 + img_mod2_scale) + img_mod2_shift
img_out = weights.img_mlp_fc1.apply(img_out)
img_out = torch.nn.functional.gelu(img_out, approximate="tanh")
img_out = weights.img_mlp_fc2.apply(img_out)
txt_out = txt_out * txt_mod1_gate
txt = txt + txt_out
txt_out = torch.nn.functional.layer_norm(txt, (txt.shape[1],), None, None, 1e-6)
txt_out = txt_out * (1 + txt_mod2_scale) + txt_mod2_shift
txt_out = weights.txt_mlp_fc1.apply(txt_out)
txt_out = torch.nn.functional.gelu(txt_out, approximate="tanh")
txt_out = weights.txt_mlp_fc2.apply(txt_out)
return img, txt, img_out, txt_out, img_mod2_gate, txt_mod2_gate
def infer_double_block_phase_3(self, img_out, img_mod2_gate, img, txt_out, txt_mod2_gate, txt):
# img
img_out = img_out * img_mod2_gate
img = img + img_out
# txt
txt_out = txt_out * txt_mod2_gate
txt = txt + txt_out
return img, txt
def infer_double_block(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num):
(
img_out,
txt_out,
img_mod1_gate,
img_mod2_shift,
img_mod2_scale,
img_mod2_gate,
tr_img_mod1_gate,
tr_img_mod2_shift,
tr_img_mod2_scale,
tr_img_mod2_gate,
txt_mod1_gate,
txt_mod2_shift,
txt_mod2_scale,
txt_mod2_gate,
) = self.infer_double_block_phase_1(weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num)
img, txt, img_out, txt_out, img_mod2_gate, txt_mod2_gate = self.infer_double_block_phase_2(
weights,
img,
txt,
vec,
cu_seqlens_qkv,
max_seqlen_qkv,
freqs_cis,
token_replace_vec,
frist_frame_token_num,
img_out,
txt_out,
img_mod1_gate,
img_mod2_shift,
img_mod2_scale,
img_mod2_gate,
tr_img_mod1_gate,
tr_img_mod2_shift,
tr_img_mod2_scale,
tr_img_mod2_gate,
txt_mod1_gate,
txt_mod2_shift,
txt_mod2_scale,
txt_mod2_gate,
)
img, txt = self.infer_double_block_phase_3(img_out, img_mod2_gate, img, txt_out, txt_mod2_gate, txt)
return img, txt
def infer_double_block_img_pre_atten(self, weights, img, img_mod1_scale, img_mod1_shift, tr_img_mod1_scale, tr_img_mod1_shift, frist_frame_token_num, freqs_cis):
img_modulated = torch.nn.functional.layer_norm(img, (img.shape[1],), None, None, 1e-6)
if tr_img_mod1_scale is not None:
x_zero = img_modulated[:frist_frame_token_num] * (1 + tr_img_mod1_scale) + tr_img_mod1_shift
x_orig = img_modulated[frist_frame_token_num:] * (1 + img_mod1_scale) + img_mod1_shift
img_modulated = torch.concat((x_zero, x_orig), dim=0)
else:
img_modulated = img_modulated * (1 + img_mod1_scale) + img_mod1_shift
img_qkv = weights.img_attn_qkv.apply(img_modulated)
img_q, img_k, img_v = rearrange(img_qkv, "L (K H D) -> K L H D", K=3, H=self.heads_num)
img_q = weights.img_attn_q_norm.apply(img_q)
img_k = weights.img_attn_k_norm.apply(img_k)
img_q, img_k = apply_rotary_emb(img_q, img_k, freqs_cis)
return img_q, img_k, img_v
def infer_double_block_txt_pre_atten(self, weights, txt, txt_mod1_scale, txt_mod1_shift):
txt_modulated = torch.nn.functional.layer_norm(txt, (txt.shape[1],), None, None, 1e-6)
txt_modulated = txt_modulated * (1 + txt_mod1_scale) + txt_mod1_shift
txt_qkv = weights.txt_attn_qkv.apply(txt_modulated)
txt_q, txt_k, txt_v = rearrange(txt_qkv, "L (K H D) -> K L H D", K=3, H=self.heads_num)
txt_q = weights.txt_attn_q_norm.apply(txt_q)
txt_k = weights.txt_attn_k_norm.apply(txt_k)
return txt_q, txt_k, txt_v
def infer_single_block_phase_1(self, weights, x, vec, txt_seq_len, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec=None, frist_frame_token_num=None):
out = torch.nn.functional.silu(vec)
out = weights.modulation.apply(out)
mod_shift, mod_scale, mod_gate = out.chunk(3, dim=-1)
if token_replace_vec is not None:
token_replace_vec_out = torch.nn.functional.silu(token_replace_vec)
token_replace_vec_out = weights.modulation.apply(token_replace_vec_out)
tr_mod_shift, tr_mod_scale, tr_mod_gate = token_replace_vec_out.chunk(3, dim=-1)
else:
tr_mod_shift, tr_mod_scale, tr_mod_gate = None, None, None
out = torch.nn.functional.layer_norm(x, (x.shape[1],), None, None, 1e-6)
if token_replace_vec is not None:
x_zero = out[:frist_frame_token_num] * (1 + tr_mod_scale) + tr_mod_shift
x_orig = out[frist_frame_token_num:] * (1 + mod_scale) + mod_shift
x_mod = torch.concat((x_zero, x_orig), dim=0)
else:
x_mod = out * (1 + mod_scale) + mod_shift
x_mod = weights.linear1.apply(x_mod)
qkv, mlp = torch.split(x_mod, [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
q, k, v = rearrange(qkv, "L (K H D) -> K L H D", K=3, H=self.heads_num)
q = weights.q_norm.apply(q)
k = weights.k_norm.apply(k)
img_q, txt_q = q[:-txt_seq_len, :, :], q[-txt_seq_len:, :, :]
img_k, txt_k = k[:-txt_seq_len, :, :], k[-txt_seq_len:, :, :]
img_q, img_k = apply_rotary_emb(img_q, img_k, freqs_cis)
q = torch.cat((img_q, txt_q), dim=0)
k = torch.cat((img_k, txt_k), dim=0)
if not self.parallel_attention:
attn = weights.single_attn.apply(
q=q,
k=k,
v=v,
cu_seqlens_q=cu_seqlens_qkv,
cu_seqlens_kv=cu_seqlens_qkv,
max_seqlen_q=max_seqlen_qkv,
max_seqlen_kv=max_seqlen_qkv,
)
else:
attn = self.parallel_attention(
attention_type=self.attention_type,
q=q,
k=k,
v=v,
img_qkv_len=img_q.shape[0],
cu_seqlens_qkv=cu_seqlens_qkv,
# cu_seqlens_qkv=cu_seqlens_qkv,
# max_seqlen_qkv=max_seqlen_qkv,
)
out = torch.nn.functional.gelu(mlp, approximate="tanh")
out = torch.cat((attn, out), 1)
out = weights.linear2.apply(out)
return out, mod_gate, tr_mod_gate
def infer_single_block_phase_2(self, x, out, tr_mod_gate, mod_gate, token_replace_vec=None, frist_frame_token_num=None):
if token_replace_vec is not None:
x_zero = out[:frist_frame_token_num] * tr_mod_gate
x_orig = out[frist_frame_token_num:] * mod_gate
out = torch.concat((x_zero, x_orig), dim=0)
else:
out = out * mod_gate
x = x + out
return x
def infer_single_block(self, weights, x, vec, txt_seq_len, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec=None, frist_frame_token_num=None):
out, mod_gate, tr_mod_gate = self.infer_single_block_phase_1(weights, x, vec, txt_seq_len, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num)
x = self.infer_single_block_phase_2(x, out, tr_mod_gate, mod_gate, token_replace_vec, frist_frame_token_num)
return x
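# Illustrative note (not part of the original file): with the hard-coded sizes
# above (hidden_size=3072, mlp_hidden_dim=12288, heads_num=24), linear1 in a
# single block emits 3*3072 + 12288 = 21504 features per token; the qkv slice is
# reshaped to (3, seq_len, 24, 128) before attention, and linear2 fuses the
# attention output back together with the gelu(mlp) half.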
import sgl_kernel
def rms_norm(x, weight, eps):
x = x.contiguous()
orig_shape = x.shape
x = x.view(-1, orig_shape[-1])
x = sgl_kernel.rmsnorm(x, weight, eps).view(orig_shape)
return x
from typing import Tuple, Union
import torch
def rms_norm(x, weight, eps):
x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
x = x * weight
return x
def rotate_half(x, shape_0, shape_1):
x_real, x_imag = x.reshape(shape_0, shape_1, -1, 2).unbind(-1)
return torch.stack([-x_imag, x_real], dim=-1).flatten(2)
def rotary_emb(x, shape_0, shape_1, cos, sin):
x_out = x * cos + rotate_half(x, shape_0, shape_1) * sin
return x_out
def apply_rotary_emb(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
) -> Tuple[torch.Tensor, torch.Tensor]:
shape_0, shape_1, shape_2 = xq.shape
cos = freqs_cis[0].view(shape_0, 1, shape_2)
sin = freqs_cis[1].view(shape_0, 1, shape_2)
xq_out = rotary_emb(xq, shape_0, shape_1, cos, sin)
xk_out = rotary_emb(xk, shape_0, shape_1, cos, sin)
return xq_out, xk_out
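# Illustrative sketch (not part of the original file): apply_rotary_emb above
# expects q / k of shape (seq_len, heads, head_dim) and freqs_cis = (cos, sin)
# with each table of shape (seq_len, head_dim), laid out for interleaved
# (real, imag) pairs.  With a zero rotation angle (cos=1, sin=0) the inputs pass
# through unchanged; the shapes below are made up for the check.
if __name__ == "__main__":
    seq_len, heads, head_dim = 4, 2, 8
    q = torch.randn(seq_len, heads, head_dim)
    k = torch.randn(seq_len, heads, head_dim)
    cos = torch.ones(seq_len, head_dim)
    sin = torch.zeros(seq_len, head_dim)
    q_out, k_out = apply_rotary_emb(q, k, (cos, sin))
    assert torch.allclose(q_out, q) and torch.allclose(k_out, k)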
from typing import Tuple, Union
import torch
from lightx2v.utils.envs import *
def rms_norm(x, weight, eps):
x = x.float()
x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
x = x.to(GET_DTYPE())
x = x * weight
return x
def rotate_half(x, shape_0, shape_1):
x_real, x_imag = x.float().reshape(shape_0, shape_1, -1, 2).unbind(-1)
return torch.stack([-x_imag, x_real], dim=-1).flatten(2)
def rotary_emb(x, shape_0, shape_1, cos, sin):
x_out = x * cos + rotate_half(x, shape_0, shape_1) * sin
return x_out.to(GET_DTYPE())
def apply_rotary_emb(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
) -> Tuple[torch.Tensor, torch.Tensor]:
shape_0, shape_1, shape_2 = xq.shape
cos = freqs_cis[0].view(shape_0, 1, shape_2)
sin = freqs_cis[1].view(shape_0, 1, shape_2)
xq_out = rotary_emb(xq.float(), shape_0, shape_1, cos, sin)
xk_out = rotary_emb(xk.float(), shape_0, shape_1, cos, sin)
return xq_out, xk_out
import json
import os
import torch
from loguru import logger
from safetensors import safe_open
from lightx2v.models.networks.hunyuan.infer.feature_caching.transformer_infer import (
HunyuanTransformerInferAdaCaching,
HunyuanTransformerInferCustomCaching,
HunyuanTransformerInferTaylorCaching,
HunyuanTransformerInferTeaCaching,
)
from lightx2v.models.networks.hunyuan.infer.post_infer import HunyuanPostInfer
from lightx2v.models.networks.hunyuan.infer.pre_infer import HunyuanPreInfer
from lightx2v.models.networks.hunyuan.infer.transformer_infer import HunyuanTransformerInfer
from lightx2v.models.networks.hunyuan.weights.post_weights import HunyuanPostWeights
from lightx2v.models.networks.hunyuan.weights.pre_weights import HunyuanPreWeights
from lightx2v.models.networks.hunyuan.weights.transformer_weights import HunyuanTransformerWeights
from lightx2v.utils.envs import *
class HunyuanModel:
pre_weight_class = HunyuanPreWeights
post_weight_class = HunyuanPostWeights
transformer_weight_class = HunyuanTransformerWeights
def __init__(self, model_path, config, device, args):
self.model_path = model_path
self.config = config
self.device = device
self.args = args
self.dit_quantized = self.config.mm_config.get("mm_type", "Default") != "Default"
self.dit_quantized_ckpt = self.config.get("dit_quantized_ckpt", None)
self.weight_auto_quant = self.config.mm_config.get("weight_auto_quant", False)
if self.dit_quantized:
assert self.weight_auto_quant or self.dit_quantized_ckpt is not None
self._init_infer_class()
self._init_weights()
self._init_infer()
if self.config["cpu_offload"]:
self.to_cpu()
def _load_ckpt(self):
if self.args.task == "t2v":
ckpt_path = os.path.join(self.model_path, "hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt")
else:
ckpt_path = os.path.join(self.model_path, "hunyuan-video-i2v-720p/transformers/mp_rank_00_model_states.pt")
weight_dict = torch.load(ckpt_path, map_location=self.device, weights_only=True)["module"]
return weight_dict
def _load_quant_ckpt(self):
ckpt_path = self.dit_quantized_ckpt
logger.info(f"Loading quant dit model from {ckpt_path}")
if ckpt_path.endswith(".pth"):
logger.info(f"Loading {ckpt_path} as PyTorch model.")
weight_dict = torch.load(ckpt_path, map_location=self.device, weights_only=True)
else:
index_files = [f for f in os.listdir(ckpt_path) if f.endswith(".index.json")]
if not index_files:
raise FileNotFoundError(f"No .pth file or *.index.json found in {ckpt_path}")
index_path = os.path.join(ckpt_path, index_files[0])
logger.info(f" Using safetensors index: {index_path}")
with open(index_path, "r") as f:
index_data = json.load(f)
weight_dict = {}
for filename in set(index_data["weight_map"].values()):
safetensor_path = os.path.join(ckpt_path, filename)
with safe_open(safetensor_path, framework="pt", device=str(self.device)) as f:
logger.info(f"Loading weights from {safetensor_path}")
for k in f.keys():
weight_dict[k] = f.get_tensor(k)
if weight_dict[k].dtype == torch.float:
weight_dict[k] = weight_dict[k].to(GET_DTYPE())
return weight_dict
def _init_weights(self):
if not self.dit_quantized or self.weight_auto_quant:
weight_dict = self._load_ckpt()
else:
weight_dict = self._load_quant_ckpt()
# init weights
self.pre_weight = self.pre_weight_class(self.config)
self.post_weight = self.post_weight_class(self.config)
self.transformer_weights = self.transformer_weight_class(self.config)
# load weights
self.pre_weight.load(weight_dict)
self.post_weight.load(weight_dict)
self.transformer_weights.load(weight_dict)
def _init_infer(self):
self.pre_infer = self.pre_infer_class(self.config)
self.post_infer = self.post_infer_class(self.config)
self.transformer_infer = self.transformer_infer_class(self.config)
def save_weights(self, save_path):
if not os.path.exists(save_path):
os.makedirs(save_path)
pre_state_dict = self.pre_weight.state_dict()
logger.info(pre_state_dict.keys())
post_state_dict = self.post_weight.state_dict()
logger.info(post_state_dict.keys())
transformer_state_dict = self.transformer_weights.state_dict()
logger.info(transformer_state_dict.keys())
save_dict = {}
save_dict.update(pre_state_dict)
save_dict.update(post_state_dict)
save_dict.update(transformer_state_dict)
save_path = os.path.join(save_path, "quant_weights.pth")
torch.save(save_dict, save_path)
logger.info(f"Save weights to {save_path}")
def set_scheduler(self, scheduler):
self.scheduler = scheduler
self.pre_infer.set_scheduler(scheduler)
self.post_infer.set_scheduler(scheduler)
self.transformer_infer.set_scheduler(scheduler)
def to_cpu(self):
self.pre_weight.to_cpu()
self.post_weight.to_cpu()
self.transformer_weights.to_cpu()
def to_cuda(self):
self.pre_weight.to_cuda()
self.post_weight.to_cuda()
self.transformer_weights.to_cuda()
@torch.no_grad()
def infer(self, inputs):
if self.config["cpu_offload"]:
self.pre_weight.to_cuda()
self.post_weight.to_cuda()
inputs = self.pre_infer.infer(self.pre_weight, inputs)
inputs = self.transformer_infer.infer(self.transformer_weights, *inputs)
self.scheduler.noise_pred = self.post_infer.infer(self.post_weight, *inputs)
if self.config["cpu_offload"]:
self.pre_weight.to_cpu()
self.post_weight.to_cpu()
def _init_infer_class(self):
self.pre_infer_class = HunyuanPreInfer
self.post_infer_class = HunyuanPostInfer
if self.config["feature_caching"] == "NoCaching":
self.transformer_infer_class = HunyuanTransformerInfer
elif self.config["feature_caching"] == "TaylorSeer":
self.transformer_infer_class = HunyuanTransformerInferTaylorCaching
elif self.config["feature_caching"] == "Tea":
self.transformer_infer_class = HunyuanTransformerInferTeaCaching
elif self.config["feature_caching"] == "Ada":
self.transformer_infer_class = HunyuanTransformerInferAdaCaching
elif self.config["feature_caching"] == "Custom":
self.transformer_infer_class = HunyuanTransformerInferCustomCaching
else:
raise NotImplementedError(f"Unsupported feature_caching type: {self.config['feature_caching']}")
from lightx2v.common.modules.weight_module import WeightModule
from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER
class HunyuanPostWeights(WeightModule):
def __init__(self, config):
super().__init__()
self.config = config
self.add_module("final_layer_linear", MM_WEIGHT_REGISTER["Default-Force-FP32"]("final_layer.linear.weight", "final_layer.linear.bias"))
self.add_module("final_layer_adaLN_modulation_1", MM_WEIGHT_REGISTER["Default"]("final_layer.adaLN_modulation.1.weight", "final_layer.adaLN_modulation.1.bias"))
from lightx2v.common.modules.weight_module import WeightModule
from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER, CONV3D_WEIGHT_REGISTER, LN_WEIGHT_REGISTER, MM_WEIGHT_REGISTER
class HunyuanPreWeights(WeightModule):
def __init__(self, config):
super().__init__()
self.config = config
self.add_module("img_in_proj", CONV3D_WEIGHT_REGISTER["Default"]("img_in.proj.weight", "img_in.proj.bias", stride=(1, 2, 2)))
self.add_module("txt_in_input_embedder", MM_WEIGHT_REGISTER["Default"]("txt_in.input_embedder.weight", "txt_in.input_embedder.bias"))
self.add_module("txt_in_t_embedder_mlp_0", MM_WEIGHT_REGISTER["Default"]("txt_in.t_embedder.mlp.0.weight", "txt_in.t_embedder.mlp.0.bias"))
self.add_module("txt_in_t_embedder_mlp_2", MM_WEIGHT_REGISTER["Default"]("txt_in.t_embedder.mlp.2.weight", "txt_in.t_embedder.mlp.2.bias"))
self.add_module("txt_in_c_embedder_linear_1", MM_WEIGHT_REGISTER["Default"]("txt_in.c_embedder.linear_1.weight", "txt_in.c_embedder.linear_1.bias"))
self.add_module("txt_in_c_embedder_linear_2", MM_WEIGHT_REGISTER["Default"]("txt_in.c_embedder.linear_2.weight", "txt_in.c_embedder.linear_2.bias"))
self.add_module(
"txt_in_individual_token_refiner_blocks_0_norm1",
LN_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.0.norm1.weight", "txt_in.individual_token_refiner.blocks.0.norm1.bias", eps=1e-6),
)
self.add_module(
"txt_in_individual_token_refiner_blocks_0_self_attn_qkv",
MM_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.0.self_attn_qkv.weight", "txt_in.individual_token_refiner.blocks.0.self_attn_qkv.bias"),
)
self.add_module(
"txt_in_individual_token_refiner_blocks_0_self_attn_proj",
MM_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.0.self_attn_proj.weight", "txt_in.individual_token_refiner.blocks.0.self_attn_proj.bias"),
)
self.add_module(
"txt_in_individual_token_refiner_blocks_0_norm2",
LN_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.0.norm2.weight", "txt_in.individual_token_refiner.blocks.0.norm2.bias", eps=1e-6),
)
self.add_module(
"txt_in_individual_token_refiner_blocks_0_mlp_fc1",
MM_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.0.mlp.fc1.weight", "txt_in.individual_token_refiner.blocks.0.mlp.fc1.bias"),
)
self.add_module(
"txt_in_individual_token_refiner_blocks_0_mlp_fc2",
MM_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.0.mlp.fc2.weight", "txt_in.individual_token_refiner.blocks.0.mlp.fc2.bias"),
)
self.add_module(
"txt_in_individual_token_refiner_blocks_0_adaLN_modulation_1",
MM_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.0.adaLN_modulation.1.weight", "txt_in.individual_token_refiner.blocks.0.adaLN_modulation.1.bias"),
)
self.add_module(
"txt_in_individual_token_refiner_blocks_1_norm1",
LN_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.1.norm1.weight", "txt_in.individual_token_refiner.blocks.1.norm1.bias", eps=1e-6),
)
self.add_module(
"txt_in_individual_token_refiner_blocks_1_self_attn_qkv",
MM_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.1.self_attn_qkv.weight", "txt_in.individual_token_refiner.blocks.1.self_attn_qkv.bias"),
)
self.add_module(
"txt_in_individual_token_refiner_blocks_1_self_attn_proj",
MM_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.1.self_attn_proj.weight", "txt_in.individual_token_refiner.blocks.1.self_attn_proj.bias"),
)
self.add_module(
"txt_in_individual_token_refiner_blocks_1_norm2",
LN_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.1.norm2.weight", "txt_in.individual_token_refiner.blocks.1.norm2.bias", eps=1e-6),
)
self.add_module(
"txt_in_individual_token_refiner_blocks_1_mlp_fc1",
MM_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.1.mlp.fc1.weight", "txt_in.individual_token_refiner.blocks.1.mlp.fc1.bias"),
)
self.add_module(
"txt_in_individual_token_refiner_blocks_1_mlp_fc2",
MM_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.1.mlp.fc2.weight", "txt_in.individual_token_refiner.blocks.1.mlp.fc2.bias"),
)
self.add_module(
"txt_in_individual_token_refiner_blocks_1_adaLN_modulation_1",
MM_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.1.adaLN_modulation.1.weight", "txt_in.individual_token_refiner.blocks.1.adaLN_modulation.1.bias"),
)
self.add_module("time_in_mlp_0", MM_WEIGHT_REGISTER["Default"]("time_in.mlp.0.weight", "time_in.mlp.0.bias"))
self.add_module("time_in_mlp_2", MM_WEIGHT_REGISTER["Default"]("time_in.mlp.2.weight", "time_in.mlp.2.bias"))
self.add_module("vector_in_in_layer", MM_WEIGHT_REGISTER["Default"]("vector_in.in_layer.weight", "vector_in.in_layer.bias"))
self.add_module("vector_in_out_layer", MM_WEIGHT_REGISTER["Default"]("vector_in.out_layer.weight", "vector_in.out_layer.bias"))
self.add_module("guidance_in_mlp_0", MM_WEIGHT_REGISTER["Default"]("guidance_in.mlp.0.weight", "guidance_in.mlp.0.bias"))
self.add_module("guidance_in_mlp_2", MM_WEIGHT_REGISTER["Default"]("guidance_in.mlp.2.weight", "guidance_in.mlp.2.bias"))
# attention weights section
self.add_module("txt_in_attn_1", ATTN_WEIGHT_REGISTER["torch_sdpa"]())
from lightx2v.common.modules.weight_module import WeightModule, WeightModuleList
from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER, MM_WEIGHT_REGISTER, RMS_WEIGHT_REGISTER
class HunyuanTransformerWeights(WeightModule):
def __init__(self, config):
super().__init__()
self.config = config
self.double_blocks_num = 20
self.single_blocks_num = 40
self.add_module("double_blocks", WeightModuleList([HunyuanTransformerDoubleBlock(i, self.config) for i in range(self.double_blocks_num)]))
self.add_module("single_blocks", WeightModuleList([HunyuanTransformerSingleBlock(i, self.config) for i in range(self.single_blocks_num)]))
class HunyuanTransformerDoubleBlock(WeightModule):
def __init__(self, block_index, config):
super().__init__()
self.block_index = block_index
self.config = config
if self.config["do_mm_calib"]:
mm_type = "Calib"
else:
mm_type = self.config["mm_config"].get("mm_type", "Default") if self.config["mm_config"] else "Default"
self.add_module("img_mod", MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.img_mod.linear.weight", f"double_blocks.{self.block_index}.img_mod.linear.bias"))
self.add_module("img_attn_qkv", MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.img_attn_qkv.weight", f"double_blocks.{self.block_index}.img_attn_qkv.bias"))
self.add_module("img_attn_q_norm", RMS_WEIGHT_REGISTER["sgl-kernel"](f"double_blocks.{self.block_index}.img_attn_q_norm.weight", eps=1e-6))
self.add_module("img_attn_k_norm", RMS_WEIGHT_REGISTER["sgl-kernel"](f"double_blocks.{self.block_index}.img_attn_k_norm.weight", eps=1e-6))
self.add_module("img_attn_proj", MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.img_attn_proj.weight", f"double_blocks.{self.block_index}.img_attn_proj.bias"))
self.add_module("img_mlp_fc1", MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.img_mlp.fc1.weight", f"double_blocks.{self.block_index}.img_mlp.fc1.bias"))
self.add_module("img_mlp_fc2", MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.img_mlp.fc2.weight", f"double_blocks.{self.block_index}.img_mlp.fc2.bias"))
self.add_module("txt_mod", MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.txt_mod.linear.weight", f"double_blocks.{self.block_index}.txt_mod.linear.bias"))
self.add_module("txt_attn_qkv", MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.txt_attn_qkv.weight", f"double_blocks.{self.block_index}.txt_attn_qkv.bias"))
self.add_module("txt_attn_q_norm", RMS_WEIGHT_REGISTER["sgl-kernel"](f"double_blocks.{self.block_index}.txt_attn_q_norm.weight", eps=1e-6))
self.add_module("txt_attn_k_norm", RMS_WEIGHT_REGISTER["sgl-kernel"](f"double_blocks.{self.block_index}.txt_attn_k_norm.weight", eps=1e-6))
self.add_module("txt_attn_proj", MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.txt_attn_proj.weight", f"double_blocks.{self.block_index}.txt_attn_proj.bias"))
self.add_module("txt_mlp_fc1", MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.txt_mlp.fc1.weight", f"double_blocks.{self.block_index}.txt_mlp.fc1.bias"))
self.add_module("txt_mlp_fc2", MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.txt_mlp.fc2.weight", f"double_blocks.{self.block_index}.txt_mlp.fc2.bias"))
# attention weights section
self.add_module("double_attn", ATTN_WEIGHT_REGISTER[self.config["attention_type"]]())
class HunyuanTransformerSingleBlock(WeightModule):
def __init__(self, block_index, config):
super().__init__()
self.block_index = block_index
self.config = config
self.sparge = config.get("sparge", False)
if self.config["do_mm_calib"]:
mm_type = "Calib"
else:
mm_type = self.config["mm_config"].get("mm_type", "Default") if self.config["mm_config"] else "Default"
self.add_module("linear1", MM_WEIGHT_REGISTER[mm_type](f"single_blocks.{self.block_index}.linear1.weight", f"single_blocks.{self.block_index}.linear1.bias"))
self.add_module("linear2", MM_WEIGHT_REGISTER[mm_type](f"single_blocks.{self.block_index}.linear2.weight", f"single_blocks.{self.block_index}.linear2.bias"))
self.add_module("q_norm", RMS_WEIGHT_REGISTER["sgl-kernel"](f"single_blocks.{self.block_index}.q_norm.weight", eps=1e-6))
self.add_module("k_norm", RMS_WEIGHT_REGISTER["sgl-kernel"](f"single_blocks.{self.block_index}.k_norm.weight", eps=1e-6))
self.add_module("modulation", MM_WEIGHT_REGISTER[mm_type](f"single_blocks.{self.block_index}.modulation.linear.weight", f"single_blocks.{self.block_index}.modulation.linear.bias"))
# attention weights section
if self.sparge:
# load sparge attention weights
#! todo
pass
else:
self.add_module("single_attn", ATTN_WEIGHT_REGISTER[self.config["attention_type"]]())