pre_weights.py 8.23 KB
Newer Older
litzh's avatar
litzh committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
from lightx2v.common.modules.weight_module import WeightModule
from lightx2v.utils.registry_factory import (
    MM_WEIGHT_REGISTER,
)


class LTX2PreWeights(WeightModule):
    """Checkpoint-backed linear weights held by the LTX2 "pre" weight module.

    Every entry is created with ``MM_WEIGHT_REGISTER["Default"]`` from a
    ``<path>.weight`` / ``<path>.bias`` key pair rooted at
    ``model.diffusion_model`` and attached via ``add_module`` under a flat,
    underscore-joined attribute name. Entries are registered in this order:

    * video: ``patchify_proj``, the ``adaln_single`` timestep-embedder and
      output linears, and ``caption_projection`` (the only entries that are
      given a ``lora_prefix``);
    * audio: the same set of layers under the ``audio_`` prefix (without a
      ``lora_prefix``);
    * audio/video cross-attention (``av_ca_*``) AdaLN layers for video
      scale-shift, audio scale-shift, a2v gate and v2a gate.

    Args:
        config: Model configuration; stored on ``self.config`` unchanged.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        ckpt_root = "model.diffusion_model"

        def register(module_name, ckpt_path, **extra):
            # Each weight here is a (weight, bias) pair that shares one
            # checkpoint path; ``extra`` forwards optional keyword arguments
            # (only ``lora_prefix`` is used below) to the weight constructor.
            self.add_module(
                module_name,
                MM_WEIGHT_REGISTER["Default"](
                    f"{ckpt_root}.{ckpt_path}.weight",
                    f"{ckpt_root}.{ckpt_path}.bias",
                    **extra,
                ),
            )

        # ---- Video branch ----
        register("patchify_proj", "patchify_proj")
        for idx in (1, 2):
            register(
                f"adaln_single_emb_timestep_embedder_linear_{idx}",
                f"adaln_single.emb.timestep_embedder.linear_{idx}",
            )
        register("adaln_single_linear", "adaln_single.linear")
        for idx in (1, 2):
            register(
                f"caption_projection_linear_{idx}",
                f"caption_projection.linear_{idx}",
                lora_prefix="diffusion_model.caption_projection",
            )

        # ---- Audio branch ----
        # NOTE(review): unlike the video caption projection, no lora_prefix is
        # passed here; kept as in the original — confirm this is intentional.
        register("audio_patchify_proj", "audio_patchify_proj")
        for idx in (1, 2):
            register(
                f"audio_adaln_single_emb_timestep_embedder_linear_{idx}",
                f"audio_adaln_single.emb.timestep_embedder.linear_{idx}",
            )
        register("audio_adaln_single_linear", "audio_adaln_single.linear")
        for idx in (1, 2):
            register(
                f"audio_caption_projection_linear_{idx}",
                f"audio_caption_projection.linear_{idx}",
            )

        # ---- Audio/video cross-attention AdaLN layers ----
        # Module names drop the "timestep_embedder" segment that the
        # checkpoint keys keep, e.g. "<base>_emb_linear_1" maps to
        # "<base>.emb.timestep_embedder.linear_1".
        for tag in ("video_scale_shift", "audio_scale_shift", "a2v_gate", "v2a_gate"):
            base = f"av_ca_{tag}_adaln_single"
            for idx in (1, 2):
                register(
                    f"{base}_emb_linear_{idx}",
                    f"{base}.emb.timestep_embedder.linear_{idx}",
                )
            register(f"{base}_linear", f"{base}.linear")