Unverified Commit 7761b89d authored by YiYi Xu, committed by GitHub

update conversion script for Kandinsky unet (#3766)



* update kandinsky conversion script

* style

---------
Co-authored-by: yiyixuxu <yixu310@gmail.com>
parent ce550493
@@ -8,7 +8,6 @@ from accelerate import load_checkpoint_and_dispatch
 from diffusers import UNet2DConditionModel
 from diffusers.models.prior_transformer import PriorTransformer
 from diffusers.models.vq_model import VQModel
-from diffusers.pipelines.kandinsky.text_proj import KandinskyTextProjModel


 """
@@ -225,37 +224,55 @@ def prior_ff_to_diffusers(checkpoint, *, diffusers_ff_prefix, original_ff_prefix
 UNET_CONFIG = {
     "act_fn": "silu",
+    "addition_embed_type": "text_image",
+    "addition_embed_type_num_heads": 64,
     "attention_head_dim": 64,
-    "block_out_channels": (384, 768, 1152, 1536),
+    "block_out_channels": [384, 768, 1152, 1536],
     "center_input_sample": False,
-    "class_embed_type": "identity",
+    "class_embed_type": None,
+    "class_embeddings_concat": False,
+    "conv_in_kernel": 3,
+    "conv_out_kernel": 3,
     "cross_attention_dim": 768,
-    "down_block_types": (
+    "cross_attention_norm": None,
+    "down_block_types": [
         "ResnetDownsampleBlock2D",
         "SimpleCrossAttnDownBlock2D",
         "SimpleCrossAttnDownBlock2D",
         "SimpleCrossAttnDownBlock2D",
-    ),
+    ],
     "downsample_padding": 1,
     "dual_cross_attention": False,
+    "encoder_hid_dim": 1024,
+    "encoder_hid_dim_type": "text_image_proj",
     "flip_sin_to_cos": True,
     "freq_shift": 0,
     "in_channels": 4,
     "layers_per_block": 3,
+    "mid_block_only_cross_attention": None,
     "mid_block_scale_factor": 1,
     "mid_block_type": "UNetMidBlock2DSimpleCrossAttn",
     "norm_eps": 1e-05,
     "norm_num_groups": 32,
+    "num_class_embeds": None,
     "only_cross_attention": False,
     "out_channels": 8,
+    "projection_class_embeddings_input_dim": None,
+    "resnet_out_scale_factor": 1.0,
+    "resnet_skip_time_act": False,
     "resnet_time_scale_shift": "scale_shift",
     "sample_size": 64,
-    "up_block_types": (
+    "time_cond_proj_dim": None,
+    "time_embedding_act_fn": None,
+    "time_embedding_dim": None,
+    "time_embedding_type": "positional",
+    "timestep_post_act": None,
+    "up_block_types": [
         "SimpleCrossAttnUpBlock2D",
         "SimpleCrossAttnUpBlock2D",
         "SimpleCrossAttnUpBlock2D",
         "ResnetUpsampleBlock2D",
-    ),
+    ],
     "upcast_attention": False,
     "use_linear_projection": False,
 }
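The effect of the new keys is that the conditioning projections previously housed in the standalone KandinskyTextProjModel are now built into UNet2DConditionModel itself: `addition_embed_type="text_image"` gives the UNet an `add_embedding` submodule, and `encoder_hid_dim_type="text_image_proj"` gives it an `encoder_hid_proj` submodule; both are the targets of the new checkpoint-mapping helpers further down. A minimal construction sketch, assuming a diffusers version that accepts these arguments (only the relevant subset of UNET_CONFIG is shown):

```python
from diffusers import UNet2DConditionModel

# Illustrative subset of UNET_CONFIG above; every omitted argument falls back
# to the UNet2DConditionModel default.
unet = UNet2DConditionModel(
    cross_attention_dim=768,
    addition_embed_type="text_image",        # creates unet.add_embedding
    addition_embed_type_num_heads=64,
    encoder_hid_dim=1024,
    encoder_hid_dim_type="text_image_proj",  # creates unet.encoder_hid_proj
)

# The submodules that unet_add_embedding / unet_encoder_hid_proj (defined
# later in this script) populate with the original checkpoint weights:
print(unet.add_embedding)     # text_norm / text_proj / image_proj layers
print(unet.encoder_hid_proj)  # image_embeds / text_proj layers
```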
@@ -274,6 +291,8 @@ def unet_original_checkpoint_to_diffusers_checkpoint(model, checkpoint):
     diffusers_checkpoint.update(unet_time_embeddings(checkpoint))
     diffusers_checkpoint.update(unet_conv_in(checkpoint))
+    diffusers_checkpoint.update(unet_add_embedding(checkpoint))
+    diffusers_checkpoint.update(unet_encoder_hid_proj(checkpoint))

     # <original>.input_blocks -> <diffusers>.down_blocks
@@ -336,37 +355,55 @@ def unet_original_checkpoint_to_diffusers_checkpoint(model, checkpoint):
 INPAINT_UNET_CONFIG = {
     "act_fn": "silu",
+    "addition_embed_type": "text_image",
+    "addition_embed_type_num_heads": 64,
     "attention_head_dim": 64,
-    "block_out_channels": (384, 768, 1152, 1536),
+    "block_out_channels": [384, 768, 1152, 1536],
     "center_input_sample": False,
-    "class_embed_type": "identity",
+    "class_embed_type": None,
+    "class_embeddings_concat": None,
+    "conv_in_kernel": 3,
+    "conv_out_kernel": 3,
     "cross_attention_dim": 768,
-    "down_block_types": (
+    "cross_attention_norm": None,
+    "down_block_types": [
         "ResnetDownsampleBlock2D",
         "SimpleCrossAttnDownBlock2D",
         "SimpleCrossAttnDownBlock2D",
         "SimpleCrossAttnDownBlock2D",
-    ),
+    ],
     "downsample_padding": 1,
     "dual_cross_attention": False,
+    "encoder_hid_dim": 1024,
+    "encoder_hid_dim_type": "text_image_proj",
     "flip_sin_to_cos": True,
     "freq_shift": 0,
     "in_channels": 9,
     "layers_per_block": 3,
+    "mid_block_only_cross_attention": None,
     "mid_block_scale_factor": 1,
     "mid_block_type": "UNetMidBlock2DSimpleCrossAttn",
     "norm_eps": 1e-05,
     "norm_num_groups": 32,
+    "num_class_embeds": None,
     "only_cross_attention": False,
     "out_channels": 8,
+    "projection_class_embeddings_input_dim": None,
+    "resnet_out_scale_factor": 1.0,
+    "resnet_skip_time_act": False,
     "resnet_time_scale_shift": "scale_shift",
     "sample_size": 64,
-    "up_block_types": (
+    "time_cond_proj_dim": None,
+    "time_embedding_act_fn": None,
+    "time_embedding_dim": None,
+    "time_embedding_type": "positional",
+    "timestep_post_act": None,
+    "up_block_types": [
         "SimpleCrossAttnUpBlock2D",
         "SimpleCrossAttnUpBlock2D",
         "SimpleCrossAttnUpBlock2D",
         "ResnetUpsampleBlock2D",
-    ),
+    ],
     "upcast_attention": False,
     "use_linear_projection": False,
 }
@@ -381,10 +418,12 @@ def inpaint_unet_model_from_original_config():
 def inpaint_unet_original_checkpoint_to_diffusers_checkpoint(model, checkpoint):
     diffusers_checkpoint = {}

-    num_head_channels = UNET_CONFIG["attention_head_dim"]
+    num_head_channels = INPAINT_UNET_CONFIG["attention_head_dim"]

     diffusers_checkpoint.update(unet_time_embeddings(checkpoint))
     diffusers_checkpoint.update(unet_conv_in(checkpoint))
+    diffusers_checkpoint.update(unet_add_embedding(checkpoint))
+    diffusers_checkpoint.update(unet_encoder_hid_proj(checkpoint))

     # <original>.input_blocks -> <diffusers>.down_blocks
@@ -440,38 +479,6 @@ def inpaint_unet_original_checkpoint_to_diffusers_checkpoint(model, checkpoint):
 # done inpaint unet

-# text proj
-
-TEXT_PROJ_CONFIG = {}
-
-
-def text_proj_from_original_config():
-    model = KandinskyTextProjModel(**TEXT_PROJ_CONFIG)
-
-    return model
-
-
-# Note that the input checkpoint is the original text2img model checkpoint
-def text_proj_original_checkpoint_to_diffusers_checkpoint(checkpoint):
-    diffusers_checkpoint = {
-        # <original>.text_seq_proj.0 -> <diffusers>.encoder_hidden_states_proj
-        "encoder_hidden_states_proj.weight": checkpoint["to_model_dim_n.weight"],
-        "encoder_hidden_states_proj.bias": checkpoint["to_model_dim_n.bias"],
-        # <original>.clip_tok_proj -> <diffusers>.clip_extra_context_tokens_proj
-        "clip_extra_context_tokens_proj.weight": checkpoint["clip_to_seq.weight"],
-        "clip_extra_context_tokens_proj.bias": checkpoint["clip_to_seq.bias"],
-        # <original>.proj_n -> <diffusers>.embedding_proj
-        "embedding_proj.weight": checkpoint["proj_n.weight"],
-        "embedding_proj.bias": checkpoint["proj_n.bias"],
-        # <original>.ln_model_n -> <diffusers>.embedding_norm
-        "embedding_norm.weight": checkpoint["ln_model_n.weight"],
-        "embedding_norm.bias": checkpoint["ln_model_n.bias"],
-        # <original>.clip_emb -> <diffusers>.clip_image_embeddings_project_to_time_embeddings
-        "clip_image_embeddings_project_to_time_embeddings.weight": checkpoint["img_layer.weight"],
-        "clip_image_embeddings_project_to_time_embeddings.bias": checkpoint["img_layer.bias"],
-    }
-
-    return diffusers_checkpoint
-

 # unet utils
@@ -506,6 +513,38 @@ def unet_conv_in(checkpoint):
     return diffusers_checkpoint


+def unet_add_embedding(checkpoint):
+    diffusers_checkpoint = {}
+
+    diffusers_checkpoint.update(
+        {
+            "add_embedding.text_norm.weight": checkpoint["ln_model_n.weight"],
+            "add_embedding.text_norm.bias": checkpoint["ln_model_n.bias"],
+            "add_embedding.text_proj.weight": checkpoint["proj_n.weight"],
+            "add_embedding.text_proj.bias": checkpoint["proj_n.bias"],
+            "add_embedding.image_proj.weight": checkpoint["img_layer.weight"],
+            "add_embedding.image_proj.bias": checkpoint["img_layer.bias"],
+        }
+    )
+
+    return diffusers_checkpoint
+
+
+def unet_encoder_hid_proj(checkpoint):
+    diffusers_checkpoint = {}
+
+    diffusers_checkpoint.update(
+        {
+            "encoder_hid_proj.image_embeds.weight": checkpoint["clip_to_seq.weight"],
+            "encoder_hid_proj.image_embeds.bias": checkpoint["clip_to_seq.bias"],
+            "encoder_hid_proj.text_proj.weight": checkpoint["to_model_dim_n.weight"],
+            "encoder_hid_proj.text_proj.bias": checkpoint["to_model_dim_n.bias"],
+        }
+    )
+
+    return diffusers_checkpoint
+
+
 # <original>.out.0 -> <diffusers>.conv_norm_out
 def unet_conv_norm_out(checkpoint):
     diffusers_checkpoint = {}
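Together these two helpers cover the same five original parameter groups that the removed text_proj_original_checkpoint_to_diffusers_checkpoint used to remap (ln_model_n, proj_n, img_layer, clip_to_seq, to_model_dim_n), but point them at the UNet's own submodules instead of a separate model. A minimal sketch of the remapping with placeholder tensors (all shapes here are arbitrary illustrations, not taken from a real checkpoint):

```python
import torch

# Stand-ins for the original Kandinsky checkpoint entries; a real checkpoint
# carries the trained weights with their actual shapes.
original_checkpoint = {
    "ln_model_n.weight": torch.ones(4), "ln_model_n.bias": torch.zeros(4),        # -> add_embedding.text_norm
    "proj_n.weight": torch.randn(4, 4), "proj_n.bias": torch.zeros(4),            # -> add_embedding.text_proj
    "img_layer.weight": torch.randn(4, 4), "img_layer.bias": torch.zeros(4),      # -> add_embedding.image_proj
    "clip_to_seq.weight": torch.randn(4, 4), "clip_to_seq.bias": torch.zeros(4),  # -> encoder_hid_proj.image_embeds
    "to_model_dim_n.weight": torch.randn(4, 4), "to_model_dim_n.bias": torch.zeros(4),  # -> encoder_hid_proj.text_proj
}

diffusers_checkpoint = {}
diffusers_checkpoint.update(unet_add_embedding(original_checkpoint))
diffusers_checkpoint.update(unet_encoder_hid_proj(original_checkpoint))

# All keys are now prefixed with the diffusers module paths.
print(sorted(diffusers_checkpoint))
```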
@@ -857,25 +896,13 @@ def text2img(*, args, checkpoint_map_location):
     unet_diffusers_checkpoint = unet_original_checkpoint_to_diffusers_checkpoint(unet_model, text2img_checkpoint)

-    # text proj interlude
-
-    # The original decoder implementation includes a set of parameters that are used
-    # for creating the `encoder_hidden_states` which are what the U-net is conditioned
-    # on. The diffusers conditional unet directly takes the encoder_hidden_states. We pull
-    # the parameters into the KandinskyTextProjModel class
-    text_proj_model = text_proj_from_original_config()
-
-    text_proj_checkpoint = text_proj_original_checkpoint_to_diffusers_checkpoint(text2img_checkpoint)
-
-    load_checkpoint_to_model(text_proj_checkpoint, text_proj_model, strict=True)
-
     del text2img_checkpoint

     load_checkpoint_to_model(unet_diffusers_checkpoint, unet_model, strict=True)

     print("done loading text2img")

-    return unet_model, text_proj_model
+    return unet_model


 def inpaint_text2img(*, args, checkpoint_map_location):
@@ -891,25 +918,13 @@ def inpaint_text2img(*, args, checkpoint_map_location):
         inpaint_unet_model, inpaint_text2img_checkpoint
     )

-    # text proj interlude
-
-    # The original decoder implementation includes a set of parameters that are used
-    # for creating the `encoder_hidden_states` which are what the U-net is conditioned
-    # on. The diffusers conditional unet directly takes the encoder_hidden_states. We pull
-    # the parameters into the KandinskyTextProjModel class
-    text_proj_model = text_proj_from_original_config()
-
-    text_proj_checkpoint = text_proj_original_checkpoint_to_diffusers_checkpoint(inpaint_text2img_checkpoint)
-
-    load_checkpoint_to_model(text_proj_checkpoint, text_proj_model, strict=True)
-
     del inpaint_text2img_checkpoint

     load_checkpoint_to_model(inpaint_unet_diffusers_checkpoint, inpaint_unet_model, strict=True)

     print("done loading inpaint text2img")

-    return inpaint_unet_model, text_proj_model
+    return inpaint_unet_model


 # movq
@@ -1384,15 +1399,11 @@ if __name__ == "__main__":
         prior_model = prior(args=args, checkpoint_map_location=checkpoint_map_location)
         prior_model.save_pretrained(args.dump_path)
     elif args.debug == "text2img":
-        unet_model, text_proj_model = text2img(args=args, checkpoint_map_location=checkpoint_map_location)
+        unet_model = text2img(args=args, checkpoint_map_location=checkpoint_map_location)
         unet_model.save_pretrained(f"{args.dump_path}/unet")
-        text_proj_model.save_pretrained(f"{args.dump_path}/text_proj")
     elif args.debug == "inpaint_text2img":
-        inpaint_unet_model, inpaint_text_proj_model = inpaint_text2img(
-            args=args, checkpoint_map_location=checkpoint_map_location
-        )
+        inpaint_unet_model = inpaint_text2img(args=args, checkpoint_map_location=checkpoint_map_location)
         inpaint_unet_model.save_pretrained(f"{args.dump_path}/inpaint_unet")
-        inpaint_text_proj_model.save_pretrained(f"{args.dump_path}/inpaint_text_proj")
     elif args.debug == "decoder":
         decoder = movq(args=args, checkpoint_map_location=checkpoint_map_location)
         decoder.save_pretrained(f"{args.dump_path}/decoder")
...
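Since each debug branch now saves a single model, the converted UNet can be reloaded on its own; a short usage sketch (the dump path is illustrative):

```python
from diffusers import UNet2DConditionModel

# After running the script with `--debug text2img`, the converted UNet is
# saved under <dump_path>/unet and loads like any other diffusers model.
unet = UNet2DConditionModel.from_pretrained("/path/to/dump/unet")
assert unet.config.addition_embed_type == "text_image"
```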