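"""
Convert an original CogVideoX checkpoint (transformer and/or VAE) into the
diffusers `CogVideoXPipeline` format.

A sketch of a typical invocation for the 2B model (the paths are placeholders;
all flags are defined in `get_args` below):

    python convert_cogvideox_to_diffusers.py \
        --transformer_ckpt_path /path/to/transformer.pt \
        --vae_ckpt_path /path/to/vae.pt \
        --output_path ./cogvideox-2b \
        --fp16
"""
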
import argparse
from typing import Any, Dict

import torch
from transformers import T5EncoderModel, T5Tokenizer

from diffusers import AutoencoderKLCogVideoX, CogVideoXDDIMScheduler, CogVideoXPipeline, CogVideoXTransformer3DModel


def reassign_query_key_value_inplace(key: str, state_dict: Dict[str, Any]):
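    """
    The original checkpoint stores the attention projections as a single fused
    `query_key_value` tensor; split it into the separate `to_q`/`to_k`/`to_v`
    parameters that the diffusers attention modules expect.
    """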
    to_q_key = key.replace("query_key_value", "to_q")
    to_k_key = key.replace("query_key_value", "to_k")
    to_v_key = key.replace("query_key_value", "to_v")
    to_q, to_k, to_v = torch.chunk(state_dict[key], chunks=3, dim=0)
    state_dict[to_q_key] = to_q
    state_dict[to_k_key] = to_k
    state_dict[to_v_key] = to_v
    state_dict.pop(key)


def reassign_query_key_layernorm_inplace(key: str, state_dict: Dict[str, Any]):
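    """
    Each entry of the original per-layer `query_layernorm_list` /
    `key_layernorm_list` becomes the `attn1.norm_q` / `attn1.norm_k` LayerNorm
    of the corresponding transformer block.
    """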
    layer_id, weight_or_bias = key.split(".")[-2:]

    if "query" in key:
        new_key = f"transformer_blocks.{layer_id}.attn1.norm_q.{weight_or_bias}"
    elif "key" in key:
        new_key = f"transformer_blocks.{layer_id}.attn1.norm_k.{weight_or_bias}"

    state_dict[new_key] = state_dict.pop(key)


def reassign_adaln_norm_inplace(key: str, state_dict: Dict[str, Any]):
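    """
    The original fused adaLN modulation projects to 12 chunks per layer;
    regroup them (0:3 + 6:9 and 3:6 + 9:12) into the two linears backing
    `norm1` and `norm2` in the diffusers block.
    """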
    layer_id, _, weight_or_bias = key.split(".")[-3:]

    weights_or_biases = state_dict[key].chunk(12, dim=0)
    norm1_weights_or_biases = torch.cat(weights_or_biases[0:3] + weights_or_biases[6:9])
    norm2_weights_or_biases = torch.cat(weights_or_biases[3:6] + weights_or_biases[9:12])

    norm1_key = f"transformer_blocks.{layer_id}.norm1.linear.{weight_or_bias}"
    state_dict[norm1_key] = norm1_weights_or_biases

    norm2_key = f"transformer_blocks.{layer_id}.norm2.linear.{weight_or_bias}"
    state_dict[norm2_key] = norm2_weights_or_biases

    state_dict.pop(key)


def remove_keys_inplace(key: str, state_dict: Dict[str, Any]):
    state_dict.pop(key)


def replace_up_keys_inplace(key: str, state_dict: Dict[str, Any]):
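    """
    The original VAE decoder indexes its `up` blocks in reverse order; remap
    `up.<i>` to `up_blocks.<3 - i>` (the decoder has 4 up blocks).
    """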
    key_split = key.split(".")
    layer_index = int(key_split[2])
    replace_layer_index = 4 - 1 - layer_index

    key_split[1] = "up_blocks"
    key_split[2] = str(replace_layer_index)
    new_key = ".".join(key_split)

    state_dict[new_key] = state_dict.pop(key)


TRANSFORMER_KEYS_RENAME_DICT = {
    "transformer.final_layernorm": "norm_final",
    "transformer": "transformer_blocks",
    "attention": "attn1",
    "mlp": "ff.net",
    "dense_h_to_4h": "0.proj",
    "dense_4h_to_h": "2",
    ".layers": "",
    "dense": "to_out.0",
    "input_layernorm": "norm1.norm",
    "post_attn1_layernorm": "norm2.norm",
    "time_embed.0": "time_embedding.linear_1",
    "time_embed.2": "time_embedding.linear_2",
    "mixins.patch_embed": "patch_embed",
    "mixins.final_layer.norm_final": "norm_out.norm",
    "mixins.final_layer.linear": "proj_out",
    "mixins.final_layer.adaLN_modulation.1": "norm_out.linear",
}
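
# Keys containing these substrings require structural handling (splitting
# fused tensors, regrouping chunks, or removal) rather than a one-to-one
# rename; each maps to the handler applied during conversion.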

TRANSFORMER_SPECIAL_KEYS_REMAP = {
    "query_key_value": reassign_query_key_value_inplace,
    "query_layernorm_list": reassign_query_key_layernorm_inplace,
    "key_layernorm_list": reassign_query_key_layernorm_inplace,
    "adaln_layer.adaLN_modulations": reassign_adaln_norm_inplace,
    "embed_tokens": remove_keys_inplace,
    "freqs_sin": remove_keys_inplace,
    "freqs_cos": remove_keys_inplace,
    "position_embedding": remove_keys_inplace,
}

VAE_KEYS_RENAME_DICT = {
    "block.": "resnets.",
    "down.": "down_blocks.",
    "downsample": "downsamplers.0",
    "upsample": "upsamplers.0",
    "nin_shortcut": "conv_shortcut",
    "encoder.mid.block_1": "encoder.mid_block.resnets.0",
    "encoder.mid.block_2": "encoder.mid_block.resnets.1",
    "decoder.mid.block_1": "decoder.mid_block.resnets.0",
    "decoder.mid.block_2": "decoder.mid_block.resnets.1",
}
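
# Keys containing "loss" belong to training-time objectives and are dropped;
# decoder "up." keys need their block index reversed.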

VAE_SPECIAL_KEYS_REMAP = {
    "loss": remove_keys_inplace,
    "up.": replace_up_keys_inplace,
}

TOKENIZER_MAX_LENGTH = 226


def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
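    """
    Unwrap the common checkpoint containers ("model", "module", "state_dict")
    to reach the raw parameter dict.
    """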
    state_dict = saved_dict
    if "model" in saved_dict.keys():
        state_dict = state_dict["model"]
    if "module" in saved_dict.keys():
        state_dict = state_dict["module"]
    if "state_dict" in saved_dict.keys():
        state_dict = state_dict["state_dict"]
    return state_dict


def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key: str) -> None:
    state_dict[new_key] = state_dict.pop(old_key)


def convert_transformer(
    ckpt_path: str,
    num_layers: int,
    num_attention_heads: int,
    use_rotary_positional_embeddings: bool,
    dtype: torch.dtype,
):
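    """
    Strip the `model.diffusion_model.` prefix from every key, apply the
    substring renames, then run the special-key handlers before loading the
    result into a diffusers `CogVideoXTransformer3DModel`.
    """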
    PREFIX_KEY = "model.diffusion_model."

    original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True))
    transformer = CogVideoXTransformer3DModel(
        num_layers=num_layers,
        num_attention_heads=num_attention_heads,
        use_rotary_positional_embeddings=use_rotary_positional_embeddings,
    ).to(dtype=dtype)

    for key in list(original_state_dict.keys()):
        new_key = key[len(PREFIX_KEY) :]
        for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
            new_key = new_key.replace(replace_key, rename_key)
        update_state_dict_inplace(original_state_dict, key, new_key)

    for key in list(original_state_dict.keys()):
        for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
            if special_key not in key:
                continue
            handler_fn_inplace(key, original_state_dict)

    transformer.load_state_dict(original_state_dict, strict=True)
    return transformer


def convert_vae(ckpt_path: str, scaling_factor: float, dtype: torch.dtype):
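    """
    Apply the VAE rename map and special-key handlers, then load the result
    into a diffusers `AutoencoderKLCogVideoX`.
    """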
    original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True))
    vae = AutoencoderKLCogVideoX(scaling_factor=scaling_factor).to(dtype=dtype)

    for key in list(original_state_dict.keys()):
        new_key = key[:]
        for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items():
            new_key = new_key.replace(replace_key, rename_key)
        update_state_dict_inplace(original_state_dict, key, new_key)

    for key in list(original_state_dict.keys()):
        for special_key, handler_fn_inplace in VAE_SPECIAL_KEYS_REMAP.items():
            if special_key not in key:
                continue
            handler_fn_inplace(key, original_state_dict)

    vae.load_state_dict(original_state_dict, strict=True)
    return vae


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint"
    )
    parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original vae checkpoint")
    parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
    parser.add_argument("--fp16", action="store_true", default=False, help="Whether to save the model weights in fp16")
    parser.add_argument("--bf16", action="store_true", default=False, help="Whether to save the model weights in bf16")
    parser.add_argument(
        "--push_to_hub", action="store_true", default=False, help="Whether to push to HF Hub after saving"
    )
    parser.add_argument(
        "--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory"
    )
    # For CogVideoX-2B, num_layers is 30. For 5B, it is 42
    parser.add_argument("--num_layers", type=int, default=30, help="Number of transformer blocks")
    # For CogVideoX-2B, num_attention_heads is 30. For 5B, it is 48
    parser.add_argument("--num_attention_heads", type=int, default=30, help="Number of attention heads")
    # For CogVideoX-2B, use_rotary_positional_embeddings is False. For 5B, it is True
    parser.add_argument(
        "--use_rotary_positional_embeddings", action="store_true", default=False, help="Whether to use RoPE or not"
    )
    # For CogVideoX-2B, scaling_factor is 1.15258426. For 5B, it is 0.7
    parser.add_argument("--scaling_factor", type=float, default=1.15258426, help="Scaling factor in the VAE")
    # For CogVideoX-2B, snr_shift_scale is 3.0. For 5B, it is 1.0
    parser.add_argument("--snr_shift_scale", type=float, default=3.0, help="Scaling factor in the VAE")
    return parser.parse_args()


if __name__ == "__main__":
    args = get_args()

    transformer = None
    vae = None

    if args.fp16 and args.bf16:
        raise ValueError("You cannot pass both --fp16 and --bf16 at the same time.")

    dtype = torch.float16 if args.fp16 else torch.bfloat16 if args.bf16 else torch.float32

    if args.transformer_ckpt_path is not None:
        transformer = convert_transformer(
            args.transformer_ckpt_path,
            args.num_layers,
            args.num_attention_heads,
            args.use_rotary_positional_embeddings,
            dtype,
        )
    if args.vae_ckpt_path is not None:
        vae = convert_vae(args.vae_ckpt_path, args.scaling_factor, dtype)

    text_encoder_id = "google/t5-v1_1-xxl"
    tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH)
    text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir)

    # safetensors serialization fails on non-contiguous tensors, so make the
    # text encoder weights contiguous before saving.
    for param in text_encoder.parameters():
        param.data = param.data.contiguous()

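    # The scheduler config mirrors the original CogVideoX training setup; only
    # snr_shift_scale differs between the 2B (3.0) and 5B (1.0) checkpoints.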
    scheduler = CogVideoXDDIMScheduler.from_config(
        {
            "snr_shift_scale": args.snr_shift_scale,
            "beta_end": 0.012,
            "beta_schedule": "scaled_linear",
            "beta_start": 0.00085,
            "clip_sample": False,
            "num_train_timesteps": 1000,
            "prediction_type": "v_prediction",
            "rescale_betas_zero_snr": True,
            "set_alpha_to_one": True,
            "timestep_spacing": "trailing",
        }
    )

    pipe = CogVideoXPipeline(
        tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
    )

    if args.fp16:
        pipe = pipe.to(dtype=torch.float16)
    if args.bf16:
        pipe = pipe.to(dtype=torch.bfloat16)

    # We don't use a variant here because the model is meant to run in fp16 (2B)
    # or bf16 (5B). Requiring users to pass a variant just to get the intended
    # non-fp32 default would be confusing.
    pipe.save_pretrained(args.output_path, safe_serialization=True, push_to_hub=args.push_to_hub)