"docker/Dockerfile" did not exist on "6658b8c95dbf163582f6043610e4bb813156e272"
post_weights.py 1.18 KB
Newer Older
TorynCurtis's avatar
TorynCurtis committed
1
2
3
4
from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER
from lightx2v.common.ops.mm.mm_weight import MMWeightTemplate


helloyongyang's avatar
helloyongyang committed
5
class WanPostWeights:
    """Holds the weights of the output head.

    Two pieces of state are managed after ``load_weights`` has run:

    * ``self.head`` -- the ``"Default"`` entry of ``MM_WEIGHT_REGISTER``,
      built for the ``head.head.weight`` / ``head.head.bias`` keys.
    * ``self.head_modulation`` -- the raw ``head.modulation`` entry of the
      weight dict (presumably a torch tensor, since ``.cpu()`` / ``.cuda()``
      are called on it -- confirm against the caller).

    ``to_cpu`` / ``to_cuda`` must only be called after ``load_weights``,
    because ``weight_list`` and ``head_modulation`` are created there.
    """

    def __init__(self, config):
        """Store the configuration; no weights are loaded here.

        Args:
            config: Mapping-like object. ``load_weights`` reads its
                ``"mm_config"`` and ``"cpu_offload"`` entries.
        """
        self.config = config

    def load_weights(self, weight_dict):
        """Create the head weights and populate them from ``weight_dict``.

        Args:
            weight_dict: Mapping that contains ``"head.modulation"`` and is
                also handed to each weight object's ``load`` method.

        Raises:
            KeyError: If ``"head.modulation"`` is missing from
                ``weight_dict`` or ``"cpu_offload"`` is missing from the
                config.
        """
        self.head = MM_WEIGHT_REGISTER["Default"]("head.head.weight", "head.head.bias")
        self.head_modulation = weight_dict["head.modulation"]

        self.weight_list = [self.head]

        # The flag does not change while loading, so read it once.
        cpu_offload = self.config["cpu_offload"]

        for weight in self.weight_list:
            if isinstance(weight, MMWeightTemplate):
                weight.set_config(self.config["mm_config"])
                weight.load(weight_dict)
                if cpu_offload:
                    weight.to_cpu()

        # head_modulation is not part of weight_list, so offload it once here
        # (mirroring to_cpu) instead of once per matmul weight inside the loop.
        if cpu_offload:
            self.head_modulation = self.head_modulation.cpu()

    def to_cpu(self):
        """Move every matmul weight and ``head_modulation`` to the CPU."""
        for weight in self.weight_list:
            if isinstance(weight, MMWeightTemplate):
                weight.to_cpu()
        self.head_modulation = self.head_modulation.cpu()

    def to_cuda(self):
        """Move every matmul weight and ``head_modulation`` to the GPU."""
        for weight in self.weight_list:
            if isinstance(weight, MMWeightTemplate):
                weight.to_cuda()
        self.head_modulation = self.head_modulation.cuda()