Commit a50bcc53 authored by Dongz, committed by Yang Yong (雍洋)

add lint feature and minor fixes (#7)

* [minor]: optimize Dockerfile for fewer layers

* [feature]: add pre-commit lint; update README with contribution guidance

* [minor]: fix run shell script privileges

* [auto]: first lint pass without rule F; fix rule E violations (see the illustrative snippet below)

* [minor]: fix Dockerfile error
parent 3b460075
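The "rule E" / "rule F" bullets above refer to the lint rule families checked by the new pre-commit hook; assuming a flake8/ruff-style linter, "E" codes are pycodestyle formatting errors and "F" codes are pyflakes checks. A purely illustrative Python example (not taken from the repository) of the kind of E-class fix this commit applies:

# Illustrative only -- not code from this commit.
x=1      # before: E225, missing whitespace around operator
x = 1    # after the auto-lint pass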
@@ -7,6 +7,7 @@ from lightx2v.text2v.models.networks.hunyuan.infer.pre_infer import HunyuanPreInfer
from lightx2v.text2v.models.networks.hunyuan.infer.post_infer import HunyuanPostInfer
from lightx2v.text2v.models.networks.hunyuan.infer.transformer_infer import HunyuanTransformerInfer
from lightx2v.text2v.models.networks.hunyuan.infer.feature_caching.transformer_infer import HunyuanTransformerInferFeatureCaching
# from lightx2v.core.distributed.partial_heads_attn.wrap import parallelize_hunyuan
from lightx2v.attentions.distributed.ulysses.wrap import parallelize_hunyuan

@@ -23,18 +24,18 @@ class HunyuanModel:
        self._init_weights()
        self._init_infer()

        if self.config["parallel_attn"]:
            parallelize_hunyuan(self)
        if self.config["cpu_offload"]:
            self.to_cpu()

    def _init_infer_class(self):
        self.pre_infer_class = HunyuanPreInfer
        self.post_infer_class = HunyuanPostInfer
        if self.config["feature_caching"] == "NoCaching":
            self.transformer_infer_class = HunyuanTransformerInfer
        elif self.config["feature_caching"] == "TaylorSeer":
            self.transformer_infer_class = HunyuanTransformerInferFeatureCaching
        else:
            raise NotImplementedError(f"Unsupported feature_caching type: {self.config['feature_caching']}")

@@ -87,9 +88,5 @@ class HunyuanModel:
            self.scheduler.freqs_sin,
            self.scheduler.guidance,
        )
        img, vec = self.transformer_infer.infer(self.transformer_weights, *pre_infer_out)
        self.scheduler.noise_pred = self.post_infer.infer(self.post_weight, img, vec, self.scheduler.latents.shape)
@@ -7,8 +7,8 @@ class HunyuanPostWeights:
        self.config = config

    def load_weights(self, weight_dict):
        self.final_layer_linear = MM_WEIGHT_REGISTER["Default-Force-FP32"]("final_layer.linear.weight", "final_layer.linear.bias")
        self.final_layer_adaLN_modulation_1 = MM_WEIGHT_REGISTER["Default"]("final_layer.adaLN_modulation.1.weight", "final_layer.adaLN_modulation.1.bias")
        self.weight_list = [
            self.final_layer_linear,

@@ -17,7 +17,7 @@ class HunyuanPostWeights:
        for mm_weight in self.weight_list:
            if isinstance(mm_weight, MMWeightTemplate):
                mm_weight.set_config(self.config["mm_config"])
                mm_weight.load(weight_dict)

    def to_cpu(self):
...
@@ -9,46 +9,72 @@ class HunyuanPreWeights:
        self.config = config

    def load_weights(self, weight_dict):
        self.img_in_proj = CONV3D_WEIGHT_REGISTER["Default"]("img_in.proj.weight", "img_in.proj.bias", stride=(1, 2, 2))
        self.txt_in_input_embedder = MM_WEIGHT_REGISTER["Default"]("txt_in.input_embedder.weight", "txt_in.input_embedder.bias")
        self.txt_in_t_embedder_mlp_0 = MM_WEIGHT_REGISTER["Default"]("txt_in.t_embedder.mlp.0.weight", "txt_in.t_embedder.mlp.0.bias")
        self.txt_in_t_embedder_mlp_2 = MM_WEIGHT_REGISTER["Default"]("txt_in.t_embedder.mlp.2.weight", "txt_in.t_embedder.mlp.2.bias")
        self.txt_in_c_embedder_linear_1 = MM_WEIGHT_REGISTER["Default"]("txt_in.c_embedder.linear_1.weight", "txt_in.c_embedder.linear_1.bias")
        self.txt_in_c_embedder_linear_2 = MM_WEIGHT_REGISTER["Default"]("txt_in.c_embedder.linear_2.weight", "txt_in.c_embedder.linear_2.bias")
        self.txt_in_individual_token_refiner_blocks_0_norm1 = LN_WEIGHT_REGISTER["Default"](
            "txt_in.individual_token_refiner.blocks.0.norm1.weight", "txt_in.individual_token_refiner.blocks.0.norm1.bias", eps=1e-6
        )
        self.txt_in_individual_token_refiner_blocks_0_self_attn_qkv = MM_WEIGHT_REGISTER["Default"](
            "txt_in.individual_token_refiner.blocks.0.self_attn_qkv.weight", "txt_in.individual_token_refiner.blocks.0.self_attn_qkv.bias"
        )
        self.txt_in_individual_token_refiner_blocks_0_self_attn_proj = MM_WEIGHT_REGISTER["Default"](
            "txt_in.individual_token_refiner.blocks.0.self_attn_proj.weight", "txt_in.individual_token_refiner.blocks.0.self_attn_proj.bias"
        )
        self.txt_in_individual_token_refiner_blocks_0_norm2 = LN_WEIGHT_REGISTER["Default"](
            "txt_in.individual_token_refiner.blocks.0.norm2.weight", "txt_in.individual_token_refiner.blocks.0.norm2.bias", eps=1e-6
        )
        self.txt_in_individual_token_refiner_blocks_0_mlp_fc1 = MM_WEIGHT_REGISTER["Default"](
            "txt_in.individual_token_refiner.blocks.0.mlp.fc1.weight", "txt_in.individual_token_refiner.blocks.0.mlp.fc1.bias"
        )
        self.txt_in_individual_token_refiner_blocks_0_mlp_fc2 = MM_WEIGHT_REGISTER["Default"](
            "txt_in.individual_token_refiner.blocks.0.mlp.fc2.weight", "txt_in.individual_token_refiner.blocks.0.mlp.fc2.bias"
        )
        self.txt_in_individual_token_refiner_blocks_0_adaLN_modulation_1 = MM_WEIGHT_REGISTER["Default"](
            "txt_in.individual_token_refiner.blocks.0.adaLN_modulation.1.weight", "txt_in.individual_token_refiner.blocks.0.adaLN_modulation.1.bias"
        )
        self.txt_in_individual_token_refiner_blocks_1_norm1 = LN_WEIGHT_REGISTER["Default"](
            "txt_in.individual_token_refiner.blocks.1.norm1.weight", "txt_in.individual_token_refiner.blocks.1.norm1.bias", eps=1e-6
        )
        self.txt_in_individual_token_refiner_blocks_1_self_attn_qkv = MM_WEIGHT_REGISTER["Default"](
            "txt_in.individual_token_refiner.blocks.1.self_attn_qkv.weight", "txt_in.individual_token_refiner.blocks.1.self_attn_qkv.bias"
        )
        self.txt_in_individual_token_refiner_blocks_1_self_attn_proj = MM_WEIGHT_REGISTER["Default"](
            "txt_in.individual_token_refiner.blocks.1.self_attn_proj.weight", "txt_in.individual_token_refiner.blocks.1.self_attn_proj.bias"
        )
        self.txt_in_individual_token_refiner_blocks_1_norm2 = LN_WEIGHT_REGISTER["Default"](
            "txt_in.individual_token_refiner.blocks.1.norm2.weight", "txt_in.individual_token_refiner.blocks.1.norm2.bias", eps=1e-6
        )
        self.txt_in_individual_token_refiner_blocks_1_mlp_fc1 = MM_WEIGHT_REGISTER["Default"](
            "txt_in.individual_token_refiner.blocks.1.mlp.fc1.weight", "txt_in.individual_token_refiner.blocks.1.mlp.fc1.bias"
        )
        self.txt_in_individual_token_refiner_blocks_1_mlp_fc2 = MM_WEIGHT_REGISTER["Default"](
            "txt_in.individual_token_refiner.blocks.1.mlp.fc2.weight", "txt_in.individual_token_refiner.blocks.1.mlp.fc2.bias"
        )
        self.txt_in_individual_token_refiner_blocks_1_adaLN_modulation_1 = MM_WEIGHT_REGISTER["Default"](
            "txt_in.individual_token_refiner.blocks.1.adaLN_modulation.1.weight", "txt_in.individual_token_refiner.blocks.1.adaLN_modulation.1.bias"
        )
        self.time_in_mlp_0 = MM_WEIGHT_REGISTER["Default"]("time_in.mlp.0.weight", "time_in.mlp.0.bias")
        self.time_in_mlp_2 = MM_WEIGHT_REGISTER["Default"]("time_in.mlp.2.weight", "time_in.mlp.2.bias")
        self.vector_in_in_layer = MM_WEIGHT_REGISTER["Default"]("vector_in.in_layer.weight", "vector_in.in_layer.bias")
        self.vector_in_out_layer = MM_WEIGHT_REGISTER["Default"]("vector_in.out_layer.weight", "vector_in.out_layer.bias")
        self.guidance_in_mlp_0 = MM_WEIGHT_REGISTER["Default"]("guidance_in.mlp.0.weight", "guidance_in.mlp.0.bias")
        self.guidance_in_mlp_2 = MM_WEIGHT_REGISTER["Default"]("guidance_in.mlp.2.weight", "guidance_in.mlp.2.bias")
        self.weight_list = [
            self.img_in_proj,
            self.txt_in_input_embedder,
            self.txt_in_t_embedder_mlp_0,
            self.txt_in_t_embedder_mlp_2,
            self.txt_in_c_embedder_linear_1,
            self.txt_in_c_embedder_linear_2,
            self.txt_in_individual_token_refiner_blocks_0_norm1,
            self.txt_in_individual_token_refiner_blocks_0_self_attn_qkv,
            self.txt_in_individual_token_refiner_blocks_0_self_attn_proj,

@@ -56,7 +82,6 @@ class HunyuanPreWeights:
            self.txt_in_individual_token_refiner_blocks_0_mlp_fc1,
            self.txt_in_individual_token_refiner_blocks_0_mlp_fc2,
            self.txt_in_individual_token_refiner_blocks_0_adaLN_modulation_1,
            self.txt_in_individual_token_refiner_blocks_1_norm1,
            self.txt_in_individual_token_refiner_blocks_1_self_attn_qkv,
            self.txt_in_individual_token_refiner_blocks_1_self_attn_proj,

@@ -64,7 +89,6 @@ class HunyuanPreWeights:
            self.txt_in_individual_token_refiner_blocks_1_mlp_fc1,
            self.txt_in_individual_token_refiner_blocks_1_mlp_fc2,
            self.txt_in_individual_token_refiner_blocks_1_adaLN_modulation_1,
            self.time_in_mlp_0,
            self.time_in_mlp_2,
            self.vector_in_in_layer,

@@ -75,7 +99,7 @@ class HunyuanPreWeights:
        for mm_weight in self.weight_list:
            if isinstance(mm_weight, MMWeightTemplate) or isinstance(mm_weight, LNWeightTemplate) or isinstance(mm_weight, Conv3dWeightTemplate):
                mm_weight.set_config(self.config["mm_config"])
                mm_weight.load(weight_dict)

    def to_cpu(self):
...
@@ -41,26 +41,26 @@ class HunyuanTransformerDoubleBlock:
        self.weight_list = []

    def load_weights(self, weight_dict):
        if self.config["do_mm_calib"]:
            mm_type = "Calib"
        else:
            mm_type = self.config["mm_config"].get("mm_type", "Default") if self.config["mm_config"] else "Default"

        self.img_mod = MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.img_mod.linear.weight", f"double_blocks.{self.block_index}.img_mod.linear.bias")
        self.img_attn_qkv = MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.img_attn_qkv.weight", f"double_blocks.{self.block_index}.img_attn_qkv.bias")
        self.img_attn_q_norm = RMS_WEIGHT_REGISTER["sgl-kernel"](f"double_blocks.{self.block_index}.img_attn_q_norm.weight", eps=1e-6)
        self.img_attn_k_norm = RMS_WEIGHT_REGISTER["sgl-kernel"](f"double_blocks.{self.block_index}.img_attn_k_norm.weight", eps=1e-6)
        self.img_attn_proj = MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.img_attn_proj.weight", f"double_blocks.{self.block_index}.img_attn_proj.bias")
        self.img_mlp_fc1 = MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.img_mlp.fc1.weight", f"double_blocks.{self.block_index}.img_mlp.fc1.bias")
        self.img_mlp_fc2 = MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.img_mlp.fc2.weight", f"double_blocks.{self.block_index}.img_mlp.fc2.bias")
        self.txt_mod = MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.txt_mod.linear.weight", f"double_blocks.{self.block_index}.txt_mod.linear.bias")
        self.txt_attn_qkv = MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.txt_attn_qkv.weight", f"double_blocks.{self.block_index}.txt_attn_qkv.bias")
        self.txt_attn_q_norm = RMS_WEIGHT_REGISTER["sgl-kernel"](f"double_blocks.{self.block_index}.txt_attn_q_norm.weight", eps=1e-6)
        self.txt_attn_k_norm = RMS_WEIGHT_REGISTER["sgl-kernel"](f"double_blocks.{self.block_index}.txt_attn_k_norm.weight", eps=1e-6)
        self.txt_attn_proj = MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.txt_attn_proj.weight", f"double_blocks.{self.block_index}.txt_attn_proj.bias")
        self.txt_mlp_fc1 = MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.txt_mlp.fc1.weight", f"double_blocks.{self.block_index}.txt_mlp.fc1.bias")
        self.txt_mlp_fc2 = MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.txt_mlp.fc2.weight", f"double_blocks.{self.block_index}.txt_mlp.fc2.bias")
        self.weight_list = [
            self.img_mod,

@@ -81,7 +81,7 @@ class HunyuanTransformerDoubleBlock:
        for mm_weight in self.weight_list:
            if isinstance(mm_weight, MMWeightTemplate) or isinstance(mm_weight, RMSWeightTemplate):
                mm_weight.set_config(self.config["mm_config"])
                mm_weight.load(weight_dict)

    def to_cpu(self):

@@ -102,16 +102,16 @@ class HunyuanTransformerSingleBlock:
        self.weight_list = []

    def load_weights(self, weight_dict):
        if self.config["do_mm_calib"]:
            mm_type = "Calib"
        else:
            mm_type = self.config["mm_config"].get("mm_type", "Default") if self.config["mm_config"] else "Default"

        self.linear1 = MM_WEIGHT_REGISTER[mm_type](f"single_blocks.{self.block_index}.linear1.weight", f"single_blocks.{self.block_index}.linear1.bias")
        self.linear2 = MM_WEIGHT_REGISTER[mm_type](f"single_blocks.{self.block_index}.linear2.weight", f"single_blocks.{self.block_index}.linear2.bias")
        self.q_norm = RMS_WEIGHT_REGISTER["sgl-kernel"](f"single_blocks.{self.block_index}.q_norm.weight", eps=1e-6)
        self.k_norm = RMS_WEIGHT_REGISTER["sgl-kernel"](f"single_blocks.{self.block_index}.k_norm.weight", eps=1e-6)
        self.modulation = MM_WEIGHT_REGISTER[mm_type](f"single_blocks.{self.block_index}.modulation.linear.weight", f"single_blocks.{self.block_index}.modulation.linear.bias")
        self.weight_list = [
            self.linear1,

@@ -123,7 +123,7 @@ class HunyuanTransformerSingleBlock:
        for mm_weight in self.weight_list:
            if isinstance(mm_weight, MMWeightTemplate) or isinstance(mm_weight, RMSWeightTemplate):
                mm_weight.set_config(self.config["mm_config"])
                mm_weight.load(weight_dict)

    def to_cpu(self):
...
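The weight classes in this commit all follow the same pattern: look up a weight-wrapper class in a registry keyed by mm_type, bind it to named tensors, then set_config and load it from the checkpoint dict. A minimal, self-contained sketch of that pattern (the registry and LinearWeight class below are illustrative, not lightx2v's actual implementations):

import torch

WEIGHT_REGISTRY = {}

def register(name):
    # Decorator that stores a wrapper class under a string key, like MM_WEIGHT_REGISTER.
    def deco(cls):
        WEIGHT_REGISTRY[name] = cls
        return cls
    return deco

@register("Default")
class LinearWeight:
    def __init__(self, weight_name, bias_name):
        self.weight_name, self.bias_name = weight_name, bias_name
    def set_config(self, config):
        self.config = config or {}
    def load(self, weight_dict):
        # Pull the named tensors out of the flat state dict.
        self.weight = weight_dict[self.weight_name]
        self.bias = weight_dict[self.bias_name]
    def apply(self, x):
        return torch.nn.functional.linear(x, self.weight, self.bias)

state = {"linear1.weight": torch.randn(8, 4), "linear1.bias": torch.zeros(8)}
w = WEIGHT_REGISTRY["Default"]("linear1.weight", "linear1.bias")
w.set_config(None)
w.load(state)
out = w.apply(torch.randn(2, 4))  # shape (2, 8)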
@@ -13,26 +13,15 @@ class WanTransformerInferFeatureCaching(WanTransformerInfer):
        # teacache
        if self.scheduler.cnt % 2 == 0:  # even -> conditon
            self.scheduler.is_even = True
            if self.scheduler.cnt < self.scheduler.ret_steps or self.scheduler.cnt >= self.scheduler.cutoff_steps:
                should_calc_even = True
                self.scheduler.accumulated_rel_l1_distance_even = 0
            else:
                rescale_func = np.poly1d(self.scheduler.coefficients)
                self.scheduler.accumulated_rel_l1_distance_even += rescale_func(
                    ((modulated_inp - self.scheduler.previous_e0_even).abs().mean() / self.scheduler.previous_e0_even.abs().mean()).cpu().item()
                )
                if self.scheduler.accumulated_rel_l1_distance_even < self.scheduler.teacache_thresh:
                    should_calc_even = False
                else:
                    should_calc_even = True

@@ -46,7 +35,9 @@ class WanTransformerInferFeatureCaching(WanTransformerInfer):
                self.scheduler.accumulated_rel_l1_distance_odd = 0
            else:
                rescale_func = np.poly1d(self.scheduler.coefficients)
                self.scheduler.accumulated_rel_l1_distance_odd += rescale_func(
                    ((modulated_inp - self.scheduler.previous_e0_odd).abs().mean() / self.scheduler.previous_e0_odd.abs().mean()).cpu().item()
                )
                if self.scheduler.accumulated_rel_l1_distance_odd < self.scheduler.teacache_thresh:
                    should_calc_odd = False
                else:
...
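The caching branches above implement a TeaCache-style skip rule: each step accumulates a polynomial-rescaled relative L1 distance between the current and previous modulated inputs, and the transformer is only recomputed once that accumulator crosses teacache_thresh. A minimal numpy sketch of the decision (function and argument names are illustrative, not the repository's API):

import numpy as np

def should_recompute(modulated_inp, previous_inp, accumulated, coefficients, thresh):
    """Accumulate the rescaled relative L1 distance; recompute only when it crosses thresh."""
    rescale_func = np.poly1d(coefficients)
    rel_l1 = np.abs(modulated_inp - previous_inp).mean() / np.abs(previous_inp).mean()
    accumulated += rescale_func(rel_l1)
    if accumulated < thresh:
        return False, accumulated  # below threshold: reuse the cached output
    return True, 0.0               # recompute this step and reset the accumulator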
@@ -10,9 +10,7 @@ class WanPostInfer:
    def infer(self, weights, x, e, grid_sizes):
        e = (weights.head_modulation + e.unsqueeze(1)).chunk(2, dim=1)
        norm_out = torch.nn.functional.layer_norm(x, (x.shape[1],), None, None, 1e-6).type_as(x)
        out = norm_out * (1 + e[1].squeeze(0)) + e[0].squeeze(0)
        x = weights.head.apply(out)
        x = self.unpatchify(x, grid_sizes)
...
@@ -6,12 +6,10 @@ import torch.cuda.amp as amp
class WanPreInfer:
    def __init__(self, config):
        assert (config["dim"] % config["num_heads"]) == 0 and (config["dim"] // config["num_heads"]) % 2 == 0
        d = config["dim"] // config["num_heads"]
        self.task = config["task"]
        self.freqs = torch.cat(
            [
                rope_params(1024, d - 4 * (d // 6)),

@@ -25,24 +23,16 @@ class WanPreInfer:
        self.text_len = config["text_len"]

    def infer(self, weights, x, t, context, seq_len, clip_fea=None, y=None):
        if self.task == "i2v":
            x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
        # embeddings
        x = [weights.patch_embedding.apply(u.unsqueeze(0)) for u in x]
        grid_sizes = torch.stack([torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
        x = [u.flatten(2).transpose(1, 2) for u in x]
        seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long).cuda()
        assert seq_lens.max() <= seq_len
        x = torch.cat([torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1) for u in x])

        embed = sinusoidal_embedding_1d(self.freq_dim, t)
        embed = weights.time_embedding_0.apply(embed)

@@ -53,17 +43,12 @@ class WanPreInfer:
        embed0 = weights.time_projection_1.apply(embed0).unflatten(1, (6, self.dim))

        # text embeddings
        stacked = torch.stack([torch.cat([u, u.new_zeros(self.text_len - u.size(0), u.size(1))]) for u in context])
        out = weights.text_embedding_0.apply(stacked.squeeze(0))
        out = torch.nn.functional.gelu(out, approximate="tanh")
        context = weights.text_embedding_2.apply(out)

        if self.task == "i2v":
            context_clip = weights.proj_0.apply(clip_fea)
            context_clip = weights.proj_1.apply(context_clip)
            context_clip = torch.nn.functional.gelu(context_clip, approximate="none")
...
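WanPreInfer.infer pads every flattened latent sequence with zeros up to a fixed seq_len before concatenating the batch, mirroring the one-liner reconstructed above. A small self-contained example of the same idiom (the shapes below are made up for illustration):

import torch

seq_len = 8
x = [torch.randn(1, 5, 4), torch.randn(1, 7, 4)]  # variable token counts per sample
x = torch.cat([torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1) for u in x])
print(x.shape)  # torch.Size([2, 8, 4])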
@@ -6,7 +6,7 @@ from lightx2v.attentions import attention
class WanTransformerInfer:
    def __init__(self, config):
        self.config = config
        self.task = config["task"]
        self.attention_type = config.get("attention_type", "flash_attn2")
        self.blocks_num = config["num_layers"]
        self.num_heads = config["num_heads"]

@@ -24,14 +24,8 @@ class WanTransformerInfer:
        q_lens = torch.tensor([lq], dtype=torch.int32, device=q.device)
        # We don't have a batch dimension anymore, so directly use the `q_lens` and `k_lens` values
        cu_seqlens_q = torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(0, dtype=torch.int32)
        cu_seqlens_k = torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(0, dtype=torch.int32)
        return cu_seqlens_q, cu_seqlens_k, lq, lk

    def infer(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):

@@ -49,14 +43,10 @@ class WanTransformerInfer:
        return x

    def infer_block(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
        embed0 = (weights.modulation + embed0).chunk(6, dim=1)

        norm1_out = torch.nn.functional.layer_norm(x, (x.shape[1],), None, None, 1e-6)
        norm1_out = (norm1_out * (1 + embed0[1]) + embed0[0]).squeeze(0)

        s, n, d = *norm1_out.shape[:1], self.num_heads, self.head_dim

@@ -92,7 +82,7 @@ class WanTransformerInfer:
            k=k,
            v=v,
            img_qkv_len=q.shape[0],
            cu_seqlens_qkv=cu_seqlens_q,
            # cu_seqlens_qkv=cu_seqlens_qkv,
            # max_seqlen_qkv=max_seqlen_qkv,
        )

@@ -102,7 +92,7 @@ class WanTransformerInfer:
        norm3_out = weights.norm3.apply(x)
        if self.task == "i2v":
            context_img = context[:257]
            context = context[257:]

@@ -111,13 +101,11 @@ class WanTransformerInfer:
        k = weights.cross_attn_norm_k.apply(weights.cross_attn_k.apply(context)).view(-1, n, d)
        v = weights.cross_attn_v.apply(context).view(-1, n, d)

        if self.task == "i2v":
            k_img = weights.cross_attn_norm_k_img.apply(weights.cross_attn_k_img.apply(context_img)).view(-1, n, d)
            v_img = weights.cross_attn_v_img.apply(context_img).view(-1, n, d)
            cu_seqlens_q, cu_seqlens_k, lq, lk = self._calculate_q_k_len(q, k_img, k_lens=torch.tensor([k_img.size(0)], dtype=torch.int32, device=k.device))
            img_attn_out = attention(
                attention_type=self.attention_type,

@@ -130,9 +118,7 @@ class WanTransformerInfer:
                max_seqlen_kv=lk,
            )

        cu_seqlens_q, cu_seqlens_k, lq, lk = self._calculate_q_k_len(q, k, k_lens=torch.tensor([k.size(0)], dtype=torch.int32, device=k.device))
        attn_out = attention(
            attention_type=self.attention_type,

@@ -147,9 +133,7 @@ class WanTransformerInfer:
        attn_out = weights.cross_attn_o.apply(attn_out)
        x = x + attn_out

        norm2_out = torch.nn.functional.layer_norm(x, (x.shape[1],), None, None, 1e-6)
        y = weights.ffn_0.apply(norm2_out * (1 + embed0[4].squeeze(0)) + embed0[3].squeeze(0))
        y = torch.nn.functional.gelu(y, approximate="tanh")
        y = weights.ffn_2.apply(y)
...
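The _calculate_q_k_len helper above builds the cumulative sequence-length tensors that varlen (flash-attention-style) kernels expect: a leading zero followed by the running sum of per-sequence lengths. A standalone torch example (the concrete lengths are illustrative):

import torch

q_lens = torch.tensor([1024], dtype=torch.int32)  # single unbatched query sequence
k_lens = torch.tensor([77], dtype=torch.int32)    # single key/value sequence, e.g. text context
cu_seqlens_q = torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(0, dtype=torch.int32)  # tensor([0, 1024])
cu_seqlens_k = torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(0, dtype=torch.int32)  # tensor([0, 77])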
@@ -23,12 +23,7 @@ def compute_freqs(c, grid_sizes, freqs):
def pad_freqs(original_tensor, target_len):
    seq_len, s1, s2 = original_tensor.shape
    pad_size = target_len - seq_len
    padding_tensor = torch.ones(pad_size, s1, s2, dtype=original_tensor.dtype, device=original_tensor.device)
    padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0)
    return padded_tensor

@@ -50,8 +45,7 @@ def compute_freqs_dist(s, c, grid_sizes, freqs):
    freqs_i = pad_freqs(freqs_i, s * world_size)
    s_per_rank = s
    freqs_i_rank = freqs_i[(cur_rank * s_per_rank) : ((cur_rank + 1) * s_per_rank), :, :]
    return freqs_i_rank

@@ -59,9 +53,7 @@ def apply_rotary_emb(x, freqs_i):
    n = x.size(1)
    seq_len = freqs_i.size(0)
    x_i = torch.view_as_complex(x[:seq_len].to(torch.float64).reshape(seq_len, n, -1, 2))

    # Apply rotary embedding
    x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
    x_i = torch.cat([x_i, x[seq_len:]]).to(torch.bfloat16)

@@ -85,8 +77,6 @@ def sinusoidal_embedding_1d(dim, position):
    position = position.type(torch.float64)

    # calculation
    sinusoid = torch.outer(position, torch.pow(10000, -torch.arange(half).to(position).div(half)))
    x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1).to(torch.bfloat16)
    return x
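sinusoidal_embedding_1d computes the classic transformer timestep embedding: for dim = 2 * half, a position p maps to [cos(p / 10000^(i/half)), sin(p / 10000^(i/half))] for i = 0..half-1. A hedged standalone sketch of the same computation (the final bfloat16 cast is omitted):

import torch

def sinusoidal_embedding_1d_sketch(dim, position):
    half = dim // 2
    position = position.type(torch.float64)
    # outer product of positions and the 1/10000^(i/half) frequency ladder
    sinusoid = torch.outer(position, torch.pow(10000, -torch.arange(half).to(position).div(half)))
    return torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)

emb = sinusoidal_embedding_1d_sketch(256, torch.tensor([0.0, 10.0, 999.0]))  # shape (3, 256)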
@@ -29,10 +29,10 @@ class WanModel:
        self._init_weights()
        self._init_infer()

        if config["parallel_attn"]:
            parallelize_wan(self)
        if self.config["cpu_offload"]:
            self.to_cpu()

    def _init_infer_class(self):

@@ -43,15 +43,11 @@ class WanModel:
        elif self.config["feature_caching"] == "Tea":
            self.transformer_infer_class = WanTransformerInferFeatureCaching
        else:
            raise NotImplementedError(f"Unsupported feature_caching type: {self.config['feature_caching']}")

    def _load_safetensor_to_dict(self, file_path):
        with safe_open(file_path, framework="pt") as f:
            tensor_dict = {key: f.get_tensor(key).to(torch.bfloat16).cuda() for key in f.keys()}
        return tensor_dict

    def _load_ckpt(self):

@@ -59,9 +55,7 @@ class WanModel:
        safetensors_files = glob.glob(safetensors_pattern)
        if not safetensors_files:
            raise FileNotFoundError(f"No .safetensors files found in directory: {self.model_path}")
        weight_dict = {}
        for file_path in safetensors_files:
            file_weights = self._load_safetensor_to_dict(file_path)

@@ -100,7 +94,6 @@ class WanModel:
    @torch.no_grad()
    def infer(self, text_encoders_output, image_encoder_output, args):
        timestep = torch.stack([self.scheduler.timesteps[self.scheduler.step_index]])
        embed, grid_sizes, pre_infer_out = self.pre_infer.infer(

@@ -112,12 +105,8 @@ class WanModel:
            image_encoder_output["clip_encoder_out"],
            [image_encoder_output["vae_encode_out"]],
        )
        x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
        noise_pred_cond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]
        if self.config["feature_caching"] == "Tea":
            self.scheduler.cnt += 1

@@ -133,18 +122,12 @@ class WanModel:
            image_encoder_output["clip_encoder_out"],
            [image_encoder_output["vae_encode_out"]],
        )
        x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
        noise_pred_uncond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]
        if self.config["feature_caching"] == "Tea":
            self.scheduler.cnt += 1
            if self.scheduler.cnt >= self.scheduler.num_steps:
                self.scheduler.cnt = 0

        self.scheduler.noise_pred = noise_pred_uncond + args.sample_guide_scale * (noise_pred_cond - noise_pred_uncond)
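The last line of WanModel.infer combines the conditional and unconditional predictions with standard classifier-free guidance, noise_pred = uncond + s * (cond - uncond). A toy torch example (tensor shapes and the guidance scale are illustrative stand-ins for args.sample_guide_scale):

import torch

noise_pred_cond = torch.randn(1, 16, 21, 60, 104)      # conditional prediction (made-up shape)
noise_pred_uncond = torch.randn_like(noise_pred_cond)  # unconditional prediction
sample_guide_scale = 5.0                                # assumed CLI value
noise_pred = noise_pred_uncond + sample_guide_scale * (noise_pred_cond - noise_pred_uncond)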
@@ -7,18 +7,14 @@ class WanPostWeights:
        self.config = config

    def load_weights(self, weight_dict):
        self.head = MM_WEIGHT_REGISTER["Default"]("head.head.weight", "head.head.bias")
        self.head_modulation = weight_dict["head.modulation"]
        self.weight_list = [self.head, self.head_modulation]

        for mm_weight in self.weight_list:
            if isinstance(mm_weight, MMWeightTemplate):
                mm_weight.set_config(self.config["mm_config"])
                mm_weight.load(weight_dict)

    def to_cpu(self):
...
@@ -4,6 +4,7 @@ from lightx2v.common.ops.mm.mm_weight import MMWeightTemplate
from lightx2v.common.ops.norm.layer_norm_weight import LNWeightTemplate
from lightx2v.common.ops.conv.conv3d import Conv3dWeightTemplate


class WanPreWeights:
    def __init__(self, config):
        self.in_dim = config["in_dim"]

@@ -12,18 +13,16 @@ class WanPreWeights:
        self.config = config

    def load_weights(self, weight_dict):
        self.patch_embedding = CONV3D_WEIGHT_REGISTER["Defaultt-Force-BF16"]("patch_embedding.weight", "patch_embedding.bias", stride=self.patch_size)
        self.text_embedding_0 = MM_WEIGHT_REGISTER["Default"]("text_embedding.0.weight", "text_embedding.0.bias")
        self.text_embedding_2 = MM_WEIGHT_REGISTER["Default"]("text_embedding.2.weight", "text_embedding.2.bias")
        self.time_embedding_0 = MM_WEIGHT_REGISTER["Default"]("time_embedding.0.weight", "time_embedding.0.bias")
        self.time_embedding_2 = MM_WEIGHT_REGISTER["Default"]("time_embedding.2.weight", "time_embedding.2.bias")
        self.time_projection_1 = MM_WEIGHT_REGISTER["Default"]("time_projection.1.weight", "time_projection.1.bias")

        self.weight_list = [
            self.patch_embedding,
            self.text_embedding_0,
            self.text_embedding_2,
            self.time_embedding_0,

@@ -31,11 +30,11 @@ class WanPreWeights:
            self.time_projection_1,
        ]

        if "img_emb.proj.0.weight" in weight_dict.keys():
            self.proj_0 = LN_WEIGHT_REGISTER["Default"]("img_emb.proj.0.weight", "img_emb.proj.0.bias", eps=1e-5)
            self.proj_1 = MM_WEIGHT_REGISTER["Default"]("img_emb.proj.1.weight", "img_emb.proj.1.bias")
            self.proj_3 = MM_WEIGHT_REGISTER["Default"]("img_emb.proj.3.weight", "img_emb.proj.3.bias")
            self.proj_4 = LN_WEIGHT_REGISTER["Default"]("img_emb.proj.4.weight", "img_emb.proj.4.bias", eps=1e-5)
            self.weight_list.append(self.proj_0)
            self.weight_list.append(self.proj_1)
            self.weight_list.append(self.proj_3)

@@ -43,7 +42,7 @@ class WanPreWeights:
        for mm_weight in self.weight_list:
            if isinstance(mm_weight, (MMWeightTemplate, LNWeightTemplate, Conv3dWeightTemplate)):
                mm_weight.set_config(self.config["mm_config"])
                mm_weight.load(weight_dict)

    def to_cpu(self):
...
@@ -7,17 +7,15 @@ from lightx2v.common.ops.norm.rms_norm_weight import RMSWeightTemplate
class WanTransformerWeights:
    def __init__(self, config):
        self.blocks_num = config["num_layers"]
        self.task = config["task"]
        self.config = config
        if config["do_mm_calib"]:
            self.mm_type = "Calib"
        else:
            self.mm_type = config["mm_config"].get("mm_type", "Default") if config["mm_config"] else "Default"

    def load_weights(self, weight_dict):
        self.blocks_weights = [WanTransformerAttentionBlock(i, self.task, self.mm_type, self.config) for i in range(self.blocks_num)]
        for block in self.blocks_weights:
            block.load_weights(weight_dict)

@@ -38,25 +36,24 @@ class WanTransformerAttentionBlock:
        self.config = config

    def load_weights(self, weight_dict):
        self.self_attn_q = MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.self_attn.q.weight", f"blocks.{self.block_index}.self_attn.q.bias")
        self.self_attn_k = MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.self_attn.k.weight", f"blocks.{self.block_index}.self_attn.k.bias")
        self.self_attn_v = MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.self_attn.v.weight", f"blocks.{self.block_index}.self_attn.v.bias")
        self.self_attn_o = MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.self_attn.o.weight", f"blocks.{self.block_index}.self_attn.o.bias")
        self.self_attn_norm_q = RMS_WEIGHT_REGISTER["sgl-kernel"](f"blocks.{self.block_index}.self_attn.norm_q.weight")
        self.self_attn_norm_k = RMS_WEIGHT_REGISTER["sgl-kernel"](f"blocks.{self.block_index}.self_attn.norm_k.weight")

        self.norm3 = LN_WEIGHT_REGISTER["Default"](f"blocks.{self.block_index}.norm3.weight", f"blocks.{self.block_index}.norm3.bias", eps=1e-6)
        self.cross_attn_q = MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.cross_attn.q.weight", f"blocks.{self.block_index}.cross_attn.q.bias")
        self.cross_attn_k = MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.cross_attn.k.weight", f"blocks.{self.block_index}.cross_attn.k.bias")
        self.cross_attn_v = MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.cross_attn.v.weight", f"blocks.{self.block_index}.cross_attn.v.bias")
        self.cross_attn_o = MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.cross_attn.o.weight", f"blocks.{self.block_index}.cross_attn.o.bias")
        self.cross_attn_norm_q = RMS_WEIGHT_REGISTER["sgl-kernel"](f"blocks.{self.block_index}.cross_attn.norm_q.weight")
        self.cross_attn_norm_k = RMS_WEIGHT_REGISTER["sgl-kernel"](f"blocks.{self.block_index}.cross_attn.norm_k.weight")

        self.ffn_0 = MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.ffn.0.weight", f"blocks.{self.block_index}.ffn.0.bias")
        self.ffn_2 = MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.ffn.2.weight", f"blocks.{self.block_index}.ffn.2.bias")
        self.modulation = weight_dict[f"blocks.{self.block_index}.modulation"]

        self.weight_list = [
            self.self_attn_q,

@@ -77,18 +74,18 @@ class WanTransformerAttentionBlock:
            self.modulation,
        ]

        if self.task == "i2v":
            self.cross_attn_k_img = MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.cross_attn.k_img.weight", f"blocks.{self.block_index}.cross_attn.k_img.bias")
            self.cross_attn_v_img = MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.cross_attn.v_img.weight", f"blocks.{self.block_index}.cross_attn.v_img.bias")
            # self.cross_attn_norm_k_img_weight = weight_dict[f'blocks.{self.block_index}.cross_attn.norm_k_img.weight']
            self.cross_attn_norm_k_img = RMS_WEIGHT_REGISTER["sgl-kernel"](f"blocks.{self.block_index}.cross_attn.norm_k_img.weight")
            self.weight_list.append(self.cross_attn_k_img)
            self.weight_list.append(self.cross_attn_v_img)
            self.weight_list.append(self.cross_attn_norm_k_img)

        for mm_weight in self.weight_list:
            if isinstance(mm_weight, (MMWeightTemplate, LNWeightTemplate, RMSWeightTemplate)):
                mm_weight.set_config(self.config["mm_config"])
                mm_weight.load(weight_dict)

    def to_cpu(self):
...
@@ -3,191 +3,193 @@ from ..scheduler import HunyuanScheduler
def cache_init(num_steps, model_kwargs=None):
    """
    Initialization for cache.
    """
    cache_dic = {}
    cache = {}
    cache_index = {}
    cache[-1] = {}
    cache_index[-1] = {}
    cache_index["layer_index"] = {}
    cache_dic["attn_map"] = {}
    cache_dic["attn_map"][-1] = {}
    cache_dic["attn_map"][-1]["double_stream"] = {}
    cache_dic["attn_map"][-1]["single_stream"] = {}

    cache_dic["k-norm"] = {}
    cache_dic["k-norm"][-1] = {}
    cache_dic["k-norm"][-1]["double_stream"] = {}
    cache_dic["k-norm"][-1]["single_stream"] = {}

    cache_dic["v-norm"] = {}
    cache_dic["v-norm"][-1] = {}
    cache_dic["v-norm"][-1]["double_stream"] = {}
    cache_dic["v-norm"][-1]["single_stream"] = {}

    cache_dic["cross_attn_map"] = {}
    cache_dic["cross_attn_map"][-1] = {}
    cache[-1]["double_stream"] = {}
    cache[-1]["single_stream"] = {}
    cache_dic["cache_counter"] = 0

    for j in range(20):
        cache[-1]["double_stream"][j] = {}
        cache_index[-1][j] = {}
        cache_dic["attn_map"][-1]["double_stream"][j] = {}
        cache_dic["attn_map"][-1]["double_stream"][j]["total"] = {}
        cache_dic["attn_map"][-1]["double_stream"][j]["txt_mlp"] = {}
        cache_dic["attn_map"][-1]["double_stream"][j]["img_mlp"] = {}
        cache_dic["k-norm"][-1]["double_stream"][j] = {}
        cache_dic["k-norm"][-1]["double_stream"][j]["txt_mlp"] = {}
        cache_dic["k-norm"][-1]["double_stream"][j]["img_mlp"] = {}
        cache_dic["v-norm"][-1]["double_stream"][j] = {}
        cache_dic["v-norm"][-1]["double_stream"][j]["txt_mlp"] = {}
        cache_dic["v-norm"][-1]["double_stream"][j]["img_mlp"] = {}

    for j in range(40):
        cache[-1]["single_stream"][j] = {}
        cache_index[-1][j] = {}
        cache_dic["attn_map"][-1]["single_stream"][j] = {}
        cache_dic["attn_map"][-1]["single_stream"][j]["total"] = {}
        cache_dic["k-norm"][-1]["single_stream"][j] = {}
        cache_dic["k-norm"][-1]["single_stream"][j]["total"] = {}
        cache_dic["v-norm"][-1]["single_stream"][j] = {}
        cache_dic["v-norm"][-1]["single_stream"][j]["total"] = {}

    cache_dic["taylor_cache"] = False
    cache_dic["duca"] = False
    cache_dic["test_FLOPs"] = False

    mode = "Taylor"
    if mode == "original":
        cache_dic["cache_type"] = "random"
        cache_dic["cache_index"] = cache_index
        cache_dic["cache"] = cache
        cache_dic["fresh_ratio_schedule"] = "ToCa"
cache_dic['fresh_ratio'] = 0.0 cache_dic["fresh_ratio"] = 0.0
cache_dic['fresh_threshold'] = 1 cache_dic["fresh_threshold"] = 1
cache_dic['force_fresh'] = 'global' cache_dic["force_fresh"] = "global"
cache_dic['soft_fresh_weight'] = 0.0 cache_dic["soft_fresh_weight"] = 0.0
cache_dic['max_order'] = 0 cache_dic["max_order"] = 0
cache_dic['first_enhance'] = 1 cache_dic["first_enhance"] = 1
elif mode == 'ToCa': elif mode == "ToCa":
cache_dic['cache_type'] = 'random' cache_dic["cache_type"] = "random"
cache_dic['cache_index'] = cache_index cache_dic["cache_index"] = cache_index
cache_dic['cache'] = cache cache_dic["cache"] = cache
cache_dic['fresh_ratio_schedule'] = 'ToCa' cache_dic["fresh_ratio_schedule"] = "ToCa"
cache_dic['fresh_ratio'] = 0.10 cache_dic["fresh_ratio"] = 0.10
cache_dic['fresh_threshold'] = 5 cache_dic["fresh_threshold"] = 5
cache_dic['force_fresh'] = 'global' cache_dic["force_fresh"] = "global"
cache_dic['soft_fresh_weight'] = 0.0 cache_dic["soft_fresh_weight"] = 0.0
cache_dic['max_order'] = 0 cache_dic["max_order"] = 0
cache_dic['first_enhance'] = 1 cache_dic["first_enhance"] = 1
cache_dic['duca'] = False cache_dic["duca"] = False
elif mode == 'DuCa': elif mode == "DuCa":
cache_dic['cache_type'] = 'random' cache_dic["cache_type"] = "random"
cache_dic['cache_index'] = cache_index cache_dic["cache_index"] = cache_index
cache_dic['cache'] = cache cache_dic["cache"] = cache
cache_dic['fresh_ratio_schedule'] = 'ToCa' cache_dic["fresh_ratio_schedule"] = "ToCa"
cache_dic['fresh_ratio'] = 0.10 cache_dic["fresh_ratio"] = 0.10
cache_dic['fresh_threshold'] = 5 cache_dic["fresh_threshold"] = 5
cache_dic['force_fresh'] = 'global' cache_dic["force_fresh"] = "global"
cache_dic['soft_fresh_weight'] = 0.0 cache_dic["soft_fresh_weight"] = 0.0
cache_dic['max_order'] = 0 cache_dic["max_order"] = 0
cache_dic['first_enhance'] = 1 cache_dic["first_enhance"] = 1
cache_dic['duca'] = True cache_dic["duca"] = True
elif mode == 'Taylor': elif mode == "Taylor":
cache_dic['cache_type'] = 'random' cache_dic["cache_type"] = "random"
cache_dic['cache_index'] = cache_index cache_dic["cache_index"] = cache_index
cache_dic['cache'] = cache cache_dic["cache"] = cache
cache_dic['fresh_ratio_schedule'] = 'ToCa' cache_dic["fresh_ratio_schedule"] = "ToCa"
cache_dic['fresh_ratio'] = 0.0 cache_dic["fresh_ratio"] = 0.0
cache_dic['fresh_threshold'] = 5 cache_dic["fresh_threshold"] = 5
cache_dic['max_order'] = 1 cache_dic["max_order"] = 1
cache_dic['force_fresh'] = 'global' cache_dic["force_fresh"] = "global"
cache_dic['soft_fresh_weight'] = 0.0 cache_dic["soft_fresh_weight"] = 0.0
cache_dic['taylor_cache'] = True cache_dic["taylor_cache"] = True
cache_dic['first_enhance'] = 1 cache_dic["first_enhance"] = 1
current = {} current = {}
current['num_steps'] = num_steps current["num_steps"] = num_steps
current['activated_steps'] = [0] current["activated_steps"] = [0]
return cache_dic, current return cache_dic, current
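`cache_init` above hard-codes `mode = "Taylor"`, so the returned `cache_dic` enables Taylor caching with `fresh_threshold = 5` and `max_order = 1`, alongside per-block slots for 20 double-stream and 40 single-stream layers. A quick sanity check, assuming the function is imported from this feature-caching module:

```python
cache_dic, current = cache_init(num_steps=50)

# With the hard-coded mode = "Taylor", the returned config enables Taylor caching:
assert cache_dic["taylor_cache"] is True
assert cache_dic["fresh_threshold"] == 5 and cache_dic["max_order"] == 1
assert current == {"num_steps": 50, "activated_steps": [0]}
```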
def force_scheduler(cache_dic, current): def force_scheduler(cache_dic, current):
if cache_dic['fresh_ratio'] == 0: if cache_dic["fresh_ratio"] == 0:
# FORA # FORA
linear_step_weight = 0.0 linear_step_weight = 0.0
else: else:
# TokenCache # TokenCache
linear_step_weight = 0.0 linear_step_weight = 0.0
step_factor = torch.tensor(1 - linear_step_weight + 2 * linear_step_weight * current['step'] / current['num_steps']) step_factor = torch.tensor(1 - linear_step_weight + 2 * linear_step_weight * current["step"] / current["num_steps"])
threshold = torch.round(cache_dic['fresh_threshold'] / step_factor) threshold = torch.round(cache_dic["fresh_threshold"] / step_factor)
# no force constrain for sensitive steps, cause the performance is good enough. # no force constrain for sensitive steps, cause the performance is good enough.
# you may have a try. # you may have a try.
cache_dic['cal_threshold'] = threshold cache_dic["cal_threshold"] = threshold
#return threshold # return threshold
def cal_type(cache_dic, current): def cal_type(cache_dic, current):
''' """
Determine calculation type for this step Determine calculation type for this step
''' """
if (cache_dic['fresh_ratio'] == 0.0) and (not cache_dic['taylor_cache']): if (cache_dic["fresh_ratio"] == 0.0) and (not cache_dic["taylor_cache"]):
# FORA:Uniform # FORA:Uniform
first_step = (current['step'] == 0) first_step = current["step"] == 0
else: else:
# ToCa: First enhanced # ToCa: First enhanced
first_step = (current['step'] < cache_dic['first_enhance']) first_step = current["step"] < cache_dic["first_enhance"]
#first_step = (current['step'] <= 3) # first_step = (current['step'] <= 3)
force_fresh = cache_dic['force_fresh'] force_fresh = cache_dic["force_fresh"]
if not first_step: if not first_step:
fresh_interval = cache_dic['cal_threshold'] fresh_interval = cache_dic["cal_threshold"]
else: else:
fresh_interval = cache_dic['fresh_threshold'] fresh_interval = cache_dic["fresh_threshold"]
if (first_step) or (cache_dic['cache_counter'] == fresh_interval - 1 ): if (first_step) or (cache_dic["cache_counter"] == fresh_interval - 1):
current['type'] = 'full' current["type"] = "full"
cache_dic['cache_counter'] = 0 cache_dic["cache_counter"] = 0
current['activated_steps'].append(current['step']) current["activated_steps"].append(current["step"])
#current['activated_times'].append(current['t']) # current['activated_times'].append(current['t'])
force_scheduler(cache_dic, current) force_scheduler(cache_dic, current)
elif (cache_dic['taylor_cache']): elif cache_dic["taylor_cache"]:
cache_dic['cache_counter'] += 1 cache_dic["cache_counter"] += 1
current['type'] = 'taylor_cache' current["type"] = "taylor_cache"
else: else:
cache_dic['cache_counter'] += 1 cache_dic["cache_counter"] += 1
if (cache_dic['duca']): if cache_dic["duca"]:
if (cache_dic['cache_counter'] % 2 == 1): # 0: ToCa-Aggresive-ToCa, 1: Aggresive-ToCa-Aggresive if cache_dic["cache_counter"] % 2 == 1: # 0: ToCa-Aggresive-ToCa, 1: Aggresive-ToCa-Aggresive
current['type'] = 'ToCa' current["type"] = "ToCa"
# 'cache_noise' 'ToCa' 'FORA' # 'cache_noise' 'ToCa' 'FORA'
else: else:
current['type'] = 'aggressive' current["type"] = "aggressive"
else: else:
current['type'] = 'ToCa' current["type"] = "ToCa"
#if current['step'] < 25: # if current['step'] < 25:
# current['type'] = 'FORA' # current['type'] = 'FORA'
#else: # else:
# current['type'] = 'aggressive' # current['type'] = 'aggressive'
###################################################################### ######################################################################
#if (current['step'] in [3,2,1,0]): # if (current['step'] in [3,2,1,0]):
# current['type'] = 'full' # current['type'] = 'full'
class HunyuanSchedulerFeatureCaching(HunyuanScheduler): class HunyuanSchedulerFeatureCaching(HunyuanScheduler):
...@@ -197,5 +199,5 @@ class HunyuanSchedulerFeatureCaching(HunyuanScheduler): ...@@ -197,5 +199,5 @@ class HunyuanSchedulerFeatureCaching(HunyuanScheduler):
def step_pre(self, step_index): def step_pre(self, step_index):
super().step_pre(step_index) super().step_pre(step_index)
self.current['step'] = step_index self.current["step"] = step_index
cal_type(self.cache_dic, self.current) cal_type(self.cache_dic, self.current)
...@@ -13,6 +13,7 @@ def _to_tuple(x, dim=2): ...@@ -13,6 +13,7 @@ def _to_tuple(x, dim=2):
else: else:
raise ValueError(f"Expected length {dim} or int, but got {x}") raise ValueError(f"Expected length {dim} or int, but got {x}")
def get_1d_rotary_pos_embed( def get_1d_rotary_pos_embed(
dim: int, dim: int,
pos: Union[torch.FloatTensor, int], pos: Union[torch.FloatTensor, int],
...@@ -49,9 +50,7 @@ def get_1d_rotary_pos_embed( ...@@ -49,9 +50,7 @@ def get_1d_rotary_pos_embed(
if theta_rescale_factor != 1.0: if theta_rescale_factor != 1.0:
theta *= theta_rescale_factor ** (dim / (dim - 2)) theta *= theta_rescale_factor ** (dim / (dim - 2))
freqs = 1.0 / ( freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) # [D/2]
theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
) # [D/2]
# assert interpolation_factor == 1.0, f"interpolation_factor: {interpolation_factor}" # assert interpolation_factor == 1.0, f"interpolation_factor: {interpolation_factor}"
freqs = torch.outer(pos * interpolation_factor, freqs) # [S, D/2] freqs = torch.outer(pos * interpolation_factor, freqs) # [S, D/2]
if use_real: if use_real:
...@@ -59,9 +58,7 @@ def get_1d_rotary_pos_embed( ...@@ -59,9 +58,7 @@ def get_1d_rotary_pos_embed(
freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D] freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D]
return freqs_cos, freqs_sin return freqs_cos, freqs_sin
else: else:
freqs_cis = torch.polar( freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2]
torch.ones_like(freqs), freqs
) # complex64 # [S, D/2]
return freqs_cis return freqs_cis
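The reflowed lines above compute the standard RoPE frequency table: `freqs = 1 / theta**(2i/dim)`, an outer product with the positions, and then either real cos/sin tables repeated to `[S, D]` or a complex `freqs_cis` of shape `[S, D/2]`. A standalone re-derivation of that math (not the project's function, just the same formula):

```python
import torch

# Minimal re-derivation of the rotary frequency table built above.
def rope_1d(dim, positions, theta=10000.0, use_real=True):
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: dim // 2].float() / dim))  # [D/2]
    freqs = torch.outer(positions.float(), freqs)                                  # [S, D/2]
    if use_real:
        # cos/sin duplicated along the last dim -> [S, D], as in the diff above
        return freqs.cos().repeat_interleave(2, dim=1), freqs.sin().repeat_interleave(2, dim=1)
    return torch.polar(torch.ones_like(freqs), freqs)                              # complex64, [S, D/2]

cos, sin = rope_1d(dim=64, positions=torch.arange(16))
print(cos.shape, sin.shape)  # torch.Size([16, 64]) torch.Size([16, 64])
```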
...@@ -109,6 +106,7 @@ def get_meshgrid_nd(start, *args, dim=2): ...@@ -109,6 +106,7 @@ def get_meshgrid_nd(start, *args, dim=2):
return grid return grid
def get_nd_rotary_pos_embed( def get_nd_rotary_pos_embed(
rope_dim_list, rope_dim_list,
start, start,
...@@ -137,25 +135,19 @@ def get_nd_rotary_pos_embed( ...@@ -137,25 +135,19 @@ def get_nd_rotary_pos_embed(
pos_embed (torch.Tensor): [HW, D/2] pos_embed (torch.Tensor): [HW, D/2]
""" """
grid = get_meshgrid_nd( grid = get_meshgrid_nd(start, *args, dim=len(rope_dim_list)) # [3, W, H, D] / [2, W, H]
start, *args, dim=len(rope_dim_list)
) # [3, W, H, D] / [2, W, H]
if isinstance(theta_rescale_factor, int) or isinstance(theta_rescale_factor, float): if isinstance(theta_rescale_factor, int) or isinstance(theta_rescale_factor, float):
theta_rescale_factor = [theta_rescale_factor] * len(rope_dim_list) theta_rescale_factor = [theta_rescale_factor] * len(rope_dim_list)
elif isinstance(theta_rescale_factor, list) and len(theta_rescale_factor) == 1: elif isinstance(theta_rescale_factor, list) and len(theta_rescale_factor) == 1:
theta_rescale_factor = [theta_rescale_factor[0]] * len(rope_dim_list) theta_rescale_factor = [theta_rescale_factor[0]] * len(rope_dim_list)
assert len(theta_rescale_factor) == len( assert len(theta_rescale_factor) == len(rope_dim_list), "len(theta_rescale_factor) should equal to len(rope_dim_list)"
rope_dim_list
), "len(theta_rescale_factor) should equal to len(rope_dim_list)"
if isinstance(interpolation_factor, int) or isinstance(interpolation_factor, float): if isinstance(interpolation_factor, int) or isinstance(interpolation_factor, float):
interpolation_factor = [interpolation_factor] * len(rope_dim_list) interpolation_factor = [interpolation_factor] * len(rope_dim_list)
elif isinstance(interpolation_factor, list) and len(interpolation_factor) == 1: elif isinstance(interpolation_factor, list) and len(interpolation_factor) == 1:
interpolation_factor = [interpolation_factor[0]] * len(rope_dim_list) interpolation_factor = [interpolation_factor[0]] * len(rope_dim_list)
assert len(interpolation_factor) == len( assert len(interpolation_factor) == len(rope_dim_list), "len(interpolation_factor) should equal to len(rope_dim_list)"
rope_dim_list
), "len(interpolation_factor) should equal to len(rope_dim_list)"
# use 1/ndim of dimensions to encode grid_axis # use 1/ndim of dimensions to encode grid_axis
embs = [] embs = []
...@@ -182,9 +174,7 @@ def get_nd_rotary_pos_embed( ...@@ -182,9 +174,7 @@ def get_nd_rotary_pos_embed(
def set_timesteps_sigmas(num_inference_steps, shift, device, num_train_timesteps=1000): def set_timesteps_sigmas(num_inference_steps, shift, device, num_train_timesteps=1000):
sigmas = torch.linspace(1, 0, num_inference_steps + 1) sigmas = torch.linspace(1, 0, num_inference_steps + 1)
sigmas = (shift * sigmas) / (1 + (shift - 1) * sigmas) sigmas = (shift * sigmas) / (1 + (shift - 1) * sigmas)
timesteps = (sigmas[:-1] * num_train_timesteps).to( timesteps = (sigmas[:-1] * num_train_timesteps).to(dtype=torch.bfloat16, device=device)
dtype=torch.bfloat16, device=device
)
return timesteps, sigmas return timesteps, sigmas
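`set_timesteps_sigmas` builds a linear sigma ramp on [1, 0] and warps it with the shift transform `sigma' = shift * sigma / (1 + (shift - 1) * sigma)`; `HunyuanScheduler` below calls it with `shift = 7.0`, which front-loads the high-noise region. A worked example using a local copy of the helper:

```python
import torch

def set_timesteps_sigmas(num_inference_steps, shift, device, num_train_timesteps=1000):
    sigmas = torch.linspace(1, 0, num_inference_steps + 1)
    sigmas = (shift * sigmas) / (1 + (shift - 1) * sigmas)
    timesteps = (sigmas[:-1] * num_train_timesteps).to(dtype=torch.bfloat16, device=device)
    return timesteps, sigmas

# With shift = 7.0 and 4 steps, the warped sigmas are 1.0, ~0.955, ~0.875, ~0.700, 0.0,
# i.e. most of the schedule is spent at high noise levels.
timesteps, sigmas = set_timesteps_sigmas(4, shift=7.0, device="cpu")
print(sigmas)
```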
...@@ -193,17 +183,17 @@ class HunyuanScheduler(BaseScheduler): ...@@ -193,17 +183,17 @@ class HunyuanScheduler(BaseScheduler):
super().__init__(args) super().__init__(args)
self.infer_steps = self.args.infer_steps self.infer_steps = self.args.infer_steps
self.shift = 7.0 self.shift = 7.0
self.timesteps, self.sigmas = set_timesteps_sigmas(self.infer_steps, self.shift, device=torch.device('cuda')) self.timesteps, self.sigmas = set_timesteps_sigmas(self.infer_steps, self.shift, device=torch.device("cuda"))
assert len(self.timesteps) == self.infer_steps assert len(self.timesteps) == self.infer_steps
self.embedded_guidance_scale = 6.0 self.embedded_guidance_scale = 6.0
self.generator = [torch.Generator('cuda').manual_seed(seed) for seed in [42]] self.generator = [torch.Generator("cuda").manual_seed(seed) for seed in [42]]
self.noise_pred = None self.noise_pred = None
self.prepare_latents(shape=self.args.target_shape, dtype=torch.bfloat16) self.prepare_latents(shape=self.args.target_shape, dtype=torch.bfloat16)
self.prepare_guidance() self.prepare_guidance()
self.prepare_rotary_pos_embedding(video_length=self.args.target_video_length, height=self.args.target_height, width=self.args.target_width) self.prepare_rotary_pos_embedding(video_length=self.args.target_video_length, height=self.args.target_height, width=self.args.target_width)
def prepare_guidance(self): def prepare_guidance(self):
self.guidance = torch.tensor([self.embedded_guidance_scale], dtype=torch.bfloat16, device=torch.device('cuda')) * 1000.0 self.guidance = torch.tensor([self.embedded_guidance_scale], dtype=torch.bfloat16, device=torch.device("cuda")) * 1000.0
def step_post(self): def step_post(self):
sample = self.latents.to(torch.float32) sample = self.latents.to(torch.float32)
...@@ -212,7 +202,7 @@ class HunyuanScheduler(BaseScheduler): ...@@ -212,7 +202,7 @@ class HunyuanScheduler(BaseScheduler):
self.latents = prev_sample self.latents = prev_sample
def prepare_latents(self, shape, dtype): def prepare_latents(self, shape, dtype):
self.latents = randn_tensor(shape, generator=self.generator, device=torch.device('cuda'), dtype=dtype) self.latents = randn_tensor(shape, generator=self.generator, device=torch.device("cuda"), dtype=dtype)
def prepare_rotary_pos_embedding(self, video_length, height, width): def prepare_rotary_pos_embedding(self, video_length, height, width):
target_ndim = 3 target_ndim = 3
...@@ -232,22 +222,11 @@ class HunyuanScheduler(BaseScheduler): ...@@ -232,22 +222,11 @@ class HunyuanScheduler(BaseScheduler):
latents_size = [video_length, height // 8, width // 8] latents_size = [video_length, height // 8, width // 8]
if isinstance(patch_size, int): if isinstance(patch_size, int):
assert all(s % patch_size == 0 for s in latents_size), ( assert all(s % patch_size == 0 for s in latents_size), f"Latent size(last {ndim} dimensions) should be divisible by patch size({patch_size}), but got {latents_size}."
f"Latent size(last {ndim} dimensions) should be divisible by patch size({patch_size}), "
f"but got {latents_size}."
)
rope_sizes = [s // patch_size for s in latents_size] rope_sizes = [s // patch_size for s in latents_size]
elif isinstance(patch_size, list): elif isinstance(patch_size, list):
assert all( assert all(s % patch_size[idx] == 0 for idx, s in enumerate(latents_size)), f"Latent size(last {ndim} dimensions) should be divisible by patch size({patch_size}), but got {latents_size}."
s % patch_size[idx] == 0 rope_sizes = [s // patch_size[idx] for idx, s in enumerate(latents_size)]
for idx, s in enumerate(latents_size)
), (
f"Latent size(last {ndim} dimensions) should be divisible by patch size({patch_size}), "
f"but got {latents_size}."
)
rope_sizes = [
s // patch_size[idx] for idx, s in enumerate(latents_size)
]
if len(rope_sizes) != target_ndim: if len(rope_sizes) != target_ndim:
rope_sizes = [1] * (target_ndim - len(rope_sizes)) + rope_sizes # time axis rope_sizes = [1] * (target_ndim - len(rope_sizes)) + rope_sizes # time axis
...@@ -255,9 +234,7 @@ class HunyuanScheduler(BaseScheduler): ...@@ -255,9 +234,7 @@ class HunyuanScheduler(BaseScheduler):
rope_dim_list = rope_dim_list rope_dim_list = rope_dim_list
if rope_dim_list is None: if rope_dim_list is None:
rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)] rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
assert ( assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal to head_dim of attention layer"
sum(rope_dim_list) == head_dim
), "sum(rope_dim_list) should equal to head_dim of attention layer"
self.freqs_cos, self.freqs_sin = get_nd_rotary_pos_embed( self.freqs_cos, self.freqs_sin = get_nd_rotary_pos_embed(
rope_dim_list, rope_dim_list,
rope_sizes, rope_sizes,
......
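`prepare_rotary_pos_embedding` divides the latent grid by the patch size to get `rope_sizes`, pads a time axis if needed, and, when `rope_dim_list` is `None`, splits `head_dim` evenly across the three axes before calling `get_nd_rotary_pos_embed`. The bookkeeping below uses illustrative sizes (the patch size, head_dim, and latent grid here are assumptions, not the model's real config):

```python
# Illustrative dimension bookkeeping for 3-axis rotary embeddings.
head_dim = 96                                # assumed; must split evenly over 3 axes for the fallback
target_ndim = 3
patch_size = [1, 2, 2]                       # assumed (t, h, w) patching
latents_size = [17, 32, 32]                  # [video_length, height // 8, width // 8], illustrative

assert all(s % p == 0 for s, p in zip(latents_size, patch_size)), "latent grid must be patch-divisible"
rope_sizes = [s // p for s, p in zip(latents_size, patch_size)]          # [17, 16, 16]

rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]    # [32, 32, 32]
assert sum(rope_dim_list) == head_dim

# Each axis then gets a 1D rotary table of size rope_sizes[i] x rope_dim_list[i],
# and the per-axis tables are combined over the flattened T*H*W grid.
print(rope_sizes, rope_dim_list)
```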
import torch import torch
class BaseScheduler(): class BaseScheduler:
def __init__(self, args): def __init__(self, args):
self.args = args self.args = args
self.step_index = 0 self.step_index = 0
......
...@@ -16,7 +16,7 @@ class WanSchedulerFeatureCaching(WanScheduler): ...@@ -16,7 +16,7 @@ class WanSchedulerFeatureCaching(WanScheduler):
self.previous_residual_odd = None self.previous_residual_odd = None
self.use_ret_steps = self.args.use_ret_steps self.use_ret_steps = self.args.use_ret_steps
if self.args.task == 'i2v': if self.args.task == "i2v":
if self.use_ret_steps: if self.use_ret_steps:
if self.args.target_width == 480 or self.args.target_height == 480: if self.args.target_width == 480 or self.args.target_height == 480:
self.coefficients = [ self.coefficients = [
...@@ -56,18 +56,18 @@ class WanSchedulerFeatureCaching(WanScheduler): ...@@ -56,18 +56,18 @@ class WanSchedulerFeatureCaching(WanScheduler):
self.ret_steps = 1 * 2 self.ret_steps = 1 * 2
self.cutoff_steps = self.args.infer_steps * 2 - 2 self.cutoff_steps = self.args.infer_steps * 2 - 2
elif self.args.task == 't2v': elif self.args.task == "t2v":
if self.use_ret_steps: if self.use_ret_steps:
if '1.3B' in self.args.model_path: if "1.3B" in self.args.model_path:
self.coefficients = [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02] self.coefficients = [-5.21862437e04, 9.23041404e03, -5.28275948e02, 1.36987616e01, -4.99875664e-02]
if '14B' in self.args.model_path: if "14B" in self.args.model_path:
self.coefficients = [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01] self.coefficients = [-3.03318725e05, 4.90537029e04, -2.65530556e03, 5.87365115e01, -3.15583525e-01]
self.ret_steps = 5 * 2 self.ret_steps = 5 * 2
self.cutoff_steps = self.args.infer_steps * 2 self.cutoff_steps = self.args.infer_steps * 2
else: else:
if '1.3B' in self.args.model_path: if "1.3B" in self.args.model_path:
self.coefficients = [2.39676752e+03, -1.31110545e+03, 2.01331979e+02, -8.29855975e+00, 1.37887774e-01] self.coefficients = [2.39676752e03, -1.31110545e03, 2.01331979e02, -8.29855975e00, 1.37887774e-01]
if '14B' in self.args.model_path: if "14B" in self.args.model_path:
self.coefficients = [-5784.54975374, 5449.50911966, -1811.16591783, 256.27178429, -13.02252404] self.coefficients = [-5784.54975374, 5449.50911966, -1811.16591783, 256.27178429, -13.02252404]
self.ret_steps = 1 * 2 self.ret_steps = 1 * 2
self.cutoff_steps = self.args.infer_steps * 2 - 2 self.cutoff_steps = self.args.infer_steps * 2 - 2
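The per-model `coefficients` lists above are polynomial coefficients selected by task, resolution, model size, and `use_ret_steps`; this hunk does not show how they are consumed. A common pattern in TeaCache-style skipping is to rescale a relative-change metric with `np.polyval` before thresholding, and the sketch below assumes that usage (it is not confirmed by this diff):

```python
import numpy as np

# Hypothetical use of the per-model coefficients shown above: rescale a raw
# relative-change metric with a fitted polynomial before thresholding.
coefficients = [-5.21862437e04, 9.23041404e03, -5.28275948e02, 1.36987616e01, -4.99875664e-02]

def rescale(rel_change, coeffs=coefficients):
    # np.polyval expects highest-order coefficient first, matching the lists above.
    return float(np.polyval(coeffs, rel_change))

print(rescale(0.05))
```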
...@@ -22,18 +22,11 @@ class WanScheduler(BaseScheduler): ...@@ -22,18 +22,11 @@ class WanScheduler(BaseScheduler):
self.prepare_latents(self.args.target_shape, dtype=torch.float32) self.prepare_latents(self.args.target_shape, dtype=torch.float32)
if self.args.task in ["t2v"]: if self.args.task in ["t2v"]:
self.seq_len = math.ceil( self.seq_len = math.ceil((self.args.target_shape[2] * self.args.target_shape[3]) / (self.args.patch_size[1] * self.args.patch_size[2]) * self.args.target_shape[1])
(self.args.target_shape[2] * self.args.target_shape[3])
/ (self.args.patch_size[1] * self.args.patch_size[2])
* self.args.target_shape[1]
)
elif self.args.task in ["i2v"]: elif self.args.task in ["i2v"]:
self.seq_len = ((self.args.target_video_length- 1) // self.args.vae_stride[0] + 1) * args.lat_h * args.lat_w // ( self.seq_len = ((self.args.target_video_length - 1) // self.args.vae_stride[0] + 1) * args.lat_h * args.lat_w // (args.patch_size[1] * args.patch_size[2])
args.patch_size[1] * args.patch_size[2])
alphas = np.linspace(1, 1 / self.num_train_timesteps, self.num_train_timesteps)[ alphas = np.linspace(1, 1 / self.num_train_timesteps, self.num_train_timesteps)[::-1].copy()
::-1
].copy()
sigmas = 1.0 - alphas sigmas = 1.0 - alphas
sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32) sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32)
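For `i2v`, `seq_len` above is the token count after temporal VAE compression and spatial patchification: the compressed frame count times the latent H×W grid, divided by the patch area. A worked example with illustrative values (`vae_stride`, `patch_size`, and the latent sizes are assumptions, not the project's config):

```python
# Worked example of the i2v seq_len formula above.
target_video_length = 81
vae_stride = [4, 8, 8]        # assumed temporal/spatial VAE strides
patch_size = [1, 2, 2]        # assumed (t, h, w) patching
lat_h, lat_w = 60, 104        # assumed latent grid

latent_frames = (target_video_length - 1) // vae_stride[0] + 1           # 21
seq_len = latent_frames * lat_h * lat_w // (patch_size[1] * patch_size[2])
print(seq_len)  # 21 * 60 * 104 / 4 = 32760
```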
...@@ -71,9 +64,7 @@ class WanScheduler(BaseScheduler): ...@@ -71,9 +64,7 @@ class WanScheduler(BaseScheduler):
mu: Optional[Union[float, None]] = None, mu: Optional[Union[float, None]] = None,
shift: Optional[Union[float, None]] = None, shift: Optional[Union[float, None]] = None,
): ):
sigmas = np.linspace(self.sigma_max, self.sigma_min, infer_steps + 1).copy()[ sigmas = np.linspace(self.sigma_max, self.sigma_min, infer_steps + 1).copy()[:-1]
:-1
]
if shift is None: if shift is None:
shift = self.shift shift = self.shift
...@@ -85,9 +76,7 @@ class WanScheduler(BaseScheduler): ...@@ -85,9 +76,7 @@ class WanScheduler(BaseScheduler):
sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
self.sigmas = torch.from_numpy(sigmas) self.sigmas = torch.from_numpy(sigmas)
self.timesteps = torch.from_numpy(timesteps).to( self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64)
device=device, dtype=torch.int64
)
assert len(self.timesteps) == self.infer_steps assert len(self.timesteps) == self.infer_steps
self.model_outputs = [ self.model_outputs = [
...@@ -108,7 +97,6 @@ class WanScheduler(BaseScheduler): ...@@ -108,7 +97,6 @@ class WanScheduler(BaseScheduler):
sample: torch.Tensor = None, sample: torch.Tensor = None,
**kwargs, **kwargs,
) -> torch.Tensor: ) -> torch.Tensor:
timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
if sample is None: if sample is None:
if len(args) > 1: if len(args) > 1:
...@@ -222,7 +210,6 @@ class WanScheduler(BaseScheduler): ...@@ -222,7 +210,6 @@ class WanScheduler(BaseScheduler):
order: int = None, order: int = None,
**kwargs, **kwargs,
) -> torch.Tensor: ) -> torch.Tensor:
this_timestep = args[0] if len(args) > 0 else kwargs.pop("this_timestep", None) this_timestep = args[0] if len(args) > 0 else kwargs.pop("this_timestep", None)
if last_sample is None: if last_sample is None:
if len(args) > 1: if len(args) > 1:
...@@ -320,11 +307,7 @@ class WanScheduler(BaseScheduler): ...@@ -320,11 +307,7 @@ class WanScheduler(BaseScheduler):
timestep = self.timesteps[self.step_index] timestep = self.timesteps[self.step_index]
sample = self.latents.to(torch.float32) sample = self.latents.to(torch.float32)
use_corrector = ( use_corrector = self.step_index > 0 and self.step_index - 1 not in self.disable_corrector and self.last_sample is not None
self.step_index > 0
and self.step_index - 1 not in self.disable_corrector
and self.last_sample is not None
)
model_output_convert = self.convert_model_output(model_output, sample=sample) model_output_convert = self.convert_model_output(model_output, sample=sample)
if use_corrector: if use_corrector:
...@@ -342,13 +325,9 @@ class WanScheduler(BaseScheduler): ...@@ -342,13 +325,9 @@ class WanScheduler(BaseScheduler):
self.model_outputs[-1] = model_output_convert self.model_outputs[-1] = model_output_convert
self.timestep_list[-1] = timestep self.timestep_list[-1] = timestep
this_order = min( this_order = min(self.solver_order, len(self.timesteps) - self.step_index)
self.solver_order, len(self.timesteps) - self.step_index
)
self.this_order = min( self.this_order = min(this_order, self.lower_order_nums + 1) # warmup for multistep
this_order, self.lower_order_nums + 1
) # warmup for multistep
assert self.this_order > 0 assert self.this_order > 0
self.last_sample = sample self.last_sample = sample
......
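The reflowed order logic above clamps the solver order near the end of sampling (`len(self.timesteps) - self.step_index`) and warms it up at the start via `lower_order_nums + 1`. The trace below assumes `lower_order_nums` grows by one per completed step, as in typical DPM-Solver-style multistep schedulers; that bookkeeping is not shown in this hunk:

```python
# Trace of the multistep-order warmup/clamp logic (standalone sketch).
def order_for_step(step_index, num_steps, solver_order=2, lower_order_nums=None):
    if lower_order_nums is None:
        lower_order_nums = step_index            # assumed: +1 per completed step
    this_order = min(solver_order, num_steps - step_index)   # clamp near the final steps
    return min(this_order, lower_order_nums + 1)             # warmup: first step is order 1

print([order_for_step(i, 10) for i in range(10)])
# [1, 2, 2, 2, 2, 2, 2, 2, 2, 1]
```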
...@@ -2,7 +2,7 @@ import torch ...@@ -2,7 +2,7 @@ import torch
from transformers import CLIPTextModel, AutoTokenizer from transformers import CLIPTextModel, AutoTokenizer
class TextEncoderHFClipModel(): class TextEncoderHFClipModel:
def __init__(self, model_path, device): def __init__(self, model_path, device):
self.device = device self.device = device
self.model_path = model_path self.model_path = model_path
...@@ -51,6 +51,6 @@ class TextEncoderHFClipModel(): ...@@ -51,6 +51,6 @@ class TextEncoderHFClipModel():
if __name__ == "__main__": if __name__ == "__main__":
model = TextEncoderHFClipModel("/mnt/nvme0/yongyang/projects/hy/HunyuanVideo/ckpts/text_encoder_2", torch.device("cuda")) model = TextEncoderHFClipModel("/mnt/nvme0/yongyang/projects/hy/HunyuanVideo/ckpts/text_encoder_2", torch.device("cuda"))
text = 'A cat walks on the grass, realistic style.' text = "A cat walks on the grass, realistic style."
outputs = model.infer(text) outputs = model.infer(text)
print(outputs) print(outputs)
...@@ -2,7 +2,7 @@ import torch ...@@ -2,7 +2,7 @@ import torch
from transformers import AutoModel, AutoTokenizer from transformers import AutoModel, AutoTokenizer
class TextEncoderHFLlamaModel(): class TextEncoderHFLlamaModel:
def __init__(self, model_path, device): def __init__(self, model_path, device):
self.device = device self.device = device
self.model_path = model_path self.model_path = model_path
...@@ -55,8 +55,8 @@ class TextEncoderHFLlamaModel(): ...@@ -55,8 +55,8 @@ class TextEncoderHFLlamaModel():
output_hidden_states=True, output_hidden_states=True,
) )
last_hidden_state = outputs.hidden_states[-(self.hidden_state_skip_layer + 1)][:, self.crop_start:] last_hidden_state = outputs.hidden_states[-(self.hidden_state_skip_layer + 1)][:, self.crop_start :]
attention_mask = tokens["attention_mask"][:, self.crop_start:] attention_mask = tokens["attention_mask"][:, self.crop_start :]
if args.cpu_offload: if args.cpu_offload:
self.to_cpu() self.to_cpu()
return last_hidden_state, attention_mask return last_hidden_state, attention_mask
...@@ -64,6 +64,6 @@ class TextEncoderHFLlamaModel(): ...@@ -64,6 +64,6 @@ class TextEncoderHFLlamaModel():
if __name__ == "__main__": if __name__ == "__main__":
model = TextEncoderHFLlamaModel("/mnt/nvme0/yongyang/projects/hy/HunyuanVideo/ckpts/text_encoder", torch.device("cuda")) model = TextEncoderHFLlamaModel("/mnt/nvme0/yongyang/projects/hy/HunyuanVideo/ckpts/text_encoder", torch.device("cuda"))
text = 'A cat walks on the grass, realistic style.' text = "A cat walks on the grass, realistic style."
outputs = model.infer(text) outputs = model.infer(text)
print(outputs) print(outputs)