Commit 218d333e authored by Samuel Tesfai

Merge https://github.com/mit-han-lab/nunchaku into migrate_tinychat

parents 73939beb c7f41661
@@ -8,4 +8,4 @@ pipeline = FluxPipeline.from_pretrained(
     "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
 ).to("cuda")
 image = pipeline("A cat holding a sign that says hello world", num_inference_steps=50, guidance_scale=3.5).images[0]
-image.save("flux.1-dev-int4.png")
+image.save("flux.1-dev.png")
 import torch
-from diffusers import FluxPriorReduxPipeline, FluxPipeline
+from diffusers import FluxPipeline, FluxPriorReduxPipeline
 from diffusers.utils import load_image
 from nunchaku.models.transformer_flux import NunchakuFluxTransformer2dModel

 pipe_prior_redux = FluxPriorReduxPipeline.from_pretrained(
...
 import torch
 from diffusers import SanaPipeline
 from nunchaku.models.transformer_sana import NunchakuSanaTransformer2DModel

 transformer = NunchakuSanaTransformer2DModel.from_pretrained("mit-han-lab/svdq-int4-sana-1600m")
...
__version__ = "0.0.2beta4" __version__ = "0.0.2beta6"
# convert the comfyui lora to diffusers format
import os

import torch
from safetensors.torch import save_file

from ...utils import load_state_dict_in_safetensors


def comfyui2diffusers(
    input_lora: str | dict[str, torch.Tensor], output_path: str | None = None
) -> dict[str, torch.Tensor]:
    if isinstance(input_lora, str):
        tensors = load_state_dict_in_safetensors(input_lora, device="cpu")
    else:
        tensors = input_lora
    new_tensors = {}
    for k, v in tensors.items():
        if "alpha" in k:
            continue
        new_k = k.replace("lora_down", "lora_A").replace("lora_up", "lora_B")
        if "lora_unet_double_blocks_" in k:
            new_k = new_k.replace("lora_unet_double_blocks_", "transformer.transformer_blocks.")
            if "qkv" in new_k:
                for i, p in enumerate(["q", "k", "v"]):
                    if "lora_A" in new_k:
                        # The down-projection is shared across q/k/v, so copy the
                        # whole tensor to each projection. Build a fresh key per
                        # iteration: mutating new_k would make the replacements
                        # no-ops after the first pass and drop the k/v entries.
                        new_k1 = new_k.replace("_img_attn_qkv", f".attn.to_{p}")
                        new_k1 = new_k1.replace("_txt_attn_qkv", f".attn.add_{p}_proj")
                        new_tensors[new_k1] = v.clone()
                    else:
                        assert "lora_B" in new_k
                        # The up-projection stacks q/k/v along dim 0; split it into
                        # three equal chunks, one per projection.
                        assert v.shape[0] % 3 == 0
                        chunk_size = v.shape[0] // 3
                        new_k1 = new_k.replace("_img_attn_qkv", f".attn.to_{p}")
                        new_k1 = new_k1.replace("_txt_attn_qkv", f".attn.add_{p}_proj")
                        new_tensors[new_k1] = v[i * chunk_size : (i + 1) * chunk_size]
            else:
                new_k = new_k.replace("_img_attn_proj", ".attn.to_out.0")
                new_k = new_k.replace("_img_mlp_0", ".ff.net.0.proj")
                new_k = new_k.replace("_img_mlp_2", ".ff.net.2")
                new_k = new_k.replace("_img_mod_lin", ".norm1.linear")
                new_k = new_k.replace("_txt_attn_proj", ".attn.to_add_out")
                new_k = new_k.replace("_txt_mlp_0", ".ff_context.net.0.proj")
                new_k = new_k.replace("_txt_mlp_2", ".ff_context.net.2")
                new_k = new_k.replace("_txt_mod_lin", ".norm1_context.linear")
                new_tensors[new_k] = v
        else:
            assert "lora_unet_single_blocks" in k
            new_k = new_k.replace("lora_unet_single_blocks_", "transformer.single_transformer_blocks.")
            if "linear1" in k:
                # linear1 fuses q, k, v and the MLP input ("i") along dim 0.
                start = 0
                for i, p in enumerate(["q", "k", "v", "i"]):
                    if p == "i":
                        new_k1 = new_k.replace("_linear1", ".proj_mlp")
                    else:
                        new_k1 = new_k.replace("_linear1", f".attn.to_{p}")
                    if "lora_A" in new_k:
                        # Shared down-projection: copy to each target.
                        new_tensors[new_k1] = v.clone()
                    else:
                        # Fused up-projection: q/k/v are 3072 rows each, the MLP
                        # slice is 12288 rows.
                        chunk_size = 12288 if p == "i" else 3072
                        new_tensors[new_k1] = v[start : start + chunk_size]
                        start += chunk_size
            else:
                new_k = new_k.replace("_linear2", ".proj_out")
                new_k = new_k.replace("_modulation_lin", ".norm.linear")
                new_tensors[new_k] = v
    if output_path is not None:
        output_dir = os.path.dirname(os.path.abspath(output_path))
        os.makedirs(output_dir, exist_ok=True)
        save_file(new_tensors, output_path)
    return new_tensors
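A hedged usage sketch of the converter above. The input and output file names are placeholders, and the module path nunchaku.lora.flux is inferred from the package's relative imports:

from nunchaku.lora.flux.comfyui_converter import comfyui2diffusers

# Remap a ComfyUI-trained FLUX LoRA into diffusers naming; the converted state
# dict is returned and, since output_path is given, also written to disk.
diffusers_lora = comfyui2diffusers(
    "my-comfyui-lora.safetensors",                # placeholder input path
    output_path="my-lora-diffusers.safetensors",  # placeholder output path
)
print(f"converted {len(diffusers_lora)} tensors")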
import argparse
import os

import torch
from safetensors.torch import save_file

from .comfyui_converter import comfyui2diffusers
from .diffusers_converter import convert_to_nunchaku_flux_lowrank_dict
from .xlab_converter import xlab2diffusers
from ...utils import filter_state_dict, load_state_dict_in_safetensors

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--quant-path",
        type=str,
        help="path to the quantized model safetensors file",
        default="mit-han-lab/svdq-int4-flux.1-dev/transformer_blocks.safetensors",
    )
    parser.add_argument("--lora-path", type=str, required=True, help="path to the LoRA weights safetensors file")
    parser.add_argument(
        "--lora-format",
        type=str,
        default="diffusers",
        choices=["comfyui", "diffusers", "xlab"],
        help="format of the LoRA weights",
    )
    parser.add_argument("--output-root", type=str, default="", help="root directory for the output safetensors file")
    parser.add_argument("--lora-name", type=str, default=None, help="name of the LoRA weights")
    parser.add_argument(
        "--dtype",
        type=str,
        default="bfloat16",
        choices=["bfloat16", "float16"],
        help="data type of the converted weights",
    )
    args = parser.parse_args()
    if not args.output_root:
        # Default to the parent directory of the quantized model safetensors file.
        args.output_root = os.path.dirname(args.quant_path)
    if args.lora_name is None:
        base_name = os.path.basename(args.lora_path)
        lora_name = base_name.rsplit(".", 1)[0]
        lora_name = "svdq-int4-" + lora_name
        print(f"LoRA name not provided, using {lora_name} as the LoRA name")
    else:
        lora_name = args.lora_name
    assert lora_name, "LoRA name must be provided."
    assert args.quant_path.endswith(".safetensors"), "Quantized model must be a safetensors file"
    assert args.lora_path.endswith(".safetensors"), "LoRA weights must be a safetensors file"
    orig_state_dict = load_state_dict_in_safetensors(args.quant_path)

    lora_format = args.lora_format
    if lora_format == "diffusers":
        extra_lora_dict = load_state_dict_in_safetensors(args.lora_path)
    elif lora_format == "comfyui":
        extra_lora_dict = comfyui2diffusers(args.lora_path)
    elif lora_format == "xlab":
        extra_lora_dict = xlab2diffusers(args.lora_path)
    else:
        raise NotImplementedError(f"LoRA format {lora_format} is not supported.")
    extra_lora_dict = filter_state_dict(extra_lora_dict)

    converted = convert_to_nunchaku_flux_lowrank_dict(
        base_model=orig_state_dict,
        lora=extra_lora_dict,
        default_dtype=torch.bfloat16 if args.dtype == "bfloat16" else torch.float16,
    )
    os.makedirs(args.output_root, exist_ok=True)
    output_path = os.path.join(args.output_root, f"{lora_name}.safetensors")
    save_file(converted, output_path)
    print(f"Saved converted LoRA weights to {output_path}.")
+# convert the diffusers lora to nunchaku format
 """Convert LoRA weights to Nunchaku format."""
-import argparse
-import os
 import typing as tp
-import safetensors
-import safetensors.torch
 import torch
 import tqdm
+from ...utils import ceil_divide, filter_state_dict, load_state_dict_in_safetensors

 # region utilities
-def ceil_divide(x: int, divisor: int) -> int:
-    """Ceiling division.
-
-    Args:
-        x (`int`):
-            dividend.
-        divisor (`int`):
-            divisor.
-
-    Returns:
-        `int`:
-            ceiling division result.
-    """
-    return (x + divisor - 1) // divisor
 def pad(
@@ -65,32 +49,6 @@ def update_state_dict(
     return lhs

-def load_state_dict_in_safetensors(
-    path: str, device: str | torch.device = "cpu", filter_prefix: str = ""
-) -> dict[str, torch.Tensor]:
-    """Load state dict in SafeTensors.
-
-    Args:
-        path (`str`):
-            file path.
-        device (`str` | `torch.device`, optional, defaults to `"cpu"`):
-            device.
-        filter_prefix (`str`, optional, defaults to `""`):
-            filter prefix.
-
-    Returns:
-        `dict`:
-            loaded SafeTensors.
-    """
-    state_dict = {}
-    with safetensors.safe_open(path, framework="pt", device=device) as f:
-        for k in f.keys():
-            if filter_prefix and not k.startswith(filter_prefix):
-                continue
-            state_dict[k.removeprefix(filter_prefix)] = f.get_tensor(k)
-    return state_dict

 # endregion
@@ -375,10 +333,20 @@ def convert_to_nunchaku_flux_transformer_block_lowrank_dict(
 def convert_to_nunchaku_flux_lowrank_dict(
-    orig_state_dict: dict[str, torch.Tensor],
-    extra_lora_dict: dict[str, torch.Tensor],
+    base_model: dict[str, torch.Tensor] | str,
+    lora: dict[str, torch.Tensor] | str,
     default_dtype: torch.dtype = torch.bfloat16,
 ) -> dict[str, torch.Tensor]:
+    if isinstance(base_model, str):
+        orig_state_dict = load_state_dict_in_safetensors(base_model)
+    else:
+        orig_state_dict = base_model
+    if isinstance(lora, str):
+        extra_lora_dict = load_state_dict_in_safetensors(lora, filter_prefix="transformer.")
+    else:
+        extra_lora_dict = filter_state_dict(lora, filter_prefix="transformer.")
     block_names: set[str] = set()
     for param_name in orig_state_dict.keys():
         if param_name.startswith(("transformer_blocks.", "single_transformer_blocks.")):
@@ -403,43 +371,3 @@ def convert_to_nunchaku_flux_lowrank_dict(
             prefix=block_name,
         )
     return converted
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--quant-path", type=str, required=True, help="path to the quantized model safetensor file")
-    parser.add_argument("--lora-path", type=str, required=True, help="path to LoRA weights safetensor file")
-    parser.add_argument("--output-root", type=str, default="", help="root to the output safetensor file")
-    parser.add_argument("--lora-name", type=str, default=None, help="name of the LoRA weights")
-    parser.add_argument(
-        "--dtype",
-        type=str,
-        default="bfloat16",
-        choices=["bfloat16", "float16"],
-        help="data type of the converted weights",
-    )
-    args = parser.parse_args()
-    if not args.output_root:
-        # output to the parent directory of the quantized model safetensor file
-        args.output_root = os.path.dirname(args.quant_path)
-    if args.lora_name is None:
-        assert args.lora_path is not None, "LoRA name or path must be provided"
-        lora_name = args.lora_path.rstrip(os.sep).split(os.sep)[-1].replace(".safetensors", "")
-        print(f"Lora name not provided, using {lora_name} as the LoRA name")
-    else:
-        lora_name = args.lora_name
-    assert lora_name, "LoRA name must be provided."
-    assert args.quant_path.endswith(".safetensors"), "Quantized model must be a safetensor file"
-    assert args.lora_path.endswith(".safetensors"), "LoRA weights must be a safetensor file"
-    orig_state_dict = load_state_dict_in_safetensors(args.quant_path)
-    extra_lora_dict = load_state_dict_in_safetensors(args.lora_path, filter_prefix="transformer.")
-    converted = convert_to_nunchaku_flux_lowrank_dict(
-        orig_state_dict=orig_state_dict,
-        extra_lora_dict=extra_lora_dict,
-        default_dtype=torch.bfloat16 if args.dtype == "bfloat16" else torch.float16,
-    )
-    os.makedirs(args.output_root, exist_ok=True)
-    safetensors.torch.save_file(converted, os.path.join(args.output_root, f"{lora_name}.safetensors"))
-    print(f"Saved LoRA weights to {args.output_root}.")
# convert the xlab lora to diffusers format
import os

import torch
from safetensors.torch import save_file

from ...utils import load_state_dict_in_safetensors


def xlab2diffusers(
    input_lora: str | dict[str, torch.Tensor], output_path: str | None = None
) -> dict[str, torch.Tensor]:
    if isinstance(input_lora, str):
        tensors = load_state_dict_in_safetensors(input_lora, device="cpu")
    else:
        tensors = input_lora
    new_tensors = {}
    # lora1 applies to the image stream, lora2 to the text stream
    for k, v in tensors.items():
        assert "double_blocks" in k
        new_k = k.replace("double_blocks", "transformer.transformer_blocks").replace("processor", "attn")
        new_k = new_k.replace(".down.", ".lora_A.")
        new_k = new_k.replace(".up.", ".lora_B.")
        if ".proj_lora" in new_k:
            new_k = new_k.replace(".proj_lora1", ".to_out.0")
            new_k = new_k.replace(".proj_lora2", ".to_add_out")
            new_tensors[new_k] = v
        else:
            assert "qkv_lora" in new_k
            if "lora_A" in new_k:
                # The shared down-projection is copied to each of q/k/v.
                for p in ["q", "k", "v"]:
                    if ".qkv_lora1." in new_k:
                        new_tensors[new_k.replace(".qkv_lora1.", f".to_{p}.")] = v.clone()
                    else:
                        assert ".qkv_lora2." in new_k
                        new_tensors[new_k.replace(".qkv_lora2.", f".add_{p}_proj.")] = v.clone()
            else:
                assert "lora_B" in new_k
                # The fused up-projection is split into equal q/k/v chunks along dim 0.
                assert v.shape[0] % 3 == 0
                chunk_size = v.shape[0] // 3
                for i, p in enumerate(["q", "k", "v"]):
                    if ".qkv_lora1." in new_k:
                        new_tensors[new_k.replace(".qkv_lora1.", f".to_{p}.")] = v[
                            i * chunk_size : (i + 1) * chunk_size
                        ]
                    else:
                        assert ".qkv_lora2." in new_k
                        new_tensors[new_k.replace(".qkv_lora2.", f".add_{p}_proj.")] = v[
                            i * chunk_size : (i + 1) * chunk_size
                        ]
    if output_path is not None:
        output_dir = os.path.dirname(os.path.abspath(output_path))
        os.makedirs(output_dir, exist_ok=True)
        save_file(new_tensors, output_path)
    return new_tensors
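The same hedged usage pattern applies to XLabs-format LoRAs; file names are placeholders and the module path is inferred from the relative imports:

from nunchaku.lora.flux.xlab_converter import xlab2diffusers

# Convert an XLabs FLUX LoRA to diffusers naming, optionally saving it to disk.
xlab_lora = xlab2diffusers("my-xlab-lora.safetensors", output_path="my-xlab-diffusers.safetensors")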
@@ -4,12 +4,13 @@ import diffusers
 import torch
 from diffusers import FluxTransformer2DModel
 from diffusers.configuration_utils import register_to_config
-from huggingface_hub import hf_hub_download, utils
+from huggingface_hub import utils
 from packaging.version import Version
 from torch import nn

 from .utils import NunchakuModelLoaderMixin, pad_tensor
 from .._C import QuantizedFluxModel, utils as cutils
+from ..utils import fetch_or_download

 SVD_RANK = 32
@@ -158,10 +159,7 @@ class NunchakuFluxTransformer2dModel(FluxTransformer2DModel, NunchakuModelLoaderMixin):
         return transformer

     def update_lora_params(self, path: str):
-        if not os.path.exists(path):
-            hf_repo_id = os.path.dirname(path)
-            filename = os.path.basename(path)
-            path = hf_hub_download(repo_id=hf_repo_id, filename=filename)
+        path = fetch_or_download(path)
         block = self.transformer_blocks[0]
         assert isinstance(block, NunchakuFluxTransformerBlocks)
         block.m.load(path, True)
...
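A minimal sketch of what this change enables: update_lora_params now accepts either a local file or a Hub-style "repo_id/filename" string, since fetch_or_download falls back to hf_hub_download when the path does not exist locally. The LoRA path below is a placeholder:

import torch
from nunchaku.models.transformer_flux import NunchakuFluxTransformer2dModel

transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-dev")
# A local path works as before; a "repo_id/filename" string is fetched from the
# Hugging Face Hub automatically. This repo/file name is a placeholder.
transformer.update_lora_params("some-org/some-lora-repo/converted-lora.safetensors")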
@@ -4,7 +4,6 @@ import torch
 from diffusers import __version__
 from huggingface_hub import constants, hf_hub_download
 from safetensors.torch import load_file
-from typing import Optional, Any

 class NunchakuModelLoaderMixin:
@@ -66,10 +65,12 @@ class NunchakuModelLoaderMixin:
         return transformer, transformer_block_path

 def ceil_div(x: int, y: int) -> int:
     return (x + y - 1) // y

-def pad_tensor(tensor: Optional[torch.Tensor], multiples: int, dim: int, fill: Any = 0) -> torch.Tensor:
+def pad_tensor(tensor: torch.Tensor | None, multiples: int, dim: int, fill=0) -> torch.Tensor:
     if multiples <= 1:
         return tensor
     if tensor is None:
@@ -81,4 +82,4 @@ def pad_tensor(tensor: Optional[torch.Tensor], multiples: int, dim: int, fill: Any = 0) -> torch.Tensor:
     result = torch.empty(shape, dtype=tensor.dtype, device=tensor.device)
     result.fill_(fill)
     result[[slice(0, extent) for extent in tensor.shape]] = tensor
-    return result
\ No newline at end of file
+    return result
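A quick worked example of pad_tensor's behavior, assuming the module path nunchaku.models.utils implied by the diff (shapes chosen arbitrarily): a (5, 3) tensor padded along dim 0 to a multiple of 4 becomes (8, 3), with the new rows set to fill.

import torch
from nunchaku.models.utils import pad_tensor  # module path inferred from the diff

x = torch.ones(5, 3)
y = pad_tensor(x, multiples=4, dim=0, fill=0)
print(y.shape)  # torch.Size([8, 3]); rows 5-7 hold the fill value 0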
import torch
from diffusers import FluxPipeline

from .models.transformer_flux import NunchakuFluxTransformer2dModel

if __name__ == "__main__":
    transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-schnell")
    pipeline = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-schnell", transformer=transformer, torch_dtype=torch.bfloat16
    ).to("cuda")
    image = pipeline(
        "A cat holding a sign that says hello world", width=1024, height=1024, num_inference_steps=4, guidance_scale=0
    ).images[0]
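As committed, the script generates an image but never persists it; if you run it manually, a one-line addition (not part of the commit, file name arbitrary) keeps the result:

    image.save("flux.1-schnell-int4.png")  # hypothetical file name, not in the commit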
import os

import safetensors
import torch
from huggingface_hub import hf_hub_download


def fetch_or_download(path: str) -> str:
    """Return a local path, downloading from the Hugging Face Hub if needed."""
    if not os.path.exists(path):
        hf_repo_id = os.path.dirname(path)
        filename = os.path.basename(path)
        path = hf_hub_download(repo_id=hf_repo_id, filename=filename)
    return path
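Because the directory part of a non-existent path is treated as an HF repo id, both calls below resolve to a local file; the hub path mirrors the default --quant-path in convert.py, while the local path is a placeholder:

from nunchaku.utils import fetch_or_download  # module path inferred from the diff

# Already-local files pass through unchanged (placeholder path).
local = fetch_or_download("/tmp/weights.safetensors")
# Non-existent paths are split into repo_id + filename and downloaded.
remote = fetch_or_download("mit-han-lab/svdq-int4-flux.1-dev/transformer_blocks.safetensors")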
def ceil_divide(x: int, divisor: int) -> int:
    """Ceiling division.

    Args:
        x (`int`):
            dividend.
        divisor (`int`):
            divisor.

    Returns:
        `int`:
            ceiling division result.
    """
    return (x + divisor - 1) // divisor
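A one-line sanity check, using the ceil_divide just defined: it rounds up where floor division rounds down.

assert ceil_divide(10, 4) == 3  # (10 + 3) // 4, whereas 10 // 4 == 2
assert ceil_divide(8, 4) == 2   # exact division is unchanged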
def load_state_dict_in_safetensors(
    path: str, device: str | torch.device = "cpu", filter_prefix: str = ""
) -> dict[str, torch.Tensor]:
    """Load a state dict from a SafeTensors file.

    Args:
        path (`str`):
            file path.
        device (`str` | `torch.device`, optional, defaults to `"cpu"`):
            device.
        filter_prefix (`str`, optional, defaults to `""`):
            filter prefix.

    Returns:
        `dict`:
            loaded state dict.
    """
    state_dict = {}
    with safetensors.safe_open(fetch_or_download(path), framework="pt", device=device) as f:
        for k in f.keys():
            if filter_prefix and not k.startswith(filter_prefix):
                continue
            state_dict[k.removeprefix(filter_prefix)] = f.get_tensor(k)
    return state_dict
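A hedged usage sketch: thanks to the fetch_or_download call, the same function accepts a local file or a hub path, and filter_prefix both selects and strips a key prefix. The file name is a placeholder:

from nunchaku.utils import load_state_dict_in_safetensors  # module path inferred

# Keep only keys targeting the transformer and drop the prefix, so
# "transformer.transformer_blocks.0...." becomes "transformer_blocks.0....".
lora = load_state_dict_in_safetensors("my-lora.safetensors", filter_prefix="transformer.")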
def filter_state_dict(state_dict: dict[str, torch.Tensor], filter_prefix: str = "") -> dict[str, torch.Tensor]:
    """Filter a state dict by key prefix, stripping the prefix from kept keys.

    Args:
        state_dict (`dict`):
            state dict.
        filter_prefix (`str`):
            filter prefix.

    Returns:
        `dict`:
            filtered state dict.
    """
    return {k.removeprefix(filter_prefix): v for k, v in state_dict.items() if k.startswith(filter_prefix)}
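And its in-memory counterpart, as a tiny worked example:

import torch
from nunchaku.utils import filter_state_dict  # module path inferred from the diff

sd = {"transformer.a.weight": torch.zeros(1), "text_encoder.b.weight": torch.zeros(1)}
print(filter_state_dict(sd, filter_prefix="transformer."))
# {'a.weight': tensor([0.])} -- non-matching keys are dropped, the prefix stripped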
 [build-system]
 requires = [
     "setuptools",
-    "torch",
+    "torch>=2.5",
     "wheel",
     "ninja",
 ]
 build-backend = "setuptools.build_meta"

+[tool.setuptools.packages.find]
+include = ["nunchaku"]

 [project]
 dynamic = ["version"]
 name = "nunchaku"
 dependencies = [
-    "torch>=2.4.1",
-    "diffusers>=0.30.3",
+    "diffusers>=0.32.2",
     "transformers",
     "accelerate",
     "sentencepiece",
     "protobuf",
     "huggingface_hub",
 ]
-]
\ No newline at end of file
+]
+requires-python = ">=3.11, <3.13"