Commit a3dffc44 authored by comfyanonymous's avatar comfyanonymous
Browse files

Support AuraFlow Lora and loading model weights in diffusers format.

You can load model weights in diffusers format using the UNETLoader node.
parent ce2473bb
......@@ -274,4 +274,12 @@ def model_lora_keys_unet(model, key_map={}):
key_lora = "lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_")) #OneTrainer lora
key_map[key_lora] = to
if isinstance(model, comfy.model_base.AuraFlow): #Diffusers lora AuraFlow
diffusers_keys = comfy.utils.auraflow_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
for k in diffusers_keys:
if k.endswith(".weight"):
to = diffusers_keys[k]
key_lora = "transformer.{}".format(k[:-len(".weight")]) #simpletrainer and probably regular diffusers lora format
key_map[key_lora] = to
return key_map
......@@ -109,6 +109,10 @@ def detect_unet_config(state_dict, key_prefix):
unet_config = {}
unet_config["max_seq"] = state_dict['{}positional_encoding'.format(key_prefix)].shape[1]
unet_config["cond_seq_dim"] = state_dict['{}cond_seq_linear.weight'.format(key_prefix)].shape[1]
double_layers = count_blocks(state_dict_keys, '{}double_layers.'.format(key_prefix) + '{}.')
single_layers = count_blocks(state_dict_keys, '{}single_layers.'.format(key_prefix) + '{}.')
unet_config["n_double_layers"] = double_layers
unet_config["n_layers"] = double_layers + single_layers
return unet_config
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
......@@ -450,11 +454,19 @@ def model_config_from_diffusers_unet(state_dict):
return None
def convert_diffusers_mmdit(state_dict, output_prefix=""):
out_sd = {}
if 'transformer_blocks.0.attn.add_q_proj.weight' in state_dict: #SD3
num_blocks = count_blocks(state_dict, 'transformer_blocks.{}.')
if num_blocks > 0:
depth = state_dict["pos_embed.proj.weight"].shape[0] // 64
out_sd = {}
sd_map = comfy.utils.mmdit_to_diffusers({"depth": depth, "num_blocks": num_blocks}, output_prefix=output_prefix)
elif 'joint_transformer_blocks.0.attn.add_k_proj.weight' in state_dict: #AuraFlow
num_joint = count_blocks(state_dict, 'joint_transformer_blocks.{}.')
num_single = count_blocks(state_dict, 'single_transformer_blocks.{}.')
sd_map = comfy.utils.auraflow_to_diffusers({"n_double_layers": num_joint, "n_layers": num_joint + num_single}, output_prefix=output_prefix)
else:
return None
for k in sd_map:
weight = state_dict.get(k, None)
if weight is not None:
......
......@@ -562,14 +562,13 @@ def load_unet_state_dict(sd): #load unet in diffusers or regular format
if model_config is not None:
new_sd = sd
elif 'transformer_blocks.0.attn.add_q_proj.weight' in sd: #MMDIT SD3
else:
new_sd = model_detection.convert_diffusers_mmdit(sd, "")
if new_sd is None:
return None
if new_sd is not None: #diffusers mmdit
model_config = model_detection.model_config_from_unet(new_sd, "")
if model_config is None:
return None
else: #diffusers
else: #diffusers unet
model_config = model_detection.model_config_from_diffusers_unet(sd)
if model_config is None:
return None
......
......@@ -332,6 +332,76 @@ def mmdit_to_diffusers(mmdit_config, output_prefix=""):
return key_map
def auraflow_to_diffusers(mmdit_config, output_prefix=""):
n_double_layers = mmdit_config.get("n_double_layers", 0)
n_layers = mmdit_config.get("n_layers", 0)
key_map = {}
for i in range(n_layers):
if i < n_double_layers:
index = i
prefix_from = "joint_transformer_blocks"
prefix_to = "{}double_layers".format(output_prefix)
block_map = {
"attn.to_q.weight": "attn.w2q.weight",
"attn.to_k.weight": "attn.w2k.weight",
"attn.to_v.weight": "attn.w2v.weight",
"attn.to_out.0.weight": "attn.w2o.weight",
"attn.add_q_proj.weight": "attn.w1q.weight",
"attn.add_k_proj.weight": "attn.w1k.weight",
"attn.add_v_proj.weight": "attn.w1v.weight",
"attn.to_add_out.weight": "attn.w1o.weight",
"ff.linear_1.weight": "mlpX.c_fc1.weight",
"ff.linear_2.weight": "mlpX.c_fc2.weight",
"ff.out_projection.weight": "mlpX.c_proj.weight",
"ff_context.linear_1.weight": "mlpC.c_fc1.weight",
"ff_context.linear_2.weight": "mlpC.c_fc2.weight",
"ff_context.out_projection.weight": "mlpC.c_proj.weight",
"norm1.linear.weight": "modX.1.weight",
"norm1_context.linear.weight": "modC.1.weight",
}
else:
index = i - n_double_layers
prefix_from = "single_transformer_blocks"
prefix_to = "{}single_layers".format(output_prefix)
block_map = {
"attn.to_q.weight": "attn.w1q.weight",
"attn.to_k.weight": "attn.w1k.weight",
"attn.to_v.weight": "attn.w1v.weight",
"attn.to_out.0.weight": "attn.w1o.weight",
"norm1.linear.weight": "modCX.1.weight",
"ff.linear_1.weight": "mlp.c_fc1.weight",
"ff.linear_2.weight": "mlp.c_fc2.weight",
"ff.out_projection.weight": "mlp.c_proj.weight"
}
for k in block_map:
key_map["{}.{}.{}".format(prefix_from, index, k)] = "{}.{}.{}".format(prefix_to, index, block_map[k])
MAP_BASIC = {
("positional_encoding", "pos_embed.pos_embed"),
("register_tokens", "register_tokens"),
("t_embedder.mlp.0.weight", "time_step_proj.linear_1.weight"),
("t_embedder.mlp.0.bias", "time_step_proj.linear_1.bias"),
("t_embedder.mlp.2.weight", "time_step_proj.linear_2.weight"),
("t_embedder.mlp.2.bias", "time_step_proj.linear_2.bias"),
("cond_seq_linear.weight", "context_embedder.weight"),
("init_x_linear.weight", "pos_embed.proj.weight"),
("init_x_linear.bias", "pos_embed.proj.bias"),
("final_linear.weight", "proj_out.weight"),
("modF.1.weight", "norm_out.linear.weight", swap_scale_shift),
}
for k in MAP_BASIC:
if len(k) > 2:
key_map[k[1]] = ("{}{}".format(output_prefix, k[0]), None, k[2])
else:
key_map[k[1]] = "{}{}".format(output_prefix, k[0])
return key_map
def repeat_to_batch_size(tensor, batch_size, dim=0):
if tensor.shape[dim] > batch_size:
return tensor.narrow(dim, 0, batch_size)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment