# pre_weights.py
import torch
from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER, LN_WEIGHT_REGISTER, CONV3D_WEIGHT_REGISTER
from lightx2v.common.ops.mm.mm_weight import MMWeightTemplate
from lightx2v.common.ops.norm.layer_norm_weight import LNWeightTemplate
from lightx2v.common.ops.conv.conv3d import Conv3dWeightTemplate


class WanPreWeights:
    """Weight wrappers for the patch, text, time and optional image embeddings.

    The wrappers are created by ``load_weights`` from the project's weight
    registries and collected in ``self.weight_list`` so that they can be
    moved between devices as a group with ``to_cpu`` / ``to_cuda``.
    """

    def __init__(self, config):
        """Store the configuration values this class reads.

        Args:
            config: Subscriptable configuration. ``"in_dim"`` and ``"dim"``
                are read here; ``"mm_config"`` and ``"cpu_offload"`` are read
                later by ``load_weights``.
        """
        self.in_dim = config["in_dim"]
        self.dim = config["dim"]
        self.patch_size = (1, 2, 2)
        self.config = config
        # Start with an empty list so that to_cpu()/to_cuda() are harmless
        # no-ops (instead of raising AttributeError) before load_weights().
        self.weight_list = []

    def load_weights(self, weight_dict):
        """Create every weight wrapper and load its data from ``weight_dict``.

        Args:
            weight_dict: Checkpoint entries keyed by parameter name. It is
                queried through ``.keys()`` and handed unchanged to each
                wrapper's ``load`` method.
        """
        # NOTE(review): "Defaultt" (double "t") is reproduced verbatim. It is a
        # registry lookup key and has to match the name used at registration;
        # confirm against the registry before changing the spelling.
        self.patch_embedding = CONV3D_WEIGHT_REGISTER["Defaultt-Force-BF16"](
            "patch_embedding.weight",
            "patch_embedding.bias",
            stride=self.patch_size,
        )

        self.text_embedding_0 = MM_WEIGHT_REGISTER["Default"]("text_embedding.0.weight", "text_embedding.0.bias")
        self.text_embedding_2 = MM_WEIGHT_REGISTER["Default"]("text_embedding.2.weight", "text_embedding.2.bias")
        self.time_embedding_0 = MM_WEIGHT_REGISTER["Default"]("time_embedding.0.weight", "time_embedding.0.bias")
        self.time_embedding_2 = MM_WEIGHT_REGISTER["Default"]("time_embedding.2.weight", "time_embedding.2.bias")
        self.time_projection_1 = MM_WEIGHT_REGISTER["Default"]("time_projection.1.weight", "time_projection.1.bias")

        self.weight_list = [
            self.patch_embedding,
            self.text_embedding_0,
            self.text_embedding_2,
            self.time_embedding_0,
            self.time_embedding_2,
            self.time_projection_1,
        ]

        # The image-embedding projection is only built when the checkpoint
        # actually contains its first parameter.
        if "img_emb.proj.0.weight" in weight_dict.keys():
            self.proj_0 = LN_WEIGHT_REGISTER["Default"]("img_emb.proj.0.weight", "img_emb.proj.0.bias", eps=1e-5)
            self.proj_1 = MM_WEIGHT_REGISTER["Default"]("img_emb.proj.1.weight", "img_emb.proj.1.bias")
            self.proj_3 = MM_WEIGHT_REGISTER["Default"]("img_emb.proj.3.weight", "img_emb.proj.3.bias")
            self.proj_4 = LN_WEIGHT_REGISTER["Default"]("img_emb.proj.4.weight", "img_emb.proj.4.bias", eps=1e-5)
            self.weight_list.extend([self.proj_0, self.proj_1, self.proj_3, self.proj_4])

        # Per wrapper: configure, load, then optionally offload right away.
        for weight in self._template_weights():
            weight.set_config(self.config["mm_config"])
            weight.load(weight_dict)
            if self.config["cpu_offload"]:
                weight.to_cpu()

    def _template_weights(self):
        """Return the entries of ``weight_list`` that are known weight templates."""
        return [
            weight
            for weight in self.weight_list
            if isinstance(weight, (MMWeightTemplate, LNWeightTemplate, Conv3dWeightTemplate))
        ]

    def to_cpu(self):
        """Call ``to_cpu()`` on every template weight in ``weight_list``."""
        for weight in self._template_weights():
            weight.to_cpu()

    def to_cuda(self):
        """Call ``to_cuda()`` on every template weight in ``weight_list``."""
        for weight in self._template_weights():
            weight.to_cuda()