Unverified Commit edc154da authored by Dhruv Nair's avatar Dhruv Nair Committed by GitHub
Browse files

Update Ruff to latest Version (#10919)

* update

* update

* update

* update
parent 552cd320
...@@ -468,7 +468,7 @@ def make_vqvae(old_vae): ...@@ -468,7 +468,7 @@ def make_vqvae(old_vae):
# assert (old_output == new_output).all() # assert (old_output == new_output).all()
print("skipping full vae equivalence check") print("skipping full vae equivalence check")
print(f"vae full diff { (old_output - new_output).float().abs().sum()}") print(f"vae full diff {(old_output - new_output).float().abs().sum()}")
return new_vae return new_vae
......
...@@ -239,7 +239,7 @@ def con_pt_to_diffuser(checkpoint_path: str, unet_config): ...@@ -239,7 +239,7 @@ def con_pt_to_diffuser(checkpoint_path: str, unet_config):
if i != len(up_block_types) - 1: if i != len(up_block_types) - 1:
new_prefix = f"up_blocks.{i}.upsamplers.0" new_prefix = f"up_blocks.{i}.upsamplers.0"
old_prefix = f"output_blocks.{current_layer-1}.1" old_prefix = f"output_blocks.{current_layer - 1}.1"
new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix) new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix)
elif layer_type == "AttnUpBlock2D": elif layer_type == "AttnUpBlock2D":
for j in range(layers_per_block + 1): for j in range(layers_per_block + 1):
...@@ -255,7 +255,7 @@ def con_pt_to_diffuser(checkpoint_path: str, unet_config): ...@@ -255,7 +255,7 @@ def con_pt_to_diffuser(checkpoint_path: str, unet_config):
if i != len(up_block_types) - 1: if i != len(up_block_types) - 1:
new_prefix = f"up_blocks.{i}.upsamplers.0" new_prefix = f"up_blocks.{i}.upsamplers.0"
old_prefix = f"output_blocks.{current_layer-1}.2" old_prefix = f"output_blocks.{current_layer - 1}.2"
new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix) new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix)
new_checkpoint["conv_norm_out.weight"] = checkpoint["out.0.weight"] new_checkpoint["conv_norm_out.weight"] = checkpoint["out.0.weight"]
......
...@@ -261,9 +261,9 @@ def main(args): ...@@ -261,9 +261,9 @@ def main(args):
model_name = args.model_path.split("/")[-1].split(".")[0] model_name = args.model_path.split("/")[-1].split(".")[0]
if not os.path.isfile(args.model_path): if not os.path.isfile(args.model_path):
assert ( assert model_name == args.model_path, (
model_name == args.model_path f"Make sure to provide one of the official model names {MODELS_MAP.keys()}"
), f"Make sure to provide one of the official model names {MODELS_MAP.keys()}" )
args.model_path = download(model_name) args.model_path = download(model_name)
sample_rate = MODELS_MAP[model_name]["sample_rate"] sample_rate = MODELS_MAP[model_name]["sample_rate"]
...@@ -290,9 +290,9 @@ def main(args): ...@@ -290,9 +290,9 @@ def main(args):
assert all(k.endswith("kernel") for k in list(diffusers_minus_renamed)), f"Problem with {diffusers_minus_renamed}" assert all(k.endswith("kernel") for k in list(diffusers_minus_renamed)), f"Problem with {diffusers_minus_renamed}"
for key, value in renamed_state_dict.items(): for key, value in renamed_state_dict.items():
assert ( assert diffusers_state_dict[key].squeeze().shape == value.squeeze().shape, (
diffusers_state_dict[key].squeeze().shape == value.squeeze().shape f"Shape for {key} doesn't match. Diffusers: {diffusers_state_dict[key].shape} vs. {value.shape}"
), f"Shape for {key} doesn't match. Diffusers: {diffusers_state_dict[key].shape} vs. {value.shape}" )
if key == "time_proj.weight": if key == "time_proj.weight":
value = value.squeeze() value = value.squeeze()
......
...@@ -52,18 +52,18 @@ for i in range(3): ...@@ -52,18 +52,18 @@ for i in range(3):
for j in range(2): for j in range(2):
# loop over resnets/attentions for downblocks # loop over resnets/attentions for downblocks
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}." hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0." sd_down_res_prefix = f"input_blocks.{3 * i + j + 1}.0."
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix)) unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
if i > 0: if i > 0:
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}." hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1." sd_down_atn_prefix = f"input_blocks.{3 * i + j + 1}.1."
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix)) unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
for j in range(4): for j in range(4):
# loop over resnets/attentions for upblocks # loop over resnets/attentions for upblocks
hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}." hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
sd_up_res_prefix = f"output_blocks.{3*i + j}.0." sd_up_res_prefix = f"output_blocks.{3 * i + j}.0."
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix)) unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
if i < 2: if i < 2:
...@@ -75,12 +75,12 @@ for i in range(3): ...@@ -75,12 +75,12 @@ for i in range(3):
if i < 3: if i < 3:
# no downsample in down_blocks.3 # no downsample in down_blocks.3
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv." hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op." sd_downsample_prefix = f"input_blocks.{3 * (i + 1)}.0.op."
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix)) unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
# no upsample in up_blocks.3 # no upsample in up_blocks.3
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}." sd_upsample_prefix = f"output_blocks.{3 * i + 2}.{1 if i == 0 else 2}."
unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix)) unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
unet_conversion_map_layer.append(("output_blocks.2.2.conv.", "output_blocks.2.1.conv.")) unet_conversion_map_layer.append(("output_blocks.2.2.conv.", "output_blocks.2.1.conv."))
...@@ -89,7 +89,7 @@ sd_mid_atn_prefix = "middle_block.1." ...@@ -89,7 +89,7 @@ sd_mid_atn_prefix = "middle_block.1."
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix)) unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
for j in range(2): for j in range(2):
hf_mid_res_prefix = f"mid_block.resnets.{j}." hf_mid_res_prefix = f"mid_block.resnets.{j}."
sd_mid_res_prefix = f"middle_block.{2*j}." sd_mid_res_prefix = f"middle_block.{2 * j}."
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix)) unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
...@@ -137,20 +137,20 @@ for i in range(4): ...@@ -137,20 +137,20 @@ for i in range(4):
vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix)) vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
sd_upsample_prefix = f"up.{3-i}.upsample." sd_upsample_prefix = f"up.{3 - i}.upsample."
vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix)) vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
# up_blocks have three resnets # up_blocks have three resnets
# also, up blocks in hf are numbered in reverse from sd # also, up blocks in hf are numbered in reverse from sd
for j in range(3): for j in range(3):
hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}." hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
sd_up_prefix = f"decoder.up.{3-i}.block.{j}." sd_up_prefix = f"decoder.up.{3 - i}.block.{j}."
vae_conversion_map.append((sd_up_prefix, hf_up_prefix)) vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
# this part accounts for mid blocks in both the encoder and the decoder # this part accounts for mid blocks in both the encoder and the decoder
for i in range(2): for i in range(2):
hf_mid_res_prefix = f"mid_block.resnets.{i}." hf_mid_res_prefix = f"mid_block.resnets.{i}."
sd_mid_res_prefix = f"mid.block_{i+1}." sd_mid_res_prefix = f"mid.block_{i + 1}."
vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix)) vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
......
...@@ -47,36 +47,36 @@ for i in range(4): ...@@ -47,36 +47,36 @@ for i in range(4):
for j in range(2): for j in range(2):
# loop over resnets/attentions for downblocks # loop over resnets/attentions for downblocks
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}." hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0." sd_down_res_prefix = f"input_blocks.{3 * i + j + 1}.0."
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix)) unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
if i < 3: if i < 3:
# no attention layers in down_blocks.3 # no attention layers in down_blocks.3
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}." hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1." sd_down_atn_prefix = f"input_blocks.{3 * i + j + 1}.1."
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix)) unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
for j in range(3): for j in range(3):
# loop over resnets/attentions for upblocks # loop over resnets/attentions for upblocks
hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}." hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
sd_up_res_prefix = f"output_blocks.{3*i + j}.0." sd_up_res_prefix = f"output_blocks.{3 * i + j}.0."
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix)) unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
if i > 0: if i > 0:
# no attention layers in up_blocks.0 # no attention layers in up_blocks.0
hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}." hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
sd_up_atn_prefix = f"output_blocks.{3*i + j}.1." sd_up_atn_prefix = f"output_blocks.{3 * i + j}.1."
unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix)) unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
if i < 3: if i < 3:
# no downsample in down_blocks.3 # no downsample in down_blocks.3
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv." hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op." sd_downsample_prefix = f"input_blocks.{3 * (i + 1)}.0.op."
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix)) unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
# no upsample in up_blocks.3 # no upsample in up_blocks.3
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}." sd_upsample_prefix = f"output_blocks.{3 * i + 2}.{1 if i == 0 else 2}."
unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix)) unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
hf_mid_atn_prefix = "mid_block.attentions.0." hf_mid_atn_prefix = "mid_block.attentions.0."
...@@ -85,7 +85,7 @@ unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix)) ...@@ -85,7 +85,7 @@ unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
for j in range(2): for j in range(2):
hf_mid_res_prefix = f"mid_block.resnets.{j}." hf_mid_res_prefix = f"mid_block.resnets.{j}."
sd_mid_res_prefix = f"middle_block.{2*j}." sd_mid_res_prefix = f"middle_block.{2 * j}."
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix)) unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
...@@ -133,20 +133,20 @@ for i in range(4): ...@@ -133,20 +133,20 @@ for i in range(4):
vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix)) vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
sd_upsample_prefix = f"up.{3-i}.upsample." sd_upsample_prefix = f"up.{3 - i}.upsample."
vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix)) vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
# up_blocks have three resnets # up_blocks have three resnets
# also, up blocks in hf are numbered in reverse from sd # also, up blocks in hf are numbered in reverse from sd
for j in range(3): for j in range(3):
hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}." hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
sd_up_prefix = f"decoder.up.{3-i}.block.{j}." sd_up_prefix = f"decoder.up.{3 - i}.block.{j}."
vae_conversion_map.append((sd_up_prefix, hf_up_prefix)) vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
# this part accounts for mid blocks in both the encoder and the decoder # this part accounts for mid blocks in both the encoder and the decoder
for i in range(2): for i in range(2):
hf_mid_res_prefix = f"mid_block.resnets.{i}." hf_mid_res_prefix = f"mid_block.resnets.{i}."
sd_mid_res_prefix = f"mid.block_{i+1}." sd_mid_res_prefix = f"mid.block_{i + 1}."
vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix)) vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
......
...@@ -21,9 +21,9 @@ def main(args): ...@@ -21,9 +21,9 @@ def main(args):
model_config = HunyuanDiT2DControlNetModel.load_config( model_config = HunyuanDiT2DControlNetModel.load_config(
"Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers", subfolder="transformer" "Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers", subfolder="transformer"
) )
model_config[ model_config["use_style_cond_and_image_meta_size"] = (
"use_style_cond_and_image_meta_size" args.use_style_cond_and_image_meta_size
] = args.use_style_cond_and_image_meta_size ### version <= v1.1: True; version >= v1.2: False ) ### version <= v1.1: True; version >= v1.2: False
print(model_config) print(model_config)
for key in state_dict: for key in state_dict:
......
...@@ -13,15 +13,14 @@ def main(args): ...@@ -13,15 +13,14 @@ def main(args):
state_dict = state_dict[args.load_key] state_dict = state_dict[args.load_key]
except KeyError: except KeyError:
raise KeyError( raise KeyError(
f"{args.load_key} not found in the checkpoint." f"{args.load_key} not found in the checkpoint.Please load from the following keys:{state_dict.keys()}"
f"Please load from the following keys:{state_dict.keys()}"
) )
device = "cuda" device = "cuda"
model_config = HunyuanDiT2DModel.load_config("Tencent-Hunyuan/HunyuanDiT-Diffusers", subfolder="transformer") model_config = HunyuanDiT2DModel.load_config("Tencent-Hunyuan/HunyuanDiT-Diffusers", subfolder="transformer")
model_config[ model_config["use_style_cond_and_image_meta_size"] = (
"use_style_cond_and_image_meta_size" args.use_style_cond_and_image_meta_size
] = args.use_style_cond_and_image_meta_size ### version <= v1.1: True; version >= v1.2: False ) ### version <= v1.1: True; version >= v1.2: False
# input_size -> sample_size, text_dim -> cross_attention_dim # input_size -> sample_size, text_dim -> cross_attention_dim
for key in state_dict: for key in state_dict:
......
...@@ -142,14 +142,14 @@ def block_to_diffusers_checkpoint(block, checkpoint, block_idx, block_type): ...@@ -142,14 +142,14 @@ def block_to_diffusers_checkpoint(block, checkpoint, block_idx, block_type):
diffusers_attention_prefix = f"{block_type}_blocks.{block_idx}.attentions.{attention_idx}" diffusers_attention_prefix = f"{block_type}_blocks.{block_idx}.attentions.{attention_idx}"
idx = n * attention_idx + 1 if block_type == "up" else n * attention_idx + 2 idx = n * attention_idx + 1 if block_type == "up" else n * attention_idx + 2
self_attention_prefix = f"{block_prefix}.{idx}" self_attention_prefix = f"{block_prefix}.{idx}"
cross_attention_prefix = f"{block_prefix}.{idx }" cross_attention_prefix = f"{block_prefix}.{idx}"
cross_attention_index = 1 if not attention.add_self_attention else 2 cross_attention_index = 1 if not attention.add_self_attention else 2
idx = ( idx = (
n * attention_idx + cross_attention_index n * attention_idx + cross_attention_index
if block_type == "up" if block_type == "up"
else n * attention_idx + cross_attention_index + 1 else n * attention_idx + cross_attention_index + 1
) )
cross_attention_prefix = f"{block_prefix}.{idx }" cross_attention_prefix = f"{block_prefix}.{idx}"
diffusers_checkpoint.update( diffusers_checkpoint.update(
cross_attn_to_diffusers_checkpoint( cross_attn_to_diffusers_checkpoint(
...@@ -220,9 +220,9 @@ def unet_model_from_original_config(original_config): ...@@ -220,9 +220,9 @@ def unet_model_from_original_config(original_config):
block_out_channels = original_config["channels"] block_out_channels = original_config["channels"]
assert ( assert len(set(original_config["depths"])) == 1, (
len(set(original_config["depths"])) == 1 "UNet2DConditionModel currently do not support blocks with different number of layers"
), "UNet2DConditionModel currently do not support blocks with different number of layers" )
layers_per_block = original_config["depths"][0] layers_per_block = original_config["depths"][0]
class_labels_dim = original_config["mapping_cond_dim"] class_labels_dim = original_config["mapping_cond_dim"]
......
...@@ -168,28 +168,28 @@ def convert_mochi_vae_state_dict_to_diffusers(encoder_ckpt_path, decoder_ckpt_pa ...@@ -168,28 +168,28 @@ def convert_mochi_vae_state_dict_to_diffusers(encoder_ckpt_path, decoder_ckpt_pa
# Convert block_in (MochiMidBlock3D) # Convert block_in (MochiMidBlock3D)
for i in range(3): # layers_per_block[-1] = 3 for i in range(3): # layers_per_block[-1] = 3
new_state_dict[f"{prefix}block_in.resnets.{i}.norm1.norm_layer.weight"] = decoder_state_dict.pop( new_state_dict[f"{prefix}block_in.resnets.{i}.norm1.norm_layer.weight"] = decoder_state_dict.pop(
f"blocks.0.{i+1}.stack.0.weight" f"blocks.0.{i + 1}.stack.0.weight"
) )
new_state_dict[f"{prefix}block_in.resnets.{i}.norm1.norm_layer.bias"] = decoder_state_dict.pop( new_state_dict[f"{prefix}block_in.resnets.{i}.norm1.norm_layer.bias"] = decoder_state_dict.pop(
f"blocks.0.{i+1}.stack.0.bias" f"blocks.0.{i + 1}.stack.0.bias"
) )
new_state_dict[f"{prefix}block_in.resnets.{i}.conv1.conv.weight"] = decoder_state_dict.pop( new_state_dict[f"{prefix}block_in.resnets.{i}.conv1.conv.weight"] = decoder_state_dict.pop(
f"blocks.0.{i+1}.stack.2.weight" f"blocks.0.{i + 1}.stack.2.weight"
) )
new_state_dict[f"{prefix}block_in.resnets.{i}.conv1.conv.bias"] = decoder_state_dict.pop( new_state_dict[f"{prefix}block_in.resnets.{i}.conv1.conv.bias"] = decoder_state_dict.pop(
f"blocks.0.{i+1}.stack.2.bias" f"blocks.0.{i + 1}.stack.2.bias"
) )
new_state_dict[f"{prefix}block_in.resnets.{i}.norm2.norm_layer.weight"] = decoder_state_dict.pop( new_state_dict[f"{prefix}block_in.resnets.{i}.norm2.norm_layer.weight"] = decoder_state_dict.pop(
f"blocks.0.{i+1}.stack.3.weight" f"blocks.0.{i + 1}.stack.3.weight"
) )
new_state_dict[f"{prefix}block_in.resnets.{i}.norm2.norm_layer.bias"] = decoder_state_dict.pop( new_state_dict[f"{prefix}block_in.resnets.{i}.norm2.norm_layer.bias"] = decoder_state_dict.pop(
f"blocks.0.{i+1}.stack.3.bias" f"blocks.0.{i + 1}.stack.3.bias"
) )
new_state_dict[f"{prefix}block_in.resnets.{i}.conv2.conv.weight"] = decoder_state_dict.pop( new_state_dict[f"{prefix}block_in.resnets.{i}.conv2.conv.weight"] = decoder_state_dict.pop(
f"blocks.0.{i+1}.stack.5.weight" f"blocks.0.{i + 1}.stack.5.weight"
) )
new_state_dict[f"{prefix}block_in.resnets.{i}.conv2.conv.bias"] = decoder_state_dict.pop( new_state_dict[f"{prefix}block_in.resnets.{i}.conv2.conv.bias"] = decoder_state_dict.pop(
f"blocks.0.{i+1}.stack.5.bias" f"blocks.0.{i + 1}.stack.5.bias"
) )
# Convert up_blocks (MochiUpBlock3D) # Convert up_blocks (MochiUpBlock3D)
...@@ -197,33 +197,35 @@ def convert_mochi_vae_state_dict_to_diffusers(encoder_ckpt_path, decoder_ckpt_pa ...@@ -197,33 +197,35 @@ def convert_mochi_vae_state_dict_to_diffusers(encoder_ckpt_path, decoder_ckpt_pa
for block in range(3): for block in range(3):
for i in range(down_block_layers[block]): for i in range(down_block_layers[block]):
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.norm1.norm_layer.weight"] = decoder_state_dict.pop( new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.norm1.norm_layer.weight"] = decoder_state_dict.pop(
f"blocks.{block+1}.blocks.{i}.stack.0.weight" f"blocks.{block + 1}.blocks.{i}.stack.0.weight"
) )
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.norm1.norm_layer.bias"] = decoder_state_dict.pop( new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.norm1.norm_layer.bias"] = decoder_state_dict.pop(
f"blocks.{block+1}.blocks.{i}.stack.0.bias" f"blocks.{block + 1}.blocks.{i}.stack.0.bias"
) )
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.conv1.conv.weight"] = decoder_state_dict.pop( new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.conv1.conv.weight"] = decoder_state_dict.pop(
f"blocks.{block+1}.blocks.{i}.stack.2.weight" f"blocks.{block + 1}.blocks.{i}.stack.2.weight"
) )
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.conv1.conv.bias"] = decoder_state_dict.pop( new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.conv1.conv.bias"] = decoder_state_dict.pop(
f"blocks.{block+1}.blocks.{i}.stack.2.bias" f"blocks.{block + 1}.blocks.{i}.stack.2.bias"
) )
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.norm2.norm_layer.weight"] = decoder_state_dict.pop( new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.norm2.norm_layer.weight"] = decoder_state_dict.pop(
f"blocks.{block+1}.blocks.{i}.stack.3.weight" f"blocks.{block + 1}.blocks.{i}.stack.3.weight"
) )
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.norm2.norm_layer.bias"] = decoder_state_dict.pop( new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.norm2.norm_layer.bias"] = decoder_state_dict.pop(
f"blocks.{block+1}.blocks.{i}.stack.3.bias" f"blocks.{block + 1}.blocks.{i}.stack.3.bias"
) )
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.conv2.conv.weight"] = decoder_state_dict.pop( new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.conv2.conv.weight"] = decoder_state_dict.pop(
f"blocks.{block+1}.blocks.{i}.stack.5.weight" f"blocks.{block + 1}.blocks.{i}.stack.5.weight"
) )
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.conv2.conv.bias"] = decoder_state_dict.pop( new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.conv2.conv.bias"] = decoder_state_dict.pop(
f"blocks.{block+1}.blocks.{i}.stack.5.bias" f"blocks.{block + 1}.blocks.{i}.stack.5.bias"
) )
new_state_dict[f"{prefix}up_blocks.{block}.proj.weight"] = decoder_state_dict.pop( new_state_dict[f"{prefix}up_blocks.{block}.proj.weight"] = decoder_state_dict.pop(
f"blocks.{block+1}.proj.weight" f"blocks.{block + 1}.proj.weight"
)
new_state_dict[f"{prefix}up_blocks.{block}.proj.bias"] = decoder_state_dict.pop(
f"blocks.{block + 1}.proj.bias"
) )
new_state_dict[f"{prefix}up_blocks.{block}.proj.bias"] = decoder_state_dict.pop(f"blocks.{block+1}.proj.bias")
# Convert block_out (MochiMidBlock3D) # Convert block_out (MochiMidBlock3D)
for i in range(3): # layers_per_block[0] = 3 for i in range(3): # layers_per_block[0] = 3
...@@ -267,133 +269,133 @@ def convert_mochi_vae_state_dict_to_diffusers(encoder_ckpt_path, decoder_ckpt_pa ...@@ -267,133 +269,133 @@ def convert_mochi_vae_state_dict_to_diffusers(encoder_ckpt_path, decoder_ckpt_pa
# Convert block_in (MochiMidBlock3D) # Convert block_in (MochiMidBlock3D)
for i in range(3): # layers_per_block[0] = 3 for i in range(3): # layers_per_block[0] = 3
new_state_dict[f"{prefix}block_in.resnets.{i}.norm1.norm_layer.weight"] = encoder_state_dict.pop( new_state_dict[f"{prefix}block_in.resnets.{i}.norm1.norm_layer.weight"] = encoder_state_dict.pop(
f"layers.{i+1}.stack.0.weight" f"layers.{i + 1}.stack.0.weight"
) )
new_state_dict[f"{prefix}block_in.resnets.{i}.norm1.norm_layer.bias"] = encoder_state_dict.pop( new_state_dict[f"{prefix}block_in.resnets.{i}.norm1.norm_layer.bias"] = encoder_state_dict.pop(
f"layers.{i+1}.stack.0.bias" f"layers.{i + 1}.stack.0.bias"
) )
new_state_dict[f"{prefix}block_in.resnets.{i}.conv1.conv.weight"] = encoder_state_dict.pop( new_state_dict[f"{prefix}block_in.resnets.{i}.conv1.conv.weight"] = encoder_state_dict.pop(
f"layers.{i+1}.stack.2.weight" f"layers.{i + 1}.stack.2.weight"
) )
new_state_dict[f"{prefix}block_in.resnets.{i}.conv1.conv.bias"] = encoder_state_dict.pop( new_state_dict[f"{prefix}block_in.resnets.{i}.conv1.conv.bias"] = encoder_state_dict.pop(
f"layers.{i+1}.stack.2.bias" f"layers.{i + 1}.stack.2.bias"
) )
new_state_dict[f"{prefix}block_in.resnets.{i}.norm2.norm_layer.weight"] = encoder_state_dict.pop( new_state_dict[f"{prefix}block_in.resnets.{i}.norm2.norm_layer.weight"] = encoder_state_dict.pop(
f"layers.{i+1}.stack.3.weight" f"layers.{i + 1}.stack.3.weight"
) )
new_state_dict[f"{prefix}block_in.resnets.{i}.norm2.norm_layer.bias"] = encoder_state_dict.pop( new_state_dict[f"{prefix}block_in.resnets.{i}.norm2.norm_layer.bias"] = encoder_state_dict.pop(
f"layers.{i+1}.stack.3.bias" f"layers.{i + 1}.stack.3.bias"
) )
new_state_dict[f"{prefix}block_in.resnets.{i}.conv2.conv.weight"] = encoder_state_dict.pop( new_state_dict[f"{prefix}block_in.resnets.{i}.conv2.conv.weight"] = encoder_state_dict.pop(
f"layers.{i+1}.stack.5.weight" f"layers.{i + 1}.stack.5.weight"
) )
new_state_dict[f"{prefix}block_in.resnets.{i}.conv2.conv.bias"] = encoder_state_dict.pop( new_state_dict[f"{prefix}block_in.resnets.{i}.conv2.conv.bias"] = encoder_state_dict.pop(
f"layers.{i+1}.stack.5.bias" f"layers.{i + 1}.stack.5.bias"
) )
# Convert down_blocks (MochiDownBlock3D) # Convert down_blocks (MochiDownBlock3D)
down_block_layers = [3, 4, 6] # layers_per_block[1], layers_per_block[2], layers_per_block[3] down_block_layers = [3, 4, 6] # layers_per_block[1], layers_per_block[2], layers_per_block[3]
for block in range(3): for block in range(3):
new_state_dict[f"{prefix}down_blocks.{block}.conv_in.conv.weight"] = encoder_state_dict.pop( new_state_dict[f"{prefix}down_blocks.{block}.conv_in.conv.weight"] = encoder_state_dict.pop(
f"layers.{block+4}.layers.0.weight" f"layers.{block + 4}.layers.0.weight"
) )
new_state_dict[f"{prefix}down_blocks.{block}.conv_in.conv.bias"] = encoder_state_dict.pop( new_state_dict[f"{prefix}down_blocks.{block}.conv_in.conv.bias"] = encoder_state_dict.pop(
f"layers.{block+4}.layers.0.bias" f"layers.{block + 4}.layers.0.bias"
) )
for i in range(down_block_layers[block]): for i in range(down_block_layers[block]):
# Convert resnets # Convert resnets
new_state_dict[ new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.norm1.norm_layer.weight"] = (
f"{prefix}down_blocks.{block}.resnets.{i}.norm1.norm_layer.weight" encoder_state_dict.pop(f"layers.{block + 4}.layers.{i + 1}.stack.0.weight")
] = encoder_state_dict.pop(f"layers.{block+4}.layers.{i+1}.stack.0.weight") )
new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.norm1.norm_layer.bias"] = encoder_state_dict.pop( new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.norm1.norm_layer.bias"] = encoder_state_dict.pop(
f"layers.{block+4}.layers.{i+1}.stack.0.bias" f"layers.{block + 4}.layers.{i + 1}.stack.0.bias"
) )
new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.conv1.conv.weight"] = encoder_state_dict.pop( new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.conv1.conv.weight"] = encoder_state_dict.pop(
f"layers.{block+4}.layers.{i+1}.stack.2.weight" f"layers.{block + 4}.layers.{i + 1}.stack.2.weight"
) )
new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.conv1.conv.bias"] = encoder_state_dict.pop( new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.conv1.conv.bias"] = encoder_state_dict.pop(
f"layers.{block+4}.layers.{i+1}.stack.2.bias" f"layers.{block + 4}.layers.{i + 1}.stack.2.bias"
)
new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.norm2.norm_layer.weight"] = (
encoder_state_dict.pop(f"layers.{block + 4}.layers.{i + 1}.stack.3.weight")
) )
new_state_dict[
f"{prefix}down_blocks.{block}.resnets.{i}.norm2.norm_layer.weight"
] = encoder_state_dict.pop(f"layers.{block+4}.layers.{i+1}.stack.3.weight")
new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.norm2.norm_layer.bias"] = encoder_state_dict.pop( new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.norm2.norm_layer.bias"] = encoder_state_dict.pop(
f"layers.{block+4}.layers.{i+1}.stack.3.bias" f"layers.{block + 4}.layers.{i + 1}.stack.3.bias"
) )
new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.conv2.conv.weight"] = encoder_state_dict.pop( new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.conv2.conv.weight"] = encoder_state_dict.pop(
f"layers.{block+4}.layers.{i+1}.stack.5.weight" f"layers.{block + 4}.layers.{i + 1}.stack.5.weight"
) )
new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.conv2.conv.bias"] = encoder_state_dict.pop( new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.conv2.conv.bias"] = encoder_state_dict.pop(
f"layers.{block+4}.layers.{i+1}.stack.5.bias" f"layers.{block + 4}.layers.{i + 1}.stack.5.bias"
) )
# Convert attentions # Convert attentions
qkv_weight = encoder_state_dict.pop(f"layers.{block+4}.layers.{i+1}.attn_block.attn.qkv.weight") qkv_weight = encoder_state_dict.pop(f"layers.{block + 4}.layers.{i + 1}.attn_block.attn.qkv.weight")
q, k, v = qkv_weight.chunk(3, dim=0) q, k, v = qkv_weight.chunk(3, dim=0)
new_state_dict[f"{prefix}down_blocks.{block}.attentions.{i}.to_q.weight"] = q new_state_dict[f"{prefix}down_blocks.{block}.attentions.{i}.to_q.weight"] = q
new_state_dict[f"{prefix}down_blocks.{block}.attentions.{i}.to_k.weight"] = k new_state_dict[f"{prefix}down_blocks.{block}.attentions.{i}.to_k.weight"] = k
new_state_dict[f"{prefix}down_blocks.{block}.attentions.{i}.to_v.weight"] = v new_state_dict[f"{prefix}down_blocks.{block}.attentions.{i}.to_v.weight"] = v
new_state_dict[f"{prefix}down_blocks.{block}.attentions.{i}.to_out.0.weight"] = encoder_state_dict.pop( new_state_dict[f"{prefix}down_blocks.{block}.attentions.{i}.to_out.0.weight"] = encoder_state_dict.pop(
f"layers.{block+4}.layers.{i+1}.attn_block.attn.out.weight" f"layers.{block + 4}.layers.{i + 1}.attn_block.attn.out.weight"
) )
new_state_dict[f"{prefix}down_blocks.{block}.attentions.{i}.to_out.0.bias"] = encoder_state_dict.pop( new_state_dict[f"{prefix}down_blocks.{block}.attentions.{i}.to_out.0.bias"] = encoder_state_dict.pop(
f"layers.{block+4}.layers.{i+1}.attn_block.attn.out.bias" f"layers.{block + 4}.layers.{i + 1}.attn_block.attn.out.bias"
) )
new_state_dict[f"{prefix}down_blocks.{block}.norms.{i}.norm_layer.weight"] = encoder_state_dict.pop( new_state_dict[f"{prefix}down_blocks.{block}.norms.{i}.norm_layer.weight"] = encoder_state_dict.pop(
f"layers.{block+4}.layers.{i+1}.attn_block.norm.weight" f"layers.{block + 4}.layers.{i + 1}.attn_block.norm.weight"
) )
new_state_dict[f"{prefix}down_blocks.{block}.norms.{i}.norm_layer.bias"] = encoder_state_dict.pop( new_state_dict[f"{prefix}down_blocks.{block}.norms.{i}.norm_layer.bias"] = encoder_state_dict.pop(
f"layers.{block+4}.layers.{i+1}.attn_block.norm.bias" f"layers.{block + 4}.layers.{i + 1}.attn_block.norm.bias"
) )
# Convert block_out (MochiMidBlock3D) # Convert block_out (MochiMidBlock3D)
for i in range(3): # layers_per_block[-1] = 3 for i in range(3): # layers_per_block[-1] = 3
# Convert resnets # Convert resnets
new_state_dict[f"{prefix}block_out.resnets.{i}.norm1.norm_layer.weight"] = encoder_state_dict.pop( new_state_dict[f"{prefix}block_out.resnets.{i}.norm1.norm_layer.weight"] = encoder_state_dict.pop(
f"layers.{i+7}.stack.0.weight" f"layers.{i + 7}.stack.0.weight"
) )
new_state_dict[f"{prefix}block_out.resnets.{i}.norm1.norm_layer.bias"] = encoder_state_dict.pop( new_state_dict[f"{prefix}block_out.resnets.{i}.norm1.norm_layer.bias"] = encoder_state_dict.pop(
f"layers.{i+7}.stack.0.bias" f"layers.{i + 7}.stack.0.bias"
) )
new_state_dict[f"{prefix}block_out.resnets.{i}.conv1.conv.weight"] = encoder_state_dict.pop( new_state_dict[f"{prefix}block_out.resnets.{i}.conv1.conv.weight"] = encoder_state_dict.pop(
f"layers.{i+7}.stack.2.weight" f"layers.{i + 7}.stack.2.weight"
) )
new_state_dict[f"{prefix}block_out.resnets.{i}.conv1.conv.bias"] = encoder_state_dict.pop( new_state_dict[f"{prefix}block_out.resnets.{i}.conv1.conv.bias"] = encoder_state_dict.pop(
f"layers.{i+7}.stack.2.bias" f"layers.{i + 7}.stack.2.bias"
) )
new_state_dict[f"{prefix}block_out.resnets.{i}.norm2.norm_layer.weight"] = encoder_state_dict.pop( new_state_dict[f"{prefix}block_out.resnets.{i}.norm2.norm_layer.weight"] = encoder_state_dict.pop(
f"layers.{i+7}.stack.3.weight" f"layers.{i + 7}.stack.3.weight"
) )
new_state_dict[f"{prefix}block_out.resnets.{i}.norm2.norm_layer.bias"] = encoder_state_dict.pop( new_state_dict[f"{prefix}block_out.resnets.{i}.norm2.norm_layer.bias"] = encoder_state_dict.pop(
f"layers.{i+7}.stack.3.bias" f"layers.{i + 7}.stack.3.bias"
) )
new_state_dict[f"{prefix}block_out.resnets.{i}.conv2.conv.weight"] = encoder_state_dict.pop( new_state_dict[f"{prefix}block_out.resnets.{i}.conv2.conv.weight"] = encoder_state_dict.pop(
f"layers.{i+7}.stack.5.weight" f"layers.{i + 7}.stack.5.weight"
) )
new_state_dict[f"{prefix}block_out.resnets.{i}.conv2.conv.bias"] = encoder_state_dict.pop( new_state_dict[f"{prefix}block_out.resnets.{i}.conv2.conv.bias"] = encoder_state_dict.pop(
f"layers.{i+7}.stack.5.bias" f"layers.{i + 7}.stack.5.bias"
) )
# Convert attentions # Convert attentions
qkv_weight = encoder_state_dict.pop(f"layers.{i+7}.attn_block.attn.qkv.weight") qkv_weight = encoder_state_dict.pop(f"layers.{i + 7}.attn_block.attn.qkv.weight")
q, k, v = qkv_weight.chunk(3, dim=0) q, k, v = qkv_weight.chunk(3, dim=0)
new_state_dict[f"{prefix}block_out.attentions.{i}.to_q.weight"] = q new_state_dict[f"{prefix}block_out.attentions.{i}.to_q.weight"] = q
new_state_dict[f"{prefix}block_out.attentions.{i}.to_k.weight"] = k new_state_dict[f"{prefix}block_out.attentions.{i}.to_k.weight"] = k
new_state_dict[f"{prefix}block_out.attentions.{i}.to_v.weight"] = v new_state_dict[f"{prefix}block_out.attentions.{i}.to_v.weight"] = v
new_state_dict[f"{prefix}block_out.attentions.{i}.to_out.0.weight"] = encoder_state_dict.pop( new_state_dict[f"{prefix}block_out.attentions.{i}.to_out.0.weight"] = encoder_state_dict.pop(
f"layers.{i+7}.attn_block.attn.out.weight" f"layers.{i + 7}.attn_block.attn.out.weight"
) )
new_state_dict[f"{prefix}block_out.attentions.{i}.to_out.0.bias"] = encoder_state_dict.pop( new_state_dict[f"{prefix}block_out.attentions.{i}.to_out.0.bias"] = encoder_state_dict.pop(
f"layers.{i+7}.attn_block.attn.out.bias" f"layers.{i + 7}.attn_block.attn.out.bias"
) )
new_state_dict[f"{prefix}block_out.norms.{i}.norm_layer.weight"] = encoder_state_dict.pop( new_state_dict[f"{prefix}block_out.norms.{i}.norm_layer.weight"] = encoder_state_dict.pop(
f"layers.{i+7}.attn_block.norm.weight" f"layers.{i + 7}.attn_block.norm.weight"
) )
new_state_dict[f"{prefix}block_out.norms.{i}.norm_layer.bias"] = encoder_state_dict.pop( new_state_dict[f"{prefix}block_out.norms.{i}.norm_layer.bias"] = encoder_state_dict.pop(
f"layers.{i+7}.attn_block.norm.bias" f"layers.{i + 7}.attn_block.norm.bias"
) )
# Convert output layers # Convert output layers
......
...@@ -662,7 +662,7 @@ def convert_open_clap_checkpoint(checkpoint): ...@@ -662,7 +662,7 @@ def convert_open_clap_checkpoint(checkpoint):
# replace sequential layers with list # replace sequential layers with list
sequential_layer = re.match(sequential_layers_pattern, key).group(1) sequential_layer = re.match(sequential_layers_pattern, key).group(1)
key = key.replace(f"sequential.{sequential_layer}.", f"layers.{int(sequential_layer)//3}.linear.") key = key.replace(f"sequential.{sequential_layer}.", f"layers.{int(sequential_layer) // 3}.linear.")
elif re.match(text_projection_pattern, key): elif re.match(text_projection_pattern, key):
projecton_layer = int(re.match(text_projection_pattern, key).group(1)) projecton_layer = int(re.match(text_projection_pattern, key).group(1))
......
...@@ -636,7 +636,7 @@ def convert_open_clap_checkpoint(checkpoint): ...@@ -636,7 +636,7 @@ def convert_open_clap_checkpoint(checkpoint):
# replace sequential layers with list # replace sequential layers with list
sequential_layer = re.match(sequential_layers_pattern, key).group(1) sequential_layer = re.match(sequential_layers_pattern, key).group(1)
key = key.replace(f"sequential.{sequential_layer}.", f"layers.{int(sequential_layer)//3}.linear.") key = key.replace(f"sequential.{sequential_layer}.", f"layers.{int(sequential_layer) // 3}.linear.")
elif re.match(text_projection_pattern, key): elif re.match(text_projection_pattern, key):
projecton_layer = int(re.match(text_projection_pattern, key).group(1)) projecton_layer = int(re.match(text_projection_pattern, key).group(1))
......
...@@ -642,7 +642,7 @@ def convert_open_clap_checkpoint(checkpoint): ...@@ -642,7 +642,7 @@ def convert_open_clap_checkpoint(checkpoint):
# replace sequential layers with list # replace sequential layers with list
sequential_layer = re.match(sequential_layers_pattern, key).group(1) sequential_layer = re.match(sequential_layers_pattern, key).group(1)
key = key.replace(f"sequential.{sequential_layer}.", f"layers.{int(sequential_layer)//3}.linear.") key = key.replace(f"sequential.{sequential_layer}.", f"layers.{int(sequential_layer) // 3}.linear.")
elif re.match(text_projection_pattern, key): elif re.match(text_projection_pattern, key):
projecton_layer = int(re.match(text_projection_pattern, key).group(1)) projecton_layer = int(re.match(text_projection_pattern, key).group(1))
......
...@@ -95,18 +95,18 @@ def convert_stable_audio_state_dict_to_diffusers(state_dict, num_autoencoder_lay ...@@ -95,18 +95,18 @@ def convert_stable_audio_state_dict_to_diffusers(state_dict, num_autoencoder_lay
# get idx of the layer # get idx of the layer
idx = int(new_key.split("coder.layers.")[1].split(".")[0]) idx = int(new_key.split("coder.layers.")[1].split(".")[0])
new_key = new_key.replace(f"coder.layers.{idx}", f"coder.block.{idx-1}") new_key = new_key.replace(f"coder.layers.{idx}", f"coder.block.{idx - 1}")
if "encoder" in new_key: if "encoder" in new_key:
for i in range(3): for i in range(3):
new_key = new_key.replace(f"block.{idx-1}.layers.{i}", f"block.{idx-1}.res_unit{i+1}") new_key = new_key.replace(f"block.{idx - 1}.layers.{i}", f"block.{idx - 1}.res_unit{i + 1}")
new_key = new_key.replace(f"block.{idx-1}.layers.3", f"block.{idx-1}.snake1") new_key = new_key.replace(f"block.{idx - 1}.layers.3", f"block.{idx - 1}.snake1")
new_key = new_key.replace(f"block.{idx-1}.layers.4", f"block.{idx-1}.conv1") new_key = new_key.replace(f"block.{idx - 1}.layers.4", f"block.{idx - 1}.conv1")
else: else:
for i in range(2, 5): for i in range(2, 5):
new_key = new_key.replace(f"block.{idx-1}.layers.{i}", f"block.{idx-1}.res_unit{i-1}") new_key = new_key.replace(f"block.{idx - 1}.layers.{i}", f"block.{idx - 1}.res_unit{i - 1}")
new_key = new_key.replace(f"block.{idx-1}.layers.0", f"block.{idx-1}.snake1") new_key = new_key.replace(f"block.{idx - 1}.layers.0", f"block.{idx - 1}.snake1")
new_key = new_key.replace(f"block.{idx-1}.layers.1", f"block.{idx-1}.conv_t1") new_key = new_key.replace(f"block.{idx - 1}.layers.1", f"block.{idx - 1}.conv_t1")
new_key = new_key.replace("layers.0.beta", "snake1.beta") new_key = new_key.replace("layers.0.beta", "snake1.beta")
new_key = new_key.replace("layers.0.alpha", "snake1.alpha") new_key = new_key.replace("layers.0.alpha", "snake1.alpha")
...@@ -118,9 +118,9 @@ def convert_stable_audio_state_dict_to_diffusers(state_dict, num_autoencoder_lay ...@@ -118,9 +118,9 @@ def convert_stable_audio_state_dict_to_diffusers(state_dict, num_autoencoder_lay
new_key = new_key.replace("layers.3.weight_", "conv2.weight_") new_key = new_key.replace("layers.3.weight_", "conv2.weight_")
if idx == num_autoencoder_layers + 1: if idx == num_autoencoder_layers + 1:
new_key = new_key.replace(f"block.{idx-1}", "snake1") new_key = new_key.replace(f"block.{idx - 1}", "snake1")
elif idx == num_autoencoder_layers + 2: elif idx == num_autoencoder_layers + 2:
new_key = new_key.replace(f"block.{idx-1}", "conv2") new_key = new_key.replace(f"block.{idx - 1}", "conv2")
else: else:
new_key = new_key new_key = new_key
......
...@@ -381,9 +381,9 @@ def convert_ldm_unet_checkpoint( ...@@ -381,9 +381,9 @@ def convert_ldm_unet_checkpoint(
# TODO resnet time_mixer.mix_factor # TODO resnet time_mixer.mix_factor
if f"input_blocks.{i}.0.time_mixer.mix_factor" in unet_state_dict: if f"input_blocks.{i}.0.time_mixer.mix_factor" in unet_state_dict:
new_checkpoint[ new_checkpoint[f"down_blocks.{block_id}.resnets.{layer_in_block_id}.time_mixer.mix_factor"] = (
f"down_blocks.{block_id}.resnets.{layer_in_block_id}.time_mixer.mix_factor" unet_state_dict[f"input_blocks.{i}.0.time_mixer.mix_factor"]
] = unet_state_dict[f"input_blocks.{i}.0.time_mixer.mix_factor"] )
if len(attentions): if len(attentions):
paths = renew_attention_paths(attentions) paths = renew_attention_paths(attentions)
...@@ -478,9 +478,9 @@ def convert_ldm_unet_checkpoint( ...@@ -478,9 +478,9 @@ def convert_ldm_unet_checkpoint(
) )
if f"output_blocks.{i}.0.time_mixer.mix_factor" in unet_state_dict: if f"output_blocks.{i}.0.time_mixer.mix_factor" in unet_state_dict:
new_checkpoint[ new_checkpoint[f"up_blocks.{block_id}.resnets.{layer_in_block_id}.time_mixer.mix_factor"] = (
f"up_blocks.{block_id}.resnets.{layer_in_block_id}.time_mixer.mix_factor" unet_state_dict[f"output_blocks.{i}.0.time_mixer.mix_factor"]
] = unet_state_dict[f"output_blocks.{i}.0.time_mixer.mix_factor"] )
output_block_list = {k: sorted(v) for k, v in output_block_list.items()} output_block_list = {k: sorted(v) for k, v in output_block_list.items()}
if ["conv.bias", "conv.weight"] in output_block_list.values(): if ["conv.bias", "conv.weight"] in output_block_list.values():
......
...@@ -51,9 +51,9 @@ PORTED_VQVAES = ["image_synthesis.modeling.codecs.image_codec.patch_vqgan.PatchV ...@@ -51,9 +51,9 @@ PORTED_VQVAES = ["image_synthesis.modeling.codecs.image_codec.patch_vqgan.PatchV
def vqvae_model_from_original_config(original_config): def vqvae_model_from_original_config(original_config):
assert ( assert original_config["target"] in PORTED_VQVAES, (
original_config["target"] in PORTED_VQVAES f"{original_config['target']} has not yet been ported to diffusers."
), f"{original_config['target']} has not yet been ported to diffusers." )
original_config = original_config["params"] original_config = original_config["params"]
...@@ -464,15 +464,15 @@ PORTED_CONTENT_EMBEDDINGS = ["image_synthesis.modeling.embeddings.dalle_mask_ima ...@@ -464,15 +464,15 @@ PORTED_CONTENT_EMBEDDINGS = ["image_synthesis.modeling.embeddings.dalle_mask_ima
def transformer_model_from_original_config( def transformer_model_from_original_config(
original_diffusion_config, original_transformer_config, original_content_embedding_config original_diffusion_config, original_transformer_config, original_content_embedding_config
): ):
assert ( assert original_diffusion_config["target"] in PORTED_DIFFUSIONS, (
original_diffusion_config["target"] in PORTED_DIFFUSIONS f"{original_diffusion_config['target']} has not yet been ported to diffusers."
), f"{original_diffusion_config['target']} has not yet been ported to diffusers." )
assert ( assert original_transformer_config["target"] in PORTED_TRANSFORMERS, (
original_transformer_config["target"] in PORTED_TRANSFORMERS f"{original_transformer_config['target']} has not yet been ported to diffusers."
), f"{original_transformer_config['target']} has not yet been ported to diffusers." )
assert ( assert original_content_embedding_config["target"] in PORTED_CONTENT_EMBEDDINGS, (
original_content_embedding_config["target"] in PORTED_CONTENT_EMBEDDINGS f"{original_content_embedding_config['target']} has not yet been ported to diffusers."
), f"{original_content_embedding_config['target']} has not yet been ported to diffusers." )
original_diffusion_config = original_diffusion_config["params"] original_diffusion_config = original_diffusion_config["params"]
original_transformer_config = original_transformer_config["params"] original_transformer_config = original_transformer_config["params"]
......
...@@ -122,7 +122,7 @@ _deps = [ ...@@ -122,7 +122,7 @@ _deps = [
"pytest-timeout", "pytest-timeout",
"pytest-xdist", "pytest-xdist",
"python>=3.8.0", "python>=3.8.0",
"ruff==0.1.5", "ruff==0.9.10",
"safetensors>=0.3.1", "safetensors>=0.3.1",
"sentencepiece>=0.1.91,!=0.1.92", "sentencepiece>=0.1.91,!=0.1.92",
"GitPython<3.1.19", "GitPython<3.1.19",
......
...@@ -29,7 +29,7 @@ deps = { ...@@ -29,7 +29,7 @@ deps = {
"pytest-timeout": "pytest-timeout", "pytest-timeout": "pytest-timeout",
"pytest-xdist": "pytest-xdist", "pytest-xdist": "pytest-xdist",
"python": "python>=3.8.0", "python": "python>=3.8.0",
"ruff": "ruff==0.1.5", "ruff": "ruff==0.9.10",
"safetensors": "safetensors>=0.3.1", "safetensors": "safetensors>=0.3.1",
"sentencepiece": "sentencepiece>=0.1.91,!=0.1.92", "sentencepiece": "sentencepiece>=0.1.91,!=0.1.92",
"GitPython": "GitPython<3.1.19", "GitPython": "GitPython<3.1.19",
......
...@@ -295,8 +295,7 @@ class IPAdapterMixin: ...@@ -295,8 +295,7 @@ class IPAdapterMixin:
): ):
if len(scale_configs) != len(attn_processor.scale): if len(scale_configs) != len(attn_processor.scale):
raise ValueError( raise ValueError(
f"Cannot assign {len(scale_configs)} scale_configs to " f"Cannot assign {len(scale_configs)} scale_configs to {len(attn_processor.scale)} IP-Adapter."
f"{len(attn_processor.scale)} IP-Adapter."
) )
elif len(scale_configs) == 1: elif len(scale_configs) == 1:
scale_configs = scale_configs * len(attn_processor.scale) scale_configs = scale_configs * len(attn_processor.scale)
......
...@@ -184,9 +184,9 @@ def _convert_non_diffusers_lora_to_diffusers(state_dict, unet_name="unet", text_ ...@@ -184,9 +184,9 @@ def _convert_non_diffusers_lora_to_diffusers(state_dict, unet_name="unet", text_
# Store DoRA scale if present. # Store DoRA scale if present.
if dora_present_in_unet: if dora_present_in_unet:
dora_scale_key_to_replace = "_lora.down." if "_lora.down." in diffusers_name else ".lora.down." dora_scale_key_to_replace = "_lora.down." if "_lora.down." in diffusers_name else ".lora.down."
unet_state_dict[ unet_state_dict[diffusers_name.replace(dora_scale_key_to_replace, ".lora_magnitude_vector.")] = (
diffusers_name.replace(dora_scale_key_to_replace, ".lora_magnitude_vector.") state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
] = state_dict.pop(key.replace("lora_down.weight", "dora_scale")) )
# Handle text encoder LoRAs. # Handle text encoder LoRAs.
elif lora_name.startswith(("lora_te_", "lora_te1_", "lora_te2_")): elif lora_name.startswith(("lora_te_", "lora_te1_", "lora_te2_")):
...@@ -206,13 +206,13 @@ def _convert_non_diffusers_lora_to_diffusers(state_dict, unet_name="unet", text_ ...@@ -206,13 +206,13 @@ def _convert_non_diffusers_lora_to_diffusers(state_dict, unet_name="unet", text_
"_lora.down." if "_lora.down." in diffusers_name else ".lora_linear_layer." "_lora.down." if "_lora.down." in diffusers_name else ".lora_linear_layer."
) )
if lora_name.startswith(("lora_te_", "lora_te1_")): if lora_name.startswith(("lora_te_", "lora_te1_")):
te_state_dict[ te_state_dict[diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")] = (
diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.") state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
] = state_dict.pop(key.replace("lora_down.weight", "dora_scale")) )
elif lora_name.startswith("lora_te2_"): elif lora_name.startswith("lora_te2_"):
te2_state_dict[ te2_state_dict[diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")] = (
diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.") state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
] = state_dict.pop(key.replace("lora_down.weight", "dora_scale")) )
# Store alpha if present. # Store alpha if present.
if lora_name_alpha in state_dict: if lora_name_alpha in state_dict:
...@@ -1020,21 +1020,21 @@ def _convert_bfl_flux_control_lora_to_diffusers(original_state_dict): ...@@ -1020,21 +1020,21 @@ def _convert_bfl_flux_control_lora_to_diffusers(original_state_dict):
for lora_key in ["lora_A", "lora_B"]: for lora_key in ["lora_A", "lora_B"]:
## time_text_embed.timestep_embedder <- time_in ## time_text_embed.timestep_embedder <- time_in
converted_state_dict[ converted_state_dict[f"time_text_embed.timestep_embedder.linear_1.{lora_key}.weight"] = (
f"time_text_embed.timestep_embedder.linear_1.{lora_key}.weight" original_state_dict.pop(f"time_in.in_layer.{lora_key}.weight")
] = original_state_dict.pop(f"time_in.in_layer.{lora_key}.weight") )
if f"time_in.in_layer.{lora_key}.bias" in original_state_dict_keys: if f"time_in.in_layer.{lora_key}.bias" in original_state_dict_keys:
converted_state_dict[ converted_state_dict[f"time_text_embed.timestep_embedder.linear_1.{lora_key}.bias"] = (
f"time_text_embed.timestep_embedder.linear_1.{lora_key}.bias" original_state_dict.pop(f"time_in.in_layer.{lora_key}.bias")
] = original_state_dict.pop(f"time_in.in_layer.{lora_key}.bias") )
converted_state_dict[ converted_state_dict[f"time_text_embed.timestep_embedder.linear_2.{lora_key}.weight"] = (
f"time_text_embed.timestep_embedder.linear_2.{lora_key}.weight" original_state_dict.pop(f"time_in.out_layer.{lora_key}.weight")
] = original_state_dict.pop(f"time_in.out_layer.{lora_key}.weight") )
if f"time_in.out_layer.{lora_key}.bias" in original_state_dict_keys: if f"time_in.out_layer.{lora_key}.bias" in original_state_dict_keys:
converted_state_dict[ converted_state_dict[f"time_text_embed.timestep_embedder.linear_2.{lora_key}.bias"] = (
f"time_text_embed.timestep_embedder.linear_2.{lora_key}.bias" original_state_dict.pop(f"time_in.out_layer.{lora_key}.bias")
] = original_state_dict.pop(f"time_in.out_layer.{lora_key}.bias") )
## time_text_embed.text_embedder <- vector_in ## time_text_embed.text_embedder <- vector_in
converted_state_dict[f"time_text_embed.text_embedder.linear_1.{lora_key}.weight"] = original_state_dict.pop( converted_state_dict[f"time_text_embed.text_embedder.linear_1.{lora_key}.weight"] = original_state_dict.pop(
...@@ -1056,21 +1056,21 @@ def _convert_bfl_flux_control_lora_to_diffusers(original_state_dict): ...@@ -1056,21 +1056,21 @@ def _convert_bfl_flux_control_lora_to_diffusers(original_state_dict):
# guidance # guidance
has_guidance = any("guidance" in k for k in original_state_dict) has_guidance = any("guidance" in k for k in original_state_dict)
if has_guidance: if has_guidance:
converted_state_dict[ converted_state_dict[f"time_text_embed.guidance_embedder.linear_1.{lora_key}.weight"] = (
f"time_text_embed.guidance_embedder.linear_1.{lora_key}.weight" original_state_dict.pop(f"guidance_in.in_layer.{lora_key}.weight")
] = original_state_dict.pop(f"guidance_in.in_layer.{lora_key}.weight") )
if f"guidance_in.in_layer.{lora_key}.bias" in original_state_dict_keys: if f"guidance_in.in_layer.{lora_key}.bias" in original_state_dict_keys:
converted_state_dict[ converted_state_dict[f"time_text_embed.guidance_embedder.linear_1.{lora_key}.bias"] = (
f"time_text_embed.guidance_embedder.linear_1.{lora_key}.bias" original_state_dict.pop(f"guidance_in.in_layer.{lora_key}.bias")
] = original_state_dict.pop(f"guidance_in.in_layer.{lora_key}.bias") )
converted_state_dict[ converted_state_dict[f"time_text_embed.guidance_embedder.linear_2.{lora_key}.weight"] = (
f"time_text_embed.guidance_embedder.linear_2.{lora_key}.weight" original_state_dict.pop(f"guidance_in.out_layer.{lora_key}.weight")
] = original_state_dict.pop(f"guidance_in.out_layer.{lora_key}.weight") )
if f"guidance_in.out_layer.{lora_key}.bias" in original_state_dict_keys: if f"guidance_in.out_layer.{lora_key}.bias" in original_state_dict_keys:
converted_state_dict[ converted_state_dict[f"time_text_embed.guidance_embedder.linear_2.{lora_key}.bias"] = (
f"time_text_embed.guidance_embedder.linear_2.{lora_key}.bias" original_state_dict.pop(f"guidance_in.out_layer.{lora_key}.bias")
] = original_state_dict.pop(f"guidance_in.out_layer.{lora_key}.bias") )
# context_embedder # context_embedder
converted_state_dict[f"context_embedder.{lora_key}.weight"] = original_state_dict.pop( converted_state_dict[f"context_embedder.{lora_key}.weight"] = original_state_dict.pop(
......
...@@ -26,6 +26,7 @@ _import_structure = {} ...@@ -26,6 +26,7 @@ _import_structure = {}
if is_torch_available(): if is_torch_available():
_import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"] _import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"]
_import_structure["auto_model"] = ["AutoModel"]
_import_structure["autoencoders.autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"] _import_structure["autoencoders.autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"]
_import_structure["autoencoders.autoencoder_dc"] = ["AutoencoderDC"] _import_structure["autoencoders.autoencoder_dc"] = ["AutoencoderDC"]
_import_structure["autoencoders.autoencoder_kl"] = ["AutoencoderKL"] _import_structure["autoencoders.autoencoder_kl"] = ["AutoencoderKL"]
...@@ -41,7 +42,6 @@ if is_torch_available(): ...@@ -41,7 +42,6 @@ if is_torch_available():
_import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"] _import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"]
_import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"] _import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
_import_structure["autoencoders.vq_model"] = ["VQModel"] _import_structure["autoencoders.vq_model"] = ["VQModel"]
_import_structure["auto_model"] = ["AutoModel"]
_import_structure["cache_utils"] = ["CacheMixin"] _import_structure["cache_utils"] = ["CacheMixin"]
_import_structure["controlnets.controlnet"] = ["ControlNetModel"] _import_structure["controlnets.controlnet"] = ["ControlNetModel"]
_import_structure["controlnets.controlnet_flux"] = ["FluxControlNetModel", "FluxMultiControlNetModel"] _import_structure["controlnets.controlnet_flux"] = ["FluxControlNetModel", "FluxMultiControlNetModel"]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment