Commit 218d333e authored by Samuel Tesfai

Merge https://github.com/mit-han-lab/nunchaku into migrate_tinychat

parents 73939beb c7f41661
@@ -8,4 +8,4 @@ pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")
image = pipeline("A cat holding a sign that says hello world", num_inference_steps=50, guidance_scale=3.5).images[0]
image.save("flux.1-dev-int4.png")
image.save("flux.1-dev.png")
import torch
from diffusers import FluxPriorReduxPipeline, FluxPipeline
from diffusers import FluxPipeline, FluxPriorReduxPipeline
from diffusers.utils import load_image
from nunchaku.models.transformer_flux import NunchakuFluxTransformer2dModel
pipe_prior_redux = FluxPriorReduxPipeline.from_pretrained(
......
import torch
from diffusers import SanaPipeline
from nunchaku.models.transformer_sana import NunchakuSanaTransformer2DModel
transformer = NunchakuSanaTransformer2DModel.from_pretrained("mit-han-lab/svdq-int4-sana-1600m")
......
__version__ = "0.0.2beta4"
__version__ = "0.0.2beta6"
# convert the comfyui lora to diffusers format
import os
import torch
from safetensors.torch import save_file
from ...utils import load_state_dict_in_safetensors
def comfyui2diffusers(
    input_lora: str | dict[str, torch.Tensor], output_path: str | None = None
) -> dict[str, torch.Tensor]:
    if isinstance(input_lora, str):
        tensors = load_state_dict_in_safetensors(input_lora, device="cpu")
    else:
        tensors = input_lora
    new_tensors = {}
    for k, v in tensors.items():
        if "alpha" in k:
            continue
        new_k = k.replace("lora_down", "lora_A").replace("lora_up", "lora_B")
        if "lora_unet_double_blocks_" in k:
            new_k = new_k.replace("lora_unet_double_blocks_", "transformer.transformer_blocks.")
            if "qkv" in new_k:
                for i, p in enumerate(["q", "k", "v"]):
                    if "lora_A" in new_k:
                        # lora_A (down projection) is shared by q, k, and v, so copy the tensor
                        new_k1 = new_k.replace("_img_attn_qkv", f".attn.to_{p}")
                        new_k1 = new_k1.replace("_txt_attn_qkv", f".attn.add_{p}_proj")
                        new_tensors[new_k1] = v.clone()
                    else:
                        assert "lora_B" in new_k
                        # lora_B (up projection) is fused; split it into equal q/k/v chunks
                        assert v.shape[0] % 3 == 0
                        chunk_size = v.shape[0] // 3
                        new_k1 = new_k.replace("_img_attn_qkv", f".attn.to_{p}")
                        new_k1 = new_k1.replace("_txt_attn_qkv", f".attn.add_{p}_proj")
                        new_tensors[new_k1] = v[i * chunk_size : (i + 1) * chunk_size]
            else:
                new_k = new_k.replace("_img_attn_proj", ".attn.to_out.0")
                new_k = new_k.replace("_img_mlp_0", ".ff.net.0.proj")
                new_k = new_k.replace("_img_mlp_2", ".ff.net.2")
                new_k = new_k.replace("_img_mod_lin", ".norm1.linear")
                new_k = new_k.replace("_txt_attn_proj", ".attn.to_add_out")
                new_k = new_k.replace("_txt_mlp_0", ".ff_context.net.0.proj")
                new_k = new_k.replace("_txt_mlp_2", ".ff_context.net.2")
                new_k = new_k.replace("_txt_mod_lin", ".norm1_context.linear")
                new_tensors[new_k] = v
        else:
            assert "lora_unet_single_blocks" in k
            new_k = new_k.replace("lora_unet_single_blocks_", "transformer.single_transformer_blocks.")
            if "linear1" in k:
                start = 0
                for i, p in enumerate(["q", "k", "v", "i"]):
                    if "lora_A" in new_k:
                        if p == "i":
                            new_k1 = new_k.replace("_linear1", ".proj_mlp")
                        else:
                            new_k1 = new_k.replace("_linear1", f".attn.to_{p}")
                        new_tensors[new_k1] = v.clone()
                    else:
                        if p == "i":
                            new_k1 = new_k.replace("_linear1", ".proj_mlp")
                        else:
                            new_k1 = new_k.replace("_linear1", f".attn.to_{p}")
                        chunk_size = 12288 if p == "i" else 3072
                        new_tensors[new_k1] = v[start : start + chunk_size]
                        start += chunk_size
            else:
                new_k = new_k.replace("_linear2", ".proj_out")
                new_k = new_k.replace("_modulation_lin", ".norm.linear")
                new_tensors[new_k] = v
    if output_path is not None:
        output_dir = os.path.dirname(os.path.abspath(output_path))
        os.makedirs(output_dir, exist_ok=True)
        save_file(new_tensors, output_path)
    return new_tensors
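# Hedged usage sketch (file names are placeholders, and the import path is
# inferred from this module's relative imports, not confirmed by the diff):
#
#   from nunchaku.lora.flux.comfyui_converter import comfyui2diffusers
#   diffusers_lora = comfyui2diffusers("comfyui_lora.safetensors", output_path="diffusers_lora.safetensors")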
import argparse
import os
import torch
from safetensors.torch import save_file
from .comfyui_converter import comfyui2diffusers
from .diffusers_converter import convert_to_nunchaku_flux_lowrank_dict
from .xlab_converter import xlab2diffusers
from ...utils import filter_state_dict, load_state_dict_in_safetensors
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--quant-path",
        type=str,
        help="path to the quantized model safetensors file",
        default="mit-han-lab/svdq-int4-flux.1-dev/transformer_blocks.safetensors",
    )
    parser.add_argument("--lora-path", type=str, required=True, help="path to the LoRA weights safetensors file")
    parser.add_argument(
        "--lora-format",
        type=str,
        default="diffusers",
        choices=["comfyui", "diffusers", "xlab"],
        help="format of the LoRA weights",
    )
    parser.add_argument("--output-root", type=str, default="", help="root directory of the output safetensors file")
    parser.add_argument("--lora-name", type=str, default=None, help="name of the LoRA weights")
    parser.add_argument(
        "--dtype",
        type=str,
        default="bfloat16",
        choices=["bfloat16", "float16"],
        help="data type of the converted weights",
    )
    args = parser.parse_args()
    if not args.output_root:
        # default to the parent directory of the quantized model safetensors file
        args.output_root = os.path.dirname(args.quant_path)
    if args.lora_name is None:
        base_name = os.path.basename(args.lora_path)
        lora_name = base_name.rsplit(".", 1)[0]
        lora_name = "svdq-int4-" + lora_name
        print(f"LoRA name not provided, using {lora_name} as the LoRA name")
    else:
        lora_name = args.lora_name
    assert lora_name, "LoRA name must be provided."
    assert args.quant_path.endswith(".safetensors"), "Quantized model must be a safetensors file"
    assert args.lora_path.endswith(".safetensors"), "LoRA weights must be a safetensors file"
    orig_state_dict = load_state_dict_in_safetensors(args.quant_path)
    lora_format = args.lora_format
    if lora_format == "diffusers":
        extra_lora_dict = load_state_dict_in_safetensors(args.lora_path)
    else:
        if lora_format == "comfyui":
            extra_lora_dict = comfyui2diffusers(args.lora_path)
        elif lora_format == "xlab":
            extra_lora_dict = xlab2diffusers(args.lora_path)
        else:
            raise NotImplementedError(f"LoRA format {lora_format} is not supported.")
        extra_lora_dict = filter_state_dict(extra_lora_dict)
    converted = convert_to_nunchaku_flux_lowrank_dict(
        base_model=orig_state_dict,
        lora=extra_lora_dict,
        default_dtype=torch.bfloat16 if args.dtype == "bfloat16" else torch.float16,
    )
    os.makedirs(args.output_root, exist_ok=True)
    save_file(converted, os.path.join(args.output_root, f"{lora_name}.safetensors"))
    print(f"Saved LoRA weights to {args.output_root}.")
# convert the diffusers lora to nunchaku format
"""Convert LoRA weights to Nunchaku format."""
import argparse
import os
import typing as tp
import safetensors
import safetensors.torch
import torch
import tqdm
# region utilities
from ...utils import ceil_divide, filter_state_dict, load_state_dict_in_safetensors
def ceil_divide(x: int, divisor: int) -> int:
    """Ceiling division.

    Args:
        x (`int`):
            dividend.
        divisor (`int`):
            divisor.

    Returns:
        `int`:
            ceiling division result.
    """
    return (x + divisor - 1) // divisor
# region utilities
def pad(
@@ -65,32 +49,6 @@ def update_state_dict(
    return lhs
def load_state_dict_in_safetensors(
    path: str, device: str | torch.device = "cpu", filter_prefix: str = ""
) -> dict[str, torch.Tensor]:
    """Load state dict in SafeTensors.

    Args:
        path (`str`):
            file path.
        device (`str` | `torch.device`, optional, defaults to `"cpu"`):
            device.
        filter_prefix (`str`, optional, defaults to `""`):
            filter prefix.

    Returns:
        `dict`:
            loaded SafeTensors.
    """
    state_dict = {}
    with safetensors.safe_open(path, framework="pt", device=device) as f:
        for k in f.keys():
            if filter_prefix and not k.startswith(filter_prefix):
                continue
            state_dict[k.removeprefix(filter_prefix)] = f.get_tensor(k)
    return state_dict
# endregion
@@ -375,10 +333,20 @@ def convert_to_nunchaku_flux_transformer_block_lowrank_dict(
def convert_to_nunchaku_flux_lowrank_dict(
    orig_state_dict: dict[str, torch.Tensor],
    extra_lora_dict: dict[str, torch.Tensor],
    base_model: dict[str, torch.Tensor] | str,
    lora: dict[str, torch.Tensor] | str,
    default_dtype: torch.dtype = torch.bfloat16,
) -> dict[str, torch.Tensor]:
    if isinstance(base_model, str):
        orig_state_dict = load_state_dict_in_safetensors(base_model)
    else:
        orig_state_dict = base_model
    if isinstance(lora, str):
        extra_lora_dict = load_state_dict_in_safetensors(lora, filter_prefix="transformer.")
    else:
        extra_lora_dict = filter_state_dict(lora, filter_prefix="transformer.")
    block_names: set[str] = set()
    for param_name in orig_state_dict.keys():
        if param_name.startswith(("transformer_blocks.", "single_transformer_blocks.")):
@@ -403,43 +371,3 @@ def convert_to_nunchaku_flux_lowrank_dict(
            prefix=block_name,
        )
    return converted
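# Hedged usage sketch: with the new signature, base_model and lora accept either
# in-memory state dicts or safetensors paths (both file names below are placeholders):
#
#   converted = convert_to_nunchaku_flux_lowrank_dict(
#       base_model="svdq-int4-flux.1-dev/transformer_blocks.safetensors",
#       lora="my_lora.safetensors",
#       default_dtype=torch.bfloat16,
#   )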
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--quant-path", type=str, required=True, help="path to the quantized model safetensors file")
    parser.add_argument("--lora-path", type=str, required=True, help="path to the LoRA weights safetensors file")
    parser.add_argument("--output-root", type=str, default="", help="root directory of the output safetensors file")
    parser.add_argument("--lora-name", type=str, default=None, help="name of the LoRA weights")
    parser.add_argument(
        "--dtype",
        type=str,
        default="bfloat16",
        choices=["bfloat16", "float16"],
        help="data type of the converted weights",
    )
    args = parser.parse_args()
    if not args.output_root:
        # default to the parent directory of the quantized model safetensors file
        args.output_root = os.path.dirname(args.quant_path)
    if args.lora_name is None:
        assert args.lora_path is not None, "LoRA name or path must be provided"
        lora_name = args.lora_path.rstrip(os.sep).split(os.sep)[-1].replace(".safetensors", "")
        print(f"LoRA name not provided, using {lora_name} as the LoRA name")
    else:
        lora_name = args.lora_name
    assert lora_name, "LoRA name must be provided."
    assert args.quant_path.endswith(".safetensors"), "Quantized model must be a safetensors file"
    assert args.lora_path.endswith(".safetensors"), "LoRA weights must be a safetensors file"
    orig_state_dict = load_state_dict_in_safetensors(args.quant_path)
    extra_lora_dict = load_state_dict_in_safetensors(args.lora_path, filter_prefix="transformer.")
    converted = convert_to_nunchaku_flux_lowrank_dict(
        orig_state_dict=orig_state_dict,
        extra_lora_dict=extra_lora_dict,
        default_dtype=torch.bfloat16 if args.dtype == "bfloat16" else torch.float16,
    )
    os.makedirs(args.output_root, exist_ok=True)
    safetensors.torch.save_file(converted, os.path.join(args.output_root, f"{lora_name}.safetensors"))
    print(f"Saved LoRA weights to {args.output_root}.")
# convert the xlab lora to diffusers format
import os
import torch
from safetensors.torch import save_file
from ...utils import load_state_dict_in_safetensors
def xlab2diffusers(
    input_lora: str | dict[str, torch.Tensor], output_path: str | None = None
) -> dict[str, torch.Tensor]:
    if isinstance(input_lora, str):
        tensors = load_state_dict_in_safetensors(input_lora, device="cpu")
    else:
        tensors = input_lora
    new_tensors = {}
    # lora1 is for the image stream, lora2 is for the text stream
    for k, v in tensors.items():
        assert "double_blocks" in k
        new_k = k.replace("double_blocks", "transformer.transformer_blocks").replace("processor", "attn")
        new_k = new_k.replace(".down.", ".lora_A.")
        new_k = new_k.replace(".up.", ".lora_B.")
        if ".proj_lora" in new_k:
            new_k = new_k.replace(".proj_lora1", ".to_out.0")
            new_k = new_k.replace(".proj_lora2", ".to_add_out")
            new_tensors[new_k] = v
        else:
            assert "qkv_lora" in new_k
            if "lora_A" in new_k:
                # lora_A (down projection) is shared by q, k, and v, so copy the tensor
                for p in ["q", "k", "v"]:
                    if ".qkv_lora1." in new_k:
                        new_tensors[new_k.replace(".qkv_lora1.", f".to_{p}.")] = v.clone()
                    else:
                        assert ".qkv_lora2." in new_k
                        new_tensors[new_k.replace(".qkv_lora2.", f".add_{p}_proj.")] = v.clone()
            else:
                assert "lora_B" in new_k
                # lora_B (up projection) is fused; split it into equal q/k/v chunks
                for i, p in enumerate(["q", "k", "v"]):
                    assert v.shape[0] % 3 == 0
                    chunk_size = v.shape[0] // 3
                    if ".qkv_lora1." in new_k:
                        new_tensors[new_k.replace(".qkv_lora1.", f".to_{p}.")] = v[
                            i * chunk_size : (i + 1) * chunk_size
                        ]
                    else:
                        assert ".qkv_lora2." in new_k
                        new_tensors[new_k.replace(".qkv_lora2.", f".add_{p}_proj.")] = v[
                            i * chunk_size : (i + 1) * chunk_size
                        ]
    if output_path is not None:
        output_dir = os.path.dirname(os.path.abspath(output_path))
        os.makedirs(output_dir, exist_ok=True)
        save_file(new_tensors, output_path)
    return new_tensors
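# Hedged usage sketch (file names are placeholders):
#
#   diffusers_lora = xlab2diffusers("xlab_lora.safetensors", output_path="xlab_lora_diffusers.safetensors")
#
# XLabs LoRAs only cover the double (joint) blocks, hence the assert on
# "double_blocks" above.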
@@ -4,12 +4,13 @@ import diffusers
import torch
from diffusers import FluxTransformer2DModel
from diffusers.configuration_utils import register_to_config
from huggingface_hub import hf_hub_download, utils
from huggingface_hub import utils
from packaging.version import Version
from torch import nn
from .utils import NunchakuModelLoaderMixin, pad_tensor
from .._C import QuantizedFluxModel, utils as cutils
from ..utils import fetch_or_download
SVD_RANK = 32
@@ -158,10 +159,7 @@ class NunchakuFluxTransformer2dModel(FluxTransformer2DModel, NunchakuModelLoaderMixin):
        return transformer

    def update_lora_params(self, path: str):
        if not os.path.exists(path):
            hf_repo_id = os.path.dirname(path)
            filename = os.path.basename(path)
            path = hf_hub_download(repo_id=hf_repo_id, filename=filename)
        path = fetch_or_download(path)
        block = self.transformer_blocks[0]
        assert isinstance(block, NunchakuFluxTransformerBlocks)
        block.m.load(path, True)
......
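# Hedged usage sketch (the LoRA file name is a placeholder produced by the
# converter above; a "<repo_id>/<filename>" path would be auto-downloaded via
# fetch_or_download):
#
#   transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-dev")
#   transformer.update_lora_params("svdq-int4-my_lora.safetensors")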
@@ -4,7 +4,6 @@ import torch
from diffusers import __version__
from huggingface_hub import constants, hf_hub_download
from safetensors.torch import load_file
from typing import Optional, Any
class NunchakuModelLoaderMixin:
@@ -66,10 +65,12 @@ class NunchakuModelLoaderMixin:
        return transformer, transformer_block_path
def ceil_div(x: int, y: int) -> int:
    return (x + y - 1) // y


def pad_tensor(tensor: Optional[torch.Tensor], multiples: int, dim: int, fill: Any = 0) -> torch.Tensor:
def pad_tensor(tensor: torch.Tensor | None, multiples: int, dim: int, fill=0) -> torch.Tensor:
    if multiples <= 1:
        return tensor
    if tensor is None:
@@ -81,4 +82,4 @@ def pad_tensor(tensor: torch.Tensor | None, multiples: int, dim: int, fill=0) -> torch.Tensor:
    result = torch.empty(shape, dtype=tensor.dtype, device=tensor.device)
    result.fill_(fill)
    result[[slice(0, extent) for extent in tensor.shape]] = tensor
    return result
\ No newline at end of file
    return result
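# Illustrative behavior sketch (shapes are assumptions, not from this diff):
# padding dim 0 of a (5, 4) tensor to a multiple of 16 gives a (16, 4) tensor
# whose first 5 rows are the original values and whose remaining rows are `fill`:
#
#   x = torch.randn(5, 4)
#   y = pad_tensor(x, multiples=16, dim=0)
#   assert y.shape == (16, 4) and torch.equal(y[:5], x) and torch.all(y[5:] == 0)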
import torch
from diffusers import FluxPipeline
from .models.transformer_flux import NunchakuFluxTransformer2dModel
if __name__ == "__main__":
    transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-schnell")
    pipeline = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-schnell", transformer=transformer, torch_dtype=torch.bfloat16
    ).to("cuda")
    image = pipeline(
        "A cat holding a sign that says hello world", width=1024, height=1024, num_inference_steps=4, guidance_scale=0
    ).images[0]
import os
import safetensors
import torch
from huggingface_hub import hf_hub_download
def fetch_or_download(path: str) -> str:
    if not os.path.exists(path):
        hf_repo_id = os.path.dirname(path)
        filename = os.path.basename(path)
        path = hf_hub_download(repo_id=hf_repo_id, filename=filename)
    return path
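# Behavior sketch: an existing local path is returned unchanged; otherwise the
# path is interpreted as "<repo_id>/<filename>" on the Hugging Face Hub, e.g.
# "mit-han-lab/svdq-int4-flux.1-dev/transformer_blocks.safetensors" resolves to
# repo_id="mit-han-lab/svdq-int4-flux.1-dev", filename="transformer_blocks.safetensors".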
def ceil_divide(x: int, divisor: int) -> int:
    """Ceiling division.

    Args:
        x (`int`):
            dividend.
        divisor (`int`):
            divisor.

    Returns:
        `int`:
            ceiling division result.
    """
    return (x + divisor - 1) // divisor
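# Worked example: ceil_divide(10, 4) == 3 while 10 // 4 == 2; this is the
# rounding-up used when padding tensor shapes to fixed multiples.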
def load_state_dict_in_safetensors(
    path: str, device: str | torch.device = "cpu", filter_prefix: str = ""
) -> dict[str, torch.Tensor]:
    """Load a state dict from a SafeTensors file.

    Args:
        path (`str`):
            file path.
        device (`str` | `torch.device`, optional, defaults to `"cpu"`):
            device.
        filter_prefix (`str`, optional, defaults to `""`):
            filter prefix.

    Returns:
        `dict`:
            loaded SafeTensors.
    """
    state_dict = {}
    with safetensors.safe_open(fetch_or_download(path), framework="pt", device=device) as f:
        for k in f.keys():
            if filter_prefix and not k.startswith(filter_prefix):
                continue
            state_dict[k.removeprefix(filter_prefix)] = f.get_tensor(k)
    return state_dict
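# Hedged example (key names are illustrative): with filter_prefix="transformer.",
# a stored key "transformer.x.lora_A.weight" is returned as "x.lora_A.weight",
# and keys without the prefix are skipped:
#
#   lora = load_state_dict_in_safetensors("my_lora.safetensors", filter_prefix="transformer.")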
def filter_state_dict(state_dict: dict[str, torch.Tensor], filter_prefix: str = "") -> dict[str, torch.Tensor]:
    """Filter state dict.

    Args:
        state_dict (`dict`):
            state dict.
        filter_prefix (`str`):
            filter prefix.

    Returns:
        `dict`:
            filtered state dict.
    """
    return {k.removeprefix(filter_prefix): v for k, v in state_dict.items() if k.startswith(filter_prefix)}
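# A hedged, runnable mini-check of filter_state_dict (toy tensors, illustrative only):
if __name__ == "__main__":
    _sd = {"transformer.a.weight": torch.zeros(1), "text_encoder.b.weight": torch.zeros(1)}
    assert set(filter_state_dict(_sd, filter_prefix="transformer.").keys()) == {"a.weight"}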
[build-system]
requires = [
    "setuptools",
    "torch",
    "torch>=2.5",
    "wheel",
    "ninja",
]
build-backend = "setuptools.build_meta"

[tool.setuptools.packages.find]
include = ["nunchaku"]

[project]
dynamic = ["version"]
name = "nunchaku"
dependencies = [
    "torch>=2.4.1",
    "diffusers>=0.30.3",
    "diffusers>=0.32.2",
    "transformers",
    "accelerate",
    "sentencepiece",
    "protobuf",
    "huggingface_hub",
]
\ No newline at end of file
]
requires-python = ">=3.11, <3.13"