Commit a50bcc53 authored by Dongz, committed by Yang Yong(雍洋)

add lint feature and minor fix (#7)

* [minor]: optimize Dockerfile for fewer layers

* [feature]: add pre-commit lint, update readme for contribution guidance

* [minor]: fix run shell privileges

* [auto]: first lint without rule F, fix rule E

* [minor]: fix Dockerfile error
parent 3b460075
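
The bullets above reference a new pre-commit lint step; the "rule E" / "rule F" codes are presumably flake8/ruff-style rule families. As a rough sketch only (pre-commit is named in the commit message, but the exact hooks and configuration are not shown in this commit, so the helper below is an assumption rather than the repository's actual setup), contributors would typically run the configured hooks over the whole tree before pushing:

# Hypothetical helper, not part of this commit: run the repository's pre-commit
# hooks (formatters and linters) across all files, assuming `pre-commit` is
# installed, e.g. via `pip install pre-commit`.
import subprocess
import sys


def run_lint() -> int:
    # "pre-commit run --all-files" executes every hook defined in
    # .pre-commit-config.yaml and returns a non-zero exit code on failure.
    result = subprocess.run(["pre-commit", "run", "--all-files"])
    return result.returncode


if __name__ == "__main__":
    sys.exit(run_lint())
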
@@ -7,6 +7,7 @@ from lightx2v.text2v.models.networks.hunyuan.infer.pre_infer import HunyuanPreIn
from lightx2v.text2v.models.networks.hunyuan.infer.post_infer import HunyuanPostInfer
from lightx2v.text2v.models.networks.hunyuan.infer.transformer_infer import HunyuanTransformerInfer
from lightx2v.text2v.models.networks.hunyuan.infer.feature_caching.transformer_infer import HunyuanTransformerInferFeatureCaching
# from lightx2v.core.distributed.partial_heads_attn.wrap import parallelize_hunyuan
from lightx2v.attentions.distributed.ulysses.wrap import parallelize_hunyuan
@@ -23,18 +24,18 @@ class HunyuanModel:
self._init_weights()
self._init_infer()
if self.config['parallel_attn']:
if self.config["parallel_attn"]:
parallelize_hunyuan(self)
if self.config['cpu_offload']:
if self.config["cpu_offload"]:
self.to_cpu()
def _init_infer_class(self):
self.pre_infer_class = HunyuanPreInfer
self.post_infer_class = HunyuanPostInfer
if self.config['feature_caching'] == "NoCaching":
if self.config["feature_caching"] == "NoCaching":
self.transformer_infer_class = HunyuanTransformerInfer
elif self.config['feature_caching'] == "TaylorSeer":
elif self.config["feature_caching"] == "TaylorSeer":
self.transformer_infer_class = HunyuanTransformerInferFeatureCaching
else:
raise NotImplementedError(f"Unsupported feature_caching type: {self.config['feature_caching']}")
@@ -87,9 +88,5 @@ class HunyuanModel:
self.scheduler.freqs_sin,
self.scheduler.guidance,
)
img, vec = self.transformer_infer.infer(
self.transformer_weights, *pre_infer_out
)
self.scheduler.noise_pred = self.post_infer.infer(
self.post_weight, img, vec, self.scheduler.latents.shape
)
img, vec = self.transformer_infer.infer(self.transformer_weights, *pre_infer_out)
self.scheduler.noise_pred = self.post_infer.infer(self.post_weight, img, vec, self.scheduler.latents.shape)
@@ -7,8 +7,8 @@ class HunyuanPostWeights:
self.config = config
def load_weights(self, weight_dict):
self.final_layer_linear = MM_WEIGHT_REGISTER['Default-Force-FP32']('final_layer.linear.weight', 'final_layer.linear.bias')
self.final_layer_adaLN_modulation_1 = MM_WEIGHT_REGISTER['Default']('final_layer.adaLN_modulation.1.weight', 'final_layer.adaLN_modulation.1.bias')
self.final_layer_linear = MM_WEIGHT_REGISTER["Default-Force-FP32"]("final_layer.linear.weight", "final_layer.linear.bias")
self.final_layer_adaLN_modulation_1 = MM_WEIGHT_REGISTER["Default"]("final_layer.adaLN_modulation.1.weight", "final_layer.adaLN_modulation.1.bias")
self.weight_list = [
self.final_layer_linear,
@@ -17,7 +17,7 @@ class HunyuanPostWeights:
for mm_weight in self.weight_list:
if isinstance(mm_weight, MMWeightTemplate):
mm_weight.set_config(self.config['mm_config'])
mm_weight.set_config(self.config["mm_config"])
mm_weight.load(weight_dict)
def to_cpu(self):
@@ -9,46 +9,72 @@ class HunyuanPreWeights:
self.config = config
def load_weights(self, weight_dict):
self.img_in_proj = CONV3D_WEIGHT_REGISTER["Default"]('img_in.proj.weight', 'img_in.proj.bias', stride=(1, 2, 2))
self.img_in_proj = CONV3D_WEIGHT_REGISTER["Default"]("img_in.proj.weight", "img_in.proj.bias", stride=(1, 2, 2))
self.txt_in_input_embedder = MM_WEIGHT_REGISTER["Default"]('txt_in.input_embedder.weight', 'txt_in.input_embedder.bias')
self.txt_in_t_embedder_mlp_0 = MM_WEIGHT_REGISTER["Default"]('txt_in.t_embedder.mlp.0.weight', 'txt_in.t_embedder.mlp.0.bias')
self.txt_in_t_embedder_mlp_2 = MM_WEIGHT_REGISTER["Default"]('txt_in.t_embedder.mlp.2.weight', 'txt_in.t_embedder.mlp.2.bias')
self.txt_in_c_embedder_linear_1 = MM_WEIGHT_REGISTER["Default"]('txt_in.c_embedder.linear_1.weight', 'txt_in.c_embedder.linear_1.bias')
self.txt_in_c_embedder_linear_2 = MM_WEIGHT_REGISTER["Default"]('txt_in.c_embedder.linear_2.weight', 'txt_in.c_embedder.linear_2.bias')
self.txt_in_input_embedder = MM_WEIGHT_REGISTER["Default"]("txt_in.input_embedder.weight", "txt_in.input_embedder.bias")
self.txt_in_t_embedder_mlp_0 = MM_WEIGHT_REGISTER["Default"]("txt_in.t_embedder.mlp.0.weight", "txt_in.t_embedder.mlp.0.bias")
self.txt_in_t_embedder_mlp_2 = MM_WEIGHT_REGISTER["Default"]("txt_in.t_embedder.mlp.2.weight", "txt_in.t_embedder.mlp.2.bias")
self.txt_in_c_embedder_linear_1 = MM_WEIGHT_REGISTER["Default"]("txt_in.c_embedder.linear_1.weight", "txt_in.c_embedder.linear_1.bias")
self.txt_in_c_embedder_linear_2 = MM_WEIGHT_REGISTER["Default"]("txt_in.c_embedder.linear_2.weight", "txt_in.c_embedder.linear_2.bias")
self.txt_in_individual_token_refiner_blocks_0_norm1 = LN_WEIGHT_REGISTER["Default"]('txt_in.individual_token_refiner.blocks.0.norm1.weight', 'txt_in.individual_token_refiner.blocks.0.norm1.bias', eps=1e-6)
self.txt_in_individual_token_refiner_blocks_0_self_attn_qkv = MM_WEIGHT_REGISTER["Default"]('txt_in.individual_token_refiner.blocks.0.self_attn_qkv.weight', 'txt_in.individual_token_refiner.blocks.0.self_attn_qkv.bias')
self.txt_in_individual_token_refiner_blocks_0_self_attn_proj = MM_WEIGHT_REGISTER["Default"]('txt_in.individual_token_refiner.blocks.0.self_attn_proj.weight', 'txt_in.individual_token_refiner.blocks.0.self_attn_proj.bias')
self.txt_in_individual_token_refiner_blocks_0_norm2 = LN_WEIGHT_REGISTER["Default"]('txt_in.individual_token_refiner.blocks.0.norm2.weight', 'txt_in.individual_token_refiner.blocks.0.norm2.bias', eps=1e-6)
self.txt_in_individual_token_refiner_blocks_0_mlp_fc1 = MM_WEIGHT_REGISTER["Default"]('txt_in.individual_token_refiner.blocks.0.mlp.fc1.weight', 'txt_in.individual_token_refiner.blocks.0.mlp.fc1.bias')
self.txt_in_individual_token_refiner_blocks_0_mlp_fc2 = MM_WEIGHT_REGISTER["Default"]('txt_in.individual_token_refiner.blocks.0.mlp.fc2.weight', 'txt_in.individual_token_refiner.blocks.0.mlp.fc2.bias')
self.txt_in_individual_token_refiner_blocks_0_adaLN_modulation_1 = MM_WEIGHT_REGISTER["Default"]('txt_in.individual_token_refiner.blocks.0.adaLN_modulation.1.weight', 'txt_in.individual_token_refiner.blocks.0.adaLN_modulation.1.bias')
self.txt_in_individual_token_refiner_blocks_0_norm1 = LN_WEIGHT_REGISTER["Default"](
"txt_in.individual_token_refiner.blocks.0.norm1.weight", "txt_in.individual_token_refiner.blocks.0.norm1.bias", eps=1e-6
)
self.txt_in_individual_token_refiner_blocks_0_self_attn_qkv = MM_WEIGHT_REGISTER["Default"](
"txt_in.individual_token_refiner.blocks.0.self_attn_qkv.weight", "txt_in.individual_token_refiner.blocks.0.self_attn_qkv.bias"
)
self.txt_in_individual_token_refiner_blocks_0_self_attn_proj = MM_WEIGHT_REGISTER["Default"](
"txt_in.individual_token_refiner.blocks.0.self_attn_proj.weight", "txt_in.individual_token_refiner.blocks.0.self_attn_proj.bias"
)
self.txt_in_individual_token_refiner_blocks_0_norm2 = LN_WEIGHT_REGISTER["Default"](
"txt_in.individual_token_refiner.blocks.0.norm2.weight", "txt_in.individual_token_refiner.blocks.0.norm2.bias", eps=1e-6
)
self.txt_in_individual_token_refiner_blocks_0_mlp_fc1 = MM_WEIGHT_REGISTER["Default"](
"txt_in.individual_token_refiner.blocks.0.mlp.fc1.weight", "txt_in.individual_token_refiner.blocks.0.mlp.fc1.bias"
)
self.txt_in_individual_token_refiner_blocks_0_mlp_fc2 = MM_WEIGHT_REGISTER["Default"](
"txt_in.individual_token_refiner.blocks.0.mlp.fc2.weight", "txt_in.individual_token_refiner.blocks.0.mlp.fc2.bias"
)
self.txt_in_individual_token_refiner_blocks_0_adaLN_modulation_1 = MM_WEIGHT_REGISTER["Default"](
"txt_in.individual_token_refiner.blocks.0.adaLN_modulation.1.weight", "txt_in.individual_token_refiner.blocks.0.adaLN_modulation.1.bias"
)
self.txt_in_individual_token_refiner_blocks_1_norm1 = LN_WEIGHT_REGISTER["Default"]('txt_in.individual_token_refiner.blocks.1.norm1.weight', 'txt_in.individual_token_refiner.blocks.1.norm1.bias', eps=1e-6)
self.txt_in_individual_token_refiner_blocks_1_self_attn_qkv = MM_WEIGHT_REGISTER["Default"]('txt_in.individual_token_refiner.blocks.1.self_attn_qkv.weight', 'txt_in.individual_token_refiner.blocks.1.self_attn_qkv.bias')
self.txt_in_individual_token_refiner_blocks_1_self_attn_proj = MM_WEIGHT_REGISTER["Default"]('txt_in.individual_token_refiner.blocks.1.self_attn_proj.weight', 'txt_in.individual_token_refiner.blocks.1.self_attn_proj.bias')
self.txt_in_individual_token_refiner_blocks_1_norm2 = LN_WEIGHT_REGISTER["Default"]('txt_in.individual_token_refiner.blocks.1.norm2.weight', 'txt_in.individual_token_refiner.blocks.1.norm2.bias', eps=1e-6)
self.txt_in_individual_token_refiner_blocks_1_mlp_fc1 = MM_WEIGHT_REGISTER["Default"]('txt_in.individual_token_refiner.blocks.1.mlp.fc1.weight', 'txt_in.individual_token_refiner.blocks.1.mlp.fc1.bias')
self.txt_in_individual_token_refiner_blocks_1_mlp_fc2 = MM_WEIGHT_REGISTER["Default"]('txt_in.individual_token_refiner.blocks.1.mlp.fc2.weight', 'txt_in.individual_token_refiner.blocks.1.mlp.fc2.bias')
self.txt_in_individual_token_refiner_blocks_1_adaLN_modulation_1 = MM_WEIGHT_REGISTER["Default"]('txt_in.individual_token_refiner.blocks.1.adaLN_modulation.1.weight', 'txt_in.individual_token_refiner.blocks.1.adaLN_modulation.1.bias')
self.time_in_mlp_0 = MM_WEIGHT_REGISTER["Default"]('time_in.mlp.0.weight', 'time_in.mlp.0.bias')
self.time_in_mlp_2 = MM_WEIGHT_REGISTER["Default"]('time_in.mlp.2.weight', 'time_in.mlp.2.bias')
self.vector_in_in_layer = MM_WEIGHT_REGISTER["Default"]('vector_in.in_layer.weight', 'vector_in.in_layer.bias')
self.vector_in_out_layer = MM_WEIGHT_REGISTER["Default"]('vector_in.out_layer.weight', 'vector_in.out_layer.bias')
self.guidance_in_mlp_0 = MM_WEIGHT_REGISTER["Default"]('guidance_in.mlp.0.weight', 'guidance_in.mlp.0.bias')
self.guidance_in_mlp_2 = MM_WEIGHT_REGISTER["Default"]('guidance_in.mlp.2.weight', 'guidance_in.mlp.2.bias')
self.txt_in_individual_token_refiner_blocks_1_norm1 = LN_WEIGHT_REGISTER["Default"](
"txt_in.individual_token_refiner.blocks.1.norm1.weight", "txt_in.individual_token_refiner.blocks.1.norm1.bias", eps=1e-6
)
self.txt_in_individual_token_refiner_blocks_1_self_attn_qkv = MM_WEIGHT_REGISTER["Default"](
"txt_in.individual_token_refiner.blocks.1.self_attn_qkv.weight", "txt_in.individual_token_refiner.blocks.1.self_attn_qkv.bias"
)
self.txt_in_individual_token_refiner_blocks_1_self_attn_proj = MM_WEIGHT_REGISTER["Default"](
"txt_in.individual_token_refiner.blocks.1.self_attn_proj.weight", "txt_in.individual_token_refiner.blocks.1.self_attn_proj.bias"
)
self.txt_in_individual_token_refiner_blocks_1_norm2 = LN_WEIGHT_REGISTER["Default"](
"txt_in.individual_token_refiner.blocks.1.norm2.weight", "txt_in.individual_token_refiner.blocks.1.norm2.bias", eps=1e-6
)
self.txt_in_individual_token_refiner_blocks_1_mlp_fc1 = MM_WEIGHT_REGISTER["Default"](
"txt_in.individual_token_refiner.blocks.1.mlp.fc1.weight", "txt_in.individual_token_refiner.blocks.1.mlp.fc1.bias"
)
self.txt_in_individual_token_refiner_blocks_1_mlp_fc2 = MM_WEIGHT_REGISTER["Default"](
"txt_in.individual_token_refiner.blocks.1.mlp.fc2.weight", "txt_in.individual_token_refiner.blocks.1.mlp.fc2.bias"
)
self.txt_in_individual_token_refiner_blocks_1_adaLN_modulation_1 = MM_WEIGHT_REGISTER["Default"](
"txt_in.individual_token_refiner.blocks.1.adaLN_modulation.1.weight", "txt_in.individual_token_refiner.blocks.1.adaLN_modulation.1.bias"
)
self.time_in_mlp_0 = MM_WEIGHT_REGISTER["Default"]("time_in.mlp.0.weight", "time_in.mlp.0.bias")
self.time_in_mlp_2 = MM_WEIGHT_REGISTER["Default"]("time_in.mlp.2.weight", "time_in.mlp.2.bias")
self.vector_in_in_layer = MM_WEIGHT_REGISTER["Default"]("vector_in.in_layer.weight", "vector_in.in_layer.bias")
self.vector_in_out_layer = MM_WEIGHT_REGISTER["Default"]("vector_in.out_layer.weight", "vector_in.out_layer.bias")
self.guidance_in_mlp_0 = MM_WEIGHT_REGISTER["Default"]("guidance_in.mlp.0.weight", "guidance_in.mlp.0.bias")
self.guidance_in_mlp_2 = MM_WEIGHT_REGISTER["Default"]("guidance_in.mlp.2.weight", "guidance_in.mlp.2.bias")
self.weight_list = [
self.img_in_proj,
self.txt_in_input_embedder,
self.txt_in_t_embedder_mlp_0,
self.txt_in_t_embedder_mlp_2,
self.txt_in_c_embedder_linear_1,
self.txt_in_c_embedder_linear_2,
self.txt_in_individual_token_refiner_blocks_0_norm1,
self.txt_in_individual_token_refiner_blocks_0_self_attn_qkv,
self.txt_in_individual_token_refiner_blocks_0_self_attn_proj,
@@ -56,7 +82,6 @@ class HunyuanPreWeights:
self.txt_in_individual_token_refiner_blocks_0_mlp_fc1,
self.txt_in_individual_token_refiner_blocks_0_mlp_fc2,
self.txt_in_individual_token_refiner_blocks_0_adaLN_modulation_1,
self.txt_in_individual_token_refiner_blocks_1_norm1,
self.txt_in_individual_token_refiner_blocks_1_self_attn_qkv,
self.txt_in_individual_token_refiner_blocks_1_self_attn_proj,
@@ -64,7 +89,6 @@ class HunyuanPreWeights:
self.txt_in_individual_token_refiner_blocks_1_mlp_fc1,
self.txt_in_individual_token_refiner_blocks_1_mlp_fc2,
self.txt_in_individual_token_refiner_blocks_1_adaLN_modulation_1,
self.time_in_mlp_0,
self.time_in_mlp_2,
self.vector_in_in_layer,
@@ -75,7 +99,7 @@ class HunyuanPreWeights:
for mm_weight in self.weight_list:
if isinstance(mm_weight, MMWeightTemplate) or isinstance(mm_weight, LNWeightTemplate) or isinstance(mm_weight, Conv3dWeightTemplate):
mm_weight.set_config(self.config['mm_config'])
mm_weight.set_config(self.config["mm_config"])
mm_weight.load(weight_dict)
def to_cpu(self):
@@ -41,26 +41,26 @@ class HunyuanTransformerDoubleBlock:
self.weight_list = []
def load_weights(self, weight_dict):
if self.config['do_mm_calib']:
mm_type = 'Calib'
if self.config["do_mm_calib"]:
mm_type = "Calib"
else:
mm_type = self.config['mm_config'].get('mm_type', 'Default') if self.config['mm_config'] else 'Default'
self.img_mod = MM_WEIGHT_REGISTER[mm_type](f'double_blocks.{self.block_index}.img_mod.linear.weight', f'double_blocks.{self.block_index}.img_mod.linear.bias')
self.img_attn_qkv = MM_WEIGHT_REGISTER[mm_type](f'double_blocks.{self.block_index}.img_attn_qkv.weight', f'double_blocks.{self.block_index}.img_attn_qkv.bias')
self.img_attn_q_norm = RMS_WEIGHT_REGISTER['sgl-kernel'](f'double_blocks.{self.block_index}.img_attn_q_norm.weight', eps=1e-6)
self.img_attn_k_norm = RMS_WEIGHT_REGISTER['sgl-kernel'](f'double_blocks.{self.block_index}.img_attn_k_norm.weight', eps=1e-6)
self.img_attn_proj = MM_WEIGHT_REGISTER[mm_type](f'double_blocks.{self.block_index}.img_attn_proj.weight', f'double_blocks.{self.block_index}.img_attn_proj.bias')
self.img_mlp_fc1 = MM_WEIGHT_REGISTER[mm_type](f'double_blocks.{self.block_index}.img_mlp.fc1.weight', f'double_blocks.{self.block_index}.img_mlp.fc1.bias')
self.img_mlp_fc2 = MM_WEIGHT_REGISTER[mm_type](f'double_blocks.{self.block_index}.img_mlp.fc2.weight', f'double_blocks.{self.block_index}.img_mlp.fc2.bias')
self.txt_mod = MM_WEIGHT_REGISTER[mm_type](f'double_blocks.{self.block_index}.txt_mod.linear.weight', f'double_blocks.{self.block_index}.txt_mod.linear.bias')
self.txt_attn_qkv = MM_WEIGHT_REGISTER[mm_type](f'double_blocks.{self.block_index}.txt_attn_qkv.weight', f'double_blocks.{self.block_index}.txt_attn_qkv.bias')
self.txt_attn_q_norm = RMS_WEIGHT_REGISTER['sgl-kernel'](f'double_blocks.{self.block_index}.txt_attn_q_norm.weight', eps=1e-6)
self.txt_attn_k_norm = RMS_WEIGHT_REGISTER['sgl-kernel'](f'double_blocks.{self.block_index}.txt_attn_k_norm.weight', eps=1e-6)
self.txt_attn_proj = MM_WEIGHT_REGISTER[mm_type](f'double_blocks.{self.block_index}.txt_attn_proj.weight', f'double_blocks.{self.block_index}.txt_attn_proj.bias')
self.txt_mlp_fc1 = MM_WEIGHT_REGISTER[mm_type](f'double_blocks.{self.block_index}.txt_mlp.fc1.weight', f'double_blocks.{self.block_index}.txt_mlp.fc1.bias')
self.txt_mlp_fc2 = MM_WEIGHT_REGISTER[mm_type](f'double_blocks.{self.block_index}.txt_mlp.fc2.weight', f'double_blocks.{self.block_index}.txt_mlp.fc2.bias')
mm_type = self.config["mm_config"].get("mm_type", "Default") if self.config["mm_config"] else "Default"
self.img_mod = MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.img_mod.linear.weight", f"double_blocks.{self.block_index}.img_mod.linear.bias")
self.img_attn_qkv = MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.img_attn_qkv.weight", f"double_blocks.{self.block_index}.img_attn_qkv.bias")
self.img_attn_q_norm = RMS_WEIGHT_REGISTER["sgl-kernel"](f"double_blocks.{self.block_index}.img_attn_q_norm.weight", eps=1e-6)
self.img_attn_k_norm = RMS_WEIGHT_REGISTER["sgl-kernel"](f"double_blocks.{self.block_index}.img_attn_k_norm.weight", eps=1e-6)
self.img_attn_proj = MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.img_attn_proj.weight", f"double_blocks.{self.block_index}.img_attn_proj.bias")
self.img_mlp_fc1 = MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.img_mlp.fc1.weight", f"double_blocks.{self.block_index}.img_mlp.fc1.bias")
self.img_mlp_fc2 = MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.img_mlp.fc2.weight", f"double_blocks.{self.block_index}.img_mlp.fc2.bias")
self.txt_mod = MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.txt_mod.linear.weight", f"double_blocks.{self.block_index}.txt_mod.linear.bias")
self.txt_attn_qkv = MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.txt_attn_qkv.weight", f"double_blocks.{self.block_index}.txt_attn_qkv.bias")
self.txt_attn_q_norm = RMS_WEIGHT_REGISTER["sgl-kernel"](f"double_blocks.{self.block_index}.txt_attn_q_norm.weight", eps=1e-6)
self.txt_attn_k_norm = RMS_WEIGHT_REGISTER["sgl-kernel"](f"double_blocks.{self.block_index}.txt_attn_k_norm.weight", eps=1e-6)
self.txt_attn_proj = MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.txt_attn_proj.weight", f"double_blocks.{self.block_index}.txt_attn_proj.bias")
self.txt_mlp_fc1 = MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.txt_mlp.fc1.weight", f"double_blocks.{self.block_index}.txt_mlp.fc1.bias")
self.txt_mlp_fc2 = MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.txt_mlp.fc2.weight", f"double_blocks.{self.block_index}.txt_mlp.fc2.bias")
self.weight_list = [
self.img_mod,
@@ -81,7 +81,7 @@ class HunyuanTransformerDoubleBlock:
for mm_weight in self.weight_list:
if isinstance(mm_weight, MMWeightTemplate) or isinstance(mm_weight, RMSWeightTemplate):
mm_weight.set_config(self.config['mm_config'])
mm_weight.set_config(self.config["mm_config"])
mm_weight.load(weight_dict)
def to_cpu(self):
@@ -102,16 +102,16 @@ class HunyuanTransformerSingleBlock:
self.weight_list = []
def load_weights(self, weight_dict):
if self.config['do_mm_calib']:
mm_type = 'Calib'
if self.config["do_mm_calib"]:
mm_type = "Calib"
else:
mm_type = self.config['mm_config'].get('mm_type', 'Default') if self.config['mm_config'] else 'Default'
mm_type = self.config["mm_config"].get("mm_type", "Default") if self.config["mm_config"] else "Default"
self.linear1 = MM_WEIGHT_REGISTER[mm_type](f'single_blocks.{self.block_index}.linear1.weight', f'single_blocks.{self.block_index}.linear1.bias')
self.linear2 = MM_WEIGHT_REGISTER[mm_type](f'single_blocks.{self.block_index}.linear2.weight', f'single_blocks.{self.block_index}.linear2.bias')
self.q_norm = RMS_WEIGHT_REGISTER['sgl-kernel'](f'single_blocks.{self.block_index}.q_norm.weight', eps=1e-6)
self.k_norm = RMS_WEIGHT_REGISTER['sgl-kernel'](f'single_blocks.{self.block_index}.k_norm.weight', eps=1e-6)
self.modulation = MM_WEIGHT_REGISTER[mm_type](f'single_blocks.{self.block_index}.modulation.linear.weight', f'single_blocks.{self.block_index}.modulation.linear.bias')
self.linear1 = MM_WEIGHT_REGISTER[mm_type](f"single_blocks.{self.block_index}.linear1.weight", f"single_blocks.{self.block_index}.linear1.bias")
self.linear2 = MM_WEIGHT_REGISTER[mm_type](f"single_blocks.{self.block_index}.linear2.weight", f"single_blocks.{self.block_index}.linear2.bias")
self.q_norm = RMS_WEIGHT_REGISTER["sgl-kernel"](f"single_blocks.{self.block_index}.q_norm.weight", eps=1e-6)
self.k_norm = RMS_WEIGHT_REGISTER["sgl-kernel"](f"single_blocks.{self.block_index}.k_norm.weight", eps=1e-6)
self.modulation = MM_WEIGHT_REGISTER[mm_type](f"single_blocks.{self.block_index}.modulation.linear.weight", f"single_blocks.{self.block_index}.modulation.linear.bias")
self.weight_list = [
self.linear1,
@@ -123,7 +123,7 @@ class HunyuanTransformerSingleBlock:
for mm_weight in self.weight_list:
if isinstance(mm_weight, MMWeightTemplate) or isinstance(mm_weight, RMSWeightTemplate):
mm_weight.set_config(self.config['mm_config'])
mm_weight.set_config(self.config["mm_config"])
mm_weight.load(weight_dict)
def to_cpu(self):
@@ -13,40 +13,31 @@ class WanTransformerInferFeatureCaching(WanTransformerInfer):
# teacache
if self.scheduler.cnt % 2 == 0: # even -> conditon
self.scheduler.is_even = True
if (
self.scheduler.cnt < self.scheduler.ret_steps
or self.scheduler.cnt >= self.scheduler.cutoff_steps
):
if self.scheduler.cnt < self.scheduler.ret_steps or self.scheduler.cnt >= self.scheduler.cutoff_steps:
should_calc_even = True
self.scheduler.accumulated_rel_l1_distance_even = 0
else:
rescale_func = np.poly1d(self.scheduler.coefficients)
self.scheduler.accumulated_rel_l1_distance_even += rescale_func(
(
(modulated_inp - self.scheduler.previous_e0_even).abs().mean()
/ self.scheduler.previous_e0_even.abs().mean()
)
.cpu()
.item()
((modulated_inp - self.scheduler.previous_e0_even).abs().mean() / self.scheduler.previous_e0_even.abs().mean()).cpu().item()
)
if (
self.scheduler.accumulated_rel_l1_distance_even
< self.scheduler.teacache_thresh
):
if self.scheduler.accumulated_rel_l1_distance_even < self.scheduler.teacache_thresh:
should_calc_even = False
else:
should_calc_even = True
self.scheduler.accumulated_rel_l1_distance_even = 0
self.scheduler.previous_e0_even = modulated_inp.clone()
else: # odd -> unconditon
else: # odd -> unconditon
self.scheduler.is_even = False
if self.scheduler.cnt < self.scheduler.ret_steps or self.scheduler.cnt >= self.scheduler.cutoff_steps:
should_calc_odd = True
self.scheduler.accumulated_rel_l1_distance_odd = 0
else:
should_calc_odd = True
self.scheduler.accumulated_rel_l1_distance_odd = 0
else:
rescale_func = np.poly1d(self.scheduler.coefficients)
self.scheduler.accumulated_rel_l1_distance_odd += rescale_func(((modulated_inp-self.scheduler.previous_e0_odd).abs().mean() / self.scheduler.previous_e0_odd.abs().mean()).cpu().item())
self.scheduler.accumulated_rel_l1_distance_odd += rescale_func(
((modulated_inp - self.scheduler.previous_e0_odd).abs().mean() / self.scheduler.previous_e0_odd.abs().mean()).cpu().item()
)
if self.scheduler.accumulated_rel_l1_distance_odd < self.scheduler.teacache_thresh:
should_calc_odd = False
else:
@@ -10,9 +10,7 @@ class WanPostInfer:
def infer(self, weights, x, e, grid_sizes):
e = (weights.head_modulation + e.unsqueeze(1)).chunk(2, dim=1)
norm_out = torch.nn.functional.layer_norm(
x, (x.shape[1],), None, None, 1e-6
).type_as(x)
norm_out = torch.nn.functional.layer_norm(x, (x.shape[1],), None, None, 1e-6).type_as(x)
out = norm_out * (1 + e[1].squeeze(0)) + e[0].squeeze(0)
x = weights.head.apply(out)
x = self.unpatchify(x, grid_sizes)
@@ -6,12 +6,10 @@ import torch.cuda.amp as amp
class WanPreInfer:
def __init__(self, config):
assert (config["dim"] % config["num_heads"]) == 0 and (
config["dim"] // config["num_heads"]
) % 2 == 0
assert (config["dim"] % config["num_heads"]) == 0 and (config["dim"] // config["num_heads"]) % 2 == 0
d = config["dim"] // config["num_heads"]
self.task = config['task']
self.task = config["task"]
self.freqs = torch.cat(
[
rope_params(1024, d - 4 * (d // 6)),
@@ -25,24 +23,16 @@ class WanPreInfer:
self.text_len = config["text_len"]
def infer(self, weights, x, t, context, seq_len, clip_fea=None, y=None):
if self.task == 'i2v':
if self.task == "i2v":
x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
# embeddings
x = [weights.patch_embedding.apply(u.unsqueeze(0)) for u in x]
grid_sizes = torch.stack(
[torch.tensor(u.shape[2:], dtype=torch.long) for u in x]
)
grid_sizes = torch.stack([torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
x = [u.flatten(2).transpose(1, 2) for u in x]
seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long).cuda()
assert seq_lens.max() <= seq_len
x = torch.cat(
[
torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1)
for u in x
]
)
x = torch.cat([torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1) for u in x])
embed = sinusoidal_embedding_1d(self.freq_dim, t)
embed = weights.time_embedding_0.apply(embed)
@@ -53,25 +43,20 @@ class WanPreInfer:
embed0 = weights.time_projection_1.apply(embed0).unflatten(1, (6, self.dim))
# text embeddings
stacked = torch.stack(
[
torch.cat([u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
for u in context
]
)
stacked = torch.stack([torch.cat([u, u.new_zeros(self.text_len - u.size(0), u.size(1))]) for u in context])
out = weights.text_embedding_0.apply(stacked.squeeze(0))
out = torch.nn.functional.gelu(out, approximate="tanh")
context = weights.text_embedding_2.apply(out)
if self.task == 'i2v':
if self.task == "i2v":
context_clip = weights.proj_0.apply(clip_fea)
context_clip = weights.proj_1.apply(context_clip)
context_clip = torch.nn.functional.gelu(context_clip, approximate="none")
context_clip = weights.proj_3.apply(context_clip)
context_clip = weights.proj_4.apply(context_clip)
context = torch.concat([context_clip, context], dim=0)
return (
embed,
grid_sizes,
@@ -6,7 +6,7 @@ from lightx2v.attentions import attention
class WanTransformerInfer:
def __init__(self, config):
self.config = config
self.task = config['task']
self.task = config["task"]
self.attention_type = config.get("attention_type", "flash_attn2")
self.blocks_num = config["num_layers"]
self.num_heads = config["num_heads"]
@@ -24,14 +24,8 @@ class WanTransformerInfer:
q_lens = torch.tensor([lq], dtype=torch.int32, device=q.device)
# We don't have a batch dimension anymore, so directly use the `q_lens` and `k_lens` values
cu_seqlens_q = (
torch.cat([q_lens.new_zeros([1]), q_lens])
.cumsum(0, dtype=torch.int32)
)
cu_seqlens_k = (
torch.cat([k_lens.new_zeros([1]), k_lens])
.cumsum(0, dtype=torch.int32)
)
cu_seqlens_q = torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(0, dtype=torch.int32)
cu_seqlens_k = torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(0, dtype=torch.int32)
return cu_seqlens_q, cu_seqlens_k, lq, lk
def infer(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
@@ -49,14 +43,10 @@ class WanTransformerInfer:
return x
def infer_block(
self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context
):
def infer_block(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
embed0 = (weights.modulation + embed0).chunk(6, dim=1)
norm1_out = torch.nn.functional.layer_norm(
x, (x.shape[1],), None, None, 1e-6
)
norm1_out = torch.nn.functional.layer_norm(x, (x.shape[1],), None, None, 1e-6)
norm1_out = (norm1_out * (1 + embed0[1]) + embed0[0]).squeeze(0)
s, n, d = *norm1_out.shape[:1], self.num_heads, self.head_dim
@@ -87,22 +77,22 @@ class WanTransformerInfer:
)
else:
attn_out = self.parallel_attention(
attention_type=self.attention_type,
q=q,
k=k,
v=v,
img_qkv_len=q.shape[0],
cu_seqlens_qkv=cu_seqlens_q
# cu_seqlens_qkv=cu_seqlens_qkv,
# max_seqlen_qkv=max_seqlen_qkv,
)
attention_type=self.attention_type,
q=q,
k=k,
v=v,
img_qkv_len=q.shape[0],
cu_seqlens_qkv=cu_seqlens_q,
# cu_seqlens_qkv=cu_seqlens_qkv,
# max_seqlen_qkv=max_seqlen_qkv,
)
y = weights.self_attn_o.apply(attn_out)
x = x + y * embed0[2].squeeze(0)
norm3_out = weights.norm3.apply(x)
if self.task == 'i2v':
if self.task == "i2v":
context_img = context[:257]
context = context[257:]
@@ -111,13 +101,11 @@ class WanTransformerInfer:
k = weights.cross_attn_norm_k.apply(weights.cross_attn_k.apply(context)).view(-1, n, d)
v = weights.cross_attn_v.apply(context).view(-1, n, d)
if self.task == 'i2v':
if self.task == "i2v":
k_img = weights.cross_attn_norm_k_img.apply(weights.cross_attn_k_img.apply(context_img)).view(-1, n, d)
v_img = weights.cross_attn_v_img.apply(context_img).view(-1, n, d)
cu_seqlens_q, cu_seqlens_k, lq, lk = self._calculate_q_k_len(
q, k_img, k_lens=torch.tensor([k_img.size(0)], dtype=torch.int32, device=k.device)
)
cu_seqlens_q, cu_seqlens_k, lq, lk = self._calculate_q_k_len(q, k_img, k_lens=torch.tensor([k_img.size(0)], dtype=torch.int32, device=k.device))
img_attn_out = attention(
attention_type=self.attention_type,
@@ -130,9 +118,7 @@ class WanTransformerInfer:
max_seqlen_kv=lk,
)
cu_seqlens_q, cu_seqlens_k, lq, lk = self._calculate_q_k_len(
q, k, k_lens=torch.tensor([k.size(0)], dtype=torch.int32, device=k.device)
)
cu_seqlens_q, cu_seqlens_k, lq, lk = self._calculate_q_k_len(q, k, k_lens=torch.tensor([k.size(0)], dtype=torch.int32, device=k.device))
attn_out = attention(
attention_type=self.attention_type,
@@ -147,9 +133,7 @@ class WanTransformerInfer:
attn_out = weights.cross_attn_o.apply(attn_out)
x = x + attn_out
norm2_out = torch.nn.functional.layer_norm(
x, (x.shape[1],), None, None, 1e-6
)
norm2_out = torch.nn.functional.layer_norm(x, (x.shape[1],), None, None, 1e-6)
y = weights.ffn_0.apply(norm2_out * (1 + embed0[4].squeeze(0)) + embed0[3].squeeze(0))
y = torch.nn.functional.gelu(y, approximate="tanh")
y = weights.ffn_2.apply(y)
@@ -23,12 +23,7 @@ def compute_freqs(c, grid_sizes, freqs):
def pad_freqs(original_tensor, target_len):
seq_len, s1, s2 = original_tensor.shape
pad_size = target_len - seq_len
padding_tensor = torch.ones(
pad_size,
s1,
s2,
dtype=original_tensor.dtype,
device=original_tensor.device)
padding_tensor = torch.ones(pad_size, s1, s2, dtype=original_tensor.dtype, device=original_tensor.device)
padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0)
return padded_tensor
@@ -50,8 +45,7 @@ def compute_freqs_dist(s, c, grid_sizes, freqs):
freqs_i = pad_freqs(freqs_i, s * world_size)
s_per_rank = s
freqs_i_rank = freqs_i[(cur_rank * s_per_rank):((cur_rank + 1) *
s_per_rank), :, :]
freqs_i_rank = freqs_i[(cur_rank * s_per_rank) : ((cur_rank + 1) * s_per_rank), :, :]
return freqs_i_rank
@@ -59,9 +53,7 @@ def apply_rotary_emb(x, freqs_i):
n = x.size(1)
seq_len = freqs_i.size(0)
x_i = torch.view_as_complex(
x[:seq_len].to(torch.float64).reshape(seq_len, n, -1, 2)
)
x_i = torch.view_as_complex(x[:seq_len].to(torch.float64).reshape(seq_len, n, -1, 2))
# Apply rotary embedding
x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
x_i = torch.cat([x_i, x[seq_len:]]).to(torch.bfloat16)
@@ -85,8 +77,6 @@ def sinusoidal_embedding_1d(dim, position):
position = position.type(torch.float64)
# calculation
sinusoid = torch.outer(
position, torch.pow(10000, -torch.arange(half).to(position).div(half))
)
sinusoid = torch.outer(position, torch.pow(10000, -torch.arange(half).to(position).div(half)))
x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1).to(torch.bfloat16)
return x
@@ -29,10 +29,10 @@ class WanModel:
self._init_weights()
self._init_infer()
if config['parallel_attn']:
if config["parallel_attn"]:
parallelize_wan(self)
if self.config['cpu_offload']:
if self.config["cpu_offload"]:
self.to_cpu()
def _init_infer_class(self):
@@ -43,15 +43,11 @@ class WanModel:
elif self.config["feature_caching"] == "Tea":
self.transformer_infer_class = WanTransformerInferFeatureCaching
else:
raise NotImplementedError(
f"Unsupported feature_caching type: {self.config['feature_caching']}"
)
raise NotImplementedError(f"Unsupported feature_caching type: {self.config['feature_caching']}")
def _load_safetensor_to_dict(self, file_path):
with safe_open(file_path, framework="pt") as f:
tensor_dict = {
key: f.get_tensor(key).to(torch.bfloat16).cuda() for key in f.keys()
}
tensor_dict = {key: f.get_tensor(key).to(torch.bfloat16).cuda() for key in f.keys()}
return tensor_dict
def _load_ckpt(self):
@@ -59,9 +55,7 @@ class WanModel:
safetensors_files = glob.glob(safetensors_pattern)
if not safetensors_files:
raise FileNotFoundError(
f"No .safetensors files found in directory: {self.model_path}"
)
raise FileNotFoundError(f"No .safetensors files found in directory: {self.model_path}")
weight_dict = {}
for file_path in safetensors_files:
file_weights = self._load_safetensor_to_dict(file_path)
@@ -100,7 +94,6 @@ class WanModel:
@torch.no_grad()
def infer(self, text_encoders_output, image_encoder_output, args):
timestep = torch.stack([self.scheduler.timesteps[self.scheduler.step_index]])
embed, grid_sizes, pre_infer_out = self.pre_infer.infer(
@@ -112,12 +105,8 @@ class WanModel:
image_encoder_output["clip_encoder_out"],
[image_encoder_output["vae_encode_out"]],
)
x = self.transformer_infer.infer(
self.transformer_weights, grid_sizes, embed, *pre_infer_out
)
noise_pred_cond = self.post_infer.infer(
self.post_weight, x, embed, grid_sizes
)[0]
x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
noise_pred_cond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]
if self.config["feature_caching"] == "Tea":
self.scheduler.cnt += 1
@@ -133,18 +122,12 @@ class WanModel:
image_encoder_output["clip_encoder_out"],
[image_encoder_output["vae_encode_out"]],
)
x = self.transformer_infer.infer(
self.transformer_weights, grid_sizes, embed, *pre_infer_out
)
noise_pred_uncond = self.post_infer.infer(
self.post_weight, x, embed, grid_sizes
)[0]
x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
noise_pred_uncond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]
if self.config["feature_caching"] == "Tea":
self.scheduler.cnt += 1
if self.scheduler.cnt >= self.scheduler.num_steps:
self.scheduler.cnt = 0
self.scheduler.noise_pred = noise_pred_uncond + args.sample_guide_scale * (
noise_pred_cond - noise_pred_uncond
)
self.scheduler.noise_pred = noise_pred_uncond + args.sample_guide_scale * (noise_pred_cond - noise_pred_uncond)
@@ -7,18 +7,14 @@ class WanPostWeights:
self.config = config
def load_weights(self, weight_dict):
self.head = MM_WEIGHT_REGISTER["Default"]("head.head.weight", "head.head.bias")
self.head_modulation = weight_dict["head.modulation"]
self.head = MM_WEIGHT_REGISTER["Default"]('head.head.weight','head.head.bias')
self.head_modulation = weight_dict['head.modulation']
self.weight_list = [
self.head,
self.head_modulation
]
self.weight_list = [self.head, self.head_modulation]
for mm_weight in self.weight_list:
if isinstance(mm_weight, MMWeightTemplate):
mm_weight.set_config(self.config['mm_config'])
mm_weight.set_config(self.config["mm_config"])
mm_weight.load(weight_dict)
def to_cpu(self):
@@ -33,4 +29,4 @@ class WanPostWeights:
if isinstance(mm_weight, MMWeightTemplate):
mm_weight.to_cuda()
else:
mm_weight.cuda()
\ No newline at end of file
mm_weight.cuda()
@@ -4,6 +4,7 @@ from lightx2v.common.ops.mm.mm_weight import MMWeightTemplate
from lightx2v.common.ops.norm.layer_norm_weight import LNWeightTemplate
from lightx2v.common.ops.conv.conv3d import Conv3dWeightTemplate
class WanPreWeights:
def __init__(self, config):
self.in_dim = config["in_dim"]
@@ -12,18 +13,16 @@ class WanPreWeights:
self.config = config
def load_weights(self, weight_dict):
self.patch_embedding = CONV3D_WEIGHT_REGISTER["Defaultt-Force-BF16"]("patch_embedding.weight", "patch_embedding.bias", stride=self.patch_size)
self.patch_embedding = CONV3D_WEIGHT_REGISTER["Defaultt-Force-BF16"]('patch_embedding.weight', 'patch_embedding.bias', stride=self.patch_size)
self.text_embedding_0 = MM_WEIGHT_REGISTER["Default"]('text_embedding.0.weight', 'text_embedding.0.bias')
self.text_embedding_2 = MM_WEIGHT_REGISTER["Default"]('text_embedding.2.weight', 'text_embedding.2.bias')
self.time_embedding_0 = MM_WEIGHT_REGISTER["Default"]('time_embedding.0.weight', 'time_embedding.0.bias')
self.time_embedding_2 = MM_WEIGHT_REGISTER["Default"]('time_embedding.2.weight', 'time_embedding.2.bias')
self.time_projection_1 = MM_WEIGHT_REGISTER["Default"]('time_projection.1.weight', 'time_projection.1.bias')
self.text_embedding_0 = MM_WEIGHT_REGISTER["Default"]("text_embedding.0.weight", "text_embedding.0.bias")
self.text_embedding_2 = MM_WEIGHT_REGISTER["Default"]("text_embedding.2.weight", "text_embedding.2.bias")
self.time_embedding_0 = MM_WEIGHT_REGISTER["Default"]("time_embedding.0.weight", "time_embedding.0.bias")
self.time_embedding_2 = MM_WEIGHT_REGISTER["Default"]("time_embedding.2.weight", "time_embedding.2.bias")
self.time_projection_1 = MM_WEIGHT_REGISTER["Default"]("time_projection.1.weight", "time_projection.1.bias")
self.weight_list = [
self.patch_embedding,
self.text_embedding_0,
self.text_embedding_2,
self.time_embedding_0,
@@ -31,11 +30,11 @@ class WanPreWeights:
self.time_projection_1,
]
if 'img_emb.proj.0.weight' in weight_dict.keys():
self.proj_0 = LN_WEIGHT_REGISTER["Default"]('img_emb.proj.0.weight', 'img_emb.proj.0.bias', eps=1e-5)
self.proj_1 = MM_WEIGHT_REGISTER["Default"]('img_emb.proj.1.weight', 'img_emb.proj.1.bias')
self.proj_3 = MM_WEIGHT_REGISTER["Default"]('img_emb.proj.3.weight', 'img_emb.proj.3.bias')
self.proj_4 = LN_WEIGHT_REGISTER["Default"]('img_emb.proj.4.weight', 'img_emb.proj.4.bias', eps=1e-5)
if "img_emb.proj.0.weight" in weight_dict.keys():
self.proj_0 = LN_WEIGHT_REGISTER["Default"]("img_emb.proj.0.weight", "img_emb.proj.0.bias", eps=1e-5)
self.proj_1 = MM_WEIGHT_REGISTER["Default"]("img_emb.proj.1.weight", "img_emb.proj.1.bias")
self.proj_3 = MM_WEIGHT_REGISTER["Default"]("img_emb.proj.3.weight", "img_emb.proj.3.bias")
self.proj_4 = LN_WEIGHT_REGISTER["Default"]("img_emb.proj.4.weight", "img_emb.proj.4.bias", eps=1e-5)
self.weight_list.append(self.proj_0)
self.weight_list.append(self.proj_1)
self.weight_list.append(self.proj_3)
@@ -43,7 +42,7 @@ class WanPreWeights:
for mm_weight in self.weight_list:
if isinstance(mm_weight, (MMWeightTemplate, LNWeightTemplate, Conv3dWeightTemplate)):
mm_weight.set_config(self.config['mm_config'])
mm_weight.set_config(self.config["mm_config"])
mm_weight.load(weight_dict)
def to_cpu(self):
@@ -54,4 +53,4 @@ class WanPreWeights:
def to_cuda(self):
for mm_weight in self.weight_list:
if isinstance(mm_weight, (MMWeightTemplate, LNWeightTemplate, Conv3dWeightTemplate)):
mm_weight.to_cuda()
\ No newline at end of file
mm_weight.to_cuda()
@@ -7,17 +7,15 @@ from lightx2v.common.ops.norm.rms_norm_weight import RMSWeightTemplate
class WanTransformerWeights:
def __init__(self, config):
self.blocks_num = config["num_layers"]
self.task = config['task']
self.task = config["task"]
self.config = config
if config['do_mm_calib']:
self.mm_type = 'Calib'
if config["do_mm_calib"]:
self.mm_type = "Calib"
else:
self.mm_type = config['mm_config'].get('mm_type', 'Default') if config['mm_config'] else 'Default'
self.mm_type = config["mm_config"].get("mm_type", "Default") if config["mm_config"] else "Default"
def load_weights(self, weight_dict):
self.blocks_weights = [
WanTransformerAttentionBlock(i, self.task, self.mm_type, self.config) for i in range(self.blocks_num)
]
self.blocks_weights = [WanTransformerAttentionBlock(i, self.task, self.mm_type, self.config) for i in range(self.blocks_num)]
for block in self.blocks_weights:
block.load_weights(weight_dict)
@@ -26,7 +24,7 @@ class WanTransformerWeights:
block.to_cpu()
def to_cuda(self):
for block in self.blocks_weights:
for block in self.blocks_weights:
block.to_cuda()
@@ -38,25 +36,24 @@ class WanTransformerAttentionBlock:
self.config = config
def load_weights(self, weight_dict):
self.self_attn_q = MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.self_attn.q.weight", f"blocks.{self.block_index}.self_attn.q.bias")
self.self_attn_k = MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.self_attn.k.weight", f"blocks.{self.block_index}.self_attn.k.bias")
self.self_attn_v = MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.self_attn.v.weight", f"blocks.{self.block_index}.self_attn.v.bias")
self.self_attn_o = MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.self_attn.o.weight", f"blocks.{self.block_index}.self_attn.o.bias")
self.self_attn_norm_q = RMS_WEIGHT_REGISTER["sgl-kernel"](f"blocks.{self.block_index}.self_attn.norm_q.weight")
self.self_attn_norm_k = RMS_WEIGHT_REGISTER["sgl-kernel"](f"blocks.{self.block_index}.self_attn.norm_k.weight")
self.self_attn_q = MM_WEIGHT_REGISTER[self.mm_type](f'blocks.{self.block_index}.self_attn.q.weight',f'blocks.{self.block_index}.self_attn.q.bias')
self.self_attn_k = MM_WEIGHT_REGISTER[self.mm_type](f'blocks.{self.block_index}.self_attn.k.weight',f'blocks.{self.block_index}.self_attn.k.bias')
self.self_attn_v = MM_WEIGHT_REGISTER[self.mm_type](f'blocks.{self.block_index}.self_attn.v.weight',f'blocks.{self.block_index}.self_attn.v.bias')
self.self_attn_o = MM_WEIGHT_REGISTER[self.mm_type](f'blocks.{self.block_index}.self_attn.o.weight',f'blocks.{self.block_index}.self_attn.o.bias')
self.self_attn_norm_q = RMS_WEIGHT_REGISTER['sgl-kernel'](f'blocks.{self.block_index}.self_attn.norm_q.weight')
self.self_attn_norm_k = RMS_WEIGHT_REGISTER['sgl-kernel'](f'blocks.{self.block_index}.self_attn.norm_k.weight')
self.norm3 = LN_WEIGHT_REGISTER['Default'](f'blocks.{self.block_index}.norm3.weight',f'blocks.{self.block_index}.norm3.bias',eps = 1e-6)
self.cross_attn_q = MM_WEIGHT_REGISTER[self.mm_type](f'blocks.{self.block_index}.cross_attn.q.weight',f'blocks.{self.block_index}.cross_attn.q.bias')
self.cross_attn_k = MM_WEIGHT_REGISTER[self.mm_type](f'blocks.{self.block_index}.cross_attn.k.weight',f'blocks.{self.block_index}.cross_attn.k.bias')
self.cross_attn_v = MM_WEIGHT_REGISTER[self.mm_type](f'blocks.{self.block_index}.cross_attn.v.weight',f'blocks.{self.block_index}.cross_attn.v.bias')
self.cross_attn_o = MM_WEIGHT_REGISTER[self.mm_type](f'blocks.{self.block_index}.cross_attn.o.weight',f'blocks.{self.block_index}.cross_attn.o.bias')
self.cross_attn_norm_q = RMS_WEIGHT_REGISTER['sgl-kernel'](f'blocks.{self.block_index}.cross_attn.norm_q.weight')
self.cross_attn_norm_k = RMS_WEIGHT_REGISTER['sgl-kernel'](f'blocks.{self.block_index}.cross_attn.norm_k.weight')
self.ffn_0 = MM_WEIGHT_REGISTER[self.mm_type](f'blocks.{self.block_index}.ffn.0.weight',f'blocks.{self.block_index}.ffn.0.bias')
self.ffn_2 = MM_WEIGHT_REGISTER[self.mm_type](f'blocks.{self.block_index}.ffn.2.weight',f'blocks.{self.block_index}.ffn.2.bias')
self.modulation = weight_dict[f'blocks.{self.block_index}.modulation']
self.norm3 = LN_WEIGHT_REGISTER["Default"](f"blocks.{self.block_index}.norm3.weight", f"blocks.{self.block_index}.norm3.bias", eps=1e-6)
self.cross_attn_q = MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.cross_attn.q.weight", f"blocks.{self.block_index}.cross_attn.q.bias")
self.cross_attn_k = MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.cross_attn.k.weight", f"blocks.{self.block_index}.cross_attn.k.bias")
self.cross_attn_v = MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.cross_attn.v.weight", f"blocks.{self.block_index}.cross_attn.v.bias")
self.cross_attn_o = MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.cross_attn.o.weight", f"blocks.{self.block_index}.cross_attn.o.bias")
self.cross_attn_norm_q = RMS_WEIGHT_REGISTER["sgl-kernel"](f"blocks.{self.block_index}.cross_attn.norm_q.weight")
self.cross_attn_norm_k = RMS_WEIGHT_REGISTER["sgl-kernel"](f"blocks.{self.block_index}.cross_attn.norm_k.weight")
self.ffn_0 = MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.ffn.0.weight", f"blocks.{self.block_index}.ffn.0.bias")
self.ffn_2 = MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.ffn.2.weight", f"blocks.{self.block_index}.ffn.2.bias")
self.modulation = weight_dict[f"blocks.{self.block_index}.modulation"]
self.weight_list = [
self.self_attn_q,
@@ -77,18 +74,18 @@ class WanTransformerAttentionBlock:
self.modulation,
]
if self.task == 'i2v':
self.cross_attn_k_img = MM_WEIGHT_REGISTER[self.mm_type](f'blocks.{self.block_index}.cross_attn.k_img.weight',f'blocks.{self.block_index}.cross_attn.k_img.bias')
self.cross_attn_v_img = MM_WEIGHT_REGISTER[self.mm_type](f'blocks.{self.block_index}.cross_attn.v_img.weight',f'blocks.{self.block_index}.cross_attn.v_img.bias')
if self.task == "i2v":
self.cross_attn_k_img = MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.cross_attn.k_img.weight", f"blocks.{self.block_index}.cross_attn.k_img.bias")
self.cross_attn_v_img = MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.cross_attn.v_img.weight", f"blocks.{self.block_index}.cross_attn.v_img.bias")
# self.cross_attn_norm_k_img_weight = weight_dict[f'blocks.{self.block_index}.cross_attn.norm_k_img.weight']
self.cross_attn_norm_k_img = RMS_WEIGHT_REGISTER['sgl-kernel'](f'blocks.{self.block_index}.cross_attn.norm_k_img.weight')
self.cross_attn_norm_k_img = RMS_WEIGHT_REGISTER["sgl-kernel"](f"blocks.{self.block_index}.cross_attn.norm_k_img.weight")
self.weight_list.append(self.cross_attn_k_img)
self.weight_list.append(self.cross_attn_v_img)
self.weight_list.append(self.cross_attn_norm_k_img)
for mm_weight in self.weight_list:
if isinstance(mm_weight, (MMWeightTemplate, LNWeightTemplate, RMSWeightTemplate)):
mm_weight.set_config(self.config['mm_config'])
mm_weight.set_config(self.config["mm_config"])
mm_weight.load(weight_dict)
def to_cpu(self):
@@ -103,4 +100,4 @@ class WanTransformerAttentionBlock:
if isinstance(mm_weight, (MMWeightTemplate, LNWeightTemplate, RMSWeightTemplate)):
mm_weight.to_cuda()
else:
mm_weight.cuda()
\ No newline at end of file
mm_weight.cuda()
@@ -2,192 +2,194 @@ import torch
from ..scheduler import HunyuanScheduler
def cache_init(num_steps, model_kwargs=None):
'''
def cache_init(num_steps, model_kwargs=None):
"""
Initialization for cache.
'''
"""
cache_dic = {}
cache = {}
cache_index = {}
cache[-1]={}
cache_index[-1]={}
cache_index['layer_index']={}
cache_dic['attn_map'] = {}
cache_dic['attn_map'][-1] = {}
cache_dic['attn_map'][-1]['double_stream'] = {}
cache_dic['attn_map'][-1]['single_stream'] = {}
cache_dic['k-norm'] = {}
cache_dic['k-norm'][-1] = {}
cache_dic['k-norm'][-1]['double_stream'] = {}
cache_dic['k-norm'][-1]['single_stream'] = {}
cache_dic['v-norm'] = {}
cache_dic['v-norm'][-1] = {}
cache_dic['v-norm'][-1]['double_stream'] = {}
cache_dic['v-norm'][-1]['single_stream'] = {}
cache_dic['cross_attn_map'] = {}
cache_dic['cross_attn_map'][-1] = {}
cache[-1]['double_stream']={}
cache[-1]['single_stream']={}
cache_dic['cache_counter'] = 0
cache[-1] = {}
cache_index[-1] = {}
cache_index["layer_index"] = {}
cache_dic["attn_map"] = {}
cache_dic["attn_map"][-1] = {}
cache_dic["attn_map"][-1]["double_stream"] = {}
cache_dic["attn_map"][-1]["single_stream"] = {}
cache_dic["k-norm"] = {}
cache_dic["k-norm"][-1] = {}
cache_dic["k-norm"][-1]["double_stream"] = {}
cache_dic["k-norm"][-1]["single_stream"] = {}
cache_dic["v-norm"] = {}
cache_dic["v-norm"][-1] = {}
cache_dic["v-norm"][-1]["double_stream"] = {}
cache_dic["v-norm"][-1]["single_stream"] = {}
cache_dic["cross_attn_map"] = {}
cache_dic["cross_attn_map"][-1] = {}
cache[-1]["double_stream"] = {}
cache[-1]["single_stream"] = {}
cache_dic["cache_counter"] = 0
for j in range(20):
cache[-1]['double_stream'][j] = {}
cache[-1]["double_stream"][j] = {}
cache_index[-1][j] = {}
cache_dic['attn_map'][-1]['double_stream'][j] = {}
cache_dic['attn_map'][-1]['double_stream'][j]['total'] = {}
cache_dic['attn_map'][-1]['double_stream'][j]['txt_mlp'] = {}
cache_dic['attn_map'][-1]['double_stream'][j]['img_mlp'] = {}
cache_dic['k-norm'][-1]['double_stream'][j] = {}
cache_dic['k-norm'][-1]['double_stream'][j]['txt_mlp'] = {}
cache_dic['k-norm'][-1]['double_stream'][j]['img_mlp'] = {}
cache_dic['v-norm'][-1]['double_stream'][j] = {}
cache_dic['v-norm'][-1]['double_stream'][j]['txt_mlp'] = {}
cache_dic['v-norm'][-1]['double_stream'][j]['img_mlp'] = {}
cache_dic["attn_map"][-1]["double_stream"][j] = {}
cache_dic["attn_map"][-1]["double_stream"][j]["total"] = {}
cache_dic["attn_map"][-1]["double_stream"][j]["txt_mlp"] = {}
cache_dic["attn_map"][-1]["double_stream"][j]["img_mlp"] = {}
cache_dic["k-norm"][-1]["double_stream"][j] = {}
cache_dic["k-norm"][-1]["double_stream"][j]["txt_mlp"] = {}
cache_dic["k-norm"][-1]["double_stream"][j]["img_mlp"] = {}
cache_dic["v-norm"][-1]["double_stream"][j] = {}
cache_dic["v-norm"][-1]["double_stream"][j]["txt_mlp"] = {}
cache_dic["v-norm"][-1]["double_stream"][j]["img_mlp"] = {}
for j in range(40):
cache[-1]['single_stream'][j] = {}
cache[-1]["single_stream"][j] = {}
cache_index[-1][j] = {}
cache_dic['attn_map'][-1]['single_stream'][j] = {}
cache_dic['attn_map'][-1]['single_stream'][j]['total'] = {}
cache_dic['k-norm'][-1]['single_stream'][j] = {}
cache_dic['k-norm'][-1]['single_stream'][j]['total'] = {}
cache_dic['v-norm'][-1]['single_stream'][j] = {}
cache_dic['v-norm'][-1]['single_stream'][j]['total'] = {}
cache_dic['taylor_cache'] = False
cache_dic['duca'] = False
cache_dic['test_FLOPs'] = False
mode = 'Taylor'
if mode == 'original':
cache_dic['cache_type'] = 'random'
cache_dic['cache_index'] = cache_index
cache_dic['cache'] = cache
cache_dic['fresh_ratio_schedule'] = 'ToCa'
cache_dic['fresh_ratio'] = 0.0
cache_dic['fresh_threshold'] = 1
cache_dic['force_fresh'] = 'global'
cache_dic['soft_fresh_weight'] = 0.0
cache_dic['max_order'] = 0
cache_dic['first_enhance'] = 1
elif mode == 'ToCa':
cache_dic['cache_type'] = 'random'
cache_dic['cache_index'] = cache_index
cache_dic['cache'] = cache
cache_dic['fresh_ratio_schedule'] = 'ToCa'
cache_dic['fresh_ratio'] = 0.10
cache_dic['fresh_threshold'] = 5
cache_dic['force_fresh'] = 'global'
cache_dic['soft_fresh_weight'] = 0.0
cache_dic['max_order'] = 0
cache_dic['first_enhance'] = 1
cache_dic['duca'] = False
elif mode == 'DuCa':
cache_dic['cache_type'] = 'random'
cache_dic['cache_index'] = cache_index
cache_dic['cache'] = cache
cache_dic['fresh_ratio_schedule'] = 'ToCa'
cache_dic['fresh_ratio'] = 0.10
cache_dic['fresh_threshold'] = 5
cache_dic['force_fresh'] = 'global'
cache_dic['soft_fresh_weight'] = 0.0
cache_dic['max_order'] = 0
cache_dic['first_enhance'] = 1
cache_dic['duca'] = True
elif mode == 'Taylor':
cache_dic['cache_type'] = 'random'
cache_dic['cache_index'] = cache_index
cache_dic['cache'] = cache
cache_dic['fresh_ratio_schedule'] = 'ToCa'
cache_dic['fresh_ratio'] = 0.0
cache_dic['fresh_threshold'] = 5
cache_dic['max_order'] = 1
cache_dic['force_fresh'] = 'global'
cache_dic['soft_fresh_weight'] = 0.0
cache_dic['taylor_cache'] = True
cache_dic['first_enhance'] = 1
cache_dic["attn_map"][-1]["single_stream"][j] = {}
cache_dic["attn_map"][-1]["single_stream"][j]["total"] = {}
cache_dic["k-norm"][-1]["single_stream"][j] = {}
cache_dic["k-norm"][-1]["single_stream"][j]["total"] = {}
cache_dic["v-norm"][-1]["single_stream"][j] = {}
cache_dic["v-norm"][-1]["single_stream"][j]["total"] = {}
cache_dic["taylor_cache"] = False
cache_dic["duca"] = False
cache_dic["test_FLOPs"] = False
mode = "Taylor"
if mode == "original":
cache_dic["cache_type"] = "random"
cache_dic["cache_index"] = cache_index
cache_dic["cache"] = cache
cache_dic["fresh_ratio_schedule"] = "ToCa"
cache_dic["fresh_ratio"] = 0.0
cache_dic["fresh_threshold"] = 1
cache_dic["force_fresh"] = "global"
cache_dic["soft_fresh_weight"] = 0.0
cache_dic["max_order"] = 0
cache_dic["first_enhance"] = 1
elif mode == "ToCa":
cache_dic["cache_type"] = "random"
cache_dic["cache_index"] = cache_index
cache_dic["cache"] = cache
cache_dic["fresh_ratio_schedule"] = "ToCa"
cache_dic["fresh_ratio"] = 0.10
cache_dic["fresh_threshold"] = 5
cache_dic["force_fresh"] = "global"
cache_dic["soft_fresh_weight"] = 0.0
cache_dic["max_order"] = 0
cache_dic["first_enhance"] = 1
cache_dic["duca"] = False
elif mode == "DuCa":
cache_dic["cache_type"] = "random"
cache_dic["cache_index"] = cache_index
cache_dic["cache"] = cache
cache_dic["fresh_ratio_schedule"] = "ToCa"
cache_dic["fresh_ratio"] = 0.10
cache_dic["fresh_threshold"] = 5
cache_dic["force_fresh"] = "global"
cache_dic["soft_fresh_weight"] = 0.0
cache_dic["max_order"] = 0
cache_dic["first_enhance"] = 1
cache_dic["duca"] = True
elif mode == "Taylor":
cache_dic["cache_type"] = "random"
cache_dic["cache_index"] = cache_index
cache_dic["cache"] = cache
cache_dic["fresh_ratio_schedule"] = "ToCa"
cache_dic["fresh_ratio"] = 0.0
cache_dic["fresh_threshold"] = 5
cache_dic["max_order"] = 1
cache_dic["force_fresh"] = "global"
cache_dic["soft_fresh_weight"] = 0.0
cache_dic["taylor_cache"] = True
cache_dic["first_enhance"] = 1
current = {}
current['num_steps'] = num_steps
current['activated_steps'] = [0]
current["num_steps"] = num_steps
current["activated_steps"] = [0]
return cache_dic, current
def force_scheduler(cache_dic, current):
if cache_dic['fresh_ratio'] == 0:
if cache_dic["fresh_ratio"] == 0:
# FORA
linear_step_weight = 0.0
else:
else:
# TokenCache
linear_step_weight = 0.0
step_factor = torch.tensor(1 - linear_step_weight + 2 * linear_step_weight * current['step'] / current['num_steps'])
threshold = torch.round(cache_dic['fresh_threshold'] / step_factor)
step_factor = torch.tensor(1 - linear_step_weight + 2 * linear_step_weight * current["step"] / current["num_steps"])
threshold = torch.round(cache_dic["fresh_threshold"] / step_factor)
# no force constrain for sensitive steps, cause the performance is good enough.
# you may have a try.
cache_dic['cal_threshold'] = threshold
#return threshold
cache_dic["cal_threshold"] = threshold
# return threshold
def cal_type(cache_dic, current):
'''
"""
Determine calculation type for this step
'''
if (cache_dic['fresh_ratio'] == 0.0) and (not cache_dic['taylor_cache']):
"""
if (cache_dic["fresh_ratio"] == 0.0) and (not cache_dic["taylor_cache"]):
# FORA:Uniform
first_step = (current['step'] == 0)
first_step = current["step"] == 0
else:
# ToCa: First enhanced
first_step = (current['step'] < cache_dic['first_enhance'])
#first_step = (current['step'] <= 3)
first_step = current["step"] < cache_dic["first_enhance"]
# first_step = (current['step'] <= 3)
force_fresh = cache_dic['force_fresh']
force_fresh = cache_dic["force_fresh"]
if not first_step:
fresh_interval = cache_dic['cal_threshold']
fresh_interval = cache_dic["cal_threshold"]
else:
fresh_interval = cache_dic['fresh_threshold']
fresh_interval = cache_dic["fresh_threshold"]
if (first_step) or (cache_dic['cache_counter'] == fresh_interval - 1 ):
current['type'] = 'full'
cache_dic['cache_counter'] = 0
current['activated_steps'].append(current['step'])
#current['activated_times'].append(current['t'])
if (first_step) or (cache_dic["cache_counter"] == fresh_interval - 1):
current["type"] = "full"
cache_dic["cache_counter"] = 0
current["activated_steps"].append(current["step"])
# current['activated_times'].append(current['t'])
force_scheduler(cache_dic, current)
elif (cache_dic['taylor_cache']):
cache_dic['cache_counter'] += 1
current['type'] = 'taylor_cache'
elif cache_dic["taylor_cache"]:
cache_dic["cache_counter"] += 1
current["type"] = "taylor_cache"
else:
cache_dic['cache_counter'] += 1
if (cache_dic['duca']):
if (cache_dic['cache_counter'] % 2 == 1): # 0: ToCa-Aggresive-ToCa, 1: Aggresive-ToCa-Aggresive
current['type'] = 'ToCa'
# 'cache_noise' 'ToCa' 'FORA'
cache_dic["cache_counter"] += 1
if cache_dic["duca"]:
if cache_dic["cache_counter"] % 2 == 1: # 0: ToCa-Aggresive-ToCa, 1: Aggresive-ToCa-Aggresive
current["type"] = "ToCa"
# 'cache_noise' 'ToCa' 'FORA'
else:
current['type'] = 'aggressive'
current["type"] = "aggressive"
else:
current['type'] = 'ToCa'
current["type"] = "ToCa"
#if current['step'] < 25:
# if current['step'] < 25:
# current['type'] = 'FORA'
#else:
# else:
# current['type'] = 'aggressive'
######################################################################
#if (current['step'] in [3,2,1,0]):
# current['type'] = 'full'
# if (current['step'] in [3,2,1,0]):
# current['type'] = 'full'
class HunyuanSchedulerFeatureCaching(HunyuanScheduler):
@@ -197,5 +199,5 @@ class HunyuanSchedulerFeatureCaching(HunyuanScheduler):
def step_pre(self, step_index):
super().step_pre(step_index)
self.current['step'] = step_index
self.current["step"] = step_index
cal_type(self.cache_dic, self.current)
@@ -13,6 +13,7 @@ def _to_tuple(x, dim=2):
else:
raise ValueError(f"Expected length {dim} or int, but got {x}")
def get_1d_rotary_pos_embed(
dim: int,
pos: Union[torch.FloatTensor, int],
@@ -49,9 +50,7 @@ def get_1d_rotary_pos_embed(
if theta_rescale_factor != 1.0:
theta *= theta_rescale_factor ** (dim / (dim - 2))
freqs = 1.0 / (
theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
) # [D/2]
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) # [D/2]
# assert interpolation_factor == 1.0, f"interpolation_factor: {interpolation_factor}"
freqs = torch.outer(pos * interpolation_factor, freqs) # [S, D/2]
if use_real:
......@@ -59,9 +58,7 @@ def get_1d_rotary_pos_embed(
freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D]
return freqs_cos, freqs_sin
else:
freqs_cis = torch.polar(
torch.ones_like(freqs), freqs
) # complex64 # [S, D/2]
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2]
return freqs_cis
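As a sanity check on the frequency formula above, a tiny standalone computation of the same per-position cos/sin tables; dim, the sequence length, and theta are arbitrary example values, and the repeat_interleave mirrors the use_real=True branch.

import torch

dim, seq_len, theta = 8, 4, 10000.0
pos = torch.arange(seq_len).float()                               # [S]
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))  # [D/2]
freqs = torch.outer(pos, freqs)                                   # [S, D/2]
freqs_cos = freqs.cos().repeat_interleave(2, dim=1)               # [S, D]
freqs_sin = freqs.sin().repeat_interleave(2, dim=1)               # [S, D]
print(freqs_cos.shape, freqs_sin.shape)                           # torch.Size([4, 8]) torch.Size([4, 8])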
......@@ -109,6 +106,7 @@ def get_meshgrid_nd(start, *args, dim=2):
return grid
def get_nd_rotary_pos_embed(
rope_dim_list,
start,
......@@ -137,25 +135,19 @@ def get_nd_rotary_pos_embed(
pos_embed (torch.Tensor): [HW, D/2]
"""
grid = get_meshgrid_nd(
start, *args, dim=len(rope_dim_list)
) # [3, W, H, D] / [2, W, H]
grid = get_meshgrid_nd(start, *args, dim=len(rope_dim_list)) # [3, W, H, D] / [2, W, H]
if isinstance(theta_rescale_factor, int) or isinstance(theta_rescale_factor, float):
theta_rescale_factor = [theta_rescale_factor] * len(rope_dim_list)
elif isinstance(theta_rescale_factor, list) and len(theta_rescale_factor) == 1:
theta_rescale_factor = [theta_rescale_factor[0]] * len(rope_dim_list)
assert len(theta_rescale_factor) == len(
rope_dim_list
), "len(theta_rescale_factor) should equal to len(rope_dim_list)"
assert len(theta_rescale_factor) == len(rope_dim_list), "len(theta_rescale_factor) should equal len(rope_dim_list)"
if isinstance(interpolation_factor, int) or isinstance(interpolation_factor, float):
interpolation_factor = [interpolation_factor] * len(rope_dim_list)
elif isinstance(interpolation_factor, list) and len(interpolation_factor) == 1:
interpolation_factor = [interpolation_factor[0]] * len(rope_dim_list)
assert len(interpolation_factor) == len(
rope_dim_list
), "len(interpolation_factor) should equal to len(rope_dim_list)"
assert len(interpolation_factor) == len(rope_dim_list), "len(interpolation_factor) should equal len(rope_dim_list)"
# use 1/ndim of dimensions to encode grid_axis
embs = []
......@@ -179,12 +171,10 @@ def get_nd_rotary_pos_embed(
return emb
def set_timesteps_sigmas(num_inference_steps, shift, device, num_train_timesteps=1000):
def set_timesteps_sigmas(num_inference_steps, shift, device, num_train_timesteps=1000):
sigmas = torch.linspace(1, 0, num_inference_steps + 1)
sigmas = (shift * sigmas) / (1 + (shift - 1) * sigmas)
timesteps = (sigmas[:-1] * num_train_timesteps).to(
dtype=torch.bfloat16, device=device
)
timesteps = (sigmas[:-1] * num_train_timesteps).to(dtype=torch.bfloat16, device=device)
return timesteps, sigmas
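A quick numeric check of the shifted sigma schedule defined above; the step count is kept small and the device set to CPU purely so the snippet runs anywhere (the scheduler below uses CUDA and shift=7.0).

import torch

timesteps, sigmas = set_timesteps_sigmas(4, shift=7.0, device=torch.device("cpu"))
print(sigmas)     # tensor([1.0000, 0.9545, 0.8750, 0.7000, 0.0000]) -- shift > 1 keeps sigma near 1 longer
print(timesteps)  # sigmas[:-1] * 1000, cast to bfloat16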
......@@ -193,17 +183,17 @@ class HunyuanScheduler(BaseScheduler):
super().__init__(args)
self.infer_steps = self.args.infer_steps
self.shift = 7.0
self.timesteps, self.sigmas = set_timesteps_sigmas(self.infer_steps, self.shift, device=torch.device('cuda'))
self.timesteps, self.sigmas = set_timesteps_sigmas(self.infer_steps, self.shift, device=torch.device("cuda"))
assert len(self.timesteps) == self.infer_steps
self.embedded_guidance_scale = 6.0
self.generator = [torch.Generator('cuda').manual_seed(seed) for seed in [42]]
self.generator = [torch.Generator("cuda").manual_seed(seed) for seed in [42]]
self.noise_pred = None
self.prepare_latents(shape=self.args.target_shape, dtype=torch.bfloat16)
self.prepare_guidance()
self.prepare_rotary_pos_embedding(video_length=self.args.target_video_length, height=self.args.target_height, width=self.args.target_width)
def prepare_guidance(self):
self.guidance = torch.tensor([self.embedded_guidance_scale], dtype=torch.bfloat16, device=torch.device('cuda')) * 1000.0
self.guidance = torch.tensor([self.embedded_guidance_scale], dtype=torch.bfloat16, device=torch.device("cuda")) * 1000.0
def step_post(self):
sample = self.latents.to(torch.float32)
......@@ -212,7 +202,7 @@ class HunyuanScheduler(BaseScheduler):
self.latents = prev_sample
def prepare_latents(self, shape, dtype):
self.latents = randn_tensor(shape, generator=self.generator, device=torch.device('cuda'), dtype=dtype)
self.latents = randn_tensor(shape, generator=self.generator, device=torch.device("cuda"), dtype=dtype)
def prepare_rotary_pos_embedding(self, video_length, height, width):
target_ndim = 3
......@@ -232,22 +222,11 @@ class HunyuanScheduler(BaseScheduler):
latents_size = [video_length, height // 8, width // 8]
if isinstance(patch_size, int):
assert all(s % patch_size == 0 for s in latents_size), (
f"Latent size(last {ndim} dimensions) should be divisible by patch size({patch_size}), "
f"but got {latents_size}."
)
assert all(s % patch_size == 0 for s in latents_size), f"Latent size (last {ndim} dimensions) should be divisible by patch size ({patch_size}), but got {latents_size}."
rope_sizes = [s // patch_size for s in latents_size]
elif isinstance(patch_size, list):
assert all(
s % patch_size[idx] == 0
for idx, s in enumerate(latents_size)
), (
f"Latent size(last {ndim} dimensions) should be divisible by patch size({patch_size}), "
f"but got {latents_size}."
)
rope_sizes = [
s // patch_size[idx] for idx, s in enumerate(latents_size)
]
assert all(s % patch_size[idx] == 0 for idx, s in enumerate(latents_size)), f"Latent size (last {ndim} dimensions) should be divisible by patch size ({patch_size}), but got {latents_size}."
rope_sizes = [s // patch_size[idx] for idx, s in enumerate(latents_size)]
if len(rope_sizes) != target_ndim:
rope_sizes = [1] * (target_ndim - len(rope_sizes)) + rope_sizes # time axis
......@@ -255,9 +234,7 @@ class HunyuanScheduler(BaseScheduler):
rope_dim_list = rope_dim_list
if rope_dim_list is None:
rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
assert (
sum(rope_dim_list) == head_dim
), "sum(rope_dim_list) should equal to head_dim of attention layer"
assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal the head_dim of the attention layer"
self.freqs_cos, self.freqs_sin = get_nd_rotary_pos_embed(
rope_dim_list,
rope_sizes,
......
import torch
class BaseScheduler():
class BaseScheduler:
def __init__(self, args):
self.args = args
self.step_index = 0
self.latents = None
def step_pre(self, step_index):
self.step_index = step_index
self.latents = self.latents.to(dtype=torch.bfloat16)
......@@ -16,7 +16,7 @@ class WanSchedulerFeatureCaching(WanScheduler):
self.previous_residual_odd = None
self.use_ret_steps = self.args.use_ret_steps
if self.args.task == 'i2v':
if self.args.task == "i2v":
if self.use_ret_steps:
if self.args.target_width == 480 or self.args.target_height == 480:
self.coefficients = [
......@@ -56,18 +56,18 @@ class WanSchedulerFeatureCaching(WanScheduler):
self.ret_steps = 1 * 2
self.cutoff_steps = self.args.infer_steps * 2 - 2
elif self.args.task == 't2v':
elif self.args.task == "t2v":
if self.use_ret_steps:
if '1.3B' in self.args.model_path:
self.coefficients = [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02]
if '14B' in self.args.model_path:
self.coefficients = [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01]
if "1.3B" in self.args.model_path:
self.coefficients = [-5.21862437e04, 9.23041404e03, -5.28275948e02, 1.36987616e01, -4.99875664e-02]
if "14B" in self.args.model_path:
self.coefficients = [-3.03318725e05, 4.90537029e04, -2.65530556e03, 5.87365115e01, -3.15583525e-01]
self.ret_steps = 5 * 2
self.cutoff_steps = self.args.infer_steps * 2
else:
if '1.3B' in self.args.model_path:
self.coefficients = [2.39676752e+03, -1.31110545e+03, 2.01331979e+02, -8.29855975e+00, 1.37887774e-01]
if '14B' in self.args.model_path:
self.coefficients = [-5784.54975374, 5449.50911966, -1811.16591783, 256.27178429, -13.02252404]
if "1.3B" in self.args.model_path:
self.coefficients = [2.39676752e03, -1.31110545e03, 2.01331979e02, -8.29855975e00, 1.37887774e-01]
if "14B" in self.args.model_path:
self.coefficients = [-5784.54975374, 5449.50911966, -1811.16591783, 256.27178429, -13.02252404]
self.ret_steps = 1 * 2
self.cutoff_steps = self.args.infer_steps * 2 - 2
\ No newline at end of file
self.cutoff_steps = self.args.infer_steps * 2 - 2
......@@ -20,20 +20,13 @@ class WanScheduler(BaseScheduler):
self.generator = torch.Generator(device=self.device)
self.generator.manual_seed(self.args.seed)
self.prepare_latents(self.args.target_shape, dtype=torch.float32)
if self.args.task in ["t2v"]:
self.seq_len = math.ceil(
(self.args.target_shape[2] * self.args.target_shape[3])
/ (self.args.patch_size[1] * self.args.patch_size[2])
* self.args.target_shape[1]
)
self.seq_len = math.ceil((self.args.target_shape[2] * self.args.target_shape[3]) / (self.args.patch_size[1] * self.args.patch_size[2]) * self.args.target_shape[1])
elif self.args.task in ["i2v"]:
self.seq_len = ((self.args.target_video_length- 1) // self.args.vae_stride[0] + 1) * args.lat_h * args.lat_w // (
args.patch_size[1] * args.patch_size[2])
self.seq_len = ((self.args.target_video_length - 1) // self.args.vae_stride[0] + 1) * args.lat_h * args.lat_w // (args.patch_size[1] * args.patch_size[2])
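To make the sequence-length arithmetic concrete, a small worked example for the t2v branch; the latent target_shape and patch_size below are assumed illustrative values, not values read from this repository's configs.

import math

target_shape = (16, 21, 60, 104)   # assumed latent (C, T, H, W)
patch_size = (1, 2, 2)             # assumed temporal/spatial patch sizes
# tokens per latent frame (H*W / (p_h*p_w)) times the number of latent frames (T)
seq_len = math.ceil((target_shape[2] * target_shape[3]) / (patch_size[1] * patch_size[2]) * target_shape[1])
print(seq_len)  # 60 * 104 / 4 * 21 = 32760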
alphas = np.linspace(1, 1 / self.num_train_timesteps, self.num_train_timesteps)[
::-1
].copy()
alphas = np.linspace(1, 1 / self.num_train_timesteps, self.num_train_timesteps)[::-1].copy()
sigmas = 1.0 - alphas
sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32)
......@@ -71,9 +64,7 @@ class WanScheduler(BaseScheduler):
mu: Optional[Union[float, None]] = None,
shift: Optional[Union[float, None]] = None,
):
sigmas = np.linspace(self.sigma_max, self.sigma_min, infer_steps + 1).copy()[
:-1
]
sigmas = np.linspace(self.sigma_max, self.sigma_min, infer_steps + 1).copy()[:-1]
if shift is None:
shift = self.shift
......@@ -85,9 +76,7 @@ class WanScheduler(BaseScheduler):
sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
self.sigmas = torch.from_numpy(sigmas)
self.timesteps = torch.from_numpy(timesteps).to(
device=device, dtype=torch.int64
)
self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64)
assert len(self.timesteps) == self.infer_steps
self.model_outputs = [
......@@ -108,7 +97,6 @@ class WanScheduler(BaseScheduler):
sample: torch.Tensor = None,
**kwargs,
) -> torch.Tensor:
timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
if sample is None:
if len(args) > 1:
......@@ -222,7 +210,6 @@ class WanScheduler(BaseScheduler):
order: int = None,
**kwargs,
) -> torch.Tensor:
this_timestep = args[0] if len(args) > 0 else kwargs.pop("this_timestep", None)
if last_sample is None:
if len(args) > 1:
......@@ -320,11 +307,7 @@ class WanScheduler(BaseScheduler):
timestep = self.timesteps[self.step_index]
sample = self.latents.to(torch.float32)
use_corrector = (
self.step_index > 0
and self.step_index - 1 not in self.disable_corrector
and self.last_sample is not None
)
use_corrector = self.step_index > 0 and self.step_index - 1 not in self.disable_corrector and self.last_sample is not None
model_output_convert = self.convert_model_output(model_output, sample=sample)
if use_corrector:
......@@ -342,13 +325,9 @@ class WanScheduler(BaseScheduler):
self.model_outputs[-1] = model_output_convert
self.timestep_list[-1] = timestep
this_order = min(
self.solver_order, len(self.timesteps) - self.step_index
)
this_order = min(self.solver_order, len(self.timesteps) - self.step_index)
self.this_order = min(
this_order, self.lower_order_nums + 1
) # warmup for multistep
self.this_order = min(this_order, self.lower_order_nums + 1) # warmup for multistep
assert self.this_order > 0
self.last_sample = sample
......
......@@ -2,7 +2,7 @@ import torch
from transformers import CLIPTextModel, AutoTokenizer
class TextEncoderHFClipModel():
class TextEncoderHFClipModel:
def __init__(self, model_path, device):
self.device = device
self.model_path = model_path
......@@ -51,6 +51,6 @@ class TextEncoderHFClipModel():
if __name__ == "__main__":
model = TextEncoderHFClipModel("/mnt/nvme0/yongyang/projects/hy/HunyuanVideo/ckpts/text_encoder_2", torch.device("cuda"))
text = 'A cat walks on the grass, realistic style.'
text = "A cat walks on the grass, realistic style."
outputs = model.infer(text)
print(outputs)
......@@ -2,7 +2,7 @@ import torch
from transformers import AutoModel, AutoTokenizer
class TextEncoderHFLlamaModel():
class TextEncoderHFLlamaModel:
def __init__(self, model_path, device):
self.device = device
self.model_path = model_path
......@@ -55,8 +55,8 @@ class TextEncoderHFLlamaModel():
output_hidden_states=True,
)
last_hidden_state = outputs.hidden_states[-(self.hidden_state_skip_layer + 1)][:, self.crop_start:]
attention_mask = tokens["attention_mask"][:, self.crop_start:]
last_hidden_state = outputs.hidden_states[-(self.hidden_state_skip_layer + 1)][:, self.crop_start :]
attention_mask = tokens["attention_mask"][:, self.crop_start :]
if args.cpu_offload:
self.to_cpu()
return last_hidden_state, attention_mask
......@@ -64,6 +64,6 @@ class TextEncoderHFLlamaModel():
if __name__ == "__main__":
model = TextEncoderHFLlamaModel("/mnt/nvme0/yongyang/projects/hy/HunyuanVideo/ckpts/text_encoder", torch.device("cuda"))
text = 'A cat walks on the grass, realistic style.'
text = "A cat walks on the grass, realistic style."
outputs = model.infer(text)
print(outputs)