pre_weights.py 6.15 KB
Newer Older
litzh's avatar
litzh committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
from lightx2v.common.modules.weight_module import WeightModule
from lightx2v.utils.registry_factory import CONV3D_WEIGHT_REGISTER, EMBEDDING_WEIGHT_REGISTER, LN_WEIGHT_REGISTER, MM_WEIGHT_REGISTER, TENSOR_REGISTER


class WanPreWeights(WeightModule):
    """Weight container for the Wan model's pre-transformer modules.

    Registers the patch/text/time embedding weights (and task- or
    model-specific extras such as image-encoder projections, guidance
    embeddings, and pose/reference patch embeddings) based on ``config``.

    Config keys read here: ``in_dim``, ``dim``, ``task``, ``model_cls``,
    and optionally ``use_image_encoder`` / ``enable_dynamic_cfg``.
    """

    def __init__(self, config):
        super().__init__()
        self.in_dim = config["in_dim"]
        self.dim = config["dim"]
        # Temporal x spatial patching; fixed for this architecture.
        self.patch_size = (1, 2, 2)
        self.config = config

        # --- core patch embedding -------------------------------------
        self._add_conv3d("patch_embedding", "patch_embedding", "diffusion_model.patch_embedding")

        # --- rs2v-only embeddings -------------------------------------
        if config["task"] in ["rs2v"]:
            for name in ("ref_patch_embedding", "prev_patch_embedding", "cont_patch_embedding"):
                self._add_conv3d(name, name, f"diffusion_model.{name}")
            self.add_module(
                "state_embedding",
                EMBEDDING_WEIGHT_REGISTER["Default"]("state_embedding.weight"),
            )

        # --- text / time MLPs -----------------------------------------
        self._add_mm("text_embedding_0", "text_embedding.0", "diffusion_model.text_embedding")
        self._add_mm("text_embedding_2", "text_embedding.2", "diffusion_model.text_embedding")
        self._add_mm("time_embedding_0", "time_embedding.0", "diffusion_model.time_embedding")
        self._add_mm("time_embedding_2", "time_embedding.2", "diffusion_model.time_embedding")
        self._add_mm("time_projection_1", "time_projection.1", "diffusion_model.time_projection")

        # --- image-encoder projection (image-conditioned tasks) -------
        if config["task"] in ["i2v", "flf2v", "animate", "s2v", "rs2v"] and config.get("use_image_encoder", True):
            self._add_ln("proj_0", "img_emb.proj.0", "diffusion_model.img_emb")
            self._add_mm("proj_1", "img_emb.proj.1", "diffusion_model.img_emb")
            self._add_mm("proj_3", "img_emb.proj.3", "diffusion_model.img_emb")
            self._add_ln("proj_4", "img_emb.proj.4", "diffusion_model.img_emb")

        # --- distill-variant extras -----------------------------------
        if config["model_cls"] == "wan2.1_distill" and config.get("enable_dynamic_cfg", False):
            self._add_mm("cfg_cond_proj_1", "guidance_embedding.linear_1")
            self._add_mm("cfg_cond_proj_2", "guidance_embedding.linear_2")

        if config["model_cls"] == "wan2.1_mean_flow_distill":
            self._add_mm("time_embedding_r_0", "time_embedding_r.0")
            self._add_mm("time_embedding_r_2", "time_embedding_r.2")

        # --- task-specific extras -------------------------------------
        if config["task"] == "flf2v" and config.get("use_image_encoder", True):
            # Fixed: was an f-string with no placeholders (ruff F541).
            self.add_module("emb_pos", TENSOR_REGISTER["Default"]("img_emb.emb_pos"))
        if config["task"] == "animate":
            self._add_conv3d("pose_patch_embedding", "pose_patch_embedding")

    def _add_conv3d(self, module_name, param_prefix, lora_prefix=None):
        """Register a Conv3d weight/bias pair under ``module_name``.

        ``lora_prefix`` is forwarded only when given so the registry's
        own default stays in effect otherwise.
        """
        kwargs = {"stride": self.patch_size}
        if lora_prefix is not None:
            kwargs["lora_prefix"] = lora_prefix
        self.add_module(
            module_name,
            CONV3D_WEIGHT_REGISTER["Default"](
                f"{param_prefix}.weight",
                f"{param_prefix}.bias",
                **kwargs,
            ),
        )

    def _add_mm(self, module_name, param_prefix, lora_prefix=None):
        """Register a matmul (linear) weight/bias pair under ``module_name``."""
        kwargs = {}
        if lora_prefix is not None:
            kwargs["lora_prefix"] = lora_prefix
        self.add_module(
            module_name,
            MM_WEIGHT_REGISTER["Default"](
                f"{param_prefix}.weight",
                f"{param_prefix}.bias",
                **kwargs,
            ),
        )

    def _add_ln(self, module_name, param_prefix, lora_prefix=None):
        """Register a LayerNorm weight/bias pair under ``module_name``."""
        kwargs = {}
        if lora_prefix is not None:
            kwargs["lora_prefix"] = lora_prefix
        self.add_module(
            module_name,
            LN_WEIGHT_REGISTER["torch"](
                f"{param_prefix}.weight",
                f"{param_prefix}.bias",
                **kwargs,
            ),
        )