Commit e0fadc93 authored by muyangli

[minor] fix some corner cases in lora conversion

parent d7896cb4
-__version__ = "0.1.0"
+__version__ = "0.1.1"

# convert the comfyui lora to diffusers format
import argparse
import os
import torch
@@ -8,7 +9,7 @@ from ...utils import load_state_dict_in_safetensors
def comfyui2diffusers(
-    input_lora: str | dict[str, torch.Tensor], output_path: str | None = None
+    input_lora: str | dict[str, torch.Tensor], output_path: str | None = None, min_rank: int | None = None
) -> dict[str, torch.Tensor]:
    if isinstance(input_lora, str):
        tensors = load_state_dict_in_safetensors(input_lora, device="cpu")
@@ -16,9 +17,10 @@ def comfyui2diffusers(
        tensors = input_lora
    new_tensors = {}
+    max_alpha = 0
    for k, v in tensors.items():
        if "alpha" in k:
+            max_alpha = max(max_alpha, v.max().item())
            continue
        new_k = k.replace("lora_down", "lora_A").replace("lora_up", "lora_B")
        if "lora_unet_double_blocks_" in k:
@@ -72,8 +74,36 @@ def comfyui2diffusers(
            new_k = new_k.replace("_modulation_lin", ".norm.linear")
        new_tensors[new_k] = v
+    if min_rank is not None:
+        for k in new_tensors.keys():
+            v = new_tensors[k]
+            if "lora_A" in k:
+                rank = v.shape[0]
+                if rank < min_rank:
+                    new_v = torch.zeros(min_rank, v.shape[1], dtype=v.dtype, device=v.device)
+                    new_v[:rank] = v
+                    new_tensors[k] = new_v
+            else:
+                assert "lora_B" in k
+                rank = v.shape[1]
+                if rank < min_rank:
+                    new_v = torch.zeros(v.shape[0], min_rank, dtype=v.dtype, device=v.device)
+                    new_v[:, :rank] = v
+                    new_tensors[k] = new_v
    if output_path is not None:
        output_dir = os.path.dirname(os.path.abspath(output_path))
        os.makedirs(output_dir, exist_ok=True)
        save_file(new_tensors, output_path)
    return new_tensors
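The zero-padding above is lossless: the extra rows of lora_A and extra columns of lora_B are zero, so the effective update lora_B @ lora_A is unchanged. A minimal check with toy shapes (illustrative, not part of the commit):

    import torch

    A = torch.randn(4, 64)       # lora_A: (rank, in_features)
    B = torch.randn(64, 4)       # lora_B: (out_features, rank)
    A_pad = torch.zeros(16, 64)
    A_pad[:4] = A                # zero-pad the rank rows
    B_pad = torch.zeros(64, 16)
    B_pad[:, :4] = B             # zero-pad the rank columns
    assert torch.allclose(B @ A, B_pad @ A_pad)  # same effective update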
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-i", "--input-path", type=str, required=True, help="path to the comfyui lora safetensor file")
    parser.add_argument(
        "-o", "--output-path", type=str, required=True, help="path to the output diffusers safetensor file"
    )
+    parser.add_argument("--min-rank", type=int, default=None, help="minimum rank for the LoRA weights")
    args = parser.parse_args()
    comfyui2diffusers(args.input_path, args.output_path, min_rank=args.min_rank)
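A usage sketch of the new min_rank option from Python (the import path below is an assumption; this diff does not show the module name):

    # Hypothetical import path; adjust to wherever comfyui2diffusers lives in the repo.
    from nunchaku.lora.flux.comfyui_converter import comfyui2diffusers

    # Convert a ComfyUI LoRA and zero-pad any adapter whose rank is below 16.
    state_dict = comfyui2diffusers(
        "lora_comfy.safetensors",
        output_path="lora_diffusers.safetensors",
        min_rank=16,
    )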
@@ -37,8 +37,8 @@ if __name__ == "__main__":
    args = parser.parse_args()
    if not args.output_root:
-        # output to the parent directory of the quantized model safetensor file
-        args.output_root = os.path.dirname(args.quant_path)
+        # output to the parent directory of the lora safetensor file
+        args.output_root = os.path.dirname(args.lora_path)
    if args.lora_name is None:
        base_name = os.path.basename(args.lora_path)
        lora_name = base_name.rsplit(".", 1)[0]
-# convert the diffusers lora to nunchaku format
+"""Convert LoRA weights to Nunchaku format."""
import typing as tp

import torch
@@ -215,8 +214,8 @@ def convert_to_nunchaku_transformer_block_lowrank_dict(  # noqa: C901
        update_state_dict(
            converted,
            {
-                "lora_down": lora[0],
-                "lora_up": reorder_adanorm_lora_up(lora[1], splits=3),
+                "lora_down": pad(lora[0], divisor=16, dim=0),
+                "lora_up": pad(reorder_adanorm_lora_up(lora[1], splits=3), divisor=16, dim=1),
            },
            prefix=converted_local_name,
        )
@@ -224,8 +223,8 @@ def convert_to_nunchaku_transformer_block_lowrank_dict(  # noqa: C901
        update_state_dict(
            converted,
            {
-                "lora_down": lora[0],
-                "lora_up": reorder_adanorm_lora_up(lora[1], splits=6),
+                "lora_down": pad(lora[0], divisor=16, dim=0),
+                "lora_up": pad(reorder_adanorm_lora_up(lora[1], splits=6), divisor=16, dim=1),
            },
            prefix=converted_local_name,
        )
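Both hunks now route the adanorm low-rank factors through pad. The helper is defined elsewhere in the repo; a plausible sketch, assuming divisor and dim mean "zero-pad dimension dim up to the next multiple of divisor":

    import torch

    def pad(tensor: torch.Tensor, divisor: int, dim: int) -> torch.Tensor:
        # Round the size of `dim` up to the next multiple of `divisor`.
        size = tensor.shape[dim]
        target = ((size + divisor - 1) // divisor) * divisor
        if target == size:
            return tensor
        pad_shape = list(tensor.shape)
        pad_shape[dim] = target - size
        zeros = torch.zeros(pad_shape, dtype=tensor.dtype, device=tensor.device)
        return torch.cat([tensor, zeros], dim=dim)

Under this reading, divisor=16 aligns every rank to a multiple of 16, and the later pad(..., divisor=rank, ...) calls pad a factor exactly up to rank whenever its current size is smaller.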
@@ -263,6 +262,22 @@ def convert_to_nunchaku_flux_single_transformer_block_lowrank_dict(
        extra_lora_dict.pop(f"{candidate_block_name}.proj_out.lora_A.weight")
        extra_lora_dict.pop(f"{candidate_block_name}.proj_out.lora_B.weight")
+    for component in ["lora_A", "lora_B"]:
+        fc1_k = f"{candidate_block_name}.proj_mlp.{component}.weight"
+        fc2_k = f"{candidate_block_name}.proj_out.linears.1.{component}.weight"
+        fc1_v = extra_lora_dict[fc1_k]
+        fc2_v = extra_lora_dict[fc2_k]
+        dim = 0 if "lora_A" in fc1_k else 1
+        fc1_rank = fc1_v.shape[dim]
+        fc2_rank = fc2_v.shape[dim]
+        if fc1_rank != fc2_rank:
+            rank = max(fc1_rank, fc2_rank)
+            if fc1_rank < rank:
+                extra_lora_dict[fc1_k] = pad(fc1_v, divisor=rank, dim=dim)
+            if fc2_rank < rank:
+                extra_lora_dict[fc2_k] = pad(fc2_v, divisor=rank, dim=dim)
    return convert_to_nunchaku_transformer_block_lowrank_dict(
        orig_state_dict=orig_state_dict,
        extra_lora_dict=extra_lora_dict,
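The loop above gives proj_mlp (fc1) and proj_out.linears.1 (fc2) a common LoRA rank before the block is converted, padding whichever factor is smaller; the fused single-block layout evidently requires the two ranks to match. An illustrative key pair for one component (the block naming is assumed, not taken from this diff):

    fc1_k = "single_transformer_blocks.0.proj_mlp.lora_A.weight"
    fc2_k = "single_transformer_blocks.0.proj_out.linears.1.lora_A.weight"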
@@ -347,6 +362,28 @@ def convert_to_nunchaku_flux_lowrank_dict(
    else:
        extra_lora_dict = filter_state_dict(lora, filter_prefix="transformer.")
+    for k in extra_lora_dict.keys():
+        fc1_k = k
+        if "ff.net.0.proj" in k:
+            fc2_k = k.replace("ff.net.0.proj", "ff.net.2")
+        elif "ff_context.net.0.proj" in k:
+            fc2_k = k.replace("ff_context.net.0.proj", "ff_context.net.2")
+        else:
+            continue
+        assert fc2_k in extra_lora_dict
+        fc1_v = extra_lora_dict[fc1_k]
+        fc2_v = extra_lora_dict[fc2_k]
+        dim = 0 if "lora_A" in fc1_k else 1
+        fc1_rank = fc1_v.shape[dim]
+        fc2_rank = fc2_v.shape[dim]
+        if fc1_rank != fc2_rank:
+            rank = max(fc1_rank, fc2_rank)
+            if fc1_rank < rank:
+                extra_lora_dict[fc1_k] = pad(fc1_v, divisor=rank, dim=dim)
+            if fc2_rank < rank:
+                extra_lora_dict[fc2_k] = pad(fc2_v, divisor=rank, dim=dim)
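    # (Illustrative, not part of the commit) Example key pair matched by the
    # loop above, for a hypothetical block index 0:
    #   fc1_k = "transformer_blocks.0.ff.net.0.proj.lora_A.weight"
    #   fc2_k = "transformer_blocks.0.ff.net.2.lora_A.weight"
    # In the diffusers FLUX feed-forward, net.0.proj is the up-projection and
    # net.2 the down-projection; padding the smaller factor up to the larger
    # rank gives both halves of the fused MLP a common LoRA rank.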
    block_names: set[str] = set()
    for param_name in orig_state_dict.keys():
        if param_name.startswith(("transformer_blocks.", "single_transformer_blocks.")):
@@ -370,4 +407,5 @@ def convert_to_nunchaku_flux_lowrank_dict(
            ),
            prefix=block_name,
        )
    return converted