Commit a3dffc44 authored by comfyanonymous

Support AuraFlow Lora and loading model weights in diffusers format.

You can load model weights in diffusers format using the UNETLoader node.
parent ce2473bb
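
The UNETLoader node mentioned above boils down to a call into comfy.sd.load_unet, which in turn relies on the load_unet_state_dict changes in this commit. A minimal sketch of loading diffusers-format weights that way, assuming the surrounding ComfyUI API and a hypothetical file path (the path is illustrative, not part of this commit):

    import comfy.sd

    # Hypothetical diffusers-format AuraFlow transformer checkpoint dropped
    # into ComfyUI's models/unet folder (path is illustrative only).
    unet_path = "models/unet/aura_flow_transformer.safetensors"

    # UNETLoader resolves to this call; with this commit it can also detect
    # and convert diffusers-style MMDiT/AuraFlow key names.
    model_patcher = comfy.sd.load_unet(unet_path)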
comfy/lora.py
@@ -274,4 +274,12 @@ def model_lora_keys_unet(model, key_map={}):
                 key_lora = "lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_")) #OneTrainer lora
                 key_map[key_lora] = to
 
+    if isinstance(model, comfy.model_base.AuraFlow): #Diffusers lora AuraFlow
+        diffusers_keys = comfy.utils.auraflow_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
+        for k in diffusers_keys:
+            if k.endswith(".weight"):
+                to = diffusers_keys[k]
+                key_lora = "transformer.{}".format(k[:-len(".weight")]) #simpletrainer and probably regular diffusers lora format
+                key_map[key_lora] = to
+
     return key_map
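
This hunk wires diffusers-style AuraFlow LoRA key names to ComfyUI's internal weight names via auraflow_to_diffusers (added further down in comfy/utils.py). A toy illustration of the resulting mapping, using one hypothetical entry of the kind that function returns:

    # Hypothetical single entry as produced by auraflow_to_diffusers()
    diffusers_keys = {
        "single_transformer_blocks.0.attn.to_q.weight":
            "diffusion_model.single_layers.0.attn.w1q.weight",
    }

    key_map = {}
    for k in diffusers_keys:
        if k.endswith(".weight"):
            to = diffusers_keys[k]
            # simpletrainer / regular diffusers lora naming
            key_map["transformer.{}".format(k[:-len(".weight")])] = to

    print(key_map)
    # {'transformer.single_transformer_blocks.0.attn.to_q':
    #  'diffusion_model.single_layers.0.attn.w1q.weight'}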
comfy/model_detection.py
@@ -109,6 +109,10 @@ def detect_unet_config(state_dict, key_prefix):
         unet_config = {}
         unet_config["max_seq"] = state_dict['{}positional_encoding'.format(key_prefix)].shape[1]
         unet_config["cond_seq_dim"] = state_dict['{}cond_seq_linear.weight'.format(key_prefix)].shape[1]
+        double_layers = count_blocks(state_dict_keys, '{}double_layers.'.format(key_prefix) + '{}.')
+        single_layers = count_blocks(state_dict_keys, '{}single_layers.'.format(key_prefix) + '{}.')
+        unet_config["n_double_layers"] = double_layers
+        unet_config["n_layers"] = double_layers + single_layers
         return unet_config
 
     if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
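
Both new config values are inferred purely by counting numbered blocks in the checkpoint's key names. count_blocks is an existing helper in comfy/model_detection.py; a minimal stand-in, assuming it counts consecutive indices for which the formatted prefix matches at least one key:

    # Stand-in sketch of count_blocks (assumed behavior, not the original code)
    def count_blocks(state_dict_keys, prefix_string):
        count = 0
        while any(k.startswith(prefix_string.format(count)) for k in state_dict_keys):
            count += 1
        return count

    keys = ["double_layers.0.attn.w1q.weight",
            "double_layers.1.attn.w1q.weight",
            "single_layers.0.attn.w1q.weight"]
    print(count_blocks(keys, "double_layers.{}."))  # 2
    print(count_blocks(keys, "single_layers.{}."))  # 1

The next hunk, also in comfy/model_detection.py, generalizes convert_diffusers_mmdit so that AuraFlow checkpoints are detected alongside SD3.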
comfy/model_detection.py
@@ -450,37 +454,45 @@ def model_config_from_diffusers_unet(state_dict):
     return None
 
 def convert_diffusers_mmdit(state_dict, output_prefix=""):
-    num_blocks = count_blocks(state_dict, 'transformer_blocks.{}.')
-    if num_blocks > 0:
+    out_sd = {}
+
+    if 'transformer_blocks.0.attn.add_q_proj.weight' in state_dict: #SD3
+        num_blocks = count_blocks(state_dict, 'transformer_blocks.{}.')
         depth = state_dict["pos_embed.proj.weight"].shape[0] // 64
-        out_sd = {}
         sd_map = comfy.utils.mmdit_to_diffusers({"depth": depth, "num_blocks": num_blocks}, output_prefix=output_prefix)
-        for k in sd_map:
-            weight = state_dict.get(k, None)
-            if weight is not None:
-                t = sd_map[k]
-
-                if not isinstance(t, str):
-                    if len(t) > 2:
-                        fun = t[2]
-                    else:
-                        fun = lambda a: a
-                    offset = t[1]
-                    if offset is not None:
-                        old_weight = out_sd.get(t[0], None)
-                        if old_weight is None:
-                            old_weight = torch.empty_like(weight)
-                            old_weight = old_weight.repeat([3] + [1] * (len(old_weight.shape) - 1))
-
-                        w = old_weight.narrow(offset[0], offset[1], offset[2])
-                    else:
-                        old_weight = weight
-                        w = weight
-                    w[:] = fun(weight)
-                    t = t[0]
-                    out_sd[t] = old_weight
-                else:
-                    out_sd[t] = weight
-                state_dict.pop(k)
-
-        return out_sd
+    elif 'joint_transformer_blocks.0.attn.add_k_proj.weight' in state_dict: #AuraFlow
+        num_joint = count_blocks(state_dict, 'joint_transformer_blocks.{}.')
+        num_single = count_blocks(state_dict, 'single_transformer_blocks.{}.')
+        sd_map = comfy.utils.auraflow_to_diffusers({"n_double_layers": num_joint, "n_layers": num_joint + num_single}, output_prefix=output_prefix)
+    else:
+        return None
+
+    for k in sd_map:
+        weight = state_dict.get(k, None)
+        if weight is not None:
+            t = sd_map[k]
+
+            if not isinstance(t, str):
+                if len(t) > 2:
+                    fun = t[2]
+                else:
+                    fun = lambda a: a
+                offset = t[1]
+                if offset is not None:
+                    old_weight = out_sd.get(t[0], None)
+                    if old_weight is None:
+                        old_weight = torch.empty_like(weight)
+                        old_weight = old_weight.repeat([3] + [1] * (len(old_weight.shape) - 1))
+
+                    w = old_weight.narrow(offset[0], offset[1], offset[2])
+                else:
+                    old_weight = weight
+                    w = weight
+                w[:] = fun(weight)
+                t = t[0]
+                out_sd[t] = old_weight
+            else:
+                out_sd[t] = weight
+            state_dict.pop(k)
+
+    return out_sd
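
When a value in sd_map is a tuple rather than a plain string, it carries (target_name, offset, fun): several diffusers tensors get packed into slices of one fused native tensor, written through views created by narrow(). A toy torch demo of that packing with assumed shapes (not ComfyUI code):

    import torch

    dim = 4
    q = torch.full((dim, dim), 1.0)
    k = torch.full((dim, dim), 2.0)
    v = torch.full((dim, dim), 3.0)

    # Allocate the fused tensor sized like one input repeated 3x along dim 0,
    # mirroring the empty_like(...).repeat(...) pattern above.
    fused = torch.empty_like(q).repeat([3] + [1] * (q.ndim - 1))
    for i, w in enumerate((q, k, v)):
        # offset = (dimension, start, length), like the t[1] tuples in sd_map
        fused.narrow(0, i * dim, dim)[:] = w

    print(fused.shape)         # torch.Size([12, 4])
    print(fused[4, 0].item())  # 2.0 -> the k block landed in the middle slice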
comfy/sd.py
@@ -562,26 +562,25 @@ def load_unet_state_dict(sd): #load unet in diffusers or regular format
     if model_config is not None:
         new_sd = sd
-    elif 'transformer_blocks.0.attn.add_q_proj.weight' in sd: #MMDIT SD3
+    else:
         new_sd = model_detection.convert_diffusers_mmdit(sd, "")
-        if new_sd is None:
-            return None
-        model_config = model_detection.model_config_from_unet(new_sd, "")
-        if model_config is None:
-            return None
-    else: #diffusers
-        model_config = model_detection.model_config_from_diffusers_unet(sd)
-        if model_config is None:
-            return None
+        if new_sd is not None: #diffusers mmdit
+            model_config = model_detection.model_config_from_unet(new_sd, "")
+            if model_config is None:
+                return None
+        else: #diffusers unet
+            model_config = model_detection.model_config_from_diffusers_unet(sd)
+            if model_config is None:
+                return None
 
-        diffusers_keys = comfy.utils.unet_to_diffusers(model_config.unet_config)
+            diffusers_keys = comfy.utils.unet_to_diffusers(model_config.unet_config)
 
-        new_sd = {}
-        for k in diffusers_keys:
-            if k in sd:
-                new_sd[diffusers_keys[k]] = sd.pop(k)
-            else:
-                logging.warning("{} {}".format(diffusers_keys[k], k))
+            new_sd = {}
+            for k in diffusers_keys:
+                if k in sd:
+                    new_sd[diffusers_keys[k]] = sd.pop(k)
+                else:
+                    logging.warning("{} {}".format(diffusers_keys[k], k))
 
     offload_device = model_management.unet_offload_device()
     unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=model_config.supported_inference_dtypes)
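
The rewritten branch changes the detection order: a native state dict is used as-is, then convert_diffusers_mmdit is attempted (returning None now means "not an MMDiT/AuraFlow checkpoint" and falls through instead of aborting the load), and only then the classic diffusers UNet key renaming runs. A hedged usage sketch, assuming comfy.utils.load_torch_file and a hypothetical file name:

    import comfy.utils
    import comfy.sd

    sd = comfy.utils.load_torch_file("aura_flow_transformer.safetensors")  # hypothetical file
    model_patcher = comfy.sd.load_unet_state_dict(sd)
    if model_patcher is None:
        print("state dict not recognized as a supported unet/transformer")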
comfy/utils.py
@@ -332,6 +332,76 @@ def mmdit_to_diffusers(mmdit_config, output_prefix=""):
     return key_map
 
+def auraflow_to_diffusers(mmdit_config, output_prefix=""):
+    n_double_layers = mmdit_config.get("n_double_layers", 0)
+    n_layers = mmdit_config.get("n_layers", 0)
+
+    key_map = {}
+    for i in range(n_layers):
+        if i < n_double_layers:
+            index = i
+            prefix_from = "joint_transformer_blocks"
+            prefix_to = "{}double_layers".format(output_prefix)
+            block_map = {
+                "attn.to_q.weight": "attn.w2q.weight",
+                "attn.to_k.weight": "attn.w2k.weight",
+                "attn.to_v.weight": "attn.w2v.weight",
+                "attn.to_out.0.weight": "attn.w2o.weight",
+                "attn.add_q_proj.weight": "attn.w1q.weight",
+                "attn.add_k_proj.weight": "attn.w1k.weight",
+                "attn.add_v_proj.weight": "attn.w1v.weight",
+                "attn.to_add_out.weight": "attn.w1o.weight",
+                "ff.linear_1.weight": "mlpX.c_fc1.weight",
+                "ff.linear_2.weight": "mlpX.c_fc2.weight",
+                "ff.out_projection.weight": "mlpX.c_proj.weight",
+                "ff_context.linear_1.weight": "mlpC.c_fc1.weight",
+                "ff_context.linear_2.weight": "mlpC.c_fc2.weight",
+                "ff_context.out_projection.weight": "mlpC.c_proj.weight",
+                "norm1.linear.weight": "modX.1.weight",
+                "norm1_context.linear.weight": "modC.1.weight",
+            }
+        else:
+            index = i - n_double_layers
+            prefix_from = "single_transformer_blocks"
+            prefix_to = "{}single_layers".format(output_prefix)
+            block_map = {
+                "attn.to_q.weight": "attn.w1q.weight",
+                "attn.to_k.weight": "attn.w1k.weight",
+                "attn.to_v.weight": "attn.w1v.weight",
+                "attn.to_out.0.weight": "attn.w1o.weight",
+                "norm1.linear.weight": "modCX.1.weight",
+                "ff.linear_1.weight": "mlp.c_fc1.weight",
+                "ff.linear_2.weight": "mlp.c_fc2.weight",
+                "ff.out_projection.weight": "mlp.c_proj.weight"
+            }
+
+        for k in block_map:
+            key_map["{}.{}.{}".format(prefix_from, index, k)] = "{}.{}.{}".format(prefix_to, index, block_map[k])
+
+    MAP_BASIC = {
+        ("positional_encoding", "pos_embed.pos_embed"),
+        ("register_tokens", "register_tokens"),
+        ("t_embedder.mlp.0.weight", "time_step_proj.linear_1.weight"),
+        ("t_embedder.mlp.0.bias", "time_step_proj.linear_1.bias"),
+        ("t_embedder.mlp.2.weight", "time_step_proj.linear_2.weight"),
+        ("t_embedder.mlp.2.bias", "time_step_proj.linear_2.bias"),
+        ("cond_seq_linear.weight", "context_embedder.weight"),
+        ("init_x_linear.weight", "pos_embed.proj.weight"),
+        ("init_x_linear.bias", "pos_embed.proj.bias"),
+        ("final_linear.weight", "proj_out.weight"),
+        ("modF.1.weight", "norm_out.linear.weight", swap_scale_shift),
+    }
+
+    for k in MAP_BASIC:
+        if len(k) > 2:
+            key_map[k[1]] = ("{}{}".format(output_prefix, k[0]), None, k[2])
+        else:
+            key_map[k[1]] = "{}{}".format(output_prefix, k[0])
+
+    return key_map
+
 def repeat_to_batch_size(tensor, batch_size, dim=0):
     if tensor.shape[dim] > batch_size:
         return tensor.narrow(dim, 0, batch_size)
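
The third element of the modF/norm_out entry above is swap_scale_shift, an existing helper in comfy/utils.py that reorders the fused modulation weight, presumably because diffusers stores the AdaLayerNorm linear output as [shift, scale] while the native AuraFlow model expects [scale, shift]. A sketch of that behavior (assumed implementation, not copied from this diff):

    import torch

    def swap_scale_shift(weight):
        # split the fused [shift | scale] weight in half along dim 0
        shift, scale = weight.chunk(2, dim=0)
        # and re-concatenate it as [scale | shift]
        return torch.cat([scale, shift], dim=0)

    w = torch.cat([torch.zeros(3, 2), torch.ones(3, 2)])  # [shift | scale]
    print(swap_scale_shift(w)[:3])  # scale half (ones) now comes first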