# post_weights.py
from lightx2v.common.modules.weight_module import WeightModule
from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER


class HunyuanPostWeights(WeightModule):
    """Weights for the ``final_layer`` block of the Hunyuan model.

    Two matmul weight modules are registered, each built through
    ``MM_WEIGHT_REGISTER``:

    * ``final_layer_linear``: built from ``final_layer.linear.weight`` /
      ``final_layer.linear.bias`` using the ``"Default-Force-FP32"`` entry.
    * ``final_layer_adaLN_modulation_1``: built from
      ``final_layer.adaLN_modulation.1.weight`` /
      ``final_layer.adaLN_modulation.1.bias`` using the ``"Default"`` entry.

    The weight/bias names are presumably keys into the model checkpoint;
    confirm against the registry implementations.

    Args:
        config: Model configuration. It is stored on ``self.config`` and not
            read within ``__init__``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        # Judging by the registry key, this projection is forced to FP32
        # rather than following the default matmul precision.
        # NOTE(review): confirm against the "Default-Force-FP32" registry entry.
        self.add_module(
            "final_layer_linear",
            MM_WEIGHT_REGISTER["Default-Force-FP32"]("final_layer.linear.weight", "final_layer.linear.bias"),
        )
        self.add_module(
            "final_layer_adaLN_modulation_1",
            MM_WEIGHT_REGISTER["Default"]("final_layer.adaLN_modulation.1.weight", "final_layer.adaLN_modulation.1.bias"),
        )