"""convert_hunyuan_video_to_diffusers.py: convert original HunyuanVideo checkpoints into diffusers models/pipelines."""

import argparse
from typing import Any, Dict

import torch
from accelerate import init_empty_weights
from transformers import (
    AutoModel,
    AutoTokenizer,
    CLIPImageProcessor,
    CLIPTextModel,
    CLIPTokenizer,
    LlavaForConditionalGeneration,
)

from diffusers import (
    AutoencoderKLHunyuanVideo,
    FlowMatchEulerDiscreteScheduler,
    HunyuanVideoImageToVideoPipeline,
    HunyuanVideoPipeline,
    HunyuanVideoTransformer3DModel,
)


def remap_norm_scale_shift_(key, state_dict):
    """Rename the final adaLN modulation tensor in place, swapping its two halves.

    The tensor under ``key`` is split in two along dim 0. The first half is treated
    as "shift" and the second as "scale". The result is re-stored as
    ``[scale; shift]`` under the ``norm_out.linear`` name.
    """
    shift_half, scale_half = state_dict.pop(key).chunk(2, dim=0)
    target_key = key.replace("final_layer.adaLN_modulation.1", "norm_out.linear")
    state_dict[target_key] = torch.cat([scale_half, shift_half], dim=0)


def remap_txt_in_(key, state_dict):
    """Rename a ``txt_in`` parameter to its diffusers name, in place.

    A fused ``self_attn_qkv`` tensor is cut into three equal chunks along dim 0
    and stored as separate ``attn.to_q`` / ``attn.to_k`` / ``attn.to_v`` entries.
    Every other key is moved to its renamed location unchanged.
    """
    # Applied strictly in this order: later patterns operate on the output of
    # earlier ones (e.g. "t_embedder.mlp.*" must be handled before "mlp" -> "ff").
    substitutions = (
        ("individual_token_refiner.blocks", "token_refiner.refiner_blocks"),
        ("adaLN_modulation.1", "norm_out.linear"),
        ("txt_in", "context_embedder"),
        ("t_embedder.mlp.0", "time_text_embed.timestep_embedder.linear_1"),
        ("t_embedder.mlp.2", "time_text_embed.timestep_embedder.linear_2"),
        ("c_embedder", "time_text_embed.text_embedder"),
        ("mlp", "ff"),
    )

    def translate(name):
        for old, new in substitutions:
            name = name.replace(old, new)
        return name

    if "self_attn_qkv" not in key:
        state_dict[translate(key)] = state_dict.pop(key)
        return

    fused = state_dict.pop(key)
    query, key_part, value = fused.chunk(3, dim=0)
    for projection, piece in zip(("attn.to_q", "attn.to_k", "attn.to_v"), (query, key_part, value)):
        state_dict[translate(key.replace("self_attn_qkv", projection))] = piece


def remap_img_attn_qkv_(key, state_dict):
    """Split a fused image-stream ``img_attn_qkv`` tensor into ``attn.to_{q,k,v}`` entries, in place.

    The tensor is cut into three equal chunks along dim 0 (q, k, v in that order).
    """
    q_part, k_part, v_part = state_dict.pop(key).chunk(3, dim=0)
    for target, part in (("attn.to_q", q_part), ("attn.to_k", k_part), ("attn.to_v", v_part)):
        state_dict[key.replace("img_attn_qkv", target)] = part


def remap_txt_attn_qkv_(key, state_dict):
    """Split a fused text-stream ``txt_attn_qkv`` tensor into ``attn.add_{q,k,v}_proj`` entries, in place.

    The tensor is cut into three equal chunks along dim 0 (q, k, v in that order).
    """
    fused = state_dict.pop(key)
    query, key_part, value = fused.chunk(3, dim=0)
    targets = ("attn.add_q_proj", "attn.add_k_proj", "attn.add_v_proj")
    for target, piece in zip(targets, (query, key_part, value)):
        state_dict[key.replace("txt_attn_qkv", target)] = piece


def remap_single_transformer_blocks_(key, state_dict, hidden_size: int = 3072):
    """Rename a ``single_blocks`` parameter in place, splitting the fused ``linear1`` projection.

    ``linear1.weight`` / ``linear1.bias`` are split along dim 0 into q, k and v
    (``hidden_size`` rows each) followed by the MLP input projection (all
    remaining rows). All other keys are renamed only:
    ``linear2`` -> ``proj_out``, ``q_norm``/``k_norm`` -> ``attn.norm_q``/``attn.norm_k``.

    Args:
        key: State-dict key containing ``single_blocks``.
        state_dict: Mapping that is modified in place (``key`` is popped).
        hidden_size: Row count of each q/k/v slice. The default 3072 equals
            ``num_attention_heads * attention_head_dim`` (24 * 128) of the
            configs in ``TRANSFORMER_CONFIGS``.

    Raises:
        ValueError: if a fused ``linear1`` tensor has fewer than
            ``3 * hidden_size`` rows.
    """
    renamed = key.replace("single_blocks", "single_transformer_blocks")

    # "weight" is checked before "bias", matching the original branch order.
    for param in ("weight", "bias"):
        if f"linear1.{param}" not in key:
            continue
        fused = state_dict.pop(key)
        mlp_rows = fused.size(0) - 3 * hidden_size
        if mlp_rows < 0:
            raise ValueError(
                f"{key}: expected at least {3 * hidden_size} rows for q/k/v with hidden_size={hidden_size}, "
                f"got {fused.size(0)}"
            )
        q, k, v, mlp = torch.split(fused, (hidden_size, hidden_size, hidden_size, mlp_rows), dim=0)
        prefix = renamed.removesuffix(f".linear1.{param}")
        state_dict[f"{prefix}.attn.to_q.{param}"] = q
        state_dict[f"{prefix}.attn.to_k.{param}"] = k
        state_dict[f"{prefix}.attn.to_v.{param}"] = v
        state_dict[f"{prefix}.proj_mlp.{param}"] = mlp
        return

    for old, new in (("linear2", "proj_out"), ("q_norm", "attn.norm_q"), ("k_norm", "attn.norm_k")):
        renamed = renamed.replace(old, new)
    state_dict[renamed] = state_dict.pop(key)


# Substring renames applied to every transformer key with str.replace, one entry
# after another in this dict's insertion order: later entries see the output of
# earlier ones, so the ordering here is significant.
TRANSFORMER_KEYS_RENAME_DICT = {
    # Input and timestep/guidance/pooled-text embedders.
    "img_in": "x_embedder",
    "time_in.mlp.0": "time_text_embed.timestep_embedder.linear_1",
    "time_in.mlp.2": "time_text_embed.timestep_embedder.linear_2",
    "guidance_in.mlp.0": "time_text_embed.guidance_embedder.linear_1",
    "guidance_in.mlp.2": "time_text_embed.guidance_embedder.linear_2",
    "vector_in.in_layer": "time_text_embed.text_embedder.linear_1",
    "vector_in.out_layer": "time_text_embed.text_embedder.linear_2",
    # Dual-stream blocks; their fused img/txt QKV tensors are split later by the
    # special handlers.
    "double_blocks": "transformer_blocks",
    "img_attn_q_norm": "attn.norm_q",
    "img_attn_k_norm": "attn.norm_k",
    "img_attn_proj": "attn.to_out.0",
    "txt_attn_q_norm": "attn.norm_added_q",
    "txt_attn_k_norm": "attn.norm_added_k",
    "txt_attn_proj": "attn.to_add_out",
    "img_mod.linear": "norm1.linear",
    "img_norm1": "norm1.norm",
    "img_norm2": "norm2",
    "img_mlp": "ff",
    "txt_mod.linear": "norm1_context.linear",
    # NOTE(review): same target as "img_norm1"; this is only harmless if these
    # norms carry no parameters in the checkpoint (strict loading would otherwise
    # collide/miss "norm1_context.norm") — confirm.
    "txt_norm1": "norm1.norm",
    "txt_norm2": "norm2_context",
    "txt_mlp": "ff_context",
    # Presumably token-refiner (self_attn_*) and single-block (modulation/pre_norm)
    # names — confirm against the original checkpoint.
    "self_attn_proj": "attn.to_out.0",
    "modulation.linear": "norm.linear",
    "pre_norm": "norm.norm",
    # Output head.
    "final_layer.norm_final": "norm_out.norm",
    "final_layer.linear": "proj_out",
    # MLP sublayers; these run after the "*_mlp" -> "ff*" renames above.
    "fc1": "net.0.proj",
    "fc2": "net.2",
    "input_embedder": "proj_in",
}

# Keys that need more than a rename (splitting fused tensors, reordering
# modulation halves). Matched by substring against the already-renamed keys;
# each handler is called as handler(key, state_dict) and edits state_dict in place.
TRANSFORMER_SPECIAL_KEYS_REMAP = {
    "txt_in": remap_txt_in_,
    "img_attn_qkv": remap_img_attn_qkv_,
    "txt_attn_qkv": remap_txt_attn_qkv_,
    "single_blocks": remap_single_transformer_blocks_,
    "final_layer.adaLN_modulation.1": remap_norm_scale_shift_,
}

# No VAE key changes are applied. Presumably the original VAE checkpoint keys
# already match AutoencoderKLHunyuanVideo (convert_vae loads with strict=True) — confirm.
VAE_KEYS_RENAME_DICT = {}

VAE_SPECIAL_KEYS_REMAP = {}


# Constructor kwargs for HunyuanVideoTransformer3DModel, keyed by the
# --transformer_type CLI value. 24 heads * 128 dims = 3072, the hidden size used
# when splitting single-block linear1 weights.
TRANSFORMER_CONFIGS = {
    # Text-to-video model; saved as a HunyuanVideoPipeline with --save_pipeline.
    # Presumably CFG-distilled (hence guidance_embeds=True) — confirm.
    "HYVideo-T/2-cfgdistill": {
        "in_channels": 16,
        "out_channels": 16,
        "num_attention_heads": 24,
        "attention_head_dim": 128,
        "num_layers": 20,
        "num_single_layers": 40,
        "num_refiner_layers": 2,
        "mlp_ratio": 4.0,
        "patch_size": 2,
        "patch_size_t": 1,
        "qk_norm": "rms_norm",
        "guidance_embeds": True,
        "text_embed_dim": 4096,
        "pooled_projection_dim": 768,
        "rope_theta": 256.0,
        "rope_axes_dim": (16, 56, 56),
    },
    # Image-to-video model; saved as a HunyuanVideoImageToVideoPipeline with --save_pipeline.
    "HYVideo-T/2-I2V": {
        # NOTE(review): 16 * 2 + 1 presumably = noisy latents + image-condition
        # latents + a 1-channel mask — confirm against the I2V model.
        "in_channels": 16 * 2 + 1,
        "out_channels": 16,
        "num_attention_heads": 24,
        "attention_head_dim": 128,
        "num_layers": 20,
        "num_single_layers": 40,
        "num_refiner_layers": 2,
        "mlp_ratio": 4.0,
        "patch_size": 2,
        "patch_size_t": 1,
        "qk_norm": "rms_norm",
        "guidance_embeds": False,
        "text_embed_dim": 4096,
        "pooled_projection_dim": 768,
        "rope_theta": 256.0,
        "rope_axes_dim": (16, 56, 56),
    },
}


def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) -> None:
    """Move the value stored under ``old_key`` to ``new_key``, in place.

    The entry is popped and re-inserted, so it moves to the end of the dict's
    iteration order even when ``old_key == new_key``. Any existing value at
    ``new_key`` is overwritten. Raises ``KeyError`` if ``old_key`` is absent.
    Returns ``None``; the previous ``Dict`` annotation did not match the code.
    """
    state_dict[new_key] = state_dict.pop(old_key)


def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
    """Unwrap the parameter mapping from a loaded checkpoint.

    Checkpoints may nest their weights under ``"model"``, ``"module"`` and/or
    ``"state_dict"``. Each wrapper that is present is peeled off in that order.
    An un-wrapped mapping is returned as-is.
    """
    state_dict = saved_dict
    for wrapper_key in ("model", "module", "state_dict"):
        # Test the *current* level. After unwrapping "model", the next wrapper
        # lives inside it, not in the outer checkpoint dict. Checking the outer
        # dict either skipped it or indexed the wrong level.
        if wrapper_key in state_dict:
            state_dict = state_dict[wrapper_key]
    return state_dict


def convert_transformer(ckpt_path: str, transformer_type: str):
    """Load an original HunyuanVideo transformer checkpoint as a diffusers model.

    Args:
        ckpt_path: Path to the original checkpoint, read on CPU with
            ``torch.load(..., weights_only=True)``.
        transformer_type: Key into ``TRANSFORMER_CONFIGS`` selecting the architecture.

    Returns:
        A ``HunyuanVideoTransformer3DModel`` whose parameters are the converted
        checkpoint tensors, attached via ``load_state_dict(..., assign=True)``.

    Raises:
        KeyError: if ``transformer_type`` is not a key of ``TRANSFORMER_CONFIGS``.
        RuntimeError: from ``load_state_dict(strict=True)`` if the converted keys
            or shapes do not match the model.
    """
    original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", weights_only=True))
    config = TRANSFORMER_CONFIGS[transformer_type]

    # Build the module skeleton with accelerate's init_empty_weights; the real
    # checkpoint tensors are attached below with assign=True.
    with init_empty_weights():
        transformer = HunyuanVideoTransformer3DModel(**config)

    # Pass 1: plain substring renames, applied in the rename dict's insertion order.
    for key in list(original_state_dict.keys()):
        new_key = key
        for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
            new_key = new_key.replace(replace_key, rename_key)
        update_state_dict_(original_state_dict, key, new_key)

    # Pass 2: keys that need tensor surgery. Each handler pops `key`, so stop after
    # the first match; a second match would otherwise raise KeyError.
    for key in list(original_state_dict.keys()):
        for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
            if special_key in key:
                handler_fn_inplace(key, original_state_dict)
                break

    transformer.load_state_dict(original_state_dict, strict=True, assign=True)
    return transformer


def convert_vae(ckpt_path: str):
    """Load an original HunyuanVideo VAE checkpoint as ``AutoencoderKLHunyuanVideo``.

    Keys go through the ``VAE_KEYS_RENAME_DICT`` substring renames, then any
    matching ``VAE_SPECIAL_KEYS_REMAP`` handlers. The result is loaded with
    ``strict=True`` and ``assign=True``.
    """
    checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=True)
    weights = get_state_dict(checkpoint)

    with init_empty_weights():
        vae = AutoencoderKLHunyuanVideo()

    for old_name in list(weights):
        renamed = old_name
        for pattern, replacement in VAE_KEYS_RENAME_DICT.items():
            renamed = renamed.replace(pattern, replacement)
        update_state_dict_(weights, old_name, renamed)

    for name in list(weights):
        for marker, handler in VAE_SPECIAL_KEYS_REMAP.items():
            if marker in name:
                handler(name, weights)

    vae.load_state_dict(weights, strict=True, assign=True)
    return vae


def get_args():
    """Parse the command-line options for the HunyuanVideo conversion script."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint"
    )
    parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original VAE checkpoint")
    parser.add_argument("--text_encoder_path", type=str, default=None, help="Path to original llama checkpoint")
    parser.add_argument("--tokenizer_path", type=str, default=None, help="Path to original llama tokenizer")
    parser.add_argument("--text_encoder_2_path", type=str, default=None, help="Path to original clip checkpoint")
    parser.add_argument(
        "--save_pipeline", action="store_true", help="Save a full diffusers pipeline instead of standalone models"
    )
    parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
    # Restrict to the keys of DTYPE_MAPPING so a typo fails at parse time with a usage
    # message instead of a KeyError later. DTYPE_MAPPING is a module-level constant;
    # it is looked up when get_args() runs, not when it is defined.
    parser.add_argument(
        "--dtype",
        default="bf16",
        choices=list(DTYPE_MAPPING.keys()),
        help="Torch dtype to save the transformer in.",
    )
    parser.add_argument(
        "--transformer_type", type=str, default="HYVideo-T/2-cfgdistill", choices=list(TRANSFORMER_CONFIGS.keys())
    )
    parser.add_argument(
        "--flow_shift", type=float, default=7.0, help="`shift` for FlowMatchEulerDiscreteScheduler (pipeline only)"
    )
    return parser.parse_args()


# Maps the --dtype CLI value to the torch dtype the converted transformer is cast to.
DTYPE_MAPPING = dict(
    fp32=torch.float32,
    fp16=torch.float16,
    bf16=torch.bfloat16,
)


if __name__ == "__main__":
    import os

    args = get_args()

    transformer = None
    vae = None
    dtype = DTYPE_MAPPING[args.dtype]

    # Validate with real exceptions rather than `assert`, which is stripped under
    # `python -O` and would then surface later as a confusing NameError/AttributeError.
    if args.save_pipeline:
        required = {
            "--transformer_ckpt_path": args.transformer_ckpt_path,
            "--vae_ckpt_path": args.vae_ckpt_path,
            "--text_encoder_path": args.text_encoder_path,
            "--tokenizer_path": args.tokenizer_path,
            "--text_encoder_2_path": args.text_encoder_2_path,
        }
        missing = [flag for flag, value in required.items() if value is None]
        if missing:
            raise ValueError(f"--save_pipeline requires: {', '.join(missing)}")

    # Both standalone models write config.json and diffusion_pytorch_model*.safetensors.
    # Saving both into one directory would let the second save clobber the first, so
    # when both are converted without --save_pipeline each gets its own subfolder.
    # Converting a single model still writes straight into output_path.
    save_both_standalone = (
        not args.save_pipeline and args.transformer_ckpt_path is not None and args.vae_ckpt_path is not None
    )

    if args.transformer_ckpt_path is not None:
        transformer = convert_transformer(args.transformer_ckpt_path, args.transformer_type)
        transformer = transformer.to(dtype=dtype)
        if not args.save_pipeline:
            transformer_dir = (
                os.path.join(args.output_path, "transformer") if save_both_standalone else args.output_path
            )
            transformer.save_pretrained(transformer_dir, safe_serialization=True, max_shard_size="5GB")

    if args.vae_ckpt_path is not None:
        # Only the transformer is cast to --dtype; the VAE keeps the checkpoint dtype.
        vae = convert_vae(args.vae_ckpt_path)
        if not args.save_pipeline:
            vae_dir = os.path.join(args.output_path, "vae") if save_both_standalone else args.output_path
            vae.save_pretrained(vae_dir, safe_serialization=True, max_shard_size="5GB")

    if args.save_pipeline:
        if args.transformer_type == "HYVideo-T/2-cfgdistill":
            pipeline_cls = HunyuanVideoPipeline
            type_specific_components = {
                "text_encoder": AutoModel.from_pretrained(args.text_encoder_path, torch_dtype=torch.float16),
            }
        else:
            # Image-to-video: LLaVA text encoder plus the CLIP image processor stored alongside it.
            pipeline_cls = HunyuanVideoImageToVideoPipeline
            type_specific_components = {
                "text_encoder": LlavaForConditionalGeneration.from_pretrained(
                    args.text_encoder_path, torch_dtype=torch.float16
                ),
                "image_processor": CLIPImageProcessor.from_pretrained(args.text_encoder_path),
            }

        pipe = pipeline_cls(
            transformer=transformer,
            vae=vae,
            tokenizer=AutoTokenizer.from_pretrained(args.tokenizer_path, padding_side="right"),
            text_encoder_2=CLIPTextModel.from_pretrained(args.text_encoder_2_path, torch_dtype=torch.float16),
            tokenizer_2=CLIPTokenizer.from_pretrained(args.text_encoder_2_path),
            scheduler=FlowMatchEulerDiscreteScheduler(shift=args.flow_shift),
            **type_specific_components,
        )
        pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")