Commit 3ef186fd authored by Muyang Li, committed by Zhekai Zhang

Multiple LoRAs

parent ca1a2e90
......@@ -17,8 +17,8 @@ _CITATION = """\
"""
_DESCRIPTION = """\
The Densely Captioned Images dataset, or DCI, consists of 7805 images from SA-1B,
each with a complete description aiming to capture the full visual detail of what is present in the image.
Much of the description is directly aligned to submasks of the image.
"""
......
......@@ -7,7 +7,7 @@ from PIL import Image
_CITATION = """\
@misc{li2024playground,
title={Playground v2.5: Three Insights towards Enhancing Aesthetic Quality in Text-to-Image Generation},
author={Daiqing Li and Aleks Kamko and Ehsan Akhgari and Ali Sabet and Linmiao Xu and Suhail Doshi},
year={2024},
eprint={2402.17245},
......@@ -17,7 +17,7 @@ _CITATION = """\
"""
_DESCRIPTION = """\
We introduce a new benchmark, MJHQ-30K, for automatic evaluation of a model’s aesthetic quality.
The benchmark computes FID on a high-quality dataset to gauge aesthetic quality.
"""
......
import torch
from controlnet_aux import CannyDetector
from diffusers import FluxControlPipeline
from diffusers.utils import load_image
from nunchaku import NunchakuFluxTransformer2dModel
transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-dev")
pipe = FluxControlPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")
### LoRA Related Code ###
transformer.update_lora_params(
"black-forest-labs/FLUX.1-Canny-dev-lora/flux1-canny-dev-lora.safetensors"
) # Path to your LoRA safetensors, can also be a remote HuggingFace path
transformer.set_lora_strength(0.85) # Your LoRA strength here
### End of LoRA Related Code ###
prompt = "A robot made of exotic candies and chocolates of different kinds. The background is filled with confetti and celebratory gifts."
control_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png")
processor = CannyDetector()
control_image = processor(
control_image, low_threshold=50, high_threshold=200, detect_resolution=1024, image_resolution=1024
)
image = pipe(
prompt=prompt, control_image=control_image, height=1024, width=1024, num_inference_steps=50, guidance_scale=30.0
).images[0]
image.save("int4-flux.1-canny-dev-lora.png")
import torch
from diffusers import FluxControlPipeline
from diffusers.utils import load_image
from image_gen_aux import DepthPreprocessor
from nunchaku import NunchakuFluxTransformer2dModel
transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-dev")
pipe = FluxControlPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")
### LoRA Related Code ###
transformer.update_lora_params(
"black-forest-labs/FLUX.1-Depth-dev-lora/flux1-depth-dev-lora.safetensors"
) # Path to your LoRA safetensors, can also be a remote HuggingFace path
transformer.set_lora_strength(0.85) # Your LoRA strength here
### End of LoRA Related Code ###
control_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png")
processor = DepthPreprocessor.from_pretrained("LiheYoung/depth-anything-large-hf")
control_image = processor(control_image)[0].convert("RGB")
image = pipe(
prompt="A robot made of exotic candies and chocolates of different kinds. The background is filled with confetti and celebratory gifts.",
control_image=control_image,
height=1024,
width=1024,
num_inference_steps=30,
guidance_scale=10.0,
generator=torch.Generator().manual_seed(42),
).images[0]
image.save("int4-flux.1-depth-dev-lora.png")
......@@ -10,8 +10,8 @@ pipeline = FluxPipeline.from_pretrained(
### LoRA Related Code ###
transformer.update_lora_params(
"mit-han-lab/svdquant-lora-collection/svdq-int4-flux.1-dev-ghibsky.safetensors"
) # Path to your converted LoRA safetensors, can also be a remote HuggingFace path
"aleksa-codes/flux-ghibsky-illustration/lora.safetensors"
) # Path to your LoRA safetensors, can also be a remote HuggingFace path
transformer.set_lora_strength(1) # Your LoRA strength here
### End of LoRA Related Code ###
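# A minimal sketch (not part of this commit's diff) of the multi-LoRA flow this commit enables:
# compose several LoRAs with their strengths folded into the weights, then load the composed
# state dict with update_lora_params. The import path and the second LoRA file below are
# assumptions for illustration; `transformer` is the NunchakuFluxTransformer2dModel from the
# snippet above.
from nunchaku.lora.flux.compose import compose_lora  # module path assumed from this diff's layout

composed = compose_lora(
    [
        ("aleksa-codes/flux-ghibsky-illustration/lora.safetensors", 1.0),  # (LoRA weights, strength)
        ("path/to/another-flux-lora.safetensors", 0.8),  # placeholder path
    ]
)
transformer.update_lora_params(composed)  # strengths are already fused, so set_lora_strength() is not needed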
......
from .comfyui_converter import comfyui2diffusers
from .xlab_converter import xlab2diffusers
from .diffusers_converter import to_diffusers
from .nunchaku_converter import convert_to_nunchaku_flux_lowrank_dict, to_nunchaku
from .utils import is_nunchaku_format
# convert the comfyui lora to diffusers format
import argparse
import os
import torch
from safetensors.torch import save_file
from ...utils import load_state_dict_in_safetensors
def comfyui2diffusers(
input_lora: str | dict[str, torch.Tensor], output_path: str | None = None, min_rank: int | None = None
) -> dict[str, torch.Tensor]:
if isinstance(input_lora, str):
tensors = load_state_dict_in_safetensors(input_lora, device="cpu")
else:
tensors = input_lora
new_tensors = {}
max_rank = 0
for k, v in tensors.items():
if "alpha" in k or "lora_te" in k:
continue
new_k = k.replace("lora_down", "lora_A").replace("lora_up", "lora_B")
if "lora_unet_double_blocks_" in k:
new_k = new_k.replace("lora_unet_double_blocks_", "transformer.transformer_blocks.")
if "qkv" in new_k:
for i, p in enumerate(["q", "k", "v"]):
if "lora_A" in new_k:
# Copy the tensor
new_k = new_k.replace("_img_attn_qkv", f".attn.to_{p}")
new_k = new_k.replace("_txt_attn_qkv", f".attn.add_{p}_proj")
rank = v.shape[0]
alpha = tensors[k.replace("lora_down.weight", "alpha")]
new_tensors[new_k] = v.clone() * alpha / rank
max_rank = max(max_rank, rank)
else:
assert "lora_B" in new_k
assert v.shape[0] % 3 == 0
chunk_size = v.shape[0] // 3
new_k = new_k.replace("_img_attn_qkv", f".attn.to_{p}")
new_k = new_k.replace("_txt_attn_qkv", f".attn.add_{p}_proj")
new_tensors[new_k] = v[i * chunk_size : (i + 1) * chunk_size]
else:
new_k = new_k.replace("_img_attn_proj", ".attn.to_out.0")
new_k = new_k.replace("_img_mlp_0", ".ff.net.0.proj")
new_k = new_k.replace("_img_mlp_2", ".ff.net.2")
new_k = new_k.replace("_img_mod_lin", ".norm1.linear")
new_k = new_k.replace("_txt_attn_proj", ".attn.to_add_out")
new_k = new_k.replace("_txt_mlp_0", ".ff_context.net.0.proj")
new_k = new_k.replace("_txt_mlp_2", ".ff_context.net.2")
new_k = new_k.replace("_txt_mod_lin", ".norm1_context.linear")
if "lora_down" in k:
alpha = tensors[k.replace("lora_down.weight", "alpha")]
rank = v.shape[0]
v = v * alpha / rank
max_rank = max(max_rank, rank)
new_tensors[new_k] = v
else:
assert "lora_unet_single_blocks" in k
new_k = new_k.replace("lora_unet_single_blocks_", "transformer.single_transformer_blocks.")
if "linear1" in k:
start = 0
for i, p in enumerate(["q", "k", "v", "i"]):
if "lora_A" in new_k:
if p == "i":
new_k1 = new_k.replace("_linear1", ".proj_mlp")
else:
new_k1 = new_k.replace("_linear1", f".attn.to_{p}")
rank = v.shape[0]
alpha = tensors[k.replace("lora_down.weight", "alpha")]
new_tensors[new_k1] = v.clone() * alpha / rank
max_rank = max(max_rank, rank)
else:
if p == "i":
new_k1 = new_k.replace("_linear1", ".proj_mlp")
else:
new_k1 = new_k.replace("_linear1", f".attn.to_{p}")
chunk_size = 12288 if p == "i" else 3072
new_tensors[new_k1] = v[start : start + chunk_size]
start += chunk_size
else:
new_k = new_k.replace("_linear2", ".proj_out")
new_k = new_k.replace("_modulation_lin", ".norm.linear")
if "lora_down" in k:
rank = v.shape[0]
alpha = tensors[k.replace("lora_down.weight", "alpha")]
v = v * alpha / rank
max_rank = max(max_rank, rank)
new_tensors[new_k] = v
if min_rank is not None:
for k in new_tensors.keys():
v = new_tensors[k]
if "lora_A" in k:
rank = v.shape[0]
if rank < min_rank:
new_v = torch.zeros(min_rank, v.shape[1], dtype=v.dtype, device=v.device)
new_v[:rank] = v
new_tensors[k] = new_v
else:
assert "lora_B" in k
rank = v.shape[1]
if rank < min_rank:
new_v = torch.zeros(v.shape[0], min_rank, dtype=v.dtype, device=v.device)
new_v[:, :rank] = v
new_tensors[k] = new_v
if output_path is not None:
output_dir = os.path.dirname(os.path.abspath(output_path))
os.makedirs(output_dir, exist_ok=True)
save_file(new_tensors, output_path)
return new_tensors
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("-i", "--input-path", type=str, required=True, help="path to the comfyui lora safetensor file")
parser.add_argument(
"-o", "--output-path", type=str, required=True, help="path to the output diffusers safetensor file"
)
parser.add_argument("--min-rank", type=int, default=None, help="minimum rank for the LoRA weights")
args = parser.parse_args()
comfyui2diffusers(args.input_path, args.output_path, min_rank=args.min_rank)
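# Quick illustration (not from the commit) of the remapping comfyui2diffusers performs, using a
# tiny synthetic ComfyUI-style state dict: keys are renamed to the diffusers convention and the
# alpha/rank scaling is folded into the lora_A ("down") tensor.
_tiny = {
    "lora_unet_double_blocks_0_img_attn_proj.lora_down.weight": torch.randn(4, 3072),
    "lora_unet_double_blocks_0_img_attn_proj.lora_up.weight": torch.randn(3072, 4),
    "lora_unet_double_blocks_0_img_attn_proj.alpha": torch.tensor(4.0),
}
_converted = comfyui2diffusers(_tiny)
assert set(_converted) == {
    "transformer.transformer_blocks.0.attn.to_out.0.lora_A.weight",
    "transformer.transformer_blocks.0.attn.to_out.0.lora_B.weight",
}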
import argparse
import os
import torch
from safetensors.torch import save_file
from .diffusers_converter import to_diffusers
from .utils import is_nunchaku_format
def compose_lora(
loras: list[tuple[str | dict[str, torch.Tensor], float]], output_path: str | None = None
) -> dict[str, torch.Tensor]:
composed = {}
for lora, strength in loras:
assert not is_nunchaku_format(lora)
lora = to_diffusers(lora)
for k, v in list(lora.items()):
if v.ndim == 1:
previous_tensor = composed.get(k, None)
if previous_tensor is None:
if "norm_q" in k or "norm_k" in k or "norm_added_q" in k or "norm_added_k" in k:
composed[k] = v
else:
composed[k] = v * strength
else:
assert not ("norm_q" in k or "norm_k" in k or "norm_added_q" in k or "norm_added_k" in k)
composed[k] = previous_tensor + v * strength
else:
assert v.ndim == 2
if "lora_A" in k:
v = v * strength
if ".to_q." in k or ".add_q_proj." in k: # qkv must all exist
if "lora_B" in k:
continue
q_a = v
k_a = lora[k.replace(".to_q.", ".to_k.").replace(".add_q_proj.", ".add_k_proj.")]
v_a = lora[k.replace(".to_q.", ".to_v.").replace(".add_q_proj.", ".add_v_proj.")]
q_b = lora[k.replace("lora_A", "lora_B")]
k_b = lora[
k.replace("lora_A", "lora_B")
.replace(".to_q.", ".to_k.")
.replace(".add_q_proj.", ".add_k_proj.")
]
v_b = lora[
k.replace("lora_A", "lora_B")
.replace(".to_q.", ".to_v.")
.replace(".add_q_proj.", ".add_v_proj.")
]
assert q_a.shape[0] == k_a.shape[0] == v_a.shape[0]
assert q_b.shape[1] == k_b.shape[1] == v_b.shape[1]
if torch.isclose(q_a, k_a).all() and torch.isclose(q_a, v_a).all():
lora_a = q_a
lora_b = torch.cat((q_b, k_b, v_b), dim=0)
else:
lora_a_group = (q_a, k_a, v_a)
new_shape_a = [sum([_.shape[0] for _ in lora_a_group]), q_a.shape[1]]
lora_a = torch.zeros(new_shape_a, dtype=q_a.dtype, device=q_a.device)
start_dim = 0
for tensor in lora_a_group:
lora_a[start_dim : start_dim + tensor.shape[0]] = tensor
start_dim += tensor.shape[0]
lora_b_group = (q_b, k_b, v_b)
new_shape_b = [sum([_.shape[0] for _ in lora_b_group]), sum([_.shape[1] for _ in lora_b_group])]
lora_b = torch.zeros(new_shape_b, dtype=q_b.dtype, device=q_b.device)
start_dims = (0, 0)
for tensor in lora_b_group:
end_dims = (start_dims[0] + tensor.shape[0], start_dims[1] + tensor.shape[1])
lora_b[start_dims[0] : end_dims[0], start_dims[1] : end_dims[1]] = tensor
start_dims = end_dims
lora_a = lora_a * strength
new_k_a = k.replace(".to_q.", ".to_qkv.").replace(".add_q_proj.", ".add_qkv_proj.")
new_k_b = new_k_a.replace("lora_A", "lora_B")
for kk, vv, dim in ((new_k_a, lora_a, 0), (new_k_b, lora_b, 1)):
previous_lora = composed.get(kk, None)
composed[kk] = vv if previous_lora is None else torch.cat([previous_lora, vv], dim=dim)
elif ".to_k." in k or ".to_v." in k or ".add_k_proj." in k or ".add_v_proj." in k:
continue
else:
if "lora_A" in k:
v = v * strength
previous_lora = composed.get(k, None)
if previous_lora is None:
composed[k] = v
else:
if "lora_A" in k:
if previous_lora.shape[1] != v.shape[1]: # flux.1-tools LoRA compatibility
assert "x_embedder" in k
expanded_size = max(previous_lora.shape[1], v.shape[1])
if expanded_size > previous_lora.shape[1]:
expanded_previous_lora = torch.zeros(
(previous_lora.shape[0], expanded_size),
device=previous_lora.device,
dtype=previous_lora.dtype,
)
expanded_previous_lora[:, : previous_lora.shape[1]] = previous_lora
else:
expanded_previous_lora = previous_lora
if expanded_size > v.shape[1]:
expanded_v = torch.zeros(
(v.shape[0], expanded_size), device=v.device, dtype=v.dtype
)
expanded_v[:, : v.shape[1]] = v
else:
expanded_v = v
composed[k] = torch.cat([expanded_previous_lora, expanded_v], dim=0)
else:
composed[k] = torch.cat([previous_lora, v], dim=0)
else:
composed[k] = torch.cat([previous_lora, v], dim=1)
if output_path is not None:
output_dir = os.path.dirname(os.path.abspath(output_path))
os.makedirs(output_dir, exist_ok=True)
save_file(composed, output_path)
return composed
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"-i", "--input-paths", type=str, nargs="*", required=True, help="paths to the lora safetensors files"
)
parser.add_argument("-s", "--strengths", type=float, nargs="*", required=True, help="strengths for each lora")
parser.add_argument("-o", "--output-path", type=str, required=True, help="path to the output safetensors file")
args = parser.parse_args()
assert len(args.input_paths) == len(args.strengths)
composed = compose_lora(list(zip(args.input_paths, args.strengths)), output_path=args.output_path)
import argparse
import os
import torch
from safetensors.torch import save_file
from .nunchaku_converter import to_nunchaku
from .utils import is_nunchaku_format
if __name__ == "__main__":
parser = argparse.ArgumentParser()
......@@ -19,13 +13,6 @@ if __name__ == "__main__":
default="mit-han-lab/svdq-int4-flux.1-dev/transformer_blocks.safetensors",
)
parser.add_argument("--lora-path", type=str, required=True, help="path to LoRA weights safetensor file")
parser.add_argument("--output-root", type=str, default="", help="root to the output safetensor file")
parser.add_argument("--lora-name", type=str, default=None, help="name of the LoRA weights")
parser.add_argument(
......@@ -37,46 +24,26 @@ if __name__ == "__main__":
)
args = parser.parse_args()
if is_nunchaku_format(args.lora_path):
print("Already in nunchaku format, no conversion needed.")
exit(0)
if not args.output_root:
# output to the parent directory of the lora safetensors file
args.output_root = os.path.dirname(args.lora_path)
if args.lora_name is None:
base_name = os.path.basename(args.lora_path)
lora_name = base_name.rsplit(".", 1)[0]
lora_name = "svdq-int4-" + lora_name
precision = "fp4" if "fp4" in args.quant_path else "int4"
lora_name = f"svdq-{precision}-{lora_name}"
print(f"LoRA name not provided, using {lora_name} as the LoRA name")
else:
lora_name = args.lora_name
assert lora_name, "LoRA name must be provided."
assert args.quant_path.endswith(".safetensors"), "Quantized model must be a safetensor file"
assert args.lora_path.endswith(".safetensors"), "LoRA weights must be a safetensor file"
to_nunchaku(
args.lora_path,
args.quant_path,
dtype=args.dtype,
output_path=os.path.join(args.output_root, f"{lora_name}.safetensors"),
)
print(f"Saved LoRA weights to {args.output_root}.")
# convert the diffusers lora to nunchaku format
"""Convert LoRA weights to Nunchaku format."""
import typing as tp
import argparse
import os
import warnings
import torch
import tqdm
from diffusers.loaders import FluxLoraLoaderMixin
from safetensors.torch import save_file
from ...utils import ceil_divide, filter_state_dict, load_state_dict_in_safetensors
from .utils import load_state_dict_in_safetensors
# region utilities
def pad(
tensor: tp.Optional[torch.Tensor],
divisor: int | tp.Sequence[int],
dim: int | tp.Sequence[int],
fill_value: float | int = 0,
) -> torch.Tensor | None:
if isinstance(divisor, int):
if divisor <= 1:
return tensor
elif all(d <= 1 for d in divisor):
return tensor
if tensor is None:
return None
shape = list(tensor.shape)
if isinstance(dim, int):
assert isinstance(divisor, int)
shape[dim] = ceil_divide(shape[dim], divisor) * divisor
else:
if isinstance(divisor, int):
divisor = [divisor] * len(dim)
for d, div in zip(dim, divisor, strict=True):
shape[d] = ceil_divide(shape[d], div) * div
result = torch.full(shape, fill_value, dtype=tensor.dtype, device=tensor.device)
result[[slice(0, extent) for extent in tensor.shape]] = tensor
return result
def update_state_dict(
lhs: dict[str, torch.Tensor], rhs: dict[str, torch.Tensor], prefix: str = ""
) -> dict[str, torch.Tensor]:
for rkey, value in rhs.items():
lkey = f"{prefix}.{rkey}" if prefix else rkey
assert lkey not in lhs, f"Key {lkey} already exists in the state dict."
lhs[lkey] = value
return lhs
# endregion
def pack_lowrank_weight(weight: torch.Tensor, down: bool) -> torch.Tensor:
"""Pack Low-Rank Weight.
Args:
weight (`torch.Tensor`):
low-rank weight tensor.
down (`bool`):
whether the weight is for down projection in low-rank branch.
"""
assert weight.dtype in (torch.float16, torch.bfloat16), f"Unsupported weight dtype {weight.dtype}."
lane_n, lane_k = 1, 2 # lane_n is always 1, lane_k is 32 bits // 16 bits = 2
n_pack_size, k_pack_size = 2, 2
num_n_lanes, num_k_lanes = 8, 4
frag_n = n_pack_size * num_n_lanes * lane_n
frag_k = k_pack_size * num_k_lanes * lane_k
weight = pad(weight, divisor=(frag_n, frag_k), dim=(0, 1))
if down:
r, c = weight.shape
r_frags, c_frags = r // frag_n, c // frag_k
weight = weight.view(r_frags, frag_n, c_frags, frag_k).permute(2, 0, 1, 3)
else:
c, r = weight.shape
c_frags, r_frags = c // frag_n, r // frag_k
weight = weight.view(c_frags, frag_n, r_frags, frag_k).permute(0, 2, 1, 3)
weight = weight.reshape(c_frags, r_frags, n_pack_size, num_n_lanes, k_pack_size, num_k_lanes, lane_k)
weight = weight.permute(0, 1, 3, 5, 2, 4, 6).contiguous()
return weight.view(c, r)
def unpack_lowrank_weight(weight: torch.Tensor, down: bool) -> torch.Tensor:
"""Unpack Low-Rank Weight.
Args:
weight (`torch.Tensor`):
low-rank weight tensor.
down (`bool`):
whether the weight is for down projection in low-rank branch.
"""
c, r = weight.shape
assert weight.dtype in (torch.float16, torch.bfloat16), f"Unsupported weight dtype {weight.dtype}."
lane_n, lane_k = 1, 2 # lane_n is always 1, lane_k is 32 bits // 16 bits = 2
n_pack_size, k_pack_size = 2, 2
num_n_lanes, num_k_lanes = 8, 4
frag_n = n_pack_size * num_n_lanes * lane_n
frag_k = k_pack_size * num_k_lanes * lane_k
if down:
r_frags, c_frags = r // frag_n, c // frag_k
else:
c_frags, r_frags = c // frag_n, r // frag_k
weight = weight.view(c_frags, r_frags, num_n_lanes, num_k_lanes, n_pack_size, k_pack_size, lane_k)
weight = weight.permute(0, 1, 4, 2, 5, 3, 6).contiguous()
weight = weight.view(c_frags, r_frags, frag_n, frag_k)
if down:
weight = weight.permute(1, 2, 0, 3).contiguous().view(r, c)
else:
weight = weight.permute(0, 2, 1, 3).contiguous().view(c, r)
return weight
def reorder_adanorm_lora_up(lora_up: torch.Tensor, splits: int) -> torch.Tensor:
c, r = lora_up.shape
assert c % splits == 0
return lora_up.view(splits, c // splits, r).transpose(0, 1).reshape(c, r).contiguous()
def to_diffusers(input_lora: str | dict[str, torch.Tensor], output_path: str | None = None) -> dict[str, torch.Tensor]:
if isinstance(input_lora, str):
tensors = load_state_dict_in_safetensors(input_lora, device="cpu")
else:
tensors = input_lora
new_tensors, alphas = FluxLoraLoaderMixin.lora_state_dict(tensors, return_alphas=True)
if alphas is not None and len(alphas) > 0:
warnings.warn("Alpha values are not used in the conversion to diffusers format.")
if output_path is not None:
output_dir = os.path.dirname(os.path.abspath(output_path))
os.makedirs(output_dir, exist_ok=True)
save_file(new_tensors, output_path)
return new_tensors
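# A short usage sketch (not in the diff): to_diffusers relies on diffusers'
# FluxLoraLoaderMixin.lora_state_dict, so any FLUX LoRA layout that loader understands
# (ComfyUI/Kohya, XLabs, plain diffusers) is normalized to "transformer.*.lora_A/lora_B" keys.
# Keys that are already in the diffusers scheme are assumed to pass through unchanged.
_demo = {
    "transformer.single_transformer_blocks.0.proj_out.lora_A.weight": torch.randn(4, 3072),
    "transformer.single_transformer_blocks.0.proj_out.lora_B.weight": torch.randn(3072, 4),
}
_converted = to_diffusers(_demo)  # a .safetensors path (local or remote HuggingFace) is also accepted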
def convert_to_nunchaku_transformer_block_lowrank_dict( # noqa: C901
orig_state_dict: dict[str, torch.Tensor],
extra_lora_dict: dict[str, torch.Tensor],
converted_block_name: str,
candidate_block_name: str,
local_name_map: dict[str, str | list[str]],
convert_map: dict[str, str],
default_dtype: torch.dtype = torch.bfloat16,
) -> dict[str, torch.Tensor]:
print(f"Converting LoRA branch for block {candidate_block_name}...")
converted: dict[str, torch.Tensor] = {}
for converted_local_name, candidate_local_names in tqdm.tqdm(
local_name_map.items(), desc=f"Converting {candidate_block_name}", dynamic_ncols=True
):
if isinstance(candidate_local_names, str):
candidate_local_names = [candidate_local_names]
# region original LoRA
orig_lora = (
orig_state_dict.get(f"{converted_block_name}.{converted_local_name}.lora_down", None),
orig_state_dict.get(f"{converted_block_name}.{converted_local_name}.lora_up", None),
)
if orig_lora[0] is None or orig_lora[1] is None:
assert orig_lora[0] is None and orig_lora[1] is None
orig_lora = None
else:
assert orig_lora[0] is not None and orig_lora[1] is not None
orig_lora = (
unpack_lowrank_weight(orig_lora[0], down=True),
unpack_lowrank_weight(orig_lora[1], down=False),
)
print(f" - Found {converted_block_name} LoRA of {converted_local_name} (rank: {orig_lora[0].shape[0]})")
# endregion
# region extra LoRA
extra_lora = [
(
extra_lora_dict.get(f"{candidate_block_name}.{candidate_local_name}.lora_A.weight", None),
extra_lora_dict.get(f"{candidate_block_name}.{candidate_local_name}.lora_B.weight", None),
)
for candidate_local_name in candidate_local_names
]
if any(lora[0] is not None or lora[1] is not None for lora in extra_lora):
# merge extra LoRAs into one LoRA
if len(extra_lora) > 1:
first_lora = None
for lora in extra_lora:
if lora[0] is not None:
assert lora[1] is not None
first_lora = lora
break
assert first_lora is not None
for lora_index in range(len(extra_lora)):
if extra_lora[lora_index][0] is None:
assert extra_lora[lora_index][1] is None
extra_lora[lora_index] = (first_lora[0].clone(), torch.zeros_like(first_lora[1]))
if all(lora[0].equal(extra_lora[0][0]) for lora in extra_lora):
# if all extra LoRAs have the same lora_down, use it
extra_lora_down = extra_lora[0][0]
extra_lora_up = torch.cat([lora[1] for lora in extra_lora], dim=0)
else:
extra_lora_down = torch.cat([lora[0] for lora in extra_lora], dim=0)
extra_lora_up_c = sum(lora[1].shape[0] for lora in extra_lora)
extra_lora_up_r = sum(lora[1].shape[1] for lora in extra_lora)
assert extra_lora_up_r == extra_lora_down.shape[0]
extra_lora_up = torch.zeros((extra_lora_up_c, extra_lora_up_r), dtype=extra_lora_down.dtype)
c, r = 0, 0
for lora in extra_lora:
c_next, r_next = c + lora[1].shape[0], r + lora[1].shape[1]
extra_lora_up[c:c_next, r:r_next] = lora[1]
c, r = c_next, r_next
else:
extra_lora_down, extra_lora_up = extra_lora[0]
extra_lora: tuple[torch.Tensor, torch.Tensor] = (extra_lora_down, extra_lora_up)
print(f" - Found {candidate_block_name} LoRA of {candidate_local_names} (rank: {extra_lora[0].shape[0]})")
else:
extra_lora = None
# endregion
# region merge LoRA
if orig_lora is None:
if extra_lora is None:
lora = None
else:
print(" - Using extra LoRA")
lora = (extra_lora[0].to(default_dtype), extra_lora[1].to(default_dtype))
elif extra_lora is None:
print(" - Using original LoRA")
lora = orig_lora
else:
lora = (
torch.cat([orig_lora[0], extra_lora[0].to(orig_lora[0].dtype)], dim=0),
torch.cat([orig_lora[1], extra_lora[1].to(orig_lora[1].dtype)], dim=1),
)
print(f" - Merging original and extra LoRA (rank: {lora[0].shape[0]})")
# endregion
if lora is not None:
if convert_map[converted_local_name] == "adanorm_single":
update_state_dict(
converted,
{
"lora_down": pad(lora[0], divisor=16, dim=0),
"lora_up": pad(reorder_adanorm_lora_up(lora[1], splits=3), divisor=16, dim=1),
},
prefix=converted_local_name,
)
elif convert_map[converted_local_name] == "adanorm_zero":
update_state_dict(
converted,
{
"lora_down": pad(lora[0], divisor=16, dim=0),
"lora_up": pad(reorder_adanorm_lora_up(lora[1], splits=6), divisor=16, dim=1),
},
prefix=converted_local_name,
)
elif convert_map[converted_local_name] == "linear":
update_state_dict(
converted,
{
"lora_down": pack_lowrank_weight(lora[0], down=True),
"lora_up": pack_lowrank_weight(lora[1], down=False),
},
prefix=converted_local_name,
)
return converted
def convert_to_nunchaku_flux_single_transformer_block_lowrank_dict(
orig_state_dict: dict[str, torch.Tensor],
extra_lora_dict: dict[str, torch.Tensor],
converted_block_name: str,
candidate_block_name: str,
default_dtype: torch.dtype = torch.bfloat16,
) -> dict[str, torch.Tensor]:
if f"{candidate_block_name}.proj_out.lora_A.weight" in extra_lora_dict:
assert f"{converted_block_name}.out_proj.qweight" in orig_state_dict
assert f"{converted_block_name}.mlp_fc2.qweight" in orig_state_dict
n1 = orig_state_dict[f"{converted_block_name}.out_proj.qweight"].shape[1] * 2
n2 = orig_state_dict[f"{converted_block_name}.mlp_fc2.qweight"].shape[1] * 2
lora_down = extra_lora_dict[f"{candidate_block_name}.proj_out.lora_A.weight"]
lora_up = extra_lora_dict[f"{candidate_block_name}.proj_out.lora_B.weight"]
assert lora_down.shape[1] == n1 + n2
extra_lora_dict[f"{candidate_block_name}.proj_out.linears.0.lora_A.weight"] = lora_down[:, :n1].clone()
extra_lora_dict[f"{candidate_block_name}.proj_out.linears.0.lora_B.weight"] = lora_up.clone()
extra_lora_dict[f"{candidate_block_name}.proj_out.linears.1.lora_A.weight"] = lora_down[:, n1:].clone()
extra_lora_dict[f"{candidate_block_name}.proj_out.linears.1.lora_B.weight"] = lora_up.clone()
extra_lora_dict.pop(f"{candidate_block_name}.proj_out.lora_A.weight")
extra_lora_dict.pop(f"{candidate_block_name}.proj_out.lora_B.weight")
for component in ["lora_A", "lora_B"]:
fc1_k = f"{candidate_block_name}.proj_mlp.{component}.weight"
fc2_k = f"{candidate_block_name}.proj_out.linears.1.{component}.weight"
fc1_v = extra_lora_dict[fc1_k]
fc2_v = extra_lora_dict[fc2_k]
dim = 0 if "lora_A" in fc1_k else 1
fc1_rank = fc1_v.shape[dim]
fc2_rank = fc2_v.shape[dim]
if fc1_rank != fc2_rank:
rank = max(fc1_rank, fc2_rank)
if fc1_rank < rank:
extra_lora_dict[fc1_k] = pad(fc1_v, divisor=rank, dim=dim)
if fc2_rank < rank:
extra_lora_dict[fc2_k] = pad(fc2_v, divisor=rank, dim=dim)
return convert_to_nunchaku_transformer_block_lowrank_dict(
orig_state_dict=orig_state_dict,
extra_lora_dict=extra_lora_dict,
converted_block_name=converted_block_name,
candidate_block_name=candidate_block_name,
local_name_map={
"norm.linear": "norm.linear",
"qkv_proj": ["attn.to_q", "attn.to_k", "attn.to_v"],
"norm_q": "attn.norm_q",
"norm_k": "attn.norm_k",
"out_proj": "proj_out.linears.0",
"mlp_fc1": "proj_mlp",
"mlp_fc2": "proj_out.linears.1",
},
convert_map={
"norm.linear": "adanorm_single",
"qkv_proj": "linear",
"out_proj": "linear",
"mlp_fc1": "linear",
"mlp_fc2": "linear",
},
default_dtype=default_dtype,
)
def convert_to_nunchaku_flux_transformer_block_lowrank_dict(
orig_state_dict: dict[str, torch.Tensor],
extra_lora_dict: dict[str, torch.Tensor],
converted_block_name: str,
candidate_block_name: str,
default_dtype: torch.dtype = torch.bfloat16,
) -> dict[str, torch.Tensor]:
return convert_to_nunchaku_transformer_block_lowrank_dict(
orig_state_dict=orig_state_dict,
extra_lora_dict=extra_lora_dict,
converted_block_name=converted_block_name,
candidate_block_name=candidate_block_name,
local_name_map={
"norm1.linear": "norm1.linear",
"norm1_context.linear": "norm1_context.linear",
"qkv_proj": ["attn.to_q", "attn.to_k", "attn.to_v"],
"qkv_proj_context": ["attn.add_q_proj", "attn.add_k_proj", "attn.add_v_proj"],
"norm_q": "attn.norm_q",
"norm_k": "attn.norm_k",
"norm_added_q": "attn.norm_added_q",
"norm_added_k": "attn.norm_added_k",
"out_proj": "attn.to_out.0",
"out_proj_context": "attn.to_add_out",
"mlp_fc1": "ff.net.0.proj",
"mlp_fc2": "ff.net.2",
"mlp_context_fc1": "ff_context.net.0.proj",
"mlp_context_fc2": "ff_context.net.2",
},
convert_map={
"norm1.linear": "adanorm_zero",
"norm1_context.linear": "adanorm_zero",
"qkv_proj": "linear",
"qkv_proj_context": "linear",
"out_proj": "linear",
"out_proj_context": "linear",
"mlp_fc1": "linear",
"mlp_fc2": "linear",
"mlp_context_fc1": "linear",
"mlp_context_fc2": "linear",
},
default_dtype=default_dtype,
)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("-i", "--input-path", type=str, required=True, help="path to the comfyui lora safetensors file")
parser.add_argument(
"-o", "--output-path", type=str, required=True, help="path to the output diffusers safetensors file"
)
args = parser.parse_args()
to_diffusers(args.input_path, args.output_path)
def convert_to_nunchaku_flux_lowrank_dict(
base_model: dict[str, torch.Tensor] | str,
lora: dict[str, torch.Tensor] | str,
default_dtype: torch.dtype = torch.bfloat16,
) -> dict[str, torch.Tensor]:
if isinstance(base_model, str):
orig_state_dict = load_state_dict_in_safetensors(base_model)
else:
orig_state_dict = base_model
if isinstance(lora, str):
extra_lora_dict = load_state_dict_in_safetensors(lora, filter_prefix="transformer.")
else:
extra_lora_dict = filter_state_dict(lora, filter_prefix="transformer.")
unquantized_lora_dict = {}
for k in list(extra_lora_dict.keys()):
if "transformer_blocks" not in k:
unquantized_lora_dict[k] = extra_lora_dict.pop(k)
for k in extra_lora_dict.keys():
fc1_k = k
if "ff.net.0.proj" in k:
fc2_k = k.replace("ff.net.0.proj", "ff.net.2")
elif "ff_context.net.0.proj" in k:
fc2_k = k.replace("ff_context.net.0.proj", "ff_context.net.2")
else:
continue
assert fc2_k in extra_lora_dict
fc1_v = extra_lora_dict[fc1_k]
fc2_v = extra_lora_dict[fc2_k]
dim = 0 if "lora_A" in fc1_k else 1
fc1_rank = fc1_v.shape[dim]
fc2_rank = fc2_v.shape[dim]
if fc1_rank != fc2_rank:
rank = max(fc1_rank, fc2_rank)
if fc1_rank < rank:
extra_lora_dict[fc1_k] = pad(fc1_v, divisor=rank, dim=dim)
if fc2_rank < rank:
extra_lora_dict[fc2_k] = pad(fc2_v, divisor=rank, dim=dim)
block_names: set[str] = set()
for param_name in orig_state_dict.keys():
if param_name.startswith(("transformer_blocks.", "single_transformer_blocks.")):
block_names.add(".".join(param_name.split(".")[:2]))
block_names = sorted(block_names, key=lambda x: (x.split(".")[0], int(x.split(".")[-1])))
print(f"Converting {len(block_names)} transformer blocks...")
converted: dict[str, torch.Tensor] = {}
for block_name in block_names:
if block_name.startswith("transformer_blocks"):
convert_fn = convert_to_nunchaku_flux_transformer_block_lowrank_dict
else:
convert_fn = convert_to_nunchaku_flux_single_transformer_block_lowrank_dict
update_state_dict(
converted,
convert_fn(
orig_state_dict=orig_state_dict,
extra_lora_dict=extra_lora_dict,
converted_block_name=block_name,
candidate_block_name=block_name,
default_dtype=default_dtype,
),
prefix=block_name,
)
converted.update(unquantized_lora_dict)
return converted
# Copy the packer from https://github.com/mit-han-lab/deepcompressor/
import torch
from .utils import pad
from ...utils import ceil_divide
class MmaWeightPackerBase:
def __init__(self, bits: int, warp_n: int, comp_n: int = None, comp_k: int = None):
self.bits = bits
assert self.bits in (1, 4, 8, 16, 32), "weight bits should be 1, 4, 8, 16, or 32."
# region compute tile size
self.comp_n = comp_n if comp_n is not None else 16
"""smallest tile size in `n` dimension for MMA computation."""
self.comp_k = comp_k if comp_k is not None else 256 // self.bits
"""smallest tile size in `k` dimension for MMA computation."""
# the smallest MMA computation may contain several MMA instructions
self.insn_n = 8 # mma instruction tile size in `n` dimension
"""tile size in `n` dimension for MMA instruction."""
self.insn_k = self.comp_k
"""tile size in `k` dimension for MMA instruction."""
assert self.insn_k * self.bits in (
128,
256,
), f"insn_k ({self.insn_k}) * bits ({self.bits}) should be 128 or 256."
assert self.comp_n % self.insn_n == 0, f"comp_n ({self.comp_n}) should be divisible by insn_n ({self.insn_n})."
self.num_lanes = 32
"""there are 32 lanes (or threds) in a warp."""
self.num_k_lanes = 4
self.num_n_lanes = 8
assert (
warp_n >= self.comp_n and warp_n % self.comp_n == 0
), f"warp_n ({warp_n}) should be divisible by comp_n({self.comp_n})."
self.warp_n = warp_n
# endregion
# region memory
self.reg_k = 32 // self.bits
"""number of elements in a register in `k` dimension."""
self.reg_n = 1
"""number of elements in a register in `n` dimension (always 1)."""
self.k_pack_size = self.comp_k // (self.num_k_lanes * self.reg_k)
"""number of elements in a pack in `k` dimension."""
self.n_pack_size = self.comp_n // (self.num_n_lanes * self.reg_n)
"""number of elements in a pack in `n` dimension."""
self.pack_size = self.k_pack_size * self.n_pack_size
"""number of elements in a pack accessed by a lane at a time."""
assert 1 <= self.pack_size <= 4, "pack size should be less than or equal to 4."
assert self.k_pack_size * self.num_k_lanes * self.reg_k == self.comp_k
assert self.n_pack_size * self.num_n_lanes * self.reg_n == self.comp_n
self.mem_k = self.comp_k
"""the tile size in `k` dimension for one tensor memory access."""
self.mem_n = warp_n
"""the tile size in `n` dimension for one tensor memory access."""
self.num_k_packs = self.mem_k // (self.k_pack_size * self.num_k_lanes * self.reg_k)
"""number of packs in `k` dimension for one tensor memory access."""
self.num_n_packs = self.mem_n // (self.n_pack_size * self.num_n_lanes * self.reg_n)
"""number of packs in `n` dimension for one tensor memory access."""
# endregion
def get_view_shape(self, n: int, k: int) -> tuple[int, int, int, int, int, int, int, int, int, int]:
assert n % self.mem_n == 0, "output channel size should be divisible by mem_n."
assert k % self.mem_k == 0, "input channel size should be divisible by mem_k."
return (
n // self.mem_n,
self.num_n_packs,
self.n_pack_size,
self.num_n_lanes,
self.reg_n,
k // self.mem_k,
self.num_k_packs,
self.k_pack_size,
self.num_k_lanes,
self.reg_k,
)
class NunchakuWeightPacker(MmaWeightPackerBase):
def __init__(self, bits: int, warp_n: int = 128):
super().__init__(bits=bits, warp_n=warp_n)
self.num_k_unrolls = 2
def pack_weight(self, weight: torch.Tensor) -> torch.Tensor:
assert weight.dtype == torch.int32, f"quantized weight should be torch.int32, but got {weight.dtype}."
n, k = weight.shape
assert n % self.mem_n == 0, f"output channel size ({n}) should be divisible by mem_n ({self.mem_n})."
# currently, Nunchaku does not check the boundary of the unrolled `k` dimension
assert k % (self.mem_k * self.num_k_unrolls) == 0, (
f"input channel size ({k}) should be divisible by "
f"mem_k ({self.mem_k}) * num_k_unrolls ({self.num_k_unrolls})."
)
n_tiles, k_tiles = n // self.mem_n, k // self.mem_k
weight = weight.reshape(
n_tiles,
self.num_n_packs, # 8 when warp_n = 128
self.n_pack_size, # always 2 in nunchaku
self.num_n_lanes, # constant 8
self.reg_n, # constant 1
k_tiles,
self.num_k_packs, # 1
self.k_pack_size, # always 2 in nunchaku
self.num_k_lanes, # constant 4
self.reg_k, # always 8 = 32 bits / 4 bits
)
# (n_tiles, num_n_packs, n_pack_size, num_n_lanes, reg_n, k_tiles, num_k_packs, k_pack_size, num_k_lanes, reg_k)
# =>
# (n_tiles, k_tiles, num_k_packs, num_n_packs, num_n_lanes, num_k_lanes, n_pack_size, k_pack_size, reg_n, reg_k)
weight = weight.permute(0, 5, 6, 1, 3, 8, 2, 7, 4, 9).contiguous()
assert weight.shape[4:-2] == (8, 4, 2, 2)
if self.bits == 4:
weight = weight.bitwise_and_(0xF)
shift = torch.arange(0, 32, 4, dtype=torch.int32, device=weight.device)
weight = weight.bitwise_left_shift_(shift)
weight = weight.sum(dim=-1, dtype=torch.int32)
elif self.bits == 8:
weight = weight.bitwise_and_(0xFF)
shift = torch.arange(0, 32, 8, dtype=torch.int32, device=weight.device)
weight = weight.bitwise_left_shift_(shift)
weight = weight.sum(dim=-1, dtype=torch.int32)
else:
raise NotImplementedError(f"weight bits {self.bits} is not supported.")
return weight.view(dtype=torch.int8).view(n, -1) # assume little-endian
def pack_scale(self, scale: torch.Tensor, group_size: int) -> torch.Tensor:
if self.check_if_micro_scale(group_size=group_size):
return self.pack_micro_scale(scale, group_size=group_size)
# note: refer to https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#mma-16864-c
assert scale.dtype in (torch.float16, torch.bfloat16), "currently nunchaku only supports fp16 and bf16."
n = scale.shape[0]
# nunchaku load scales all in one access
# for `[warp_n, warp_k]` weights, we load `[warp_n, warp_k / group_size]` scales
# scale loading is parallelized in `n` dimension, that is,
# `num_s_lanes` in a warp load `num_s_packs` of `s_pack_size` elements, in total `warp_s` elements
# each element in `n` dimension is 16 bit as it contains 1 fp16
# min `s_pack_size` set to 2 elements, since each lane holds at least 2 accumulator results in `n` dimension
# max `s_pack_size` set to 128b/16b = 8 elements
# for `warp_n = 8`, we have
# `s_pack_size = 2`, `num_s_lanes = 4`, `num_s_packs = 1`
# for `warp_n = 128`, we have
# `s_pack_size = 4`, `num_s_lanes = 32`, `num_s_packs = 1`
# for `warp_n = 512`, we have
# `s_pack_size = 8`, `num_s_lanes = 32`, `num_s_packs = 2`
s_pack_size = min(max(self.warp_n // self.num_lanes, 2), 8)
num_s_lanes = min(self.num_lanes, self.warp_n // s_pack_size)
num_s_packs = self.warp_n // (s_pack_size * num_s_lanes)
warp_s = num_s_packs * num_s_lanes * s_pack_size
assert warp_s == self.warp_n, "warp_n for scales should be equal to warp_n for weights."
# `num_n_lanes = 8 (constant)` generates 8 elements consecutive in `n` dimension
# however, they are held by 4 lanes, each lane holds 2 elements in `n` dimension
# thus, we start from first 4 lanes, assign 2 elements to each lane, until all 8 elements are assigned
# we then repeat the process for the same 4 lanes, until each lane holds `s_pack_size` elements
# finally, we move to next 4 lanes, and repeat the process until all `num_s_lanes` lanes are assigned
# the process is repeated for `num_s_packs` times
# here is an example for `warp_n = 128, s_pack_size = 4, num_s_lanes = 32, num_s_packs = 1`
# wscales store order:
# 0 1 8 9 <-- load by lane 0, broadcast to lane {0, 4, 8, ..., 28} (8x)
# 2 3 10 11 <-- load by lane 1, broadcast to lane {1, 5, 9, ..., 29} (8x)
# 4 5 12 13 <-- load by lane 2, broadcast to lane {2, 6, 10, ..., 30} (8x)
# 6 7 14 15 <-- load by lane 3, broadcast to lane {3, 7, 11, ..., 31} (8x)
# 16 17 24 25 <-- load by lane 4, broadcast to lane {0, 4, 8, ..., 28} (8x)
# ...
# 22 23 30 31 <-- load by lane 7, broadcast to lane {3, 7, 11, ..., 31} (8x)
# ... ...
# 112 113 120 121 <-- load by lane 28, broadcast to lane {0, 4, 8, ..., 28} (8x)
# ...
# 118 119 126 127 <-- load by lane 31, broadcast to lane {3, 7, 11, ..., 31} (8x)
scale = scale.reshape(n // warp_s, num_s_packs, num_s_lanes // 4, s_pack_size // 2, 4, 2, -1)
scale = scale.permute(0, 6, 1, 2, 4, 3, 5).contiguous()
return scale.view(-1) if group_size == -1 else scale.view(-1, n) # the shape is just used for validation
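# A small sanity sketch (not in the source) for the store order documented above, assuming a
# per-tensor scale (group_size=-1) and the default warp_n=128: lane 0 ends up with scale
# elements 0, 1, 8, 9.
#   packer = NunchakuWeightPacker(bits=4)
#   packed = packer.pack_scale(torch.arange(128, dtype=torch.float16), group_size=-1)
#   assert packed[:4].tolist() == [0.0, 1.0, 8.0, 9.0]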
def pack_micro_scale(self, scale: torch.Tensor, group_size: int) -> torch.Tensor:
assert scale.dtype in (torch.float16, torch.bfloat16), "currently nunchaku only supports fp16 and bf16."
assert scale.max() <= 448, "scale should be less than 448."
assert scale.min() >= -448, "scale should be greater than -448."
assert group_size == 16, "currently only support group size 16."
assert self.insn_k == 64, "insn_k should be 64."
scale = scale.to(dtype=torch.float8_e4m3fn)
n = scale.shape[0]
assert self.warp_n >= 32, "currently only support warp_n >= 32."
# for `[warp_n, warp_k]` weights, we load `[warp_n, warp_k / group_size]` scales
# scale loading is parallelized in `n` dimension, that is,
# `num_s_lanes` in a warp load `num_s_packs` of `s_pack_size` elements, in total `warp_s` elements
# each element in `n` dimension is 32 bit as it contains 4 fp8 in `k` dimension
# min `s_pack_size` set to 1 element
# max `s_pack_size` set to 128b/32b = 4 elements
# for `warp_n = 128`, we have
# `s_pack_size = 4`, `num_s_lanes = 32`, `num_s_packs = 1`
# for `warp_n = 512`, we have
# `s_pack_size = 4`, `num_s_lanes = 32`, `num_s_packs = 4`
s_pack_size = min(max(self.warp_n // self.num_lanes, 1), 4)
num_s_lanes = 4 * 8 # 32 lanes is divided into 4 pieces, each piece has 8 lanes at a stride of 4
num_s_packs = ceil_divide(self.warp_n, s_pack_size * num_s_lanes)
warp_s = num_s_packs * num_s_lanes * s_pack_size
assert warp_s == self.warp_n, "warp_n for scales should be equal to warp_n for weights."
# note: refer to https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#mma-scaling-thread-id-b-selection
# we start from the first 8 lanes at a stride of 4, assign 1 element to each lane, until all 8 elements are assigned
# we then move to the next 8 lanes at a stride of 4, and repeat the process until all 32 lanes are assigned
# here is an example for `warp_n = 128, s_pack_size = 4, num_s_lanes = 32, num_s_packs = 1`
# wscales store order:
# 0 32 64 96 <-- load by lane 0
# 8 40 72 104 <-- load by lane 1
# 16 48 80 112 <-- load by lane 2
# 24 56 88 120 <-- load by lane 3
# 1 33 65 97 <-- load by lane 4
# ...
# 25 57 81 113 <-- load by lane 7
# ...
# 7 39 71 103 <-- load by lane 28
# ...
# 31 63 95 127 <-- load by lane 31
scale = scale.view(n // warp_s, num_s_packs, s_pack_size, 4, 8, -1, self.insn_k // group_size)
scale = scale.permute(0, 5, 1, 4, 3, 2, 6).contiguous()
return scale.view(-1, n) # the shape is just used for validation
def pack_lowrank_weight(self, weight: torch.Tensor, down: bool) -> torch.Tensor:
"""Pack Low-Rank Weight.
Args:
weight (`torch.Tensor`):
low-rank weight tensor.
down (`bool`):
whether the weight is for down projection in low-rank branch.
"""
assert weight.dtype in (torch.float16, torch.bfloat16), f"Unsupported weight dtype {weight.dtype}."
reg_n, reg_k = 1, 2 # reg_n is always 1, reg_k is 32 bits // 16 bits = 2
pack_n = self.n_pack_size * self.num_n_lanes * reg_n
pack_k = self.k_pack_size * self.num_k_lanes * reg_k
weight = pad(weight, divisor=(pack_n, pack_k), dim=(0, 1))
if down:
r, c = weight.shape
r_packs, c_packs = r // pack_n, c // pack_k
weight = weight.view(r_packs, pack_n, c_packs, pack_k).permute(2, 0, 1, 3)
else:
c, r = weight.shape
c_packs, r_packs = c // pack_n, r // pack_k
weight = weight.view(c_packs, pack_n, r_packs, pack_k).permute(0, 2, 1, 3)
weight = weight.reshape(
c_packs, r_packs, self.n_pack_size, self.num_n_lanes, reg_n, self.k_pack_size, self.num_k_lanes, reg_k
)
# (c_packs, r_packs, n_pack_size, num_n_lanes, reg_n, k_pack_size, num_k_lanes, reg_k)
# =>
# (c_packs, r_packs, num_n_lanes, num_k_lanes, n_pack_size, k_pack_size, reg_n, reg_k)
weight = weight.permute(0, 1, 3, 6, 2, 5, 4, 7).contiguous()
return weight.view(c, r)
def unpack_lowrank_weight(self, weight: torch.Tensor, down: bool) -> torch.Tensor:
"""Unpack Low-Rank Weight.
Args:
weight (`torch.Tensor`):
low-rank weight tensor.
down (`bool`):
whether the weight is for down projection in low-rank branch.
"""
c, r = weight.shape
assert weight.dtype in (torch.float16, torch.bfloat16), f"Unsupported weight dtype {weight.dtype}."
reg_n, reg_k = 1, 2 # reg_n is always 1, reg_k is 32 bits // 16 bits = 2
pack_n = self.n_pack_size * self.num_n_lanes * reg_n
pack_k = self.k_pack_size * self.num_k_lanes * reg_k
if down:
r_packs, c_packs = r // pack_n, c // pack_k
else:
c_packs, r_packs = c // pack_n, r // pack_k
weight = weight.view(
c_packs, r_packs, self.num_n_lanes, self.num_k_lanes, self.n_pack_size, self.k_pack_size, reg_n, reg_k
)
# (c_packs, r_packs, num_n_lanes, num_k_lanes, n_pack_size, k_pack_size, reg_n, reg_k)
# =>
# (c_packs, r_packs, n_pack_size, num_n_lanes, reg_n, k_pack_size, num_k_lanes, reg_k)
weight = weight.permute(0, 1, 4, 2, 6, 5, 3, 7).contiguous()
weight = weight.view(c_packs, r_packs, pack_n, pack_k)
if down:
weight = weight.permute(1, 2, 0, 3).contiguous().view(r, c)
else:
weight = weight.permute(0, 2, 1, 3).contiguous().view(c, r)
return weight
def check_if_micro_scale(self, group_size: int) -> bool:
return self.insn_k == group_size * 4
def pad_weight(self, weight: torch.Tensor) -> torch.Tensor:
assert weight.ndim == 2, "weight tensor should be 2D."
return pad(weight, divisor=(self.mem_n, self.mem_k * self.num_k_unrolls), dim=(0, 1))
def pad_scale(self, scale: torch.Tensor, group_size: int, fill_value: float = 0) -> torch.Tensor:
if group_size > 0 and scale.numel() > scale.shape[0]:
scale = scale.view(scale.shape[0], 1, -1, 1)
if self.check_if_micro_scale(group_size=group_size):
scale = pad(scale, divisor=(self.warp_n, self.insn_k // group_size), dim=(0, 2), fill_value=fill_value)
else:
scale = pad(scale, divisor=(self.warp_n, self.num_k_unrolls), dim=(0, 2), fill_value=fill_value)
else:
scale = pad(scale, divisor=self.warp_n, dim=0, fill_value=fill_value)
return scale
def pad_lowrank_weight(self, weight: torch.Tensor, down: bool) -> torch.Tensor:
assert weight.ndim == 2, "weight tensor should be 2D."
return pad(weight, divisor=self.warp_n, dim=1 if down else 0)
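# A small round-trip sketch (not in the source): pack_lowrank_weight and unpack_lowrank_weight
# above apply inverse permutations, so packing and unpacking an already 16-aligned fp16 tensor
# reproduces it exactly.
_packer = NunchakuWeightPacker(bits=4)  # warp_n defaults to 128
_w = torch.randn(32, 64, dtype=torch.float16)  # both dims are multiples of the 16x16 fragment
_packed = _packer.pack_lowrank_weight(_w, down=True)
assert torch.equal(_packer.unpack_lowrank_weight(_packed, down=True), _w)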
import typing as tp
import torch
from ...utils import ceil_divide, load_state_dict_in_safetensors
def is_nunchaku_format(lora: str | dict[str, torch.Tensor]) -> bool:
if isinstance(lora, str):
tensors = load_state_dict_in_safetensors(lora, device="cpu")
else:
tensors = lora
for k in tensors.keys():
if "lora_unet_double_blocks_" in k or "lora_unet_single_blocks" in k:
return "comfyui"
elif ".mlp_fc" in k or "mlp_context_fc1" in k:
return "svdquant"
elif "double_blocks." in k or "single_blocks." in k:
return "xlab"
elif "transformer." in k:
return "diffusers"
raise ValueError("Unknown format, please provide the format explicitly.")
if ".mlp_fc" in k or "mlp_context_fc1" in k:
return True
return False
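# For illustration (not in the source): the check keys off module names that only exist in the
# Nunchaku fused layout (mlp_fc1/mlp_fc2, mlp_context_fc1), so diffusers-style keys return False.
assert is_nunchaku_format({"transformer_blocks.0.mlp_fc1.lora_down": torch.zeros(1)})
assert not is_nunchaku_format({"transformer.transformer_blocks.0.attn.to_q.lora_A.weight": torch.zeros(1)})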
def pad(
tensor: tp.Optional[torch.Tensor],
divisor: int | tp.Sequence[int],
dim: int | tp.Sequence[int],
fill_value: float | int = 0,
) -> torch.Tensor | None:
if isinstance(divisor, int):
if divisor <= 1:
return tensor
elif all(d <= 1 for d in divisor):
return tensor
if tensor is None:
return None
shape = list(tensor.shape)
if isinstance(dim, int):
assert isinstance(divisor, int)
shape[dim] = ceil_divide(shape[dim], divisor) * divisor
else:
if isinstance(divisor, int):
divisor = [divisor] * len(dim)
for d, div in zip(dim, divisor, strict=True):
shape[d] = ceil_divide(shape[d], div) * div
result = torch.full(shape, fill_value, dtype=tensor.dtype, device=tensor.device)
result[[slice(0, extent) for extent in tensor.shape]] = tensor
return result
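# Quick illustration (not in the source): padding a (5, 3) tensor along dim 0 to a multiple of 4
# yields an (8, 3) tensor whose extra rows are filled with `fill_value` (0 by default).
_x = torch.randn(5, 3)
_y = pad(_x, divisor=4, dim=0)
assert _y.shape == (8, 3) and torch.equal(_y[:5], _x) and torch.equal(_y[5:], torch.zeros(3, 3))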
# convert the xlab lora to diffusers format
import os
import torch
from safetensors.torch import save_file
from ...utils import load_state_dict_in_safetensors
def xlab2diffusers(
input_lora: str | dict[str, torch.Tensor], output_path: str | None = None
) -> dict[str, torch.Tensor]:
if isinstance(input_lora, str):
tensors = load_state_dict_in_safetensors(input_lora, device="cpu")
else:
tensors = input_lora
new_tensors = {}
# lora1 is for img, lora2 is for text
for k, v in tensors.items():
assert "double_blocks" in k
new_k = k.replace("double_blocks", "transformer.transformer_blocks").replace("processor", "attn")
new_k = new_k.replace(".down.", ".lora_A.")
new_k = new_k.replace(".up.", ".lora_B.")
if ".proj_lora" in new_k:
new_k = new_k.replace(".proj_lora1", ".to_out.0")
new_k = new_k.replace(".proj_lora2", ".to_add_out")
new_tensors[new_k] = v
else:
assert "qkv_lora" in new_k
if "lora_A" in new_k:
for p in ["q", "k", "v"]:
if ".qkv_lora1." in new_k:
new_tensors[new_k.replace(".qkv_lora1.", f".to_{p}.")] = v.clone()
else:
assert ".qkv_lora2." in new_k
new_tensors[new_k.replace(".qkv_lora2.", f".add_{p}_proj.")] = v.clone()
else:
assert "lora_B" in new_k
for i, p in enumerate(["q", "k", "v"]):
assert v.shape[0] % 3 == 0
chunk_size = v.shape[0] // 3
if ".qkv_lora1." in new_k:
new_tensors[new_k.replace(".qkv_lora1.", f".to_{p}.")] = v[
i * chunk_size : (i + 1) * chunk_size
]
else:
assert ".qkv_lora2." in new_k
new_tensors[new_k.replace(".qkv_lora2.", f".add_{p}_proj.")] = v[
i * chunk_size : (i + 1) * chunk_size
]
if output_path is not None:
output_dir = os.path.dirname(os.path.abspath(output_path))
os.makedirs(output_dir, exist_ok=True)
save_file(new_tensors, output_path)
return new_tensors
......@@ -120,4 +120,4 @@ def convert_to_tinychat_w4x16y16_linear_weight(
_zero = torch.zeros((_ng, oc), dtype=dtype, device=device)
_scale[:ng] = scale.view(oc, ng).t().to(dtype=dtype)
_zero[:ng] = zero.view(oc, ng).t().to(dtype=dtype).neg_()
return _weight, _scale, _zero
import logging
import os
import diffusers
......@@ -6,14 +7,24 @@ from diffusers import FluxTransformer2DModel
from diffusers.configuration_utils import register_to_config
from huggingface_hub import utils
from packaging.version import Version
from safetensors.torch import load_file
from torch import nn
from .utils import NunchakuModelLoaderMixin, pad_tensor
from .utils import get_precision, NunchakuModelLoaderMixin, pad_tensor
from ..._C import QuantizedFluxModel, utils as cutils
from ...lora.flux.nunchaku_converter import fuse_vectors, to_nunchaku
from ...lora.flux.utils import is_nunchaku_format
from ...utils import load_state_dict_in_safetensors
SVD_RANK = 32
# Get log level from environment variable (default to INFO)
log_level = os.getenv("LOG_LEVEL", "INFO").upper()
# Configure logging
logging.basicConfig(level=getattr(logging, log_level, logging.INFO), format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
class NunchakuFluxTransformerBlocks(nn.Module):
def __init__(self, m: QuantizedFluxModel, device: str | torch.device):
......@@ -35,9 +46,9 @@ class NunchakuFluxTransformerBlocks(nn.Module):
rotemb = rotemb.permute(0, 1, 3, 2, 4)
# 16*8 pack, FP32 accumulator (C) format
# https://docs.nvidia.com/cuda/parallel-thread-execution/#mma-16816-c
##########################################|--M--|--D--|
##########################################|-3--4--5--6|
########################################## : : : :
rotemb = rotemb.reshape(*rotemb.shape[0:3], 2, 8, 4, 2)
rotemb = rotemb.permute(0, 1, 2, 4, 5, 3, 6)
rotemb = rotemb.contiguous()
......@@ -208,8 +219,8 @@ class NunchakuFluxTransformer2dModel(FluxTransformer2DModel, NunchakuModelLoader
patch_size=patch_size,
in_channels=in_channels,
out_channels=out_channels,
num_layers=num_layers,
num_single_layers=num_single_layers,
attention_head_dim=attention_head_dim,
num_attention_heads=num_attention_heads,
joint_attention_dim=joint_attention_dim,
......@@ -217,76 +228,201 @@ class NunchakuFluxTransformer2dModel(FluxTransformer2DModel, NunchakuModelLoader
guidance_embeds=guidance_embeds,
axes_dims_rope=axes_dims_rope,
)
# these state_dicts are used for supporting lora
self._unquantized_part_sd: dict[str, torch.Tensor] = {}
self._unquantized_part_loras: dict[str, torch.Tensor] = {}
self._quantized_part_sd: dict[str, torch.Tensor] = {}
self._quantized_part_vectors: dict[str, torch.Tensor] = {}
@classmethod
@utils.validate_hf_hub_args
def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike, **kwargs):
device = kwargs.get("device", "cuda")
precision = kwargs.get("precision", "int4")
if isinstance(device, str):
device = torch.device(device)
offload = kwargs.get("offload", False)
assert precision in ["int4", "fp4"]
transformer, transformer_block_path = cls._build_model(pretrained_model_name_or_path, **kwargs)
precision = get_precision(kwargs.get("precision", "auto"), device, pretrained_model_name_or_path)
transformer, unquantized_part_path, transformer_block_path = cls._build_model(
pretrained_model_name_or_path, **kwargs
)
# get the default LoRA branch and all the vectors
quantized_part_sd = load_file(transformer_block_path)
new_quantized_part_sd = {}
for k, v in quantized_part_sd.items():
if v.ndim == 1:
new_quantized_part_sd[k] = v
elif "qweight" in k:
# only the shape information of this tensor is needed
new_quantized_part_sd[k] = v.to("meta")
elif "lora" in k:
new_quantized_part_sd[k] = v
transformer._quantized_part_sd = new_quantized_part_sd
m = load_quantized_module(transformer_block_path, device=device, use_fp4=precision == "fp4", offload=offload)
transformer.inject_quantized_module(m, device)
transformer.to_empty(device=device)
unquantized_part_sd = load_file(unquantized_part_path)
transformer.load_state_dict(unquantized_part_sd, strict=False)
transformer._unquantized_part_sd = unquantized_part_sd
return transformer
def inject_quantized_module(self, m: QuantizedFluxModel, device: str | torch.device = "cuda"):
print("Injecting quantized module")
self.pos_embed = EmbedND(dim=self.inner_dim, theta=10000, axes_dim=[16, 56, 56])
### Compatible with the original forward method
self.transformer_blocks = nn.ModuleList([NunchakuFluxTransformerBlocks(m, device)])
self.single_transformer_blocks = nn.ModuleList([])
return self
def set_attention_impl(self, impl: str):
block = self.transformer_blocks[0]
assert isinstance(block, NunchakuFluxTransformerBlocks)
block.m.setAttentionImpl(impl)
### LoRA Related Functions
def _expand_module(self, module_name: str, new_shape: tuple[int, int]):
module = self.get_submodule(module_name)
assert isinstance(module, nn.Linear)
weight_shape = module.weight.shape
logger.info("Expand the shape of module {} from {} to {}".format(module_name, tuple(weight_shape), new_shape))
assert new_shape[0] >= weight_shape[0] and new_shape[1] >= weight_shape[1]
new_module = nn.Linear(
new_shape[1],
new_shape[0],
bias=module.bias is not None,
device=module.weight.device,
dtype=module.weight.dtype,
)
new_module.weight.data.zero_()
new_module.weight.data[: weight_shape[0], : weight_shape[1]] = module.weight.data
self._unquantized_part_sd[f"{module_name}.weight"] = new_module.weight.data.clone()
if new_module.bias is not None:
new_module.bias.data.zero_()
new_module.bias.data[: weight_shape[0]] = module.bias.data
self._unquantized_part_sd[f"{module_name}.bias"] = new_module.bias.data.clone()
parent_name = ".".join(module_name.split(".")[:-1])
parent_module = self.get_submodule(parent_name)
parent_module.add_module(module_name.split(".")[-1], new_module)
if module_name == "x_embedder":
new_value = int(new_module.weight.data.shape[1])
old_value = getattr(self.config, "in_channels")
if new_value != old_value:
logger.info(f"Update in_channels from {old_value} to {new_value}")
setattr(self.config, "in_channels", new_value)
def _update_unquantized_part_lora_params(self, strength: float = 1):
# check if we need to expand the linear layers
device = next(self.parameters()).device
for k, v in self._unquantized_part_loras.items():
if "lora_A" in k:
lora_a = v
lora_b = self._unquantized_part_loras[k.replace(".lora_A.", ".lora_B.")]
diff_shape = (lora_b.shape[0], lora_a.shape[1])
weight_shape = self._unquantized_part_sd[k.replace(".lora_A.", ".")].shape
if diff_shape[0] > weight_shape[0] or diff_shape[1] > weight_shape[1]:
module_name = ".".join(k.split(".")[:-2])
self._expand_module(module_name, diff_shape)
elif v.ndim == 1:
diff_shape = v.shape
weight_shape = self._unquantized_part_sd[k].shape
if diff_shape[0] > weight_shape[0]:
assert diff_shape[0] >= weight_shape[0]
module_name = ".".join(k.split(".")[:-1])
module = self.get_submodule(module_name)
weight_shape = module.weight.shape
diff_shape = (diff_shape[0], weight_shape[1])
self._expand_module(module_name, diff_shape)
new_state_dict = {}
for k in self.unquantized_state_dict.keys():
v = self.unquantized_state_dict[k]
if k.replace(".weight", ".lora_B.weight") in self.unquantized_loras:
new_state_dict[k] = v + strength * (
self.unquantized_loras[k.replace(".weight", ".lora_B.weight")]
@ self.unquantized_loras[k.replace(".weight", ".lora_A.weight")]
)
for k in self._unquantized_part_sd.keys():
v = self._unquantized_part_sd[k]
v = v.to(device)
self._unquantized_part_sd[k] = v
if v.ndim == 1 and k in self._unquantized_part_loras:
diff = strength * self._unquantized_part_loras[k]
if diff.shape[0] < v.shape[0]:
diff = torch.cat(
[diff, torch.zeros(v.shape[0] - diff.shape[0], device=device, dtype=v.dtype)], dim=0
)
new_state_dict[k] = v + diff
elif v.ndim == 2 and k.replace(".weight", ".lora_B.weight") in self._unquantized_part_loras:
lora_a = self._unquantized_part_loras[k.replace(".weight", ".lora_A.weight")]
lora_b = self._unquantized_part_loras[k.replace(".weight", ".lora_B.weight")]
if lora_a.shape[1] < v.shape[1]:
lora_a = torch.cat(
[
lora_a,
torch.zeros(lora_a.shape[0], v.shape[1] - lora_a.shape[1], device=device, dtype=v.dtype),
],
dim=1,
)
if lora_b.shape[0] < v.shape[0]:
lora_b = torch.cat(
[
lora_b,
torch.zeros(v.shape[0] - lora_b.shape[0], lora_b.shape[1], device=device, dtype=v.dtype),
],
dim=0,
)
diff = strength * (lora_b @ lora_a)
new_state_dict[k] = v + diff
else:
new_state_dict[k] = v
self.load_state_dict(new_state_dict, strict=True)
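For reference, the fusion applied above is the standard low-rank update W' = W + strength * (B @ A); a toy-sized sketch (hypothetical shapes, rank 4) of what a single new_state_dict entry ends up being:

import torch

out_features, in_features, rank, strength = 12, 8, 4, 0.85
W = torch.randn(out_features, in_features)  # base weight from the unquantized part
lora_A = torch.randn(rank, in_features)     # "*.lora_A.weight"
lora_B = torch.randn(out_features, rank)    # "*.lora_B.weight"
W_fused = W + strength * (lora_B @ lora_A)
assert W_fused.shape == W.shape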
def update_lora_params(self, path_or_state_dict: str | dict[str, torch.Tensor]):
if isinstance(path_or_state_dict, dict):
state_dict = path_or_state_dict
state_dict = {
k: v for k, v in path_or_state_dict.items()
} # copy a new one to avoid modifying the original one
else:
state_dict = load_state_dict_in_safetensors(path_or_state_dict)
unquantized_loras = {}
for k in state_dict.keys():
if not is_nunchaku_format(state_dict):
state_dict = to_nunchaku(state_dict, base_sd=self._quantized_part_sd)
unquantized_part_loras = {}
for k, v in list(state_dict.items()):
device = next(self.parameters()).device
if "transformer_blocks" not in k:
unquantized_loras[k] = state_dict[k]
for k in unquantized_loras.keys():
state_dict.pop(k)
unquantized_part_loras[k] = state_dict.pop(k).to(device)
if len(self._unquantized_part_loras) > 0 or len(unquantized_part_loras) > 0:
self._unquantized_part_loras = unquantized_part_loras
self._update_unquantized_part_lora_params(1)
self.unquantized_loras = unquantized_loras
if len(unquantized_loras) > 0:
if self.unquantized_state_dict is None:
unquantized_state_dict = self.state_dict()
self.unquantized_state_dict = {k: v.cpu() for k, v in unquantized_state_dict.items()}
self.update_unquantized_lora_params(1)
quantized_part_vectors = {}
for k, v in list(state_dict.items()):
if v.ndim == 1:
quantized_part_vectors[k] = state_dict.pop(k)
if len(self._quantized_part_vectors) > 0 or len(quantized_part_vectors) > 0:
self._quantized_part_vectors = quantized_part_vectors
updated_vectors = fuse_vectors(quantized_part_vectors, self._quantized_part_sd, 1)
state_dict.update(updated_vectors)
# Get the vectors from the quantized part
block = self.transformer_blocks[0]
assert isinstance(block, NunchakuFluxTransformerBlocks)
block.m.loadDict(state_dict, True)
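Because update_lora_params also accepts an in-memory dict (see the signature above), a LoRA can be loaded and filtered before it is applied. In the sketch below, transformer is assumed to be a NunchakuFluxTransformer2dModel instance loaded as usual; the path and the key filter are placeholders:

from safetensors.torch import load_file

lora_sd = load_file("path/to/your-lora.safetensors")  # placeholder path
lora_sd = {k: v for k, v in lora_sd.items() if not k.startswith("text_encoder")}  # hypothetical filter
transformer.update_lora_params(lora_sd)  # dict input; the method copies it before modifying
transformer.set_lora_strength(0.8)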
# This function can only be used with a single LoRA.
# For multiple LoRAs, please fuse the lora scale into the weights.
def set_lora_strength(self, strength: float = 1):
block = self.transformer_blocks[0]
assert isinstance(block, NunchakuFluxTransformerBlocks)
block.m.setLoraScale(SVD_RANK, strength)
if len(self.unquantized_loras) > 0:
self.update_unquantized_lora_params(strength)
def set_attention_impl(self, impl: str):
block = self.transformer_blocks[0]
assert isinstance(block, NunchakuFluxTransformerBlocks)
block.m.setAttentionImpl(impl)
def inject_quantized_module(self, m: QuantizedFluxModel, device: str | torch.device = "cuda"):
print("Injecting quantized module")
self.pos_embed = EmbedND(dim=self.inner_dim, theta=10000, axes_dim=[16, 56, 56])
### Compatible with the original forward method
self.transformer_blocks = nn.ModuleList([NunchakuFluxTransformerBlocks(m, device)])
self.single_transformer_blocks = nn.ModuleList([])
return self
if len(self._unquantized_part_loras) > 0:
self._update_unquantized_part_lora_params(strength)
if len(self._quantized_part_vectors) > 0:
vector_dict = fuse_vectors(self._quantized_part_vectors, self._quantized_part_sd, strength)
block.m.loadDict(vector_dict, True)
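As the comment before set_lora_strength notes, the runtime scale applies to a single LoRA; when stacking several, one option is to fold each strength into the low-rank factors before merging. The helper below is a hypothetical sketch, not an API of this repository; it assumes the lora_A/lora_B key convention used above, and whether the quantized backend accepts the resulting combined rank is also an assumption:

import torch

def compose_loras(loras: list[tuple[dict[str, torch.Tensor], float]]) -> dict[str, torch.Tensor]:
    # sum_i s_i * (B_i @ A_i) == cat([s_i * B_i], dim=1) @ cat([A_i], dim=0),
    # so folding each strength into lora_B and concatenating along the rank
    # dimension yields a single, already-scaled LoRA state dict.
    composed: dict[str, torch.Tensor] = {}
    for sd, strength in loras:
        for k, v in sd.items():
            if ".lora_B." in k or v.ndim == 1:
                v = v * strength                                  # fold the per-LoRA scale in
            if k not in composed:
                composed[k] = v
            elif ".lora_B." in k:
                composed[k] = torch.cat([composed[k], v], dim=1)  # rank grows along B's columns
            elif ".lora_A." in k:
                composed[k] = torch.cat([composed[k], v], dim=0)  # rank grows along A's rows
            else:
                composed[k] = composed[k] + v                     # 1-D diffs simply add
    return composed

# transformer.update_lora_params(compose_loras([(lora_sd_1, 0.8), (lora_sd_2, 1.0)]))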
......@@ -2,13 +2,13 @@ import os
from typing import Optional
import torch
import torch.nn.functional as F
from diffusers import SanaTransformer2DModel
from diffusers.configuration_utils import register_to_config
from huggingface_hub import utils
from safetensors.torch import load_file
from torch import nn
from torch.nn import functional as F
from .utils import NunchakuModelLoaderMixin
from .utils import get_precision, NunchakuModelLoaderMixin
from ..._C import QuantizedSanaModel, utils as cutils
SVD_RANK = 32
......@@ -30,7 +30,7 @@ class NunchakuSanaTransformerBlocks(nn.Module):
timestep: Optional[torch.LongTensor] = None,
height: Optional[int] = None,
width: Optional[int] = None,
skip_first_layer: Optional[bool] = False
skip_first_layer: Optional[bool] = False,
):
batch_size = hidden_states.shape[0]
......@@ -77,15 +77,15 @@ class NunchakuSanaTransformerBlocks(nn.Module):
)
def forward_layer_at(
self,
idx: int,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
timestep: Optional[torch.LongTensor] = None,
height: Optional[int] = None,
width: Optional[int] = None,
self,
idx: int,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
timestep: Optional[torch.LongTensor] = None,
height: Optional[int] = None,
width: Optional[int] = None,
):
batch_size = hidden_states.shape[0]
img_tokens = hidden_states.shape[1]
......@@ -132,62 +132,22 @@ class NunchakuSanaTransformerBlocks(nn.Module):
class NunchakuSanaTransformer2DModel(SanaTransformer2DModel, NunchakuModelLoaderMixin):
@register_to_config
def __init__(
self,
in_channels: int = 32,
out_channels: Optional[int] = 32,
num_attention_heads: int = 70,
attention_head_dim: int = 32,
num_layers: int = 20,
num_cross_attention_heads: Optional[int] = 20,
cross_attention_head_dim: Optional[int] = 112,
cross_attention_dim: Optional[int] = 2240,
caption_channels: int = 2304,
mlp_ratio: float = 2.5,
dropout: float = 0.0,
attention_bias: bool = False,
sample_size: int = 32,
patch_size: int = 1,
norm_elementwise_affine: bool = False,
norm_eps: float = 1e-6,
interpolation_scale: Optional[int] = None,
) -> None:
# set num_layers to 0 to avoid creating transformer blocks
self.original_num_layers = num_layers
super(NunchakuSanaTransformer2DModel, self).__init__(
in_channels=in_channels,
out_channels=out_channels,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
num_layers=0,
num_cross_attention_heads=num_cross_attention_heads,
cross_attention_head_dim=cross_attention_head_dim,
cross_attention_dim=cross_attention_dim,
caption_channels=caption_channels,
mlp_ratio=mlp_ratio,
dropout=dropout,
attention_bias=attention_bias,
sample_size=sample_size,
patch_size=patch_size,
norm_elementwise_affine=norm_elementwise_affine,
norm_eps=norm_eps,
interpolation_scale=interpolation_scale,
)
@classmethod
@utils.validate_hf_hub_args
def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike, **kwargs):
device = kwargs.get("device", "cuda")
pag_layers = kwargs.get("pag_layers", [])
precision = kwargs.get("precision", "int4")
assert precision in ["int4", "fp4"]
transformer, transformer_block_path = cls._build_model(pretrained_model_name_or_path, **kwargs)
transformer.config["num_layers"] = transformer.original_num_layers
precision = get_precision(kwargs.get("precision", "auto"), device, pretrained_model_name_or_path)
transformer, unquantized_part_path, transformer_block_path = cls._build_model(
pretrained_model_name_or_path, **kwargs
)
m = load_quantized_module(
transformer, transformer_block_path, device=device, pag_layers=pag_layers, use_fp4=precision == "fp4"
)
transformer.inject_quantized_module(m, device)
transformer.to_empty(device=device)
unquantized_state_dict = load_file(unquantized_part_path)
transformer.load_state_dict(unquantized_state_dict, strict=False)
return transformer
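A usage sketch for the Sana loader above; the repository id, the pag_layers value, and the top-level import are assumptions (the Flux model is imported from nunchaku this way, so the Sana class presumably is too):

import torch
from nunchaku import NunchakuSanaTransformer2DModel  # assumed re-export, mirroring the Flux model

transformer = NunchakuSanaTransformer2DModel.from_pretrained(
    "mit-han-lab/svdq-int4-sana-1600m",  # placeholder repository id
    precision="auto",                    # resolved to "int4" or "fp4" by get_precision below
    pag_layers=[8],                      # placeholder: layers prepared for perturbed-attention guidance
    torch_dtype=torch.bfloat16,
    device="cuda",
)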
def inject_quantized_module(self, m: QuantizedSanaModel, device: str | torch.device = "cuda"):
......
import os
import warnings
from typing import Any, Optional
import torch
from diffusers import __version__
from huggingface_hub import constants, hf_hub_download
from safetensors.torch import load_file
from typing import Optional, Any
from torch import nn
from nunchaku.utils import ceil_divide
class NunchakuModelLoaderMixin:
@classmethod
def _build_model(cls, pretrained_model_name_or_path: str | os.PathLike, **kwargs):
def _build_model(cls, pretrained_model_name_or_path: str | os.PathLike, **kwargs) -> tuple[nn.Module, str, str]:
subfolder = kwargs.get("subfolder", None)
if os.path.exists(pretrained_model_name_or_path):
dirname = (
......@@ -60,16 +63,13 @@ class NunchakuModelLoaderMixin:
**kwargs,
)
transformer = cls.from_config(config).to(kwargs.get("torch_dtype", torch.bfloat16))
state_dict = load_file(unquantized_part_path)
transformer.load_state_dict(state_dict, strict=False)
with torch.device("meta"):
transformer = cls.from_config(config).to(kwargs.get("torch_dtype", torch.bfloat16))
return transformer, transformer_block_path
return transformer, unquantized_part_path, transformer_block_path
def ceil_div(x: int, y: int) -> int:
return (x + y - 1) // y
def pad_tensor(tensor: Optional[torch.Tensor], multiples: int, dim: int, fill: Any = 0) -> torch.Tensor:
def pad_tensor(tensor: Optional[torch.Tensor], multiples: int, dim: int, fill: Any = 0) -> torch.Tensor | None:
if multiples <= 1:
return tensor
if tensor is None:
......@@ -77,8 +77,26 @@ def pad_tensor(tensor: Optional[torch.Tensor], multiples: int, dim: int, fill: A
shape = list(tensor.shape)
if shape[dim] % multiples == 0:
return tensor
shape[dim] = ceil_div(shape[dim], multiples) * multiples
shape[dim] = ceil_divide(shape[dim], multiples) * multiples
result = torch.empty(shape, dtype=tensor.dtype, device=tensor.device)
result.fill_(fill)
result[[slice(0, extent) for extent in tensor.shape]] = tensor
return result
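A quick illustration of the padding helper above (toy tensor; the target dimension is rounded up with ceil_divide and the new slots take the fill value):

import torch

t = torch.arange(6, dtype=torch.float32).reshape(2, 3)
padded = pad_tensor(t, multiples=4, dim=1, fill=0)
print(padded.shape)  # torch.Size([2, 4]) -- 3 rounded up to the next multiple of 4
print(padded[:, 3])  # tensor([0., 0.]) -- the padded column holds the fill value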
def get_precision(precision: str, device: str | torch.device, pretrained_model_name_or_path: str | None = None) -> str:
assert precision in ("auto", "int4", "fp4")
if precision == "auto":
if isinstance(device, str):
device = torch.device(device)
capability = torch.cuda.get_device_capability(0 if device.index is None else device.index)
sm = f"{capability[0]}{capability[1]}"
precision = "fp4" if sm == "120" else "int4"
if pretrained_model_name_or_path is not None:
if precision == "int4":
if "fp4" in pretrained_model_name_or_path:
warnings.warn("The model may be quantized to fp4, but you are loading it with int4 precision.")
elif precision == "fp4":
if "int4" in pretrained_model_name_or_path:
warnings.warn("The model may be quantized to int4, but you are loading it with fp4 precision.")
return precision
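For example, the auto path above maps compute capability 12.0 (sm_120) to fp4 and everything else to int4, and warns when the choice contradicts the checkpoint name:

import torch

if torch.cuda.is_available():
    precision = get_precision("auto", torch.device("cuda:0"), "mit-han-lab/svdq-int4-flux.1-dev")
    print(precision)  # "int4" on pre-sm_120 GPUs, "fp4" on sm_120 (with a name-mismatch warning)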
......@@ -28,4 +28,4 @@ dependencies = [
"protobuf",
"huggingface_hub",
]
requires-python = ">=3.10, <3.13"
requires-python = ">=3.10"