Commit e0fadc93 authored by muyangli

[minor] fix some corner cases in lora conversion

parent d7896cb4
__version__ = "0.1.0" __version__ = "0.1.1"
 # convert the comfyui lora to diffusers format
+import argparse
 import os
 import torch
@@ -8,7 +9,7 @@ from ...utils import load_state_dict_in_safetensors
 def comfyui2diffusers(
-    input_lora: str | dict[str, torch.Tensor], output_path: str | None = None
+    input_lora: str | dict[str, torch.Tensor], output_path: str | None = None, min_rank: int | None = None
 ) -> dict[str, torch.Tensor]:
     if isinstance(input_lora, str):
         tensors = load_state_dict_in_safetensors(input_lora, device="cpu")
@@ -16,9 +17,10 @@ def comfyui2diffusers(
         tensors = input_lora
     new_tensors = {}
+    max_alpha = 0
     for k, v in tensors.items():
         if "alpha" in k:
+            max_alpha = max(max_alpha, v.max().item())
             continue
         new_k = k.replace("lora_down", "lora_A").replace("lora_up", "lora_B")
         if "lora_unet_double_blocks_" in k:
@@ -72,8 +74,36 @@ def comfyui2diffusers(
             new_k = new_k.replace("_modulation_lin", ".norm.linear")
         new_tensors[new_k] = v
+    if min_rank is not None:
+        for k in new_tensors.keys():
+            v = new_tensors[k]
+            if "lora_A" in k:
+                rank = v.shape[0]
+                if rank < min_rank:
+                    new_v = torch.zeros(min_rank, v.shape[1], dtype=v.dtype, device=v.device)
+                    new_v[:rank] = v
+                    new_tensors[k] = new_v
+            else:
+                assert "lora_B" in k
+                rank = v.shape[1]
+                if rank < min_rank:
+                    new_v = torch.zeros(v.shape[0], min_rank, dtype=v.dtype, device=v.device)
+                    new_v[:, :rank] = v
+                    new_tensors[k] = new_v
     if output_path is not None:
         output_dir = os.path.dirname(os.path.abspath(output_path))
         os.makedirs(output_dir, exist_ok=True)
         save_file(new_tensors, output_path)
     return new_tensors
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("-i", "--input-path", type=str, required=True, help="path to the comfyui lora safetensor file")
parser.add_argument(
"-o", "--output-path", type=str, required=True, help="path to the output diffusers safetensor file"
)
parser.add_argument("--min-rank", type=int, default=None, help="minimum rank for the LoRA weights")
args = parser.parse_args()
comfyui2diffusers(args.input_path, args.output_path, min_rank=args.min_rank)
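The new min_rank branch zero-pads every adapter up to a common minimum rank. A minimal sketch (not part of the commit; all shapes and names below are illustrative) of why this is lossless: the extra zero rows of lora_A and zero columns of lora_B contribute nothing to the effective update lora_B @ lora_A.

import torch

rank, min_rank, d_in, d_out = 4, 16, 64, 64
lora_A = torch.randn(rank, d_in)   # "lora_down" in ComfyUI naming
lora_B = torch.randn(d_out, rank)  # "lora_up" in ComfyUI naming

# Pad exactly as the new branch does: zero rows for lora_A, zero columns for lora_B.
A_pad = torch.zeros(min_rank, d_in, dtype=lora_A.dtype)
A_pad[:rank] = lora_A
B_pad = torch.zeros(d_out, min_rank, dtype=lora_B.dtype)
B_pad[:, :rank] = lora_B

# The zero rows/columns cancel in the matrix product, so the weight delta is unchanged.
assert torch.allclose(lora_B @ lora_A, B_pad @ A_pad)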
@@ -37,8 +37,8 @@ if __name__ == "__main__":
     args = parser.parse_args()
     if not args.output_root:
-        # output to the parent directory of the quantized model safetensor file
-        args.output_root = os.path.dirname(args.quant_path)
+        # output to the parent directory of the lora safetensor file
+        args.output_root = os.path.dirname(args.lora_path)
     if args.lora_name is None:
         base_name = os.path.basename(args.lora_path)
         lora_name = base_name.rsplit(".", 1)[0]
...

 # convert the diffusers lora to nunchaku format
 """Convert LoRA weights to Nunchaku format."""
 import typing as tp
 import torch
@@ -215,8 +214,8 @@ def convert_to_nunchaku_transformer_block_lowrank_dict(  # noqa: C901
         update_state_dict(
             converted,
             {
-                "lora_down": lora[0],
-                "lora_up": reorder_adanorm_lora_up(lora[1], splits=3),
+                "lora_down": pad(lora[0], divisor=16, dim=0),
+                "lora_up": pad(reorder_adanorm_lora_up(lora[1], splits=3), divisor=16, dim=1),
             },
             prefix=converted_local_name,
         )
@@ -224,8 +223,8 @@ def convert_to_nunchaku_transformer_block_lowrank_dict(  # noqa: C901
         update_state_dict(
             converted,
             {
-                "lora_down": lora[0],
-                "lora_up": reorder_adanorm_lora_up(lora[1], splits=6),
+                "lora_down": pad(lora[0], divisor=16, dim=0),
+                "lora_up": pad(reorder_adanorm_lora_up(lora[1], splits=6), divisor=16, dim=1),
             },
             prefix=converted_local_name,
         )
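The pad() helper used in these hunks is not itself shown in the diff. A plausible sketch, assuming it zero-pads the given dimension up to the next multiple of divisor (16 here, presumably so padded ranks meet the alignment the Nunchaku kernels expect):

import torch

def pad(t: torch.Tensor, divisor: int, dim: int) -> torch.Tensor:
    # Round the size of `dim` up to the next multiple of `divisor`.
    size = t.shape[dim]
    target = ((size + divisor - 1) // divisor) * divisor
    if target == size:
        return t
    # Append zeros along `dim`; zeros are a no-op in the low-rank product.
    pad_shape = list(t.shape)
    pad_shape[dim] = target - size
    zeros = torch.zeros(pad_shape, dtype=t.dtype, device=t.device)
    return torch.cat([t, zeros], dim=dim)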
@@ -263,6 +262,22 @@ def convert_to_nunchaku_flux_single_transformer_block_lowrank_dict(
         extra_lora_dict.pop(f"{candidate_block_name}.proj_out.lora_A.weight")
         extra_lora_dict.pop(f"{candidate_block_name}.proj_out.lora_B.weight")
+    for component in ["lora_A", "lora_B"]:
+        fc1_k = f"{candidate_block_name}.proj_mlp.{component}.weight"
+        fc2_k = f"{candidate_block_name}.proj_out.linears.1.{component}.weight"
+        fc1_v = extra_lora_dict[fc1_k]
+        fc2_v = extra_lora_dict[fc2_k]
+        dim = 0 if "lora_A" in fc1_k else 1
+        fc1_rank = fc1_v.shape[dim]
+        fc2_rank = fc2_v.shape[dim]
+        if fc1_rank != fc2_rank:
+            rank = max(fc1_rank, fc2_rank)
+            if fc1_rank < rank:
+                extra_lora_dict[fc1_k] = pad(fc1_v, divisor=rank, dim=dim)
+            if fc2_rank < rank:
+                extra_lora_dict[fc2_k] = pad(fc2_v, divisor=rank, dim=dim)
     return convert_to_nunchaku_transformer_block_lowrank_dict(
         orig_state_dict=orig_state_dict,
         extra_lora_dict=extra_lora_dict,
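Note the trick in the block above: passing the larger rank as divisor pads the smaller adapter up to exactly that rank, since ceil(size / rank) == 1 whenever 1 <= size < rank. Under the pad() sketch above, with hypothetical shapes:

fc1_A = torch.randn(8, 3072)   # hypothetical rank-8 proj_mlp lora_A
fc2_A = torch.randn(16, 3072)  # hypothetical rank-16 proj_out lora_A
fc1_A = pad(fc1_A, divisor=16, dim=0)  # divisor doubles as the target rank
assert fc1_A.shape[0] == fc2_A.shape[0] == 16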
@@ -347,6 +362,28 @@ def convert_to_nunchaku_flux_lowrank_dict(
     else:
         extra_lora_dict = filter_state_dict(lora, filter_prefix="transformer.")
+    for k in extra_lora_dict.keys():
+        fc1_k = k
+        if "ff.net.0.proj" in k:
+            fc2_k = k.replace("ff.net.0.proj", "ff.net.2")
+        elif "ff_context.net.0.proj" in k:
+            fc2_k = k.replace("ff_context.net.0.proj", "ff_context.net.2")
+        else:
+            continue
+        assert fc2_k in extra_lora_dict
+        fc1_v = extra_lora_dict[fc1_k]
+        fc2_v = extra_lora_dict[fc2_k]
+        dim = 0 if "lora_A" in fc1_k else 1
+        fc1_rank = fc1_v.shape[dim]
+        fc2_rank = fc2_v.shape[dim]
+        if fc1_rank != fc2_rank:
+            rank = max(fc1_rank, fc2_rank)
+            if fc1_rank < rank:
+                extra_lora_dict[fc1_k] = pad(fc1_v, divisor=rank, dim=dim)
+            if fc2_rank < rank:
+                extra_lora_dict[fc2_k] = pad(fc2_v, divisor=rank, dim=dim)
     block_names: set[str] = set()
     for param_name in orig_state_dict.keys():
         if param_name.startswith(("transformer_blocks.", "single_transformer_blocks.")):
@@ -370,4 +407,5 @@ def convert_to_nunchaku_flux_lowrank_dict(
         ),
         prefix=block_name,
     )
     return converted
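Taken together, the commit makes rank handling consistent across the converters: mismatched or undersized adapters are zero-padded rather than rejected. A hypothetical end-to-end use of the ComfyUI converter extended above; the file names are placeholders, and the module path is not shown in the diff. The new __main__ block wraps exactly this call behind the -i/-o/--min-rank flags.

# from the package's comfyui converter module import comfyui2diffusers  # exact path not shown in the diff
state_dict = comfyui2diffusers(
    "comfyui_lora.safetensors",    # placeholder input path
    "diffusers_lora.safetensors",  # placeholder output path, written via save_file()
    min_rank=16,                   # zero-pad every adapter to at least rank 16
)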