Commit ae96fdbf authored by helloyongyang's avatar helloyongyang

Update weight modules. Simplify code.

parent 3996d421
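
In short, weight-holding classes no longer keep a hand-maintained weight_list and re-implement load_weights / to_cpu / to_cuda; they subclass the new WeightModule, register children with add_module / register_parameter, and inherit loading and device movement from the base class. A minimal sketch of the new shape follows (the ExampleWeights class below is hypothetical; real classes such as HunyuanPostWeights in this diff follow the same pattern):

from lightx2v.common.modules.weight_module import WeightModule
from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER, TENSOR_REGISTER

class ExampleWeights(WeightModule):
    # Hypothetical illustration of the registration pattern introduced by this commit.
    def __init__(self, config):
        super().__init__()
        self.config = config
        # Children registered here are picked up by WeightModule.load()/to_cpu()/to_cuda().
        self.add_module("proj", MM_WEIGHT_REGISTER["Default"]("proj.weight", "proj.bias"))
        self.register_parameter("modulation", TENSOR_REGISTER["Default"]("modulation"))

# weights = ExampleWeights(config)
# weights.load(weight_dict)   # was load_weights(weight_dict) in the old classes
# weights.to_cuda()           # cascades through every registered module/parameter
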
......@@ -8,6 +8,8 @@
"seed": 42,
"sample_guide_scale": 6,
"sample_shift": 8,
"enable_cfg": true,
"cpu_offload": false,
"mm_config": {
"mm_type": "W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Sgl",
"weight_auto_quant": true
......
class WeightModule:
def __init__(self):
self._modules = {}
self._parameters = {}
def add_module(self, name, module):
self._modules[name] = module
setattr(self, name, module)
def register_parameter(self, name, param):
self._parameters[name] = param
setattr(self, name, param)
def load(self, weight_dict):
for _, module in self._modules.items():
if hasattr(module, "set_config"):
module.set_config(self.config["mm_config"])
if hasattr(module, "load"):
module.load(weight_dict)
for _, parameter in self._parameters.items():
if hasattr(parameter, "set_config"):
parameter.set_config(self.config["mm_config"])
if hasattr(parameter, "load"):
parameter.load(weight_dict)
def state_dict(self, destination=None, prefix=""):
if destination is None:
destination = {}
for name, param in self._parameters.items():
if param is not None:
destination[prefix + name] = param.detach().cpu().clone()
for name, module in self._modules.items():
if module is not None:
module.state_dict(destination, prefix + name + ".")
return destination
def named_parameters(self, prefix=""):
for name, param in self._parameters.items():
if param is not None:
yield prefix + name, param
for name, module in self._modules.items():
if module is not None:
yield from module.named_parameters(prefix + name + ".")
def to_cpu(self):
for name, param in self._parameters.items():
if param is not None and hasattr(param, "cpu"):
self._parameters[name] = param.cpu()
setattr(self, name, self._parameters[name])
for module in self._modules.values():
if module is not None and hasattr(module, "to_cpu"):
module.to_cpu()
def to_cuda(self):
for name, param in self._parameters.items():
if param is not None and hasattr(param, "cuda"):
self._parameters[name] = param.cuda()
setattr(self, name, self._parameters[name])
for module in self._modules.values():
if module is not None and hasattr(module, "to_cuda"):
module.to_cuda()
def to_cpu_sync(self):
for name, param in self._parameters.items():
if param is not None and hasattr(param, "to"):
self._parameters[name] = param.to("cpu", non_blocking=True)
setattr(self, name, self._parameters[name])
for module in self._modules.values():
if module is not None and hasattr(module, "to_cpu_sync"):
module.to_cpu_sync()
def to_cuda_sync(self):
for name, param in self._parameters.items():
if param is not None and hasattr(param, "cuda"):
self._parameters[name] = param.cuda(non_blocking=True)
setattr(self, name, self._parameters[name])
for module in self._modules.values():
if module is not None and hasattr(module, "to_cuda_sync"):
module.to_cuda_sync()
class WeightModuleList(WeightModule):
def __init__(self, modules=None):
super().__init__()
self._list = []
if modules is not None:
for idx, module in enumerate(modules):
self.append(module)
def append(self, module):
idx = len(self._list)
self._list.append(module)
self.add_module(str(idx), module)
def __getitem__(self, idx):
return self._list[idx]
def __len__(self):
return len(self._list)
def __iter__(self):
return iter(self._list)
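
For illustration only (this usage is implied by the diff rather than spelled out in it), a WeightModuleList behaves like a plain Python list for indexing and iteration while still cascading load/to_cuda through the base class, which is what lets weights.blocks_weights[...] become weights.blocks[...] in the infer code further down:

from lightx2v.common.modules.weight_module import WeightModule, WeightModuleList
from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER

class ExampleBlock(WeightModule):
    # Hypothetical stand-in for e.g. WanTransformerAttentionBlock.
    def __init__(self, idx, config):
        super().__init__()
        self.config = config
        self.add_module("ffn_0", MM_WEIGHT_REGISTER["Default"](f"blocks.{idx}.ffn.0.weight", f"blocks.{idx}.ffn.0.bias"))

blocks = WeightModuleList([ExampleBlock(i, {"mm_config": None}) for i in range(2)])
# blocks.load(weight_dict)   # cascades into every block's registered weights
# blocks[0].to_cuda()        # indexing works via __getitem__
# for block in blocks: ...   # iteration works via __iter__
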
from .mm import *
from .norm import *
from .conv import *
from .tensor import *
from .rms_norm_weight import *
from .layer_norm_weight import *
from .tensor import DefaultTensor
from lightx2v.utils.registry_factory import TENSOR_REGISTER
@TENSOR_REGISTER("Default")
class DefaultTensor:
def __init__(self, tensor_name):
self.tensor_name = tensor_name
def load(self, weight_dict):
self.tensor = weight_dict[self.tensor_name]
def to_cpu(self, non_blocking=False):
self.tensor = self.tensor.to("cpu", non_blocking=non_blocking)
def to_cuda(self, non_blocking=False):
self.tensor = self.tensor.cuda(non_blocking=non_blocking)
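
As a quick illustration (not part of the diff), this wrapper is what lets weights previously stored as raw tensors, such as head.modulation or the per-block modulation below, take part in the shared load/offload flow; the raw tensor is then read via .tensor:

from lightx2v.utils.registry_factory import TENSOR_REGISTER

# Hypothetical standalone usage; in the diff this goes through register_parameter(...).
head_modulation = TENSOR_REGISTER["Default"]("head.modulation")
# head_modulation.load(weight_dict)   # pulls weight_dict["head.modulation"]
# head_modulation.to_cuda()           # device moves go through the wrapper
# value = head_modulation.tensor      # raw tensor, as accessed in WanPostInfer
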
......@@ -36,14 +36,14 @@ class HunyuanTransformerInfer:
for double_block_idx in range(self.double_blocks_num):
if double_block_idx == 0:
self.double_weights_stream_mgr.active_weights[0] = weights.double_blocks_weights[0]
self.double_weights_stream_mgr.active_weights[0] = weights.double_blocks[0]
self.double_weights_stream_mgr.active_weights[0].to_cuda()
with torch.cuda.stream(self.double_weights_stream_mgr.compute_stream):
img, txt = self.infer_double_block(self.double_weights_stream_mgr.active_weights[0], img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num)
if double_block_idx < self.double_blocks_num - 1:
self.double_weights_stream_mgr.prefetch_weights(double_block_idx + 1, weights.double_blocks_weights)
self.double_weights_stream_mgr.prefetch_weights(double_block_idx + 1, weights.double_blocks)
self.double_weights_stream_mgr.swap_weights()
x = torch.cat((img, txt), 0)
......@@ -55,12 +55,12 @@ class HunyuanTransformerInfer:
for single_block_idx in range(self.single_blocks_num):
if single_block_idx == 0:
self.single_weights_stream_mgr.active_weights[0] = weights.single_blocks_weights[0]
self.single_weights_stream_mgr.active_weights[0] = weights.single_blocks[0]
self.single_weights_stream_mgr.active_weights[0].to_cuda()
with torch.cuda.stream(self.single_weights_stream_mgr.compute_stream):
x = self.infer_single_block(self.single_weights_stream_mgr.active_weights[0], x, vec, txt_seq_len, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num)
if single_block_idx < self.single_blocks_num - 1:
self.single_weights_stream_mgr.prefetch_weights(single_block_idx + 1, weights.single_blocks_weights)
self.single_weights_stream_mgr.prefetch_weights(single_block_idx + 1, weights.single_blocks)
self.single_weights_stream_mgr.swap_weights()
torch.cuda.empty_cache()
......@@ -72,12 +72,12 @@ class HunyuanTransformerInfer:
img_seq_len = img.shape[0]
for i in range(self.double_blocks_num):
img, txt = self.infer_double_block(weights.double_blocks_weights[i], img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num)
img, txt = self.infer_double_block(weights.double_blocks[i], img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num)
x = torch.cat((img, txt), 0)
for i in range(self.single_blocks_num):
x = self.infer_single_block(weights.single_blocks_weights[i], x, vec, txt_seq_len, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num)
x = self.infer_single_block(weights.single_blocks[i], x, vec, txt_seq_len, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num)
img = x[:img_seq_len, ...]
return img, vec
......
......@@ -64,9 +64,9 @@ class HunyuanModel:
self.post_weight = self.post_weight_class(self.config)
self.transformer_weights = self.transformer_weight_class(self.config)
# load weights
self.pre_weight.load_weights(weight_dict)
self.post_weight.load_weights(weight_dict)
self.transformer_weights.load_weights(weight_dict)
self.pre_weight.load(weight_dict)
self.post_weight.load(weight_dict)
self.transformer_weights.load(weight_dict)
def _init_infer(self):
self.pre_infer = self.pre_infer_class(self.config)
......
from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER
from lightx2v.common.ops.mm.mm_weight import MMWeightTemplate
from lightx2v.common.modules.weight_module import WeightModule
class HunyuanPostWeights:
class HunyuanPostWeights(WeightModule):
def __init__(self, config):
super().__init__()
self.config = config
def load_weights(self, weight_dict):
self.final_layer_linear = MM_WEIGHT_REGISTER["Default-Force-FP32"]("final_layer.linear.weight", "final_layer.linear.bias")
self.final_layer_adaLN_modulation_1 = MM_WEIGHT_REGISTER["Default"]("final_layer.adaLN_modulation.1.weight", "final_layer.adaLN_modulation.1.bias")
self.weight_list = [
self.final_layer_linear,
self.final_layer_adaLN_modulation_1,
]
for weight in self.weight_list:
if isinstance(weight, MMWeightTemplate):
weight.set_config(self.config["mm_config"])
weight.load(weight_dict)
def to_cpu(self):
for weight in self.weight_list:
if isinstance(weight, MMWeightTemplate):
weight.to_cpu()
def to_cuda(self):
for weight in self.weight_list:
if isinstance(weight, MMWeightTemplate):
weight.to_cuda()
self.add_module("final_layer_linear", MM_WEIGHT_REGISTER["Default-Force-FP32"]("final_layer.linear.weight", "final_layer.linear.bias"))
self.add_module("final_layer_adaLN_modulation_1", MM_WEIGHT_REGISTER["Default"]("final_layer.adaLN_modulation.1.weight", "final_layer.adaLN_modulation.1.bias"))
from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER, LN_WEIGHT_REGISTER, CONV3D_WEIGHT_REGISTER
from lightx2v.common.ops.mm.mm_weight import MMWeightTemplate
from lightx2v.common.ops.norm.layer_norm_weight import LNWeightTemplate
from lightx2v.common.ops.conv.conv3d import Conv3dWeightTemplate
from lightx2v.common.modules.weight_module import WeightModule
class HunyuanPreWeights:
class HunyuanPreWeights(WeightModule):
def __init__(self, config):
super().__init__()
self.config = config
def load_weights(self, weight_dict):
self.img_in_proj = CONV3D_WEIGHT_REGISTER["Default"]("img_in.proj.weight", "img_in.proj.bias", stride=(1, 2, 2))
self.add_module("img_in_proj", CONV3D_WEIGHT_REGISTER["Default"]("img_in.proj.weight", "img_in.proj.bias", stride=(1, 2, 2)))
self.txt_in_input_embedder = MM_WEIGHT_REGISTER["Default"]("txt_in.input_embedder.weight", "txt_in.input_embedder.bias")
self.txt_in_t_embedder_mlp_0 = MM_WEIGHT_REGISTER["Default"]("txt_in.t_embedder.mlp.0.weight", "txt_in.t_embedder.mlp.0.bias")
self.txt_in_t_embedder_mlp_2 = MM_WEIGHT_REGISTER["Default"]("txt_in.t_embedder.mlp.2.weight", "txt_in.t_embedder.mlp.2.bias")
self.txt_in_c_embedder_linear_1 = MM_WEIGHT_REGISTER["Default"]("txt_in.c_embedder.linear_1.weight", "txt_in.c_embedder.linear_1.bias")
self.txt_in_c_embedder_linear_2 = MM_WEIGHT_REGISTER["Default"]("txt_in.c_embedder.linear_2.weight", "txt_in.c_embedder.linear_2.bias")
self.add_module("txt_in_input_embedder", MM_WEIGHT_REGISTER["Default"]("txt_in.input_embedder.weight", "txt_in.input_embedder.bias"))
self.add_module("txt_in_t_embedder_mlp_0", MM_WEIGHT_REGISTER["Default"]("txt_in.t_embedder.mlp.0.weight", "txt_in.t_embedder.mlp.0.bias"))
self.add_module("txt_in_t_embedder_mlp_2", MM_WEIGHT_REGISTER["Default"]("txt_in.t_embedder.mlp.2.weight", "txt_in.t_embedder.mlp.2.bias"))
self.add_module("txt_in_c_embedder_linear_1", MM_WEIGHT_REGISTER["Default"]("txt_in.c_embedder.linear_1.weight", "txt_in.c_embedder.linear_1.bias"))
self.add_module("txt_in_c_embedder_linear_2", MM_WEIGHT_REGISTER["Default"]("txt_in.c_embedder.linear_2.weight", "txt_in.c_embedder.linear_2.bias"))
self.txt_in_individual_token_refiner_blocks_0_norm1 = LN_WEIGHT_REGISTER["Default"](
"txt_in.individual_token_refiner.blocks.0.norm1.weight", "txt_in.individual_token_refiner.blocks.0.norm1.bias", eps=1e-6
self.add_module(
"txt_in_individual_token_refiner_blocks_0_norm1",
LN_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.0.norm1.weight", "txt_in.individual_token_refiner.blocks.0.norm1.bias", eps=1e-6),
)
self.txt_in_individual_token_refiner_blocks_0_self_attn_qkv = MM_WEIGHT_REGISTER["Default"](
"txt_in.individual_token_refiner.blocks.0.self_attn_qkv.weight", "txt_in.individual_token_refiner.blocks.0.self_attn_qkv.bias"
self.add_module(
"txt_in_individual_token_refiner_blocks_0_self_attn_qkv",
MM_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.0.self_attn_qkv.weight", "txt_in.individual_token_refiner.blocks.0.self_attn_qkv.bias"),
)
self.txt_in_individual_token_refiner_blocks_0_self_attn_proj = MM_WEIGHT_REGISTER["Default"](
"txt_in.individual_token_refiner.blocks.0.self_attn_proj.weight", "txt_in.individual_token_refiner.blocks.0.self_attn_proj.bias"
self.add_module(
"txt_in_individual_token_refiner_blocks_0_self_attn_proj",
MM_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.0.self_attn_proj.weight", "txt_in.individual_token_refiner.blocks.0.self_attn_proj.bias"),
)
self.txt_in_individual_token_refiner_blocks_0_norm2 = LN_WEIGHT_REGISTER["Default"](
"txt_in.individual_token_refiner.blocks.0.norm2.weight", "txt_in.individual_token_refiner.blocks.0.norm2.bias", eps=1e-6
self.add_module(
"txt_in_individual_token_refiner_blocks_0_norm2",
LN_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.0.norm2.weight", "txt_in.individual_token_refiner.blocks.0.norm2.bias", eps=1e-6),
)
self.txt_in_individual_token_refiner_blocks_0_mlp_fc1 = MM_WEIGHT_REGISTER["Default"](
"txt_in.individual_token_refiner.blocks.0.mlp.fc1.weight", "txt_in.individual_token_refiner.blocks.0.mlp.fc1.bias"
self.add_module(
"txt_in_individual_token_refiner_blocks_0_mlp_fc1",
MM_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.0.mlp.fc1.weight", "txt_in.individual_token_refiner.blocks.0.mlp.fc1.bias"),
)
self.txt_in_individual_token_refiner_blocks_0_mlp_fc2 = MM_WEIGHT_REGISTER["Default"](
"txt_in.individual_token_refiner.blocks.0.mlp.fc2.weight", "txt_in.individual_token_refiner.blocks.0.mlp.fc2.bias"
self.add_module(
"txt_in_individual_token_refiner_blocks_0_mlp_fc2",
MM_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.0.mlp.fc2.weight", "txt_in.individual_token_refiner.blocks.0.mlp.fc2.bias"),
)
self.txt_in_individual_token_refiner_blocks_0_adaLN_modulation_1 = MM_WEIGHT_REGISTER["Default"](
"txt_in.individual_token_refiner.blocks.0.adaLN_modulation.1.weight", "txt_in.individual_token_refiner.blocks.0.adaLN_modulation.1.bias"
self.add_module(
"txt_in_individual_token_refiner_blocks_0_adaLN_modulation_1",
MM_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.0.adaLN_modulation.1.weight", "txt_in.individual_token_refiner.blocks.0.adaLN_modulation.1.bias"),
)
self.txt_in_individual_token_refiner_blocks_1_norm1 = LN_WEIGHT_REGISTER["Default"](
"txt_in.individual_token_refiner.blocks.1.norm1.weight", "txt_in.individual_token_refiner.blocks.1.norm1.bias", eps=1e-6
self.add_module(
"txt_in_individual_token_refiner_blocks_1_norm1",
LN_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.1.norm1.weight", "txt_in.individual_token_refiner.blocks.1.norm1.bias", eps=1e-6),
)
self.txt_in_individual_token_refiner_blocks_1_self_attn_qkv = MM_WEIGHT_REGISTER["Default"](
"txt_in.individual_token_refiner.blocks.1.self_attn_qkv.weight", "txt_in.individual_token_refiner.blocks.1.self_attn_qkv.bias"
self.add_module(
"txt_in_individual_token_refiner_blocks_1_self_attn_qkv",
MM_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.1.self_attn_qkv.weight", "txt_in.individual_token_refiner.blocks.1.self_attn_qkv.bias"),
)
self.txt_in_individual_token_refiner_blocks_1_self_attn_proj = MM_WEIGHT_REGISTER["Default"](
"txt_in.individual_token_refiner.blocks.1.self_attn_proj.weight", "txt_in.individual_token_refiner.blocks.1.self_attn_proj.bias"
self.add_module(
"txt_in_individual_token_refiner_blocks_1_self_attn_proj",
MM_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.1.self_attn_proj.weight", "txt_in.individual_token_refiner.blocks.1.self_attn_proj.bias"),
)
self.txt_in_individual_token_refiner_blocks_1_norm2 = LN_WEIGHT_REGISTER["Default"](
"txt_in.individual_token_refiner.blocks.1.norm2.weight", "txt_in.individual_token_refiner.blocks.1.norm2.bias", eps=1e-6
self.add_module(
"txt_in_individual_token_refiner_blocks_1_norm2",
LN_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.1.norm2.weight", "txt_in.individual_token_refiner.blocks.1.norm2.bias", eps=1e-6),
)
self.txt_in_individual_token_refiner_blocks_1_mlp_fc1 = MM_WEIGHT_REGISTER["Default"](
"txt_in.individual_token_refiner.blocks.1.mlp.fc1.weight", "txt_in.individual_token_refiner.blocks.1.mlp.fc1.bias"
self.add_module(
"txt_in_individual_token_refiner_blocks_1_mlp_fc1",
MM_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.1.mlp.fc1.weight", "txt_in.individual_token_refiner.blocks.1.mlp.fc1.bias"),
)
self.txt_in_individual_token_refiner_blocks_1_mlp_fc2 = MM_WEIGHT_REGISTER["Default"](
"txt_in.individual_token_refiner.blocks.1.mlp.fc2.weight", "txt_in.individual_token_refiner.blocks.1.mlp.fc2.bias"
self.add_module(
"txt_in_individual_token_refiner_blocks_1_mlp_fc2",
MM_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.1.mlp.fc2.weight", "txt_in.individual_token_refiner.blocks.1.mlp.fc2.bias"),
)
self.txt_in_individual_token_refiner_blocks_1_adaLN_modulation_1 = MM_WEIGHT_REGISTER["Default"](
"txt_in.individual_token_refiner.blocks.1.adaLN_modulation.1.weight", "txt_in.individual_token_refiner.blocks.1.adaLN_modulation.1.bias"
self.add_module(
"txt_in_individual_token_refiner_blocks_1_adaLN_modulation_1",
MM_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.1.adaLN_modulation.1.weight", "txt_in.individual_token_refiner.blocks.1.adaLN_modulation.1.bias"),
)
self.time_in_mlp_0 = MM_WEIGHT_REGISTER["Default"]("time_in.mlp.0.weight", "time_in.mlp.0.bias")
self.time_in_mlp_2 = MM_WEIGHT_REGISTER["Default"]("time_in.mlp.2.weight", "time_in.mlp.2.bias")
self.vector_in_in_layer = MM_WEIGHT_REGISTER["Default"]("vector_in.in_layer.weight", "vector_in.in_layer.bias")
self.vector_in_out_layer = MM_WEIGHT_REGISTER["Default"]("vector_in.out_layer.weight", "vector_in.out_layer.bias")
self.guidance_in_mlp_0 = MM_WEIGHT_REGISTER["Default"]("guidance_in.mlp.0.weight", "guidance_in.mlp.0.bias")
self.guidance_in_mlp_2 = MM_WEIGHT_REGISTER["Default"]("guidance_in.mlp.2.weight", "guidance_in.mlp.2.bias")
self.weight_list = [
self.img_in_proj,
self.txt_in_input_embedder,
self.txt_in_t_embedder_mlp_0,
self.txt_in_t_embedder_mlp_2,
self.txt_in_c_embedder_linear_1,
self.txt_in_c_embedder_linear_2,
self.txt_in_individual_token_refiner_blocks_0_norm1,
self.txt_in_individual_token_refiner_blocks_0_self_attn_qkv,
self.txt_in_individual_token_refiner_blocks_0_self_attn_proj,
self.txt_in_individual_token_refiner_blocks_0_norm2,
self.txt_in_individual_token_refiner_blocks_0_mlp_fc1,
self.txt_in_individual_token_refiner_blocks_0_mlp_fc2,
self.txt_in_individual_token_refiner_blocks_0_adaLN_modulation_1,
self.txt_in_individual_token_refiner_blocks_1_norm1,
self.txt_in_individual_token_refiner_blocks_1_self_attn_qkv,
self.txt_in_individual_token_refiner_blocks_1_self_attn_proj,
self.txt_in_individual_token_refiner_blocks_1_norm2,
self.txt_in_individual_token_refiner_blocks_1_mlp_fc1,
self.txt_in_individual_token_refiner_blocks_1_mlp_fc2,
self.txt_in_individual_token_refiner_blocks_1_adaLN_modulation_1,
self.time_in_mlp_0,
self.time_in_mlp_2,
self.vector_in_in_layer,
self.vector_in_out_layer,
self.guidance_in_mlp_0,
self.guidance_in_mlp_2,
]
for weight in self.weight_list:
if isinstance(weight, MMWeightTemplate) or isinstance(weight, LNWeightTemplate) or isinstance(weight, Conv3dWeightTemplate):
weight.set_config(self.config["mm_config"])
weight.load(weight_dict)
def to_cpu(self):
for weight in self.weight_list:
if isinstance(weight, MMWeightTemplate) or isinstance(weight, LNWeightTemplate) or isinstance(weight, Conv3dWeightTemplate):
weight.to_cpu()
def to_cuda(self):
for weight in self.weight_list:
if isinstance(weight, MMWeightTemplate) or isinstance(weight, LNWeightTemplate) or isinstance(weight, Conv3dWeightTemplate):
weight.to_cuda()
self.add_module("time_in_mlp_0", MM_WEIGHT_REGISTER["Default"]("time_in.mlp.0.weight", "time_in.mlp.0.bias"))
self.add_module("time_in_mlp_2", MM_WEIGHT_REGISTER["Default"]("time_in.mlp.2.weight", "time_in.mlp.2.bias"))
self.add_module("vector_in_in_layer", MM_WEIGHT_REGISTER["Default"]("vector_in.in_layer.weight", "vector_in.in_layer.bias"))
self.add_module("vector_in_out_layer", MM_WEIGHT_REGISTER["Default"]("vector_in.out_layer.weight", "vector_in.out_layer.bias"))
self.add_module("guidance_in_mlp_0", MM_WEIGHT_REGISTER["Default"]("guidance_in.mlp.0.weight", "guidance_in.mlp.0.bias"))
self.add_module("guidance_in_mlp_2", MM_WEIGHT_REGISTER["Default"]("guidance_in.mlp.2.weight", "guidance_in.mlp.2.bias"))
from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER
from lightx2v.common.ops.norm.rms_norm_weight import RMS_WEIGHT_REGISTER
from lightx2v.common.ops.mm.mm_weight import MMWeightTemplate
from lightx2v.common.ops.norm.rms_norm_weight import RMSWeightTemplate
from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER, RMS_WEIGHT_REGISTER
from lightx2v.common.modules.weight_module import WeightModule, WeightModuleList
class HunyuanTransformerWeights:
class HunyuanTransformerWeights(WeightModule):
def __init__(self, config):
super().__init__()
self.config = config
self.init()
def init(self):
self.double_blocks_num = 20
self.single_blocks_num = 40
def load_weights(self, weight_dict):
self.double_blocks_weights = [HunyuanTransformerDoubleBlock(i, self.config) for i in range(self.double_blocks_num)]
self.single_blocks_weights = [HunyuanTransformerSingleBlock(i, self.config) for i in range(self.single_blocks_num)]
for double_block in self.double_blocks_weights:
double_block.load_weights(weight_dict)
for single_block in self.single_blocks_weights:
single_block.load_weights(weight_dict)
def to_cpu(self):
for double_block in self.double_blocks_weights:
double_block.to_cpu()
for single_block in self.single_blocks_weights:
single_block.to_cpu()
self.add_module("double_blocks", WeightModuleList([HunyuanTransformerDoubleBlock(i, self.config) for i in range(self.double_blocks_num)]))
self.add_module("single_blocks", WeightModuleList([HunyuanTransformerSingleBlock(i, self.config) for i in range(self.single_blocks_num)]))
def to_cuda(self):
for double_block in self.double_blocks_weights:
double_block.to_cuda()
for single_block in self.single_blocks_weights:
single_block.to_cuda()
class HunyuanTransformerDoubleBlock:
class HunyuanTransformerDoubleBlock(WeightModule):
def __init__(self, block_index, config):
super().__init__()
self.block_index = block_index
self.config = config
self.weight_list = []
def load_weights(self, weight_dict):
if self.config["do_mm_calib"]:
mm_type = "Calib"
else:
mm_type = self.config["mm_config"].get("mm_type", "Default") if self.config["mm_config"] else "Default"
self.img_mod = MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.img_mod.linear.weight", f"double_blocks.{self.block_index}.img_mod.linear.bias")
self.img_attn_qkv = MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.img_attn_qkv.weight", f"double_blocks.{self.block_index}.img_attn_qkv.bias")
self.img_attn_q_norm = RMS_WEIGHT_REGISTER["sgl-kernel"](f"double_blocks.{self.block_index}.img_attn_q_norm.weight", eps=1e-6)
self.img_attn_k_norm = RMS_WEIGHT_REGISTER["sgl-kernel"](f"double_blocks.{self.block_index}.img_attn_k_norm.weight", eps=1e-6)
self.img_attn_proj = MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.img_attn_proj.weight", f"double_blocks.{self.block_index}.img_attn_proj.bias")
self.img_mlp_fc1 = MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.img_mlp.fc1.weight", f"double_blocks.{self.block_index}.img_mlp.fc1.bias")
self.img_mlp_fc2 = MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.img_mlp.fc2.weight", f"double_blocks.{self.block_index}.img_mlp.fc2.bias")
self.txt_mod = MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.txt_mod.linear.weight", f"double_blocks.{self.block_index}.txt_mod.linear.bias")
self.txt_attn_qkv = MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.txt_attn_qkv.weight", f"double_blocks.{self.block_index}.txt_attn_qkv.bias")
self.txt_attn_q_norm = RMS_WEIGHT_REGISTER["sgl-kernel"](f"double_blocks.{self.block_index}.txt_attn_q_norm.weight", eps=1e-6)
self.txt_attn_k_norm = RMS_WEIGHT_REGISTER["sgl-kernel"](f"double_blocks.{self.block_index}.txt_attn_k_norm.weight", eps=1e-6)
self.txt_attn_proj = MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.txt_attn_proj.weight", f"double_blocks.{self.block_index}.txt_attn_proj.bias")
self.txt_mlp_fc1 = MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.txt_mlp.fc1.weight", f"double_blocks.{self.block_index}.txt_mlp.fc1.bias")
self.txt_mlp_fc2 = MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.txt_mlp.fc2.weight", f"double_blocks.{self.block_index}.txt_mlp.fc2.bias")
self.weight_list = [
self.img_mod,
self.img_attn_qkv,
self.img_attn_q_norm,
self.img_attn_k_norm,
self.img_attn_proj,
self.img_mlp_fc1,
self.img_mlp_fc2,
self.txt_mod,
self.txt_attn_qkv,
self.txt_attn_q_norm,
self.txt_attn_k_norm,
self.txt_attn_proj,
self.txt_mlp_fc1,
self.txt_mlp_fc2,
]
for weight in self.weight_list:
if isinstance(weight, (MMWeightTemplate, RMSWeightTemplate)):
weight.set_config(self.config["mm_config"])
weight.load(weight_dict)
def to_cpu(self):
for weight in self.weight_list:
if isinstance(weight, (MMWeightTemplate, RMSWeightTemplate)):
weight.to_cpu()
def to_cuda(self):
for weight in self.weight_list:
if isinstance(weight, (MMWeightTemplate, RMSWeightTemplate)):
weight.to_cuda()
self.add_module("img_mod", MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.img_mod.linear.weight", f"double_blocks.{self.block_index}.img_mod.linear.bias"))
self.add_module("img_attn_qkv", MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.img_attn_qkv.weight", f"double_blocks.{self.block_index}.img_attn_qkv.bias"))
self.add_module("img_attn_q_norm", RMS_WEIGHT_REGISTER["sgl-kernel"](f"double_blocks.{self.block_index}.img_attn_q_norm.weight", eps=1e-6))
self.add_module("img_attn_k_norm", RMS_WEIGHT_REGISTER["sgl-kernel"](f"double_blocks.{self.block_index}.img_attn_k_norm.weight", eps=1e-6))
self.add_module("img_attn_proj", MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.img_attn_proj.weight", f"double_blocks.{self.block_index}.img_attn_proj.bias"))
self.add_module("img_mlp_fc1", MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.img_mlp.fc1.weight", f"double_blocks.{self.block_index}.img_mlp.fc1.bias"))
self.add_module("img_mlp_fc2", MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.img_mlp.fc2.weight", f"double_blocks.{self.block_index}.img_mlp.fc2.bias"))
def to_cpu_sync(self):
for weight in self.weight_list:
if isinstance(weight, (MMWeightTemplate, RMSWeightTemplate)):
weight.to_cpu(non_blocking=True)
self.add_module("txt_mod", MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.txt_mod.linear.weight", f"double_blocks.{self.block_index}.txt_mod.linear.bias"))
self.add_module("txt_attn_qkv", MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.txt_attn_qkv.weight", f"double_blocks.{self.block_index}.txt_attn_qkv.bias"))
self.add_module("txt_attn_q_norm", RMS_WEIGHT_REGISTER["sgl-kernel"](f"double_blocks.{self.block_index}.txt_attn_q_norm.weight", eps=1e-6))
self.add_module("txt_attn_k_norm", RMS_WEIGHT_REGISTER["sgl-kernel"](f"double_blocks.{self.block_index}.txt_attn_k_norm.weight", eps=1e-6))
self.add_module("txt_attn_proj", MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.txt_attn_proj.weight", f"double_blocks.{self.block_index}.txt_attn_proj.bias"))
self.add_module("txt_mlp_fc1", MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.txt_mlp.fc1.weight", f"double_blocks.{self.block_index}.txt_mlp.fc1.bias"))
self.add_module("txt_mlp_fc2", MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.txt_mlp.fc2.weight", f"double_blocks.{self.block_index}.txt_mlp.fc2.bias"))
def to_cuda_sync(self):
for weight in self.weight_list:
if isinstance(weight, (MMWeightTemplate, RMSWeightTemplate)):
weight.to_cuda(non_blocking=True)
class HunyuanTransformerSingleBlock:
class HunyuanTransformerSingleBlock(WeightModule):
def __init__(self, block_index, config):
super().__init__()
self.block_index = block_index
self.config = config
self.weight_list = []
def load_weights(self, weight_dict):
if self.config["do_mm_calib"]:
mm_type = "Calib"
else:
mm_type = self.config["mm_config"].get("mm_type", "Default") if self.config["mm_config"] else "Default"
self.linear1 = MM_WEIGHT_REGISTER[mm_type](f"single_blocks.{self.block_index}.linear1.weight", f"single_blocks.{self.block_index}.linear1.bias")
self.linear2 = MM_WEIGHT_REGISTER[mm_type](f"single_blocks.{self.block_index}.linear2.weight", f"single_blocks.{self.block_index}.linear2.bias")
self.q_norm = RMS_WEIGHT_REGISTER["sgl-kernel"](f"single_blocks.{self.block_index}.q_norm.weight", eps=1e-6)
self.k_norm = RMS_WEIGHT_REGISTER["sgl-kernel"](f"single_blocks.{self.block_index}.k_norm.weight", eps=1e-6)
self.modulation = MM_WEIGHT_REGISTER[mm_type](f"single_blocks.{self.block_index}.modulation.linear.weight", f"single_blocks.{self.block_index}.modulation.linear.bias")
self.weight_list = [
self.linear1,
self.linear2,
self.q_norm,
self.k_norm,
self.modulation,
]
for weight in self.weight_list:
if isinstance(weight, (MMWeightTemplate, RMSWeightTemplate)):
weight.set_config(self.config["mm_config"])
weight.load(weight_dict)
def to_cpu(self):
for weight in self.weight_list:
if isinstance(weight, (MMWeightTemplate, RMSWeightTemplate)):
weight.to_cpu()
def to_cuda(self):
for weight in self.weight_list:
if isinstance(weight, (MMWeightTemplate, RMSWeightTemplate)):
weight.to_cuda()
def to_cpu_sync(self):
for weight in self.weight_list:
if isinstance(weight, (MMWeightTemplate, RMSWeightTemplate)):
weight.to_cpu(non_blocking=True)
def to_cuda_sync(self):
for weight in self.weight_list:
if isinstance(weight, (MMWeightTemplate, RMSWeightTemplate)):
weight.to_cuda(non_blocking=True)
self.add_module("linear1", MM_WEIGHT_REGISTER[mm_type](f"single_blocks.{self.block_index}.linear1.weight", f"single_blocks.{self.block_index}.linear1.bias"))
self.add_module("linear2", MM_WEIGHT_REGISTER[mm_type](f"single_blocks.{self.block_index}.linear2.weight", f"single_blocks.{self.block_index}.linear2.bias"))
self.add_module("q_norm", RMS_WEIGHT_REGISTER["sgl-kernel"](f"single_blocks.{self.block_index}.q_norm.weight", eps=1e-6))
self.add_module("k_norm", RMS_WEIGHT_REGISTER["sgl-kernel"](f"single_blocks.{self.block_index}.k_norm.weight", eps=1e-6))
self.add_module("modulation", MM_WEIGHT_REGISTER[mm_type](f"single_blocks.{self.block_index}.modulation.linear.weight", f"single_blocks.{self.block_index}.modulation.linear.bias"))
......@@ -13,10 +13,10 @@ class WanPostInfer:
def infer(self, weights, x, e, grid_sizes):
if e.dim() == 2:
modulation = weights.head_modulation # 1, 2, dim
modulation = weights.head_modulation.tensor # 1, 2, dim
e = (modulation + e.unsqueeze(1)).chunk(2, dim=1)
elif e.dim() == 3: # For Diffusion forcing
modulation = weights.head_modulation.unsqueeze(2) # 1, 2, seq, dim
modulation = weights.head_modulation.tensor.unsqueeze(2) # 1, 2, seq, dim
e = (modulation + e.unsqueeze(1)).chunk(2, dim=1)
e = [ei.squeeze(1) for ei in e]
......
......@@ -42,7 +42,7 @@ class WanTransformerInfer:
def _infer_with_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
for block_idx in range(self.blocks_num):
if block_idx == 0:
self.weights_stream_mgr.active_weights[0] = weights.blocks_weights[0]
self.weights_stream_mgr.active_weights[0] = weights.blocks[0]
self.weights_stream_mgr.active_weights[0].to_cuda()
with torch.cuda.stream(self.weights_stream_mgr.compute_stream):
......@@ -58,7 +58,7 @@ class WanTransformerInfer:
)
if block_idx < self.blocks_num - 1:
self.weights_stream_mgr.prefetch_weights(block_idx + 1, weights.blocks_weights)
self.weights_stream_mgr.prefetch_weights(block_idx + 1, weights.blocks)
self.weights_stream_mgr.swap_weights()
return x
......@@ -66,7 +66,7 @@ class WanTransformerInfer:
def _infer_without_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
for block_idx in range(self.blocks_num):
x = self.infer_block(
weights.blocks_weights[block_idx],
weights.blocks[block_idx],
grid_sizes,
embed,
x,
......@@ -79,12 +79,12 @@ class WanTransformerInfer:
def infer_block(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
if embed0.dim() == 3:
modulation = weights.modulation.unsqueeze(2) # 1, 6, 1, dim
modulation = weights.modulation.tensor.unsqueeze(2) # 1, 6, 1, dim
embed0 = embed0.unsqueeze(0) #
embed0 = (modulation + embed0).chunk(6, dim=1)
embed0 = [ei.squeeze(1) for ei in embed0]
elif embed0.dim() == 2:
embed0 = (weights.modulation + embed0).chunk(6, dim=1)
embed0 = (weights.modulation.tensor + embed0).chunk(6, dim=1)
norm1_out = torch.nn.functional.layer_norm(x, (x.shape[1],), None, None, 1e-6)
norm1_out = (norm1_out * (1 + embed0[1]) + embed0[0]).squeeze(0)
......
......@@ -84,9 +84,9 @@ class WanModel:
self.post_weight = self.post_weight_class(self.config)
self.transformer_weights = self.transformer_weight_class(self.config)
# load weights
self.pre_weight.load_weights(self.original_weight_dict)
self.post_weight.load_weights(self.original_weight_dict)
self.transformer_weights.load_weights(self.original_weight_dict)
self.pre_weight.load(self.original_weight_dict)
self.post_weight.load(self.original_weight_dict)
self.transformer_weights.load(self.original_weight_dict)
def _init_infer(self):
self.pre_infer = self.pre_infer_class(self.config)
......@@ -109,9 +109,6 @@ class WanModel:
self.post_weight.to_cuda()
self.transformer_weights.to_cuda()
def do_classifier_free_guidance(self) -> bool:
return self.config.sample_guide_scale > 1
@torch.no_grad()
def infer(self, inputs):
if self.config["cpu_offload"]:
......@@ -128,7 +125,7 @@ class WanModel:
self.scheduler.cnt = 0
self.scheduler.noise_pred = noise_pred_cond
if self.do_classifier_free_guidance():
if self.config["enable_cfg"]:
embed, grid_sizes, pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs, positive=False)
x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
noise_pred_uncond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]
......
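
Classifier-free guidance is now toggled by the explicit enable_cfg flag added to the sample config at the top of this diff, rather than being inferred from sample_guide_scale via the removed do_classifier_free_guidance() helper. An illustrative config fragment (values match the sample config above):

config = {
    "sample_guide_scale": 6,  # guidance scale; no longer used to decide whether CFG runs
    "enable_cfg": True,       # replaces the removed do_classifier_free_guidance() check
    "cpu_offload": False,
}
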
from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER
from lightx2v.common.ops.mm.mm_weight import MMWeightTemplate
from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER, TENSOR_REGISTER
from lightx2v.common.modules.weight_module import WeightModule
class WanPostWeights:
class WanPostWeights(WeightModule):
def __init__(self, config):
super().__init__()
self.config = config
def load_weights(self, weight_dict):
self.head = MM_WEIGHT_REGISTER["Default"]("head.head.weight", "head.head.bias")
self.head_modulation = weight_dict["head.modulation"]
self.weight_list = [self.head]
for weight in self.weight_list:
if isinstance(weight, MMWeightTemplate):
weight.set_config(self.config["mm_config"])
weight.load(weight_dict)
if self.config["cpu_offload"]:
weight.to_cpu()
self.head_modulation = self.head_modulation.cpu()
def to_cpu(self):
for weight in self.weight_list:
if isinstance(weight, MMWeightTemplate):
weight.to_cpu()
self.head_modulation = self.head_modulation.cpu()
def to_cuda(self):
for weight in self.weight_list:
if isinstance(weight, MMWeightTemplate):
weight.to_cuda()
self.head_modulation = self.head_modulation.cuda()
self.add_module("head", MM_WEIGHT_REGISTER["Default"]("head.head.weight", "head.head.bias"))
self.register_parameter("head_modulation", TENSOR_REGISTER["Default"]("head.modulation"))
import torch
from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER, LN_WEIGHT_REGISTER, CONV3D_WEIGHT_REGISTER
from lightx2v.common.ops.mm.mm_weight import MMWeightTemplate
from lightx2v.common.ops.norm.layer_norm_weight import LNWeightTemplate
from lightx2v.common.ops.conv.conv3d import Conv3dWeightTemplate
from lightx2v.common.modules.weight_module import WeightModule
class WanPreWeights:
class WanPreWeights(WeightModule):
def __init__(self, config):
super().__init__()
self.in_dim = config["in_dim"]
self.dim = config["dim"]
self.patch_size = (1, 2, 2)
self.config = config
def load_weights(self, weight_dict):
self.patch_embedding = CONV3D_WEIGHT_REGISTER["Defaultt-Force-BF16"]("patch_embedding.weight", "patch_embedding.bias", stride=self.patch_size)
self.text_embedding_0 = MM_WEIGHT_REGISTER["Default"]("text_embedding.0.weight", "text_embedding.0.bias")
self.text_embedding_2 = MM_WEIGHT_REGISTER["Default"]("text_embedding.2.weight", "text_embedding.2.bias")
self.time_embedding_0 = MM_WEIGHT_REGISTER["Default"]("time_embedding.0.weight", "time_embedding.0.bias")
self.time_embedding_2 = MM_WEIGHT_REGISTER["Default"]("time_embedding.2.weight", "time_embedding.2.bias")
self.time_projection_1 = MM_WEIGHT_REGISTER["Default"]("time_projection.1.weight", "time_projection.1.bias")
self.weight_list = [
self.patch_embedding,
self.text_embedding_0,
self.text_embedding_2,
self.time_embedding_0,
self.time_embedding_2,
self.time_projection_1,
]
if "img_emb.proj.0.weight" in weight_dict.keys():
self.proj_0 = LN_WEIGHT_REGISTER["Default"]("img_emb.proj.0.weight", "img_emb.proj.0.bias", eps=1e-5)
self.proj_1 = MM_WEIGHT_REGISTER["Default"]("img_emb.proj.1.weight", "img_emb.proj.1.bias")
self.proj_3 = MM_WEIGHT_REGISTER["Default"]("img_emb.proj.3.weight", "img_emb.proj.3.bias")
self.proj_4 = LN_WEIGHT_REGISTER["Default"]("img_emb.proj.4.weight", "img_emb.proj.4.bias", eps=1e-5)
self.weight_list.append(self.proj_0)
self.weight_list.append(self.proj_1)
self.weight_list.append(self.proj_3)
self.weight_list.append(self.proj_4)
for weight in self.weight_list:
if isinstance(weight, (MMWeightTemplate, LNWeightTemplate, Conv3dWeightTemplate)):
weight.set_config(self.config["mm_config"])
weight.load(weight_dict)
if self.config["cpu_offload"]:
weight.to_cpu()
def to_cpu(self):
for weight in self.weight_list:
if isinstance(weight, (MMWeightTemplate, LNWeightTemplate, Conv3dWeightTemplate)):
weight.to_cpu()
def to_cuda(self):
for weight in self.weight_list:
if isinstance(weight, (MMWeightTemplate, LNWeightTemplate, Conv3dWeightTemplate)):
weight.to_cuda()
self.add_module("patch_embedding", CONV3D_WEIGHT_REGISTER["Defaultt-Force-BF16"]("patch_embedding.weight", "patch_embedding.bias", stride=self.patch_size))
self.add_module("text_embedding_0", MM_WEIGHT_REGISTER["Default"]("text_embedding.0.weight", "text_embedding.0.bias"))
self.add_module("text_embedding_2", MM_WEIGHT_REGISTER["Default"]("text_embedding.2.weight", "text_embedding.2.bias"))
self.add_module("time_embedding_0", MM_WEIGHT_REGISTER["Default"]("time_embedding.0.weight", "time_embedding.0.bias"))
self.add_module("time_embedding_2", MM_WEIGHT_REGISTER["Default"]("time_embedding.2.weight", "time_embedding.2.bias"))
self.add_module("time_projection_1", MM_WEIGHT_REGISTER["Default"]("time_projection.1.weight", "time_projection.1.bias"))
if config.task == "i2v":
self.add_module("proj_0", LN_WEIGHT_REGISTER["Default"]("img_emb.proj.0.weight", "img_emb.proj.0.bias", eps=1e-5))
self.add_module("proj_1", MM_WEIGHT_REGISTER["Default"]("img_emb.proj.1.weight", "img_emb.proj.1.bias"))
self.add_module("proj_3", MM_WEIGHT_REGISTER["Default"]("img_emb.proj.3.weight", "img_emb.proj.3.bias"))
self.add_module("proj_4", LN_WEIGHT_REGISTER["Default"]("img_emb.proj.4.weight", "img_emb.proj.4.bias", eps=1e-5))
from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER, LN_WEIGHT_REGISTER, RMS_WEIGHT_REGISTER
from lightx2v.common.ops.mm.mm_weight import MMWeightTemplate
from lightx2v.common.ops.norm.layer_norm_weight import LNWeightTemplate
from lightx2v.common.ops.norm.rms_norm_weight import RMSWeightTemplate
from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER, LN_WEIGHT_REGISTER, RMS_WEIGHT_REGISTER, TENSOR_REGISTER
from lightx2v.common.modules.weight_module import WeightModule, WeightModuleList
class WanTransformerWeights:
class WanTransformerWeights(WeightModule):
def __init__(self, config):
super().__init__()
self.blocks_num = config["num_layers"]
self.task = config["task"]
self.config = config
......@@ -13,99 +12,39 @@ class WanTransformerWeights:
self.mm_type = "Calib"
else:
self.mm_type = config["mm_config"].get("mm_type", "Default") if config["mm_config"] else "Default"
self.blocks = WeightModuleList([WanTransformerAttentionBlock(i, self.task, self.mm_type, self.config) for i in range(self.blocks_num)])
self.add_module("blocks", self.blocks)
def load_weights(self, weight_dict):
self.blocks_weights = [WanTransformerAttentionBlock(i, self.task, self.mm_type, self.config) for i in range(self.blocks_num)]
for block in self.blocks_weights:
block.load_weights(weight_dict)
def to_cpu(self):
for block in self.blocks_weights:
block.to_cpu()
def to_cuda(self):
for block in self.blocks_weights:
block.to_cuda()
class WanTransformerAttentionBlock:
class WanTransformerAttentionBlock(WeightModule):
def __init__(self, block_index, task, mm_type, config):
super().__init__()
self.block_index = block_index
self.mm_type = mm_type
self.task = task
self.config = config
def load_weights(self, weight_dict):
self.self_attn_q = MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.self_attn.q.weight", f"blocks.{self.block_index}.self_attn.q.bias")
self.self_attn_k = MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.self_attn.k.weight", f"blocks.{self.block_index}.self_attn.k.bias")
self.self_attn_v = MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.self_attn.v.weight", f"blocks.{self.block_index}.self_attn.v.bias")
self.self_attn_o = MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.self_attn.o.weight", f"blocks.{self.block_index}.self_attn.o.bias")
self.self_attn_norm_q = RMS_WEIGHT_REGISTER["sgl-kernel"](f"blocks.{self.block_index}.self_attn.norm_q.weight")
self.self_attn_norm_k = RMS_WEIGHT_REGISTER["sgl-kernel"](f"blocks.{self.block_index}.self_attn.norm_k.weight")
self.norm3 = LN_WEIGHT_REGISTER["Default"](f"blocks.{self.block_index}.norm3.weight", f"blocks.{self.block_index}.norm3.bias", eps=1e-6)
self.cross_attn_q = MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.cross_attn.q.weight", f"blocks.{self.block_index}.cross_attn.q.bias")
self.cross_attn_k = MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.cross_attn.k.weight", f"blocks.{self.block_index}.cross_attn.k.bias")
self.cross_attn_v = MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.cross_attn.v.weight", f"blocks.{self.block_index}.cross_attn.v.bias")
self.cross_attn_o = MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.cross_attn.o.weight", f"blocks.{self.block_index}.cross_attn.o.bias")
self.cross_attn_norm_q = RMS_WEIGHT_REGISTER["sgl-kernel"](f"blocks.{self.block_index}.cross_attn.norm_q.weight")
self.cross_attn_norm_k = RMS_WEIGHT_REGISTER["sgl-kernel"](f"blocks.{self.block_index}.cross_attn.norm_k.weight")
self.add_module("self_attn_q", MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.self_attn.q.weight", f"blocks.{self.block_index}.self_attn.q.bias"))
self.add_module("self_attn_k", MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.self_attn.k.weight", f"blocks.{self.block_index}.self_attn.k.bias"))
self.add_module("self_attn_v", MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.self_attn.v.weight", f"blocks.{self.block_index}.self_attn.v.bias"))
self.add_module("self_attn_o", MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.self_attn.o.weight", f"blocks.{self.block_index}.self_attn.o.bias"))
self.add_module("self_attn_norm_q", RMS_WEIGHT_REGISTER["sgl-kernel"](f"blocks.{self.block_index}.self_attn.norm_q.weight"))
self.add_module("self_attn_norm_k", RMS_WEIGHT_REGISTER["sgl-kernel"](f"blocks.{self.block_index}.self_attn.norm_k.weight"))
self.ffn_0 = MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.ffn.0.weight", f"blocks.{self.block_index}.ffn.0.bias")
self.ffn_2 = MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.ffn.2.weight", f"blocks.{self.block_index}.ffn.2.bias")
self.modulation = weight_dict[f"blocks.{self.block_index}.modulation"]
self.add_module("norm3", LN_WEIGHT_REGISTER["Default"](f"blocks.{self.block_index}.norm3.weight", f"blocks.{self.block_index}.norm3.bias", eps=1e-6))
self.add_module("cross_attn_q", MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.cross_attn.q.weight", f"blocks.{self.block_index}.cross_attn.q.bias"))
self.add_module("cross_attn_k", MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.cross_attn.k.weight", f"blocks.{self.block_index}.cross_attn.k.bias"))
self.add_module("cross_attn_v", MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.cross_attn.v.weight", f"blocks.{self.block_index}.cross_attn.v.bias"))
self.add_module("cross_attn_o", MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.cross_attn.o.weight", f"blocks.{self.block_index}.cross_attn.o.bias"))
self.add_module("cross_attn_norm_q", RMS_WEIGHT_REGISTER["sgl-kernel"](f"blocks.{self.block_index}.cross_attn.norm_q.weight"))
self.add_module("cross_attn_norm_k", RMS_WEIGHT_REGISTER["sgl-kernel"](f"blocks.{self.block_index}.cross_attn.norm_k.weight"))
self.weight_list = [
self.self_attn_q,
self.self_attn_k,
self.self_attn_v,
self.self_attn_o,
self.self_attn_norm_q,
self.self_attn_norm_k,
self.norm3,
self.cross_attn_q,
self.cross_attn_k,
self.cross_attn_v,
self.cross_attn_o,
self.cross_attn_norm_q,
self.cross_attn_norm_k,
self.ffn_0,
self.ffn_2,
]
self.add_module("ffn_0", MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.ffn.0.weight", f"blocks.{self.block_index}.ffn.0.bias"))
self.add_module("ffn_2", MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.ffn.2.weight", f"blocks.{self.block_index}.ffn.2.bias"))
if self.task == "i2v":
self.cross_attn_k_img = MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.cross_attn.k_img.weight", f"blocks.{self.block_index}.cross_attn.k_img.bias")
self.cross_attn_v_img = MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.cross_attn.v_img.weight", f"blocks.{self.block_index}.cross_attn.v_img.bias")
self.cross_attn_norm_k_img = RMS_WEIGHT_REGISTER["sgl-kernel"](f"blocks.{self.block_index}.cross_attn.norm_k_img.weight")
self.weight_list.append(self.cross_attn_k_img)
self.weight_list.append(self.cross_attn_v_img)
self.weight_list.append(self.cross_attn_norm_k_img)
for weight in self.weight_list:
if isinstance(weight, (MMWeightTemplate, LNWeightTemplate, RMSWeightTemplate)):
weight.set_config(self.config["mm_config"])
weight.load(weight_dict)
def to_cpu(self):
for weight in self.weight_list:
if isinstance(weight, (MMWeightTemplate, LNWeightTemplate, RMSWeightTemplate)):
weight.to_cpu()
self.modulation = self.modulation.cpu()
def to_cuda(self):
for weight in self.weight_list:
if isinstance(weight, (MMWeightTemplate, LNWeightTemplate, RMSWeightTemplate)):
weight.to_cuda()
self.modulation = self.modulation.cuda()
def to_cpu_sync(self):
for weight in self.weight_list:
if isinstance(weight, (MMWeightTemplate, LNWeightTemplate, RMSWeightTemplate)):
weight.to_cpu(non_blocking=True)
self.modulation = self.modulation.to("cpu", non_blocking=True)
self.add_module("cross_attn_k_img", MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.cross_attn.k_img.weight", f"blocks.{self.block_index}.cross_attn.k_img.bias"))
self.add_module("cross_attn_v_img", MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.cross_attn.v_img.weight", f"blocks.{self.block_index}.cross_attn.v_img.bias"))
self.add_module("cross_attn_norm_k_img", RMS_WEIGHT_REGISTER["sgl-kernel"](f"blocks.{self.block_index}.cross_attn.norm_k_img.weight"))
def to_cuda_sync(self):
for weight in self.weight_list:
if isinstance(weight, (MMWeightTemplate, LNWeightTemplate, RMSWeightTemplate)):
weight.to_cuda(non_blocking=True)
self.modulation = self.modulation.cuda(non_blocking=True)
self.register_parameter("modulation", TENSOR_REGISTER["Default"](f"blocks.{self.block_index}.modulation"))
......@@ -50,4 +50,6 @@ LN_WEIGHT_REGISTER = Register()
CONV3D_WEIGHT_REGISTER = Register()
CONV2D_WEIGHT_REGISTER = Register()
TENSOR_REGISTER = Register()
RUNNER_REGISTER = Register()
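
The new TENSOR_REGISTER instance follows the same pattern as the existing registries. The Register implementation itself is not part of this diff; purely as an assumption for illustration, a minimal registry compatible with the usage seen above (decorator registration plus subscript lookup) could look like this:

class Register:
    # Hypothetical sketch; the real class lives in lightx2v/utils/registry_factory.py.
    def __init__(self):
        self._map = {}

    def __call__(self, name):
        # Decorator form, as in @TENSOR_REGISTER("Default") above.
        def _wrap(cls):
            self._map[name] = cls
            return cls
        return _wrap

    def __getitem__(self, name):
        # Lookup form, as in TENSOR_REGISTER["Default"]("head.modulation").
        return self._map[name]
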