post_weights.py 1.13 KB
Newer Older
helloyongyang's avatar
helloyongyang committed
1
2
3
4
5
6
7
8
9
from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER
from lightx2v.common.ops.mm.mm_weight import MMWeightTemplate


class HunyuanPostWeights:
    """Container for the Hunyuan model's final-layer ("post") weights.

    Holds two matmul weights built from ``MM_WEIGHT_REGISTER``:
    ``final_layer.linear`` and ``final_layer.adaLN_modulation.1``. It also
    forwards load / device-transfer calls to each of them.
    """

    def __init__(self, config):
        """Store the model config.

        Args:
            config: Mapping that must provide an ``"mm_config"`` entry by the
                time ``load_weights`` is called (it is read there).
        """
        self.config = config
        # Start empty so to_cpu()/to_cuda() are safe no-ops before
        # load_weights() runs, instead of raising AttributeError.
        self.weight_list = []

    def load_weights(self, weight_dict):
        """Build the final-layer weights and load them from ``weight_dict``.

        Args:
            weight_dict: Source of tensors, passed unchanged to each weight's
                ``load``. Presumably keyed by the parameter names used below
                (e.g. ``"final_layer.linear.weight"``) — TODO confirm.
        """
        # NOTE(review): the "Default-Force-FP32" registry key presumably keeps
        # this projection in fp32 regardless of mm_config — confirm in registry.
        self.final_layer_linear = MM_WEIGHT_REGISTER["Default-Force-FP32"]("final_layer.linear.weight", "final_layer.linear.bias")
        self.final_layer_adaLN_modulation_1 = MM_WEIGHT_REGISTER["Default"]("final_layer.adaLN_modulation.1.weight", "final_layer.adaLN_modulation.1.bias")

        self.weight_list = [
            self.final_layer_linear,
            self.final_layer_adaLN_modulation_1,
        ]

        # Configure each weight before loading it, one weight at a time.
        for weight in self._mm_weights():
            weight.set_config(self.config["mm_config"])
            weight.load(weight_dict)

    def to_cpu(self):
        """Call ``to_cpu()`` on every loaded matmul weight (no-op before loading)."""
        for weight in self._mm_weights():
            weight.to_cpu()

    def to_cuda(self):
        """Call ``to_cuda()`` on every loaded matmul weight (no-op before loading)."""
        for weight in self._mm_weights():
            weight.to_cuda()

    def _mm_weights(self):
        """Lazily yield the entries of ``weight_list`` that are ``MMWeightTemplate``s."""
        return (weight for weight in self.weight_list if isinstance(weight, MMWeightTemplate))