post_weights.py 1.21 KB
Newer Older
TorynCurtis's avatar
TorynCurtis committed
1
2
3
4
from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER
from lightx2v.common.ops.mm.mm_weight import MMWeightTemplate


helloyongyang's avatar
helloyongyang committed
5
class WanPostWeights:
    """Container for the Wan model's output-head weights.

    The head consists of:
      * ``self.head`` -- a matmul weight created via
        ``MM_WEIGHT_REGISTER["Default"]`` for the names
        ``"head.head.weight"`` / ``"head.head.bias"``;
      * ``self.head_modulation`` -- the raw ``"head.modulation"`` entry taken
        directly from the weight dict (presumably a torch tensor, since it is
        moved with ``.cpu()`` / ``.cuda()`` -- TODO confirm).
    """

    def __init__(self, config):
        # Indexed with "mm_config" and "cpu_offload" in load_weights().
        self.config = config

    def load_weights(self, weight_dict):
        """Create the head weights and load them from ``weight_dict``.

        Args:
            weight_dict: mapping of parameter names to stored values. It must
                contain "head.modulation", plus whatever keys the registered
                matmul weight's ``load`` reads (presumably "head.head.weight"
                and "head.head.bias" -- TODO confirm).

        When ``config["cpu_offload"]`` is truthy, every loaded matmul weight
        and the modulation entry are moved to CPU after loading.
        """
        self.head = MM_WEIGHT_REGISTER["Default"]("head.head.weight", "head.head.bias")
        self.head_modulation = weight_dict["head.modulation"]

        # Only MMWeightTemplate entries are configured/loaded/moved here; the
        # modulation entry is handled separately below.
        self.weight_list = [self.head]

        for mm_weight in self.weight_list:
            if isinstance(mm_weight, MMWeightTemplate):
                mm_weight.set_config(self.config["mm_config"])
                mm_weight.load(weight_dict)
                if self.config["cpu_offload"]:
                    mm_weight.to_cpu()

        # The modulation entry is not in weight_list, so offload it exactly
        # once here (matching to_cpu()) rather than once per matmul weight.
        if self.config["cpu_offload"]:
            self.head_modulation = self.head_modulation.cpu()

    def to_cpu(self):
        """Move every matmul weight and the modulation entry to CPU.

        Requires load_weights() to have been called first (it creates
        ``weight_list`` and ``head_modulation``).
        """
        for mm_weight in self.weight_list:
            if isinstance(mm_weight, MMWeightTemplate):
                mm_weight.to_cpu()
        self.head_modulation = self.head_modulation.cpu()

    def to_cuda(self):
        """Move every matmul weight and the modulation entry to CUDA.

        Requires load_weights() to have been called first (it creates
        ``weight_list`` and ``head_modulation``).
        """
        for mm_weight in self.weight_list:
            if isinstance(mm_weight, MMWeightTemplate):
                mm_weight.to_cuda()
        self.head_modulation = self.head_modulation.cuda()