Commit daf4c74e authored by helloyongyang, committed by Yang Yong(雍洋)

first commit

parent 6c79160f
import os
import torch
from lightx2v.text2v.models.networks.hunyuan.weights.pre_weights import HunyuanPreWeights
from lightx2v.text2v.models.networks.hunyuan.weights.post_weights import HunyuanPostWeights
from lightx2v.text2v.models.networks.hunyuan.weights.transformer_weights import HunyuanTransformerWeights
from lightx2v.text2v.models.networks.hunyuan.infer.pre_infer import HunyuanPreInfer
from lightx2v.text2v.models.networks.hunyuan.infer.post_infer import HunyuanPostInfer
from lightx2v.text2v.models.networks.hunyuan.infer.transformer_infer import HunyuanTransformerInfer
from lightx2v.text2v.models.networks.hunyuan.infer.feature_caching.transformer_infer import HunyuanTransformerInferFeatureCaching
# from lightx2v.core.distributed.partial_heads_attn.wrap import parallelize_hunyuan
from lightx2v.attentions.distributed.ulysses.wrap import parallelize_hunyuan
class HunyuanModel:
pre_weight_class = HunyuanPreWeights
post_weight_class = HunyuanPostWeights
transformer_weight_class = HunyuanTransformerWeights
def __init__(self, model_path, config):
self.model_path = model_path
self.config = config
self._init_infer_class()
self._init_weights()
self._init_infer()
if self.config['parallel_attn']:
parallelize_hunyuan(self)
if self.config['cpu_offload']:
self.to_cpu()
def _init_infer_class(self):
self.pre_infer_class = HunyuanPreInfer
self.post_infer_class = HunyuanPostInfer
if self.config['feature_caching'] == "NoCaching":
self.transformer_infer_class = HunyuanTransformerInfer
elif self.config['feature_caching'] == "TaylorSeer":
self.transformer_infer_class = HunyuanTransformerInferFeatureCaching
else:
raise NotImplementedError(f"Unsupported feature_caching type: {self.config['feature_caching']}")
def _load_ckpt(self):
ckpt_path = os.path.join(self.model_path, "hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt")
weight_dict = torch.load(ckpt_path, map_location="cuda", weights_only=True)["module"]
return weight_dict
def _init_weights(self):
weight_dict = self._load_ckpt()
# init weights
self.pre_weight = self.pre_weight_class(self.config)
self.post_weight = self.post_weight_class(self.config)
self.transformer_weights = self.transformer_weight_class(self.config)
# load weights
self.pre_weight.load_weights(weight_dict)
self.post_weight.load_weights(weight_dict)
self.transformer_weights.load_weights(weight_dict)
def _init_infer(self):
self.pre_infer = self.pre_infer_class()
self.post_infer = self.post_infer_class()
self.transformer_infer = self.transformer_infer_class(self.config)
def set_scheduler(self, scheduler):
self.scheduler = scheduler
self.transformer_infer.set_scheduler(scheduler)
def to_cpu(self):
self.pre_weight.to_cpu()
self.post_weight.to_cpu()
self.transformer_weights.to_cpu()
def to_cuda(self):
self.pre_weight.to_cuda()
self.post_weight.to_cuda()
self.transformer_weights.to_cuda()
@torch.no_grad()
def infer(self, text_encoder_output, image_encoder_output, args):
pre_infer_out = self.pre_infer.infer(
self.pre_weight,
self.scheduler.latents,
self.scheduler.timesteps[self.scheduler.step_index],
text_encoder_output["text_encoder_1_text_states"],
text_encoder_output["text_encoder_1_attention_mask"],
text_encoder_output["text_encoder_2_text_states"],
self.scheduler.freqs_cos,
self.scheduler.freqs_sin,
self.scheduler.guidance,
)
img, vec = self.transformer_infer.infer(
self.transformer_weights, *pre_infer_out
)
self.scheduler.noise_pred = self.post_infer.infer(
self.post_weight, img, vec, self.scheduler.latents.shape
)
from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER
from lightx2v.common.ops.mm.mm_weight import MMWeightTemplate
class HunyuanPostWeights:
def __init__(self, config):
self.config = config
def load_weights(self, weight_dict):
self.final_layer_linear = MM_WEIGHT_REGISTER['Default-Force-FP32']('final_layer.linear.weight', 'final_layer.linear.bias')
self.final_layer_adaLN_modulation_1 = MM_WEIGHT_REGISTER['Default']('final_layer.adaLN_modulation.1.weight', 'final_layer.adaLN_modulation.1.bias')
self.weight_list = [
self.final_layer_linear,
self.final_layer_adaLN_modulation_1,
]
for mm_weight in self.weight_list:
if isinstance(mm_weight, MMWeightTemplate):
mm_weight.set_config(self.config['mm_config'])
mm_weight.load(weight_dict)
def to_cpu(self):
for mm_weight in self.weight_list:
if isinstance(mm_weight, MMWeightTemplate):
mm_weight.to_cpu()
def to_cuda(self):
for mm_weight in self.weight_list:
if isinstance(mm_weight, MMWeightTemplate):
mm_weight.to_cuda()
from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER, LN_WEIGHT_REGISTER, CONV3D_WEIGHT_REGISTER
from lightx2v.common.ops.mm.mm_weight import MMWeightTemplate
from lightx2v.common.ops.norm.layer_norm_weight import LNWeightTemplate
from lightx2v.common.ops.conv.conv3d import Conv3dWeightTemplate
class HunyuanPreWeights:
def __init__(self, config):
self.config = config
def load_weights(self, weight_dict):
self.img_in_proj = CONV3D_WEIGHT_REGISTER["Default"]('img_in.proj.weight', 'img_in.proj.bias', stride=(1, 2, 2))
self.txt_in_input_embedder = MM_WEIGHT_REGISTER["Default"]('txt_in.input_embedder.weight', 'txt_in.input_embedder.bias')
self.txt_in_t_embedder_mlp_0 = MM_WEIGHT_REGISTER["Default"]('txt_in.t_embedder.mlp.0.weight', 'txt_in.t_embedder.mlp.0.bias')
self.txt_in_t_embedder_mlp_2 = MM_WEIGHT_REGISTER["Default"]('txt_in.t_embedder.mlp.2.weight', 'txt_in.t_embedder.mlp.2.bias')
self.txt_in_c_embedder_linear_1 = MM_WEIGHT_REGISTER["Default"]('txt_in.c_embedder.linear_1.weight', 'txt_in.c_embedder.linear_1.bias')
self.txt_in_c_embedder_linear_2 = MM_WEIGHT_REGISTER["Default"]('txt_in.c_embedder.linear_2.weight', 'txt_in.c_embedder.linear_2.bias')
self.txt_in_individual_token_refiner_blocks_0_norm1 = LN_WEIGHT_REGISTER["Default"]('txt_in.individual_token_refiner.blocks.0.norm1.weight', 'txt_in.individual_token_refiner.blocks.0.norm1.bias', eps=1e-6)
self.txt_in_individual_token_refiner_blocks_0_self_attn_qkv = MM_WEIGHT_REGISTER["Default"]('txt_in.individual_token_refiner.blocks.0.self_attn_qkv.weight', 'txt_in.individual_token_refiner.blocks.0.self_attn_qkv.bias')
self.txt_in_individual_token_refiner_blocks_0_self_attn_proj = MM_WEIGHT_REGISTER["Default"]('txt_in.individual_token_refiner.blocks.0.self_attn_proj.weight', 'txt_in.individual_token_refiner.blocks.0.self_attn_proj.bias')
self.txt_in_individual_token_refiner_blocks_0_norm2 = LN_WEIGHT_REGISTER["Default"]('txt_in.individual_token_refiner.blocks.0.norm2.weight', 'txt_in.individual_token_refiner.blocks.0.norm2.bias', eps=1e-6)
self.txt_in_individual_token_refiner_blocks_0_mlp_fc1 = MM_WEIGHT_REGISTER["Default"]('txt_in.individual_token_refiner.blocks.0.mlp.fc1.weight', 'txt_in.individual_token_refiner.blocks.0.mlp.fc1.bias')
self.txt_in_individual_token_refiner_blocks_0_mlp_fc2 = MM_WEIGHT_REGISTER["Default"]('txt_in.individual_token_refiner.blocks.0.mlp.fc2.weight', 'txt_in.individual_token_refiner.blocks.0.mlp.fc2.bias')
self.txt_in_individual_token_refiner_blocks_0_adaLN_modulation_1 = MM_WEIGHT_REGISTER["Default"]('txt_in.individual_token_refiner.blocks.0.adaLN_modulation.1.weight', 'txt_in.individual_token_refiner.blocks.0.adaLN_modulation.1.bias')
self.txt_in_individual_token_refiner_blocks_1_norm1 = LN_WEIGHT_REGISTER["Default"]('txt_in.individual_token_refiner.blocks.1.norm1.weight', 'txt_in.individual_token_refiner.blocks.1.norm1.bias', eps=1e-6)
self.txt_in_individual_token_refiner_blocks_1_self_attn_qkv = MM_WEIGHT_REGISTER["Default"]('txt_in.individual_token_refiner.blocks.1.self_attn_qkv.weight', 'txt_in.individual_token_refiner.blocks.1.self_attn_qkv.bias')
self.txt_in_individual_token_refiner_blocks_1_self_attn_proj = MM_WEIGHT_REGISTER["Default"]('txt_in.individual_token_refiner.blocks.1.self_attn_proj.weight', 'txt_in.individual_token_refiner.blocks.1.self_attn_proj.bias')
self.txt_in_individual_token_refiner_blocks_1_norm2 = LN_WEIGHT_REGISTER["Default"]('txt_in.individual_token_refiner.blocks.1.norm2.weight', 'txt_in.individual_token_refiner.blocks.1.norm2.bias', eps=1e-6)
self.txt_in_individual_token_refiner_blocks_1_mlp_fc1 = MM_WEIGHT_REGISTER["Default"]('txt_in.individual_token_refiner.blocks.1.mlp.fc1.weight', 'txt_in.individual_token_refiner.blocks.1.mlp.fc1.bias')
self.txt_in_individual_token_refiner_blocks_1_mlp_fc2 = MM_WEIGHT_REGISTER["Default"]('txt_in.individual_token_refiner.blocks.1.mlp.fc2.weight', 'txt_in.individual_token_refiner.blocks.1.mlp.fc2.bias')
self.txt_in_individual_token_refiner_blocks_1_adaLN_modulation_1 = MM_WEIGHT_REGISTER["Default"]('txt_in.individual_token_refiner.blocks.1.adaLN_modulation.1.weight', 'txt_in.individual_token_refiner.blocks.1.adaLN_modulation.1.bias')
self.time_in_mlp_0 = MM_WEIGHT_REGISTER["Default"]('time_in.mlp.0.weight', 'time_in.mlp.0.bias')
self.time_in_mlp_2 = MM_WEIGHT_REGISTER["Default"]('time_in.mlp.2.weight', 'time_in.mlp.2.bias')
self.vector_in_in_layer = MM_WEIGHT_REGISTER["Default"]('vector_in.in_layer.weight', 'vector_in.in_layer.bias')
self.vector_in_out_layer = MM_WEIGHT_REGISTER["Default"]('vector_in.out_layer.weight', 'vector_in.out_layer.bias')
self.guidance_in_mlp_0 = MM_WEIGHT_REGISTER["Default"]('guidance_in.mlp.0.weight', 'guidance_in.mlp.0.bias')
self.guidance_in_mlp_2 = MM_WEIGHT_REGISTER["Default"]('guidance_in.mlp.2.weight', 'guidance_in.mlp.2.bias')
self.weight_list = [
self.img_in_proj,
self.txt_in_input_embedder,
self.txt_in_t_embedder_mlp_0,
self.txt_in_t_embedder_mlp_2,
self.txt_in_c_embedder_linear_1,
self.txt_in_c_embedder_linear_2,
self.txt_in_individual_token_refiner_blocks_0_norm1,
self.txt_in_individual_token_refiner_blocks_0_self_attn_qkv,
self.txt_in_individual_token_refiner_blocks_0_self_attn_proj,
self.txt_in_individual_token_refiner_blocks_0_norm2,
self.txt_in_individual_token_refiner_blocks_0_mlp_fc1,
self.txt_in_individual_token_refiner_blocks_0_mlp_fc2,
self.txt_in_individual_token_refiner_blocks_0_adaLN_modulation_1,
self.txt_in_individual_token_refiner_blocks_1_norm1,
self.txt_in_individual_token_refiner_blocks_1_self_attn_qkv,
self.txt_in_individual_token_refiner_blocks_1_self_attn_proj,
self.txt_in_individual_token_refiner_blocks_1_norm2,
self.txt_in_individual_token_refiner_blocks_1_mlp_fc1,
self.txt_in_individual_token_refiner_blocks_1_mlp_fc2,
self.txt_in_individual_token_refiner_blocks_1_adaLN_modulation_1,
self.time_in_mlp_0,
self.time_in_mlp_2,
self.vector_in_in_layer,
self.vector_in_out_layer,
self.guidance_in_mlp_0,
self.guidance_in_mlp_2,
]
for mm_weight in self.weight_list:
if isinstance(mm_weight, MMWeightTemplate) or isinstance(mm_weight, LNWeightTemplate) or isinstance(mm_weight, Conv3dWeightTemplate):
mm_weight.set_config(self.config['mm_config'])
mm_weight.load(weight_dict)
def to_cpu(self):
for mm_weight in self.weight_list:
if isinstance(mm_weight, MMWeightTemplate) or isinstance(mm_weight, LNWeightTemplate) or isinstance(mm_weight, Conv3dWeightTemplate):
mm_weight.to_cpu()
def to_cuda(self):
for mm_weight in self.weight_list:
if isinstance(mm_weight, MMWeightTemplate) or isinstance(mm_weight, LNWeightTemplate) or isinstance(mm_weight, Conv3dWeightTemplate):
mm_weight.to_cuda()
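# --- Naming-pattern note (illustrative) ---
# The two token-refiner blocks are registered by hand above, but the underlying
# checkpoint keys follow a regular scheme; a hypothetical helper (not used by the
# loader) that enumerates the same weight names would look like this:
def _refiner_param_names(num_blocks=2):
    leaves = ("norm1", "self_attn_qkv", "self_attn_proj", "norm2",
              "mlp.fc1", "mlp.fc2", "adaLN_modulation.1")
    return [f"txt_in.individual_token_refiner.blocks.{i}.{leaf}.weight"
            for i in range(num_blocks) for leaf in leaves]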
from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER
from lightx2v.common.ops.norm.rms_norm_weight import RMS_WEIGHT_REGISTER
from lightx2v.common.ops.mm.mm_weight import MMWeightTemplate
from lightx2v.common.ops.norm.rms_norm_weight import RMSWeightTemplate
class HunyuanTransformerWeights:
def __init__(self, config):
self.config = config
self.init()
def init(self):
self.double_blocks_num = 20
self.single_blocks_num = 40
def load_weights(self, weight_dict):
self.double_blocks_weights = [HunyuanTransformerDoubleBlock(i, self.config) for i in range(self.double_blocks_num)]
self.single_blocks_weights = [HunyuanTransformerSingleBlock(i, self.config) for i in range(self.single_blocks_num)]
for double_block in self.double_blocks_weights:
double_block.load_weights(weight_dict)
for single_block in self.single_blocks_weights:
single_block.load_weights(weight_dict)
def to_cpu(self):
for double_block in self.double_blocks_weights:
double_block.to_cpu()
for single_block in self.single_blocks_weights:
single_block.to_cpu()
def to_cuda(self):
for double_block in self.double_blocks_weights:
double_block.to_cuda()
for single_block in self.single_blocks_weights:
single_block.to_cuda()
class HunyuanTransformerDoubleBlock:
def __init__(self, block_index, config):
self.block_index = block_index
self.config = config
self.weight_list = []
def load_weights(self, weight_dict):
if self.config['do_mm_calib']:
mm_type = 'Calib'
else:
mm_type = self.config['mm_config'].get('mm_type', 'Default') if self.config['mm_config'] else 'Default'
self.img_mod = MM_WEIGHT_REGISTER[mm_type](f'double_blocks.{self.block_index}.img_mod.linear.weight', f'double_blocks.{self.block_index}.img_mod.linear.bias')
self.img_attn_qkv = MM_WEIGHT_REGISTER[mm_type](f'double_blocks.{self.block_index}.img_attn_qkv.weight', f'double_blocks.{self.block_index}.img_attn_qkv.bias')
self.img_attn_q_norm = RMS_WEIGHT_REGISTER['sgl-kernel'](f'double_blocks.{self.block_index}.img_attn_q_norm.weight', eps=1e-6)
self.img_attn_k_norm = RMS_WEIGHT_REGISTER['sgl-kernel'](f'double_blocks.{self.block_index}.img_attn_k_norm.weight', eps=1e-6)
self.img_attn_proj = MM_WEIGHT_REGISTER[mm_type](f'double_blocks.{self.block_index}.img_attn_proj.weight', f'double_blocks.{self.block_index}.img_attn_proj.bias')
self.img_mlp_fc1 = MM_WEIGHT_REGISTER[mm_type](f'double_blocks.{self.block_index}.img_mlp.fc1.weight', f'double_blocks.{self.block_index}.img_mlp.fc1.bias')
self.img_mlp_fc2 = MM_WEIGHT_REGISTER[mm_type](f'double_blocks.{self.block_index}.img_mlp.fc2.weight', f'double_blocks.{self.block_index}.img_mlp.fc2.bias')
self.txt_mod = MM_WEIGHT_REGISTER[mm_type](f'double_blocks.{self.block_index}.txt_mod.linear.weight', f'double_blocks.{self.block_index}.txt_mod.linear.bias')
self.txt_attn_qkv = MM_WEIGHT_REGISTER[mm_type](f'double_blocks.{self.block_index}.txt_attn_qkv.weight', f'double_blocks.{self.block_index}.txt_attn_qkv.bias')
self.txt_attn_q_norm = RMS_WEIGHT_REGISTER['sgl-kernel'](f'double_blocks.{self.block_index}.txt_attn_q_norm.weight', eps=1e-6)
self.txt_attn_k_norm = RMS_WEIGHT_REGISTER['sgl-kernel'](f'double_blocks.{self.block_index}.txt_attn_k_norm.weight', eps=1e-6)
self.txt_attn_proj = MM_WEIGHT_REGISTER[mm_type](f'double_blocks.{self.block_index}.txt_attn_proj.weight', f'double_blocks.{self.block_index}.txt_attn_proj.bias')
self.txt_mlp_fc1 = MM_WEIGHT_REGISTER[mm_type](f'double_blocks.{self.block_index}.txt_mlp.fc1.weight', f'double_blocks.{self.block_index}.txt_mlp.fc1.bias')
self.txt_mlp_fc2 = MM_WEIGHT_REGISTER[mm_type](f'double_blocks.{self.block_index}.txt_mlp.fc2.weight', f'double_blocks.{self.block_index}.txt_mlp.fc2.bias')
self.weight_list = [
self.img_mod,
self.img_attn_qkv,
self.img_attn_q_norm,
self.img_attn_k_norm,
self.img_attn_proj,
self.img_mlp_fc1,
self.img_mlp_fc2,
self.txt_mod,
self.txt_attn_qkv,
self.txt_attn_q_norm,
self.txt_attn_k_norm,
self.txt_attn_proj,
self.txt_mlp_fc1,
self.txt_mlp_fc2,
]
for mm_weight in self.weight_list:
if isinstance(mm_weight, MMWeightTemplate) or isinstance(mm_weight, RMSWeightTemplate):
mm_weight.set_config(self.config['mm_config'])
mm_weight.load(weight_dict)
def to_cpu(self):
for mm_weight in self.weight_list:
if isinstance(mm_weight, MMWeightTemplate) or isinstance(mm_weight, RMSWeightTemplate):
mm_weight.to_cpu()
def to_cuda(self):
for mm_weight in self.weight_list:
if isinstance(mm_weight, MMWeightTemplate) or isinstance(mm_weight, RMSWeightTemplate):
mm_weight.to_cuda()
class HunyuanTransformerSingleBlock:
def __init__(self, block_index, config):
self.block_index = block_index
self.config = config
self.weight_list = []
def load_weights(self, weight_dict):
if self.config['do_mm_calib']:
mm_type = 'Calib'
else:
mm_type = self.config['mm_config'].get('mm_type', 'Default') if self.config['mm_config'] else 'Default'
self.linear1 = MM_WEIGHT_REGISTER[mm_type](f'single_blocks.{self.block_index}.linear1.weight', f'single_blocks.{self.block_index}.linear1.bias')
self.linear2 = MM_WEIGHT_REGISTER[mm_type](f'single_blocks.{self.block_index}.linear2.weight', f'single_blocks.{self.block_index}.linear2.bias')
self.q_norm = RMS_WEIGHT_REGISTER['sgl-kernel'](f'single_blocks.{self.block_index}.q_norm.weight', eps=1e-6)
self.k_norm = RMS_WEIGHT_REGISTER['sgl-kernel'](f'single_blocks.{self.block_index}.k_norm.weight', eps=1e-6)
self.modulation = MM_WEIGHT_REGISTER[mm_type](f'single_blocks.{self.block_index}.modulation.linear.weight', f'single_blocks.{self.block_index}.modulation.linear.bias')
self.weight_list = [
self.linear1,
self.linear2,
self.q_norm,
self.k_norm,
self.modulation,
]
for mm_weight in self.weight_list:
if isinstance(mm_weight, MMWeightTemplate) or isinstance(mm_weight, RMSWeightTemplate):
mm_weight.set_config(self.config['mm_config'])
mm_weight.load(weight_dict)
def to_cpu(self):
for mm_weight in self.weight_list:
if isinstance(mm_weight, MMWeightTemplate) or isinstance(mm_weight, RMSWeightTemplate):
mm_weight.to_cpu()
def to_cuda(self):
for mm_weight in self.weight_list:
if isinstance(mm_weight, MMWeightTemplate) or isinstance(mm_weight, RMSWeightTemplate):
mm_weight.to_cuda()
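# --- Pattern sketch (illustrative toy, not lightx2v's actual implementation) ---
# The weight classes above look a name up in a registry keyed by mm_type and get
# back an object exposing set_config / load / to_cpu / to_cuda / apply. A minimal
# stand-in for that contract, useful only for following the control flow:
class _ToyMMWeight:
    def __init__(self, weight_name, bias_name):
        self.weight_name, self.bias_name = weight_name, bias_name
        self.weight, self.bias, self.config = None, None, None
    def set_config(self, config=None):
        self.config = config
    def load(self, weight_dict):
        self.weight = weight_dict[self.weight_name]
        self.bias = weight_dict[self.bias_name]
    def apply(self, x):
        # y = x @ W^T + b, as a plain linear layer would compute during inference
        return x @ self.weight.t() + self.bias

_TOY_MM_REGISTER = {"Default": _ToyMMWeight}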
import numpy as np
from ..transformer_infer import WanTransformerInfer
from lightx2v.attentions import attention
class WanTransformerInferFeatureCaching(WanTransformerInfer):
def __init__(self, config):
super().__init__(config)
def infer(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
modulated_inp = embed0 if self.scheduler.use_ret_steps else embed
# teacache
if self.scheduler.cnt % 2 == 0:  # even -> conditional (cond) pass
self.scheduler.is_even = True
if (
self.scheduler.cnt < self.scheduler.ret_steps
or self.scheduler.cnt >= self.scheduler.cutoff_steps
):
should_calc_even = True
self.scheduler.accumulated_rel_l1_distance_even = 0
else:
rescale_func = np.poly1d(self.scheduler.coefficients)
self.scheduler.accumulated_rel_l1_distance_even += rescale_func(
(
(modulated_inp - self.scheduler.previous_e0_even).abs().mean()
/ self.scheduler.previous_e0_even.abs().mean()
)
.cpu()
.item()
)
if (
self.scheduler.accumulated_rel_l1_distance_even
< self.scheduler.teacache_thresh
):
should_calc_even = False
else:
should_calc_even = True
self.scheduler.accumulated_rel_l1_distance_even = 0
self.scheduler.previous_e0_even = modulated_inp.clone()
else:  # odd -> unconditional (uncond) pass
self.scheduler.is_even = False
if self.scheduler.cnt < self.scheduler.ret_steps or self.scheduler.cnt >= self.scheduler.cutoff_steps:
should_calc_odd = True
self.scheduler.accumulated_rel_l1_distance_odd = 0
else:
rescale_func = np.poly1d(self.scheduler.coefficients)
self.scheduler.accumulated_rel_l1_distance_odd += rescale_func(((modulated_inp-self.scheduler.previous_e0_odd).abs().mean() / self.scheduler.previous_e0_odd.abs().mean()).cpu().item())
if self.scheduler.accumulated_rel_l1_distance_odd < self.scheduler.teacache_thresh:
should_calc_odd = False
else:
should_calc_odd = True
self.scheduler.accumulated_rel_l1_distance_odd = 0
self.scheduler.previous_e0_odd = modulated_inp.clone()
if self.scheduler.is_even:
if not should_calc_even:
x += self.scheduler.previous_residual_even
else:
ori_x = x.clone()
x = super().infer(
weights,
grid_sizes,
embed,
x,
embed0,
seq_lens,
freqs,
context,
)
self.scheduler.previous_residual_even = x - ori_x
else:
if not should_calc_odd:
x += self.scheduler.previous_residual_odd
else:
ori_x = x.clone()
x = super().infer(
weights,
grid_sizes,
embed,
x,
embed0,
seq_lens,
freqs,
context,
)
self.scheduler.previous_residual_odd = x - ori_x
return x
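# --- TeaCache decision sketch (illustrative, mirrors the logic above) ---
# The skip test accumulates a polynomial-rescaled relative L1 distance between the
# current and previous modulated inputs, and only recomputes the transformer blocks
# once the accumulator crosses teacache_thresh. A standalone version of that test:
def _teacache_should_calc(modulated_inp, previous_e0, accumulated, coefficients, thresh):
    rescale_func = np.poly1d(coefficients)
    rel_l1 = ((modulated_inp - previous_e0).abs().mean() / previous_e0.abs().mean()).cpu().item()
    accumulated += rescale_func(rel_l1)
    if accumulated < thresh:
        return False, accumulated  # reuse the cached residual
    return True, 0                 # recompute and reset the accumulator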
import math
import torch
import torch.cuda.amp as amp
class WanPostInfer:
def __init__(self, config):
self.out_dim = config["out_dim"]
self.patch_size = (1, 2, 2)
def infer(self, weights, x, e, grid_sizes):
e = (weights.head_modulation + e.unsqueeze(1)).chunk(2, dim=1)
norm_out = torch.nn.functional.layer_norm(
x, (x.shape[1],), None, None, 1e-6
).type_as(x)
out = norm_out * (1 + e[1].squeeze(0)) + e[0].squeeze(0)
x = torch.addmm(weights.head_bias, out, weights.head_weight.t())
x = self.unpatchify(x, grid_sizes)
return [u.float() for u in x]
def unpatchify(self, x, grid_sizes):
x = x.unsqueeze(0)
c = self.out_dim
out = []
for u, v in zip(x, grid_sizes.tolist()):
u = u[: math.prod(v)].view(*v, *self.patch_size, c)
u = torch.einsum("fhwpqrc->cfphqwr", u)
u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])
out.append(u)
return out
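# --- Shape sanity sketch (illustrative) ---
# unpatchify folds a [seq_len, prod(patch_size) * c] token sequence back into a
# [c, f*pf, h*ph, w*pw] video tensor. A small CPU check of that contract, with
# out_dim and grid sizes chosen here purely for illustration:
def _demo_unpatchify():
    post = WanPostInfer({"out_dim": 16})
    f, h, w = 3, 4, 5                              # latent grid sizes
    x = torch.randn(f * h * w, 1 * 2 * 2 * 16)     # one token per patch
    grid_sizes = torch.tensor([[f, h, w]])
    out = post.unpatchify(x, grid_sizes)
    assert out[0].shape == (16, f * 1, h * 2, w * 2)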
import torch
import math
from .utils import rope_params, sinusoidal_embedding_1d
import torch.cuda.amp as amp
class WanPreInfer:
def __init__(self, config):
assert (config["dim"] % config["num_heads"]) == 0 and (
config["dim"] // config["num_heads"]
) % 2 == 0
d = config["dim"] // config["num_heads"]
self.task = config['task']
self.freqs = torch.cat(
[
rope_params(1024, d - 4 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
],
dim=1,
).cuda()
self.freq_dim = config["freq_dim"]
self.dim = config["dim"]
self.text_len = config["text_len"]
def infer(self, weights, x, t, context, seq_len, clip_fea=None, y=None):
if self.task == 'i2v':
x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
# embeddings
x = [weights.patch_embedding(u.unsqueeze(0)) for u in x]
grid_sizes = torch.stack(
[torch.tensor(u.shape[2:], dtype=torch.long) for u in x]
)
x = [u.flatten(2).transpose(1, 2) for u in x]
seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long).cuda()
assert seq_lens.max() <= seq_len
x = torch.cat(
[
torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1)
for u in x
]
)
embed = sinusoidal_embedding_1d(self.freq_dim, t)
embed = torch.addmm(
weights.time_embedding_0_bias,
embed,
weights.time_embedding_0_weight.t(),
)
embed = torch.nn.functional.silu(embed)
embed = torch.addmm(
weights.time_embedding_2_bias,
embed,
weights.time_embedding_2_weight.t(),
)
embed0 = torch.nn.functional.silu(embed)
embed0 = torch.addmm(
weights.time_projection_1_bias,
embed0,
weights.time_projection_1_weight.t(),
).unflatten(1, (6, self.dim))
# text embeddings
stacked = torch.stack(
[
torch.cat([u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
for u in context
]
)
out = torch.addmm(
weights.text_embedding_0_bias,
stacked.squeeze(0),
weights.text_embedding_0_weight.t(),
)
out = torch.nn.functional.gelu(out, approximate="tanh")
context = torch.addmm(
weights.text_embedding_2_bias,
out,
weights.text_embedding_2_weight.t(),
)
if self.task == 'i2v':
context_clip = torch.nn.functional.layer_norm(
clip_fea,
normalized_shape=(clip_fea.shape[1],),
weight=weights.proj_0_weight,
bias=weights.proj_0_bias,
eps=1e-5,
)
context_clip = torch.addmm(
weights.proj_1_bias,
context_clip,
weights.proj_1_weight.t(),
)
context_clip = torch.nn.functional.gelu(context_clip, approximate="none")
context_clip = torch.addmm(
weights.proj_3_bias,
context_clip,
weights.proj_3_weight.t(),
)
context_clip = torch.nn.functional.layer_norm(
context_clip,
normalized_shape=(context_clip.shape[1],),
weight=weights.proj_4_weight,
bias=weights.proj_4_bias,
eps=1e-5,
)
context = torch.concat([context_clip, context], dim=0)
return (
embed,
grid_sizes,
(x.squeeze(0), embed0.squeeze(0), seq_lens, self.freqs, context),
)
import torch
from .utils import compute_freqs, apply_rotary_emb, rms_norm
from lightx2v.attentions import attention
class WanTransformerInfer:
def __init__(self, config):
self.config = config
self.task = config['task']
self.attention_type = config.get("attention_type", "flash_attn2")
self.blocks_num = config["num_layers"]
self.num_heads = config["num_heads"]
self.head_dim = config["dim"] // config["num_heads"]
self.window_size = config.get("window_size", (-1, -1))
def set_scheduler(self, scheduler):
self.scheduler = scheduler
def _calculate_q_k_len(self, q, k, k_lens):
lq, nq, c1 = q.size()
lk, nk, c1_k = k.size()
# Handle query and key lengths (use `q_lens` and `k_lens` or set them to Lq and Lk if None)
q_lens = torch.tensor([lq], dtype=torch.int32, device=q.device)
# We don't have a batch dimension anymore, so directly use the `q_lens` and `k_lens` values
cu_seqlens_q = (
torch.cat([q_lens.new_zeros([1]), q_lens])
.cumsum(0, dtype=torch.int32)
)
cu_seqlens_k = (
torch.cat([k_lens.new_zeros([1]), k_lens])
.cumsum(0, dtype=torch.int32)
)
return cu_seqlens_q, cu_seqlens_k, lq, lk
def infer(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
for i in range(self.blocks_num):
x = self.infer_block(
weights.blocks_weights[i],
grid_sizes,
embed,
x,
embed0,
seq_lens,
freqs,
context,
)
return x
def infer_block(
self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context
):
embed0 = (weights.modulation + embed0).chunk(6, dim=1)
norm1_out = torch.nn.functional.layer_norm(
x, (x.shape[1],), None, None, 1e-6
)
norm1_out = (norm1_out * (1 + embed0[1]) + embed0[0]).squeeze(0)
s, n, d = *norm1_out.shape[:1], self.num_heads, self.head_dim
q = rms_norm(
weights.self_attn_q.apply(norm1_out), weights.self_attn_norm_q_weight, 1e-6
).view(s, n, d)
k = rms_norm(
weights.self_attn_k.apply(norm1_out), weights.self_attn_norm_k_weight, 1e-6
).view(s, n, d)
v = weights.self_attn_v.apply(norm1_out).view(s, n, d)
freqs_i = compute_freqs(q.size(2) // 2, grid_sizes, freqs)
q = apply_rotary_emb(q, freqs_i)
k = apply_rotary_emb(k, freqs_i)
cu_seqlens_q, cu_seqlens_k, lq, lk = self._calculate_q_k_len(q, k, k_lens=seq_lens)
attn_out = attention(
attention_type=self.attention_type,
q=q,
k=k,
v=v,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_k,
max_seqlen_q=lq,
max_seqlen_kv=lk,
)
y = weights.self_attn_o.apply(attn_out)
x = x + y * embed0[2].squeeze(0)
norm3_out = torch.nn.functional.layer_norm(
x,
normalized_shape=(x.shape[1],),
weight=weights.norm3_weight,
bias=weights.norm3_bias,
eps=1e-6,
)
if self.task == 'i2v':
context_img = context[:257]
context = context[257:]
n, d = self.num_heads, self.head_dim
q = rms_norm(
weights.cross_attn_q.apply(norm3_out), weights.cross_attn_norm_q_weight, 1e-6
).view(-1, n, d)
k = rms_norm(
weights.cross_attn_k.apply(context), weights.cross_attn_norm_k_weight, 1e-6
).view(-1, n, d)
v = weights.cross_attn_v.apply(context).view(-1, n, d)
if self.task == 'i2v':
k_img = rms_norm(
weights.cross_attn_k_img.apply(context_img), weights.cross_attn_norm_k_img_weight, 1e-6
).view(-1, n, d)
v_img = weights.cross_attn_v_img.apply(context_img).view(-1, n, d)
cu_seqlens_q, cu_seqlens_k, lq, lk = self._calculate_q_k_len(
q, k_img, k_lens=torch.tensor([k_img.size(0)], dtype=torch.int32, device=k.device)
)
img_attn_out = attention(
attention_type=self.attention_type,
q=q,
k=k_img,
v=v_img,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_k,
max_seqlen_q=lq,
max_seqlen_kv=lk,
)
cu_seqlens_q, cu_seqlens_k, lq, lk = self._calculate_q_k_len(
q, k, k_lens=torch.tensor([k.size(0)], dtype=torch.int32, device=k.device)
)
attn_out = attention(
attention_type=self.attention_type,
q=q,
k=k,
v=v,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_k,
max_seqlen_q=lq,
max_seqlen_kv=lk,
)
attn_out = weights.cross_attn_o.apply(attn_out)
x = x + attn_out
norm2_out = torch.nn.functional.layer_norm(
x, (x.shape[1],), None, None, 1e-6
)
y = weights.ffn_0.apply(norm2_out * (1 + embed0[4].squeeze(0)) + embed0[3].squeeze(0))
y = torch.nn.functional.gelu(y, approximate="tanh")
y = weights.ffn_2.apply(y)
x = x + y * embed0[5].squeeze(0)
return x
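# --- cu_seqlens sketch (illustrative) ---
# The varlen attention kernels take cumulative sequence lengths instead of a batch
# dimension; _calculate_q_k_len above builds them as prefix sums over per-sample
# lengths. For a single sample of length 100 that is simply [0, 100]:
def _demo_cu_seqlens():
    k_lens = torch.tensor([100], dtype=torch.int32)
    cu_seqlens_k = torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(0, dtype=torch.int32)
    assert cu_seqlens_k.tolist() == [0, 100]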
import torch
import sgl_kernel
import torch.cuda.amp as amp
def rms_norm(x, weight, eps):
x = x.contiguous()
orig_shape = x.shape
x = x.view(-1, orig_shape[-1])
x = sgl_kernel.rmsnorm(x, weight, eps).view(orig_shape)
return x
def compute_freqs(c, grid_sizes, freqs):
freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
f, h, w = grid_sizes[0].tolist()
seq_len = f * h * w
freqs_i = torch.cat(
[
freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
],
dim=-1,
).reshape(seq_len, 1, -1)
return freqs_i
def apply_rotary_emb(x, freqs_i):
n = x.size(1)
seq_len = freqs_i.size(0)
x_i = torch.view_as_complex(
x[:seq_len].to(torch.float64).reshape(seq_len, n, -1, 2)
)
# Apply rotary embedding
x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
x_i = torch.cat([x_i, x[seq_len:]]).to(torch.bfloat16)
return x_i
def rope_params(max_seq_len, dim, theta=10000):
assert dim % 2 == 0
freqs = torch.outer(
torch.arange(max_seq_len),
1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float64).div(dim)),
)
freqs = torch.polar(torch.ones_like(freqs), freqs)
return freqs
def sinusoidal_embedding_1d(dim, position):
# preprocess
assert dim % 2 == 0
half = dim // 2
position = position.type(torch.float64)
# calculation
sinusoid = torch.outer(
position, torch.pow(10000, -torch.arange(half).to(position).div(half))
)
x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1).to(torch.bfloat16)
return x
import os
import torch
import time
import glob
from lightx2v.text2v.models.networks.wan.weights.pre_weights import WanPreWeights
from lightx2v.text2v.models.networks.wan.weights.post_weights import WanPostWeights
from lightx2v.text2v.models.networks.wan.weights.transformer_weights import (
WanTransformerWeights,
)
from lightx2v.text2v.models.networks.wan.infer.pre_infer import WanPreInfer
from lightx2v.text2v.models.networks.wan.infer.post_infer import WanPostInfer
from lightx2v.text2v.models.networks.wan.infer.transformer_infer import (
WanTransformerInfer,
)
from lightx2v.text2v.models.networks.wan.infer.feature_caching.transformer_infer import WanTransformerInferFeatureCaching
from safetensors import safe_open
class WanModel:
pre_weight_class = WanPreWeights
post_weight_class = WanPostWeights
transformer_weight_class = WanTransformerWeights
def __init__(self, model_path, config):
self.model_path = model_path
self.config = config
self._init_infer_class()
self._init_weights()
self._init_infer()
def _init_infer_class(self):
self.pre_infer_class = WanPreInfer
self.post_infer_class = WanPostInfer
if self.config["feature_caching"] == "NoCaching":
self.transformer_infer_class = WanTransformerInfer
elif self.config["feature_caching"] == "Tea":
self.transformer_infer_class = WanTransformerInferFeatureCaching
else:
raise NotImplementedError(
f"Unsupported feature_caching type: {self.config['feature_caching']}"
)
def _load_safetensor_to_dict(self, file_path):
with safe_open(file_path, framework="pt") as f:
tensor_dict = {
key: f.get_tensor(key).to(torch.bfloat16).cuda() for key in f.keys()
}
return tensor_dict
def _load_ckpt(self):
safetensors_pattern = os.path.join(self.model_path, "*.safetensors")
safetensors_files = glob.glob(safetensors_pattern)
if not safetensors_files:
raise FileNotFoundError(
f"No .safetensors files found in directory: {self.model_path}"
)
weight_dict = {}
for file_path in safetensors_files:
file_weights = self._load_safetensor_to_dict(file_path)
weight_dict.update(file_weights)
return weight_dict
def _init_weights(self):
weight_dict = self._load_ckpt()
# init weights
self.pre_weight = self.pre_weight_class(self.config)
self.post_weight = self.post_weight_class()
self.transformer_weights = self.transformer_weight_class(self.config)
# load weights
self.pre_weight.load_weights(weight_dict)
self.post_weight.load_weights(weight_dict)
self.transformer_weights.load_weights(weight_dict)
def _init_infer(self):
self.pre_infer = self.pre_infer_class(self.config)
self.post_infer = self.post_infer_class(self.config)
self.transformer_infer = self.transformer_infer_class(self.config)
def set_scheduler(self, scheduler):
self.scheduler = scheduler
self.transformer_infer.set_scheduler(scheduler)
@torch.no_grad()
def infer(self, text_encoders_output, image_encoder_output, args):
timestep = torch.stack([self.scheduler.timesteps[self.scheduler.step_index]])
embed, grid_sizes, pre_infer_out = self.pre_infer.infer(
self.pre_weight,
[self.scheduler.latents],
timestep,
text_encoders_output["context"],
self.scheduler.seq_len,
image_encoder_output["clip_encoder_out"],
[image_encoder_output["vae_encode_out"]],
)
x = self.transformer_infer.infer(
self.transformer_weights, grid_sizes, embed, *pre_infer_out
)
noise_pred_cond = self.post_infer.infer(
self.post_weight, x, embed, grid_sizes
)[0]
if self.config["feature_caching"] == "Tea":
self.scheduler.cnt += 1
if self.scheduler.cnt >= self.scheduler.num_steps:
self.scheduler.cnt = 0
embed, grid_sizes, pre_infer_out = self.pre_infer.infer(
self.pre_weight,
[self.scheduler.latents],
timestep,
text_encoders_output["context_null"],
self.scheduler.seq_len,
image_encoder_output["clip_encoder_out"],
[image_encoder_output["vae_encode_out"]],
)
x = self.transformer_infer.infer(
self.transformer_weights, grid_sizes, embed, *pre_infer_out
)
noise_pred_uncond = self.post_infer.infer(
self.post_weight, x, embed, grid_sizes
)[0]
if self.config["feature_caching"] == "Tea":
self.scheduler.cnt += 1
if self.scheduler.cnt >= self.scheduler.num_steps:
self.scheduler.cnt = 0
self.scheduler.noise_pred = noise_pred_uncond + args.sample_guide_scale * (
noise_pred_cond - noise_pred_uncond
)
class WanPostWeights:
def __init__(self):
pass
def load_weights(self, weight_dict):
head_layers = {"head": ["head.weight", "head.bias", "modulation"]}
for param_name, param_keys in head_layers.items():
for key in param_keys:
weight_path = f"{param_name}.{key}"
key = key.split('.')
setattr(self, f"{param_name}_{key[-1]}", weight_dict[weight_path])
import torch
class WanPreWeights:
def __init__(self, config):
self.in_dim = config["in_dim"]
self.dim = config["dim"]
self.patch_size = (1, 2, 2)
def load_weights(self, weight_dict):
layers = {
"text_embedding": {"0": ["weight", "bias"], "2": ["weight", "bias"]},
"time_embedding": {"0": ["weight", "bias"], "2": ["weight", "bias"]},
"time_projection": {"1": ["weight", "bias"]},
}
self.patch_embedding = (
torch.nn.Conv3d(
self.in_dim,
self.dim,
kernel_size=self.patch_size,
stride=self.patch_size,
)
.to(torch.bfloat16)
.cuda()
)
self.patch_embedding.weight.data.copy_(weight_dict["patch_embedding.weight"])
self.patch_embedding.bias.data.copy_(weight_dict["patch_embedding.bias"])
for module_name, sub_layers in layers.items():
for param_name, param_keys in sub_layers.items():
for key in param_keys:
weight_path = f"{module_name}.{param_name}.{key}"
setattr(
self,
f"{module_name}_{param_name}_{key}",
weight_dict[weight_path],
)
if 'img_emb.proj.0.weight' in weight_dict.keys():
MLP_layers = {
"proj_0_weight": "proj.0.weight",
"proj_0_bias": "proj.0.bias",
"proj_1_weight": "proj.1.weight",
"proj_1_bias": "proj.1.bias",
"proj_3_weight": "proj.3.weight",
"proj_3_bias": "proj.3.bias",
"proj_4_weight": "proj.4.weight",
"proj_4_bias": "proj.4.bias",
}
for layer_name, weight_keys in MLP_layers.items():
weight_path = f"img_emb.{weight_keys}"
setattr(self, layer_name, weight_dict[weight_path])
from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER
from lightx2v.common.ops.mm.mm_weight import MMWeightTemplate
class WanTransformerWeights:
def __init__(self, config):
self.blocks_num = config["num_layers"]
self.task = config['task']
if config['do_mm_calib']:
self.mm_type = 'Calib'
else:
self.mm_type = config['mm_config'].get('mm_type', 'Default') if config['mm_config'] else 'Default'
def load_weights(self, weight_dict):
self.blocks_weights = [
WanTransformerAttentionBlock(i, self.task, self.mm_type) for i in range(self.blocks_num)
]
for block in self.blocks_weights:
block.load_weights(weight_dict)
class WanTransformerAttentionBlock:
def __init__(self, block_index, task, mm_type):
self.block_index = block_index
self.mm_type = mm_type
self.task = task
def load_weights(self, weight_dict):
if self.task == 't2v':
layers = {
"self_attn_q": ["self_attn.q.weight", "self_attn.q.bias"],
"self_attn_k": ["self_attn.k.weight", "self_attn.k.bias"],
"self_attn_v": ["self_attn.v.weight", "self_attn.v.bias"],
"self_attn_o": ["self_attn.o.weight", "self_attn.o.bias"],
"self_attn_norm_q_weight": "self_attn.norm_q.weight",
"self_attn_norm_k_weight": "self_attn.norm_k.weight",
"norm3_weight": "norm3.weight",
"norm3_bias": "norm3.bias",
"cross_attn_q": ["cross_attn.q.weight", "cross_attn.q.bias"],
"cross_attn_k": ["cross_attn.k.weight", "cross_attn.k.bias"],
"cross_attn_v": ["cross_attn.v.weight", "cross_attn.v.bias"],
"cross_attn_o": ["cross_attn.o.weight", "cross_attn.o.bias"],
"cross_attn_norm_q_weight": "cross_attn.norm_q.weight",
"cross_attn_norm_k_weight": "cross_attn.norm_k.weight",
"ffn_0": ["ffn.0.weight", "ffn.0.bias"],
"ffn_2": ["ffn.2.weight", "ffn.2.bias"],
"modulation": "modulation",
}
elif self.task == 'i2v':
layers = {
"self_attn_q": ["self_attn.q.weight", "self_attn.q.bias"],
"self_attn_k": ["self_attn.k.weight", "self_attn.k.bias"],
"self_attn_v": ["self_attn.v.weight", "self_attn.v.bias"],
"self_attn_o": ["self_attn.o.weight", "self_attn.o.bias"],
"self_attn_norm_q_weight": "self_attn.norm_q.weight",
"self_attn_norm_k_weight": "self_attn.norm_k.weight",
"norm3_weight": "norm3.weight",
"norm3_bias": "norm3.bias",
"cross_attn_q": ["cross_attn.q.weight", "cross_attn.q.bias"],
"cross_attn_k": ["cross_attn.k.weight", "cross_attn.k.bias"],
"cross_attn_v": ["cross_attn.v.weight", "cross_attn.v.bias"],
"cross_attn_o": ["cross_attn.o.weight", "cross_attn.o.bias"],
"cross_attn_norm_q_weight": "cross_attn.norm_q.weight",
"cross_attn_norm_k_weight": "cross_attn.norm_k.weight",
"cross_attn_k_img": ["cross_attn.k_img.weight", "cross_attn.k_img.bias"],
"cross_attn_v_img": ["cross_attn.v_img.weight", "cross_attn.v_img.bias"],
"cross_attn_norm_k_img_weight": "cross_attn.norm_k_img.weight",
"ffn_0": ["ffn.0.weight", "ffn.0.bias"],
"ffn_2": ["ffn.2.weight", "ffn.2.bias"],
"modulation": "modulation",
}
for layer_name, weight_keys in layers.items():
if isinstance(weight_keys, list):
weight_key, bias_key = weight_keys
weight_path = f"blocks.{self.block_index}.{weight_key}"
bias_path = f"blocks.{self.block_index}.{bias_key}"
setattr(self, layer_name, MM_WEIGHT_REGISTER[self.mm_type](weight_path, bias_path))
getattr(self, layer_name).load(weight_dict)
else:
weight_path = f"blocks.{self.block_index}.{weight_keys}"
setattr(self, layer_name, weight_dict[weight_path])
import torch
from ..scheduler import HunyuanScheduler
def cache_init(num_steps, model_kwargs=None):
'''
Initialization for cache.
'''
cache_dic = {}
cache = {}
cache_index = {}
cache[-1]={}
cache_index[-1]={}
cache_index['layer_index']={}
cache_dic['attn_map'] = {}
cache_dic['attn_map'][-1] = {}
cache_dic['attn_map'][-1]['double_stream'] = {}
cache_dic['attn_map'][-1]['single_stream'] = {}
cache_dic['k-norm'] = {}
cache_dic['k-norm'][-1] = {}
cache_dic['k-norm'][-1]['double_stream'] = {}
cache_dic['k-norm'][-1]['single_stream'] = {}
cache_dic['v-norm'] = {}
cache_dic['v-norm'][-1] = {}
cache_dic['v-norm'][-1]['double_stream'] = {}
cache_dic['v-norm'][-1]['single_stream'] = {}
cache_dic['cross_attn_map'] = {}
cache_dic['cross_attn_map'][-1] = {}
cache[-1]['double_stream']={}
cache[-1]['single_stream']={}
cache_dic['cache_counter'] = 0
for j in range(20):
cache[-1]['double_stream'][j] = {}
cache_index[-1][j] = {}
cache_dic['attn_map'][-1]['double_stream'][j] = {}
cache_dic['attn_map'][-1]['double_stream'][j]['total'] = {}
cache_dic['attn_map'][-1]['double_stream'][j]['txt_mlp'] = {}
cache_dic['attn_map'][-1]['double_stream'][j]['img_mlp'] = {}
cache_dic['k-norm'][-1]['double_stream'][j] = {}
cache_dic['k-norm'][-1]['double_stream'][j]['txt_mlp'] = {}
cache_dic['k-norm'][-1]['double_stream'][j]['img_mlp'] = {}
cache_dic['v-norm'][-1]['double_stream'][j] = {}
cache_dic['v-norm'][-1]['double_stream'][j]['txt_mlp'] = {}
cache_dic['v-norm'][-1]['double_stream'][j]['img_mlp'] = {}
for j in range(40):
cache[-1]['single_stream'][j] = {}
cache_index[-1][j] = {}
cache_dic['attn_map'][-1]['single_stream'][j] = {}
cache_dic['attn_map'][-1]['single_stream'][j]['total'] = {}
cache_dic['k-norm'][-1]['single_stream'][j] = {}
cache_dic['k-norm'][-1]['single_stream'][j]['total'] = {}
cache_dic['v-norm'][-1]['single_stream'][j] = {}
cache_dic['v-norm'][-1]['single_stream'][j]['total'] = {}
cache_dic['taylor_cache'] = False
cache_dic['duca'] = False
cache_dic['test_FLOPs'] = False
mode = 'Taylor'
if mode == 'original':
cache_dic['cache_type'] = 'random'
cache_dic['cache_index'] = cache_index
cache_dic['cache'] = cache
cache_dic['fresh_ratio_schedule'] = 'ToCa'
cache_dic['fresh_ratio'] = 0.0
cache_dic['fresh_threshold'] = 1
cache_dic['force_fresh'] = 'global'
cache_dic['soft_fresh_weight'] = 0.0
cache_dic['max_order'] = 0
cache_dic['first_enhance'] = 1
elif mode == 'ToCa':
cache_dic['cache_type'] = 'random'
cache_dic['cache_index'] = cache_index
cache_dic['cache'] = cache
cache_dic['fresh_ratio_schedule'] = 'ToCa'
cache_dic['fresh_ratio'] = 0.10
cache_dic['fresh_threshold'] = 5
cache_dic['force_fresh'] = 'global'
cache_dic['soft_fresh_weight'] = 0.0
cache_dic['max_order'] = 0
cache_dic['first_enhance'] = 1
cache_dic['duca'] = False
elif mode == 'DuCa':
cache_dic['cache_type'] = 'random'
cache_dic['cache_index'] = cache_index
cache_dic['cache'] = cache
cache_dic['fresh_ratio_schedule'] = 'ToCa'
cache_dic['fresh_ratio'] = 0.10
cache_dic['fresh_threshold'] = 5
cache_dic['force_fresh'] = 'global'
cache_dic['soft_fresh_weight'] = 0.0
cache_dic['max_order'] = 0
cache_dic['first_enhance'] = 1
cache_dic['duca'] = True
elif mode == 'Taylor':
cache_dic['cache_type'] = 'random'
cache_dic['cache_index'] = cache_index
cache_dic['cache'] = cache
cache_dic['fresh_ratio_schedule'] = 'ToCa'
cache_dic['fresh_ratio'] = 0.0
cache_dic['fresh_threshold'] = 5
cache_dic['max_order'] = 1
cache_dic['force_fresh'] = 'global'
cache_dic['soft_fresh_weight'] = 0.0
cache_dic['taylor_cache'] = True
cache_dic['first_enhance'] = 1
current = {}
current['num_steps'] = num_steps
current['activated_steps'] = [0]
return cache_dic, current
def force_scheduler(cache_dic, current):
if cache_dic['fresh_ratio'] == 0:
# FORA
linear_step_weight = 0.0
else:
# TokenCache
linear_step_weight = 0.0
step_factor = torch.tensor(1 - linear_step_weight + 2 * linear_step_weight * current['step'] / current['num_steps'])
threshold = torch.round(cache_dic['fresh_threshold'] / step_factor)
# No forced refresh constraint for the sensitive steps, since performance is good enough without it;
# you may experiment with adding one.
cache_dic['cal_threshold'] = threshold
#return threshold
def cal_type(cache_dic, current):
'''
Determine calculation type for this step
'''
if (cache_dic['fresh_ratio'] == 0.0) and (not cache_dic['taylor_cache']):
# FORA: uniform refresh interval
first_step = (current['step'] == 0)
else:
# ToCa: First enhanced
first_step = (current['step'] < cache_dic['first_enhance'])
#first_step = (current['step'] <= 3)
force_fresh = cache_dic['force_fresh']
if not first_step:
fresh_interval = cache_dic['cal_threshold']
else:
fresh_interval = cache_dic['fresh_threshold']
if (first_step) or (cache_dic['cache_counter'] == fresh_interval - 1 ):
current['type'] = 'full'
cache_dic['cache_counter'] = 0
current['activated_steps'].append(current['step'])
#current['activated_times'].append(current['t'])
force_scheduler(cache_dic, current)
elif (cache_dic['taylor_cache']):
cache_dic['cache_counter'] += 1
current['type'] = 'taylor_cache'
else:
cache_dic['cache_counter'] += 1
if (cache_dic['duca']):
if (cache_dic['cache_counter'] % 2 == 1): # 0: ToCa-Aggressive-ToCa, 1: Aggressive-ToCa-Aggressive
current['type'] = 'ToCa'
# 'cache_noise' 'ToCa' 'FORA'
else:
current['type'] = 'aggressive'
else:
current['type'] = 'ToCa'
#if current['step'] < 25:
# current['type'] = 'FORA'
#else:
# current['type'] = 'aggressive'
######################################################################
#if (current['step'] in [3,2,1,0]):
# current['type'] = 'full'
class HunyuanSchedulerFeatureCaching(HunyuanScheduler):
def __init__(self, args):
super().__init__(args)
self.cache_dic, self.current = cache_init(self.infer_steps)
def step_pre(self, step_index):
super().step_pre(step_index)
self.current['step'] = step_index
cal_type(self.cache_dic, self.current)
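# --- Scheduling walk-through (illustrative) ---
# With the Taylor defaults from cache_init (fresh_threshold=5, taylor_cache=True,
# first_enhance=1), cal_type marks a 'full' refresh roughly every 5 steps and
# 'taylor_cache' in between. A small CPU trace of that behaviour:
def _demo_cal_type(num_steps=10):
    cache_dic, current = cache_init(num_steps)
    types = []
    for step in range(num_steps):
        current['step'] = step
        cal_type(cache_dic, current)
        types.append(current['type'])
    return types  # e.g. ['full', 'taylor_cache', 'taylor_cache', ..., 'full', ...]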
import torch
from diffusers.utils.torch_utils import randn_tensor
from typing import Union, Tuple, List
from typing import Any, Callable, Dict, List, Optional, Union, Tuple
from lightx2v.text2v.models.schedulers.scheduler import BaseScheduler
def _to_tuple(x, dim=2):
if isinstance(x, int):
return (x,) * dim
elif len(x) == dim:
return x
else:
raise ValueError(f"Expected length {dim} or int, but got {x}")
def get_1d_rotary_pos_embed(
dim: int,
pos: Union[torch.FloatTensor, int],
theta: float = 10000.0,
use_real: bool = False,
theta_rescale_factor: float = 1.0,
interpolation_factor: float = 1.0,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""
Precompute the frequency tensor for complex exponential (cis) with given dimensions.
(Note: `cis` means `cos + i * sin`, where i is the imaginary unit.)
This function calculates a frequency tensor with complex exponentials using the given dimension 'dim'
and the position indices 'pos'. The 'theta' parameter scales the frequencies.
The returned tensor contains complex values in complex64 data type.
Args:
dim (int): Dimension of the frequency tensor.
pos (int or torch.FloatTensor): Position indices for the frequency tensor. [S] or scalar
theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
use_real (bool, optional): If True, return real part and imaginary part separately.
Otherwise, return complex numbers.
theta_rescale_factor (float, optional): Rescale factor for theta. Defaults to 1.0.
Returns:
freqs_cis: Precomputed frequency tensor with complex exponential. [S, D/2]
freqs_cos, freqs_sin: Precomputed frequency tensor with real and imaginary parts separately. [S, D]
"""
if isinstance(pos, int):
pos = torch.arange(pos).float()
# proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
# has some connection to NTK literature
if theta_rescale_factor != 1.0:
theta *= theta_rescale_factor ** (dim / (dim - 2))
freqs = 1.0 / (
theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
) # [D/2]
# assert interpolation_factor == 1.0, f"interpolation_factor: {interpolation_factor}"
freqs = torch.outer(pos * interpolation_factor, freqs) # [S, D/2]
if use_real:
freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D]
freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D]
return freqs_cos, freqs_sin
else:
freqs_cis = torch.polar(
torch.ones_like(freqs), freqs
) # complex64 # [S, D/2]
return freqs_cis
def get_meshgrid_nd(start, *args, dim=2):
"""
Get n-D meshgrid with start, stop and num.
Args:
start (int or tuple): If len(args) == 0, start is num; If len(args) == 1, start is start, args[0] is stop,
step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num. For n-dim, start/stop/num
should be int or n-tuple. If n-tuple is provided, the meshgrid will be stacked following the dim order in
n-tuples.
*args: See above.
dim (int): Dimension of the meshgrid. Defaults to 2.
Returns:
grid (torch.Tensor): [dim, ...]
"""
if len(args) == 0:
# start is grid_size
num = _to_tuple(start, dim=dim)
start = (0,) * dim
stop = num
elif len(args) == 1:
# start is start, args[0] is stop, step is 1
start = _to_tuple(start, dim=dim)
stop = _to_tuple(args[0], dim=dim)
num = [stop[i] - start[i] for i in range(dim)]
elif len(args) == 2:
# start is start, args[0] is stop, args[1] is num
start = _to_tuple(start, dim=dim) # Left-Top eg: 12,0
stop = _to_tuple(args[0], dim=dim) # Right-Bottom eg: 20,32
num = _to_tuple(args[1], dim=dim) # Target Size eg: 32,124
else:
raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")
# PyTorch implement of np.linspace(start[i], stop[i], num[i], endpoint=False)
axis_grid = []
for i in range(dim):
a, b, n = start[i], stop[i], num[i]
g = torch.linspace(a, b, n + 1, dtype=torch.float32)[:n]
axis_grid.append(g)
grid = torch.meshgrid(*axis_grid, indexing="ij") # dim x [W, H, D]
grid = torch.stack(grid, dim=0) # [dim, W, H, D]
return grid
def get_nd_rotary_pos_embed(
rope_dim_list,
start,
*args,
theta=10000.0,
use_real=False,
theta_rescale_factor: Union[float, List[float]] = 1.0,
interpolation_factor: Union[float, List[float]] = 1.0,
):
"""
This is a n-d version of precompute_freqs_cis, which is a RoPE for tokens with n-d structure.
Args:
rope_dim_list (list of int): Dimension of each rope. len(rope_dim_list) should equal to n.
sum(rope_dim_list) should equal to head_dim of attention layer.
start (int | tuple of int | list of int): If len(args) == 0, start is num; If len(args) == 1, start is start,
args[0] is stop, step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num.
*args: See above.
theta (float): Scaling factor for frequency computation. Defaults to 10000.0.
use_real (bool): If True, return real part and imaginary part separately. Otherwise, return complex numbers.
Some libraries such as TensorRT do not support the complex64 data type, so it is useful to provide
the real and imaginary parts separately.
theta_rescale_factor (float): Rescale factor for theta. Defaults to 1.0.
Returns:
pos_embed (torch.Tensor): [HW, D/2]
"""
grid = get_meshgrid_nd(
start, *args, dim=len(rope_dim_list)
) # [3, W, H, D] / [2, W, H]
if isinstance(theta_rescale_factor, int) or isinstance(theta_rescale_factor, float):
theta_rescale_factor = [theta_rescale_factor] * len(rope_dim_list)
elif isinstance(theta_rescale_factor, list) and len(theta_rescale_factor) == 1:
theta_rescale_factor = [theta_rescale_factor[0]] * len(rope_dim_list)
assert len(theta_rescale_factor) == len(
rope_dim_list
), "len(theta_rescale_factor) should equal to len(rope_dim_list)"
if isinstance(interpolation_factor, int) or isinstance(interpolation_factor, float):
interpolation_factor = [interpolation_factor] * len(rope_dim_list)
elif isinstance(interpolation_factor, list) and len(interpolation_factor) == 1:
interpolation_factor = [interpolation_factor[0]] * len(rope_dim_list)
assert len(interpolation_factor) == len(
rope_dim_list
), "len(interpolation_factor) should equal to len(rope_dim_list)"
# use 1/ndim of dimensions to encode grid_axis
embs = []
for i in range(len(rope_dim_list)):
emb = get_1d_rotary_pos_embed(
rope_dim_list[i],
grid[i].reshape(-1),
theta,
use_real=use_real,
theta_rescale_factor=theta_rescale_factor[i],
interpolation_factor=interpolation_factor[i],
) # 2 x [WHD, rope_dim_list[i]]
embs.append(emb)
if use_real:
cos = torch.cat([emb[0] for emb in embs], dim=1) # (WHD, D/2)
sin = torch.cat([emb[1] for emb in embs], dim=1) # (WHD, D/2)
return cos, sin
else:
emb = torch.cat(embs, dim=1) # (WHD, D/2)
return emb
def set_timesteps_sigmas(num_inference_steps, shift, device, num_train_timesteps=1000):
sigmas = torch.linspace(1, 0, num_inference_steps + 1)
sigmas = (shift * sigmas) / (1 + (shift - 1) * sigmas)
timesteps = (sigmas[:-1] * num_train_timesteps).to(
dtype=torch.bfloat16, device=device
)
return timesteps, sigmas
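# --- Schedule sketch (illustrative) ---
# With shift=7 the sigma schedule is strongly front-loaded: for 4 inference steps
# the shifted sigmas are roughly [1.0, 0.955, 0.875, 0.7, 0.0], and the timesteps
# are sigmas[:-1] scaled by num_train_timesteps (1000 by default).
def _demo_timesteps():
    timesteps, sigmas = set_timesteps_sigmas(4, shift=7.0, device="cpu")
    assert len(timesteps) == 4 and sigmas[-1] == 0
    return timesteps, sigmas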
class HunyuanScheduler(BaseScheduler):
def __init__(self, args):
super().__init__(args)
self.infer_steps = self.args.infer_steps
self.shift = 7.0
self.timesteps, self.sigmas = set_timesteps_sigmas(self.infer_steps, self.shift, device=torch.device('cuda'))
assert len(self.timesteps) == self.infer_steps
self.embedded_guidance_scale = 6.0
self.generator = [torch.Generator('cuda').manual_seed(seed) for seed in [42]]
self.noise_pred = None
self.prepare_latents(shape=self.args.target_shape, dtype=torch.bfloat16)
self.prepare_guidance()
self.prepare_rotary_pos_embedding(video_length=self.args.target_video_length, height=self.args.target_height, width=self.args.target_width)
def prepare_guidance(self):
self.guidance = torch.tensor([self.embedded_guidance_scale], dtype=torch.bfloat16, device=torch.device('cuda')) * 1000.0
def step_post(self):
sample = self.latents.to(torch.float32)
dt = self.sigmas[self.step_index + 1] - self.sigmas[self.step_index]
prev_sample = sample + self.noise_pred.to(torch.float32) * dt
self.latents = prev_sample
def prepare_latents(self, shape, dtype):
self.latents = randn_tensor(shape, generator=self.generator, device=torch.device('cuda'), dtype=dtype)
def prepare_rotary_pos_embedding(self, video_length, height, width):
target_ndim = 3
ndim = 5 - 2
# 884
vae = "884-16c-hy"
patch_size = [1, 2, 2]
hidden_size = 3072
heads_num = 24
rope_theta = 256
rope_dim_list = [16, 56, 56]
if "884" in vae:
latents_size = [(video_length - 1) // 4 + 1, height // 8, width // 8]
elif "888" in vae:
latents_size = [(video_length - 1) // 8 + 1, height // 8, width // 8]
else:
latents_size = [video_length, height // 8, width // 8]
if isinstance(patch_size, int):
assert all(s % patch_size == 0 for s in latents_size), (
f"Latent size(last {ndim} dimensions) should be divisible by patch size({patch_size}), "
f"but got {latents_size}."
)
rope_sizes = [s // patch_size for s in latents_size]
elif isinstance(patch_size, list):
assert all(
s % patch_size[idx] == 0
for idx, s in enumerate(latents_size)
), (
f"Latent size(last {ndim} dimensions) should be divisible by patch size({patch_size}), "
f"but got {latents_size}."
)
rope_sizes = [
s // patch_size[idx] for idx, s in enumerate(latents_size)
]
if len(rope_sizes) != target_ndim:
rope_sizes = [1] * (target_ndim - len(rope_sizes)) + rope_sizes # time axis
head_dim = hidden_size // heads_num
rope_dim_list = rope_dim_list
if rope_dim_list is None:
rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
assert (
sum(rope_dim_list) == head_dim
), "sum(rope_dim_list) should equal to head_dim of attention layer"
self.freqs_cos, self.freqs_sin = get_nd_rotary_pos_embed(
rope_dim_list,
rope_sizes,
theta=rope_theta,
use_real=True,
theta_rescale_factor=1,
)
self.freqs_cos = self.freqs_cos.to(dtype=torch.bfloat16, device=torch.device("cuda"))
self.freqs_sin = self.freqs_sin.to(dtype=torch.bfloat16, device=torch.device("cuda"))
import torch
class BaseScheduler():
def __init__(self, args):
self.args = args
self.step_index = 0
self.latents = None
def step_pre(self, step_index):
self.step_index = step_index
self.latents = self.latents.to(dtype=torch.bfloat16)
import torch
from ..scheduler import WanScheduler
class WanSchedulerFeatureCaching(WanScheduler):
def __init__(self, args):
super().__init__(args)
self.cnt = 0
self.num_steps = self.args.infer_steps * 2
self.teacache_thresh = self.args.teacache_thresh
self.accumulated_rel_l1_distance_even = 0
self.accumulated_rel_l1_distance_odd = 0
self.previous_e0_even = None
self.previous_e0_odd = None
self.previous_residual_even = None
self.previous_residual_odd = None
self.use_ret_steps = self.args.use_ret_steps
if self.use_ret_steps:
if self.args.target_width == 480 or self.args.target_height == 480:
self.coefficients = [
2.57151496e05,
-3.54229917e04,
1.40286849e03,
-1.35890334e01,
1.32517977e-01,
]
if self.args.target_width == 720 or self.args.target_height == 720:
self.coefficients = [
8.10705460e03,
2.13393892e03,
-3.72934672e02,
1.66203073e01,
-4.17769401e-02,
]
self.ret_steps = 5 * 2
self.cutoff_steps = self.args.infer_steps * 2
else:
if self.args.target_width == 480 or self.args.target_height == 480:
self.coefficients = [
-3.02331670e02,
2.23948934e02,
-5.25463970e01,
5.87348440e00,
-2.01973289e-01,
]
if self.args.target_width == 720 or self.args.target_height == 720:
self.coefficients = [
-114.36346466,
65.26524496,
-18.82220707,
4.91518089,
-0.23412683,
]
self.ret_steps = 1 * 2
self.cutoff_steps = self.args.infer_steps * 2 - 2