Commit 73939beb authored by Samuel Tesfai

Merge github.com:mit-han-lab/nunchaku into migrate_tinychat

parents 91c7b53c bfd9aa3a
.gradio-container{max-width: 1200px !important}
.gradio-container {
max-width: 1200px !important;
margin: auto; /* Centers the element horizontally */
}
......@@ -276,8 +276,6 @@ with gr.Blocks(
outputs=[prompt_template],
api_name=False,
queue=False,
).then(
fn=generate_func, inputs=input_args, outputs=[*image_results, *latency_results], api_name=False, queue=False
)
gr.Markdown("MIT Accessibility: https://accessibility.mit.edu/", elem_id="accessibility")
......
......@@ -36,7 +36,7 @@ LORA_PATHS = {
},
}
SVDQ_LORA_PATH_FORMAT = "mit-han-lab/svdquant-models/svdq-flux.1-dev-lora-{name}.safetensors"
SVDQ_LORA_PATH_FORMAT = "mit-han-lab/svdquant-lora-collection/svdq-int4-flux.1-dev-{name}.safetensors"
SVDQ_LORA_PATHS = {
"Anime": SVDQ_LORA_PATH_FORMAT.format(name="anime"),
"GHIBSKY Illustration": SVDQ_LORA_PATH_FORMAT.format(name="ghibsky"),
......
......@@ -4,43 +4,43 @@
<img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/logo.svg"
alt="logo"
style="height: 40px; width: auto; display: block; margin: auto;"/>
<a href='https://nvlabs.github.io/Sana/'>SANA-1600M</a> Demo
<a href='https://nvlabs.github.io/Sana/' target="_blank">SANA-1.6B</a> Demo
</h1>
<h2>
SVDQuant: Absorbing Outliers by Low-Rank Components for 4-Bit Diffusion Models
</h2>
<h3>
<a href='https://lmxyy.me'>Muyang Li*</a>,
<a href='https://yujunlin.com'>Yujun Lin*</a>,
<a href='https://hanlab.mit.edu/team/zhekai-zhang'>Zhekai Zhang*</a>,
<a href='https://www.tianle.website/#/'>Tianle Cai</a>,
<a href='https://xiuyuli.com'>Xiuyu Li</a>,
<a href='https://lmxyy.me' target="_blank">Muyang Li*</a>,
<a href='https://yujunlin.com' target="_blank">Yujun Lin*</a>,
<a href='https://hanlab.mit.edu/team/zhekai-zhang' target="_blank">Zhekai Zhang*</a>,
<a href='https://www.tianle.website/#/' target="_blank">Tianle Cai</a>,
<a href='https://xiuyuli.com' target="_blank">Xiuyu Li</a>,
<br>
<a href='https://github.com/JerryGJX'>Junxian Guo</a>,
<a href='https://xieenze.github.io'>Enze Xie</a>,
<a href='https://cs.stanford.edu/~chenlin/'>Chenlin Meng</a>,
<a href='https://www.cs.cmu.edu/~junyanz/'>Jun-Yan Zhu</a>,
and <a href='https://hanlab.mit.edu/songhan'>Song Han</a>
<a href='https://github.com/JerryGJX' target="_blank">Junxian Guo</a>,
<a href='https://xieenze.github.io' target="_blank">Enze Xie</a>,
<a href='https://cs.stanford.edu/~chenlin/' target="_blank">Chenlin Meng</a>,
<a href='https://www.cs.cmu.edu/~junyanz/' target="_blank">Jun-Yan Zhu</a>,
and <a href='https://hanlab.mit.edu/songhan' target="_blank">Song Han</a>
</h3>
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<a href="https://arxiv.org/abs/2411.05007">[Paper]</a>
<a href="https://arxiv.org/abs/2411.05007" target="_blank">[Paper]</a>
&nbsp;
<a href='https://github.com/mit-han-lab/nunchaku'>
<a href='https://github.com/mit-han-lab/nunchaku' target="_blank">
[Code]
</a>
&nbsp;
<a href='https://hanlab.mit.edu/projects/svdquant'>
<a href='https://hanlab.mit.edu/projects/svdquant' target="_blank">
[Website]
</a>
&nbsp;
<a href='https://hanlab.mit.edu/blog/svdquant'>
<a href='https://hanlab.mit.edu/blog/svdquant' target="_blank">
[Blog]
</a>
</div>
<h4>Quantization Library:
<a href='https://github.com/mit-han-lab/deepcompressor'>DeepCompressor</a>
<a href='https://github.com/mit-han-lab/deepcompressor' target="_blank">DeepCompressor</a>
&nbsp;
Inference Engine: <a href='https://github.com/mit-han-lab/nunchaku'>Nunchaku</a>
Inference Engine: <a href='https://github.com/mit-han-lab/nunchaku' target="_blank">Nunchaku</a>
</h4>
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
{device_info}
......
.gradio-container{max-width: 560px !important}
.gradio-container {
max-width: 640px !important;
margin: auto; /* Centers the element horizontally */
}
.gradio-container{max-width: 1200px !important}
.gradio-container {
max-width: 1200px !important;
margin: auto; /* Centers the element horizontally */
}
......@@ -4,13 +4,21 @@
## Installation
1. Install `nunchaku` following [README.md](https://github.com/mit-han-lab/nunchaku?tab=readme-ov-file#installation).
2. Set up the dependencies for [ComfyUI](https://github.com/comfyanonymous/ComfyUI/tree/master) with the following commands:
2. Install the dependencies needed to run the custom ComfyUI nodes:
```shell
pip install git+https://github.com/asomoza/image_gen_aux.git
```
3. Set up the dependencies for [ComfyUI](https://github.com/comfyanonymous/ComfyUI/tree/master) with the following commands:
```shell
git clone https://github.com/comfyanonymous/ComfyUI.git
cd ComfyUI
pip install -r requirements.txt
```
4. Install [ComfyUI-Manager](https://github.com/ltdrdata/ComfyUI-Manager) with the following commands, then restart ComfyUI:
```shell
cd ComfyUI/custom_nodes
git clone https://github.com/ltdrdata/ComfyUI-Manager comfyui-manager
```
## Usage
......@@ -34,6 +42,7 @@ pip install -r requirements.txt
cd custom_nodes
ln -s ../../nunchaku/comfyui svdquant
```
* Install any missing nodes (e.g., comfyui-inpainteasy) by following [this tutorial](https://github.com/ltdrdata/ComfyUI-Manager?tab=readme-ov-file#support-of-missing-nodes-installation).
2. **Download Required Models**: Follow [this tutorial](https://comfyanonymous.github.io/ComfyUI_examples/flux/) and download the required models into the appropriate directories using the commands below:
......@@ -50,13 +59,13 @@ pip install -r requirements.txt
python main.py
```
4. **Select the SVDQuant Workflow**: Choose one of the SVDQuant workflows (`flux.1-dev-svdquant.json` or `flux.1-schnell-svdquant.json`) to get started.
4. **Select the SVDQuant Workflow**: Choose one of the SVDQuant workflows (`flux.1-dev-svdquant.json`, `flux.1-schnell-svdquant.json`, `flux.1-depth-svdquant.json`, `flux.1-canny-svdquant.json`, or `flux.1-fill-svdquant.json`) to get started. For the flux.1 fill workflow, you can use the built-in MaskEditor tool to add a mask on top of an image.
## SVDQuant Nodes
* **SVDQuant Flux DiT Loader**: A node for loading the FLUX diffusion model.
* `model_path`: Specifies the model location. If set to `mit-han-lab/svdq-int4-flux.1-schnell` or `mit-han-lab/svdq-int4-flux.1-dev`, the model will be automatically downloaded from our Hugging Face repository. Alternatively, you can manually download the model directory by running the following command:
* `model_path`: Specifies the model location. If set to `mit-han-lab/svdq-int4-flux.1-schnell`, `mit-han-lab/svdq-int4-flux.1-dev`, `mit-han-lab/svdq-int4-flux.1-canny-dev`, `mit-han-lab/svdq-int4-flux.1-fill-dev`, or `mit-han-lab/svdq-int4-flux.1-depth-dev`, the model will be automatically downloaded from our Hugging Face repository. Alternatively, you can manually download the model directory by running a command like the following:
```shell
huggingface-cli download mit-han-lab/svdq-int4-flux.1-dev --local-dir models/diffusion_models/svdq-int4-flux.1-dev
......
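The same pattern applies to the other supported variants. For example, a manual download of the Canny model (the local directory name here is just illustrative) could look like:
```shell
huggingface-cli download mit-han-lab/svdq-int4-flux.1-canny-dev --local-dir models/diffusion_models/svdq-int4-flux.1-canny-dev
```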
......@@ -7,16 +7,17 @@ import comfy.sd
import folder_paths
import GPUtil
import torch
import numpy as np
from comfy.ldm.common_dit import pad_to_patch_size
from comfy.supported_models import Flux, FluxSchnell
from diffusers import FluxTransformer2DModel
from einops import rearrange, repeat
from torch import nn
from transformers import T5EncoderModel
from image_gen_aux import DepthPreprocessor
from nunchaku.models.transformer_flux import NunchakuFluxTransformer2dModel
class ComfyUIFluxForwardWrapper(nn.Module):
def __init__(self, model: NunchakuFluxTransformer2dModel, config):
super(ComfyUIFluxForwardWrapper, self).__init__()
......@@ -24,13 +25,25 @@ class ComfyUIFluxForwardWrapper(nn.Module):
self.dtype = next(model.parameters()).dtype
self.config = config
def forward(self, x, timestep, context, y, guidance, control=None, transformer_options={}, **kwargs):
def forward(
self,
x,
timestep,
context,
y,
guidance,
control=None,
transformer_options={},
**kwargs,
):
assert control is None # for now
bs, c, h, w = x.shape
patch_size = self.config["patch_size"]
x = pad_to_patch_size(x, (patch_size, patch_size))
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
img = rearrange(
x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size
)
h_len = (h + (patch_size // 2)) // patch_size
w_len = (w + (patch_size // 2)) // patch_size
......@@ -54,21 +67,30 @@ class ComfyUIFluxForwardWrapper(nn.Module):
guidance=guidance if self.config["guidance_embed"] else None,
).sample
out = rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:, :, :h, :w]
out = rearrange(
out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2
)[:, :, :h, :w]
return out
class SVDQuantFluxDiTLoader:
@classmethod
def INPUT_TYPES(s):
model_paths = ["mit-han-lab/svdq-int4-flux.1-schnell", "mit-han-lab/svdq-int4-flux.1-dev"]
model_paths = [
"mit-han-lab/svdq-int4-flux.1-schnell",
"mit-han-lab/svdq-int4-flux.1-dev",
"mit-han-lab/svdq-int4-flux.1-canny-dev",
"mit-han-lab/svdq-int4-flux.1-depth-dev",
"mit-han-lab/svdq-int4-flux.1-fill-dev",
]
prefix = "models/diffusion_models"
local_folders = os.listdir(prefix)
local_folders = sorted(
[
folder
for folder in local_folders
if not folder.startswith(".") and os.path.isdir(os.path.join(prefix, folder))
if not folder.startswith(".")
and os.path.isdir(os.path.join(prefix, folder))
]
)
model_paths.extend(local_folders)
......@@ -78,7 +100,14 @@ class SVDQuantFluxDiTLoader:
"model_path": (model_paths,),
"device_id": (
"INT",
{"default": 0, "min": 0, "max": ngpus, "step": 1, "display": "number", "lazy": True},
{
"default": 0,
"min": 0,
"max": ngpus,
"step": 1,
"display": "number",
"lazy": True,
},
),
}
}
......@@ -88,17 +117,20 @@ class SVDQuantFluxDiTLoader:
CATEGORY = "SVDQuant"
TITLE = "SVDQuant Flux DiT Loader"
def load_model(self, model_path: str, device_id: int, **kwargs) -> tuple[FluxTransformer2DModel]:
def load_model(
self, model_path: str, device_id: int, **kwargs
) -> tuple[FluxTransformer2DModel]:
device = f"cuda:{device_id}"
prefix = "models/diffusion_models"
if os.path.exists(os.path.join(prefix, model_path)):
model_path = os.path.join(prefix, model_path)
else:
model_path = model_path
transformer = NunchakuFluxTransformer2dModel.from_pretrained(model_path).to(device)
transformer = NunchakuFluxTransformer2dModel.from_pretrained(model_path).to(
device
)
dit_config = {
"image_model": "flux",
"in_channels": 16,
"patch_size": 2,
"out_channels": 16,
"vec_in_dim": 768,
......@@ -111,21 +143,34 @@ class SVDQuantFluxDiTLoader:
"axes_dim": [16, 56, 56],
"theta": 10000,
"qkv_bias": True,
"guidance_embed": True,
"disable_unet_model_creation": True,
}
if "schnell" in model_path:
dit_config["guidance_embed"] = False
dit_config["in_channels"] = 16
model_config = FluxSchnell(dit_config)
elif "canny" in model_path or "depth" in model_path:
dit_config["in_channels"] = 32
model_config = Flux(dit_config)
elif "fill" in model_path:
dit_config["in_channels"] = 64
model_config = Flux(dit_config)
else:
assert "dev" in model_path
dit_config["guidance_embed"] = True
assert (
model_path == "mit-han-lab/svdq-int4-flux.1-dev"
), f"model {model_path} not supported"
dit_config["in_channels"] = 16
model_config = Flux(dit_config)
model_config.set_inference_dtype(torch.bfloat16, None)
model_config.custom_operations = None
model = model_config.get_model({})
model.diffusion_model = ComfyUIFluxForwardWrapper(transformer, config=dit_config)
model.diffusion_model = ComfyUIFluxForwardWrapper(
transformer, config=dit_config
)
model = comfy.model_patcher.ModelPatcher(model, device, device_id)
return (model,)
......@@ -157,7 +202,8 @@ class SVDQuantTextEncoderLoader:
[
folder
for folder in local_folders
if not folder.startswith(".") and os.path.isdir(os.path.join(prefix, folder))
if not folder.startswith(".")
and os.path.isdir(os.path.join(prefix, folder))
]
)
model_paths.extend(local_folders)
......@@ -168,7 +214,14 @@ class SVDQuantTextEncoderLoader:
"text_encoder2": (folder_paths.get_filename_list("text_encoders"),),
"t5_min_length": (
"INT",
{"default": 512, "min": 256, "max": 1024, "step": 128, "display": "number", "lazy": True},
{
"default": 512,
"min": 256,
"max": 1024,
"step": 128,
"display": "number",
"lazy": True,
},
),
"t5_precision": (["BF16", "INT4"],),
"int4_model": (model_paths, {"tooltip": "The name of the INT4 model."}),
......@@ -191,8 +244,12 @@ class SVDQuantTextEncoderLoader:
t5_precision: str,
int4_model: str,
):
text_encoder_path1 = folder_paths.get_full_path_or_raise("text_encoders", text_encoder1)
text_encoder_path2 = folder_paths.get_full_path_or_raise("text_encoders", text_encoder2)
text_encoder_path1 = folder_paths.get_full_path_or_raise(
"text_encoders", text_encoder1
)
text_encoder_path2 = folder_paths.get_full_path_or_raise(
"text_encoders", text_encoder2
)
if model_type == "flux":
clip_type = comfy.sd.CLIPType.FLUX
else:
......@@ -223,7 +280,9 @@ class SVDQuantTextEncoderLoader:
transformer = NunchakuT5EncoderModel.from_pretrained(model_path)
transformer.forward = types.MethodType(svdquant_t5_forward, transformer)
clip.cond_stage_model.t5xxl.transformer = (
transformer.to(device=device, dtype=dtype) if device.type == "cuda" else transformer
transformer.to(device=device, dtype=dtype)
if device.type == "cuda"
else transformer
)
return (clip,)
......@@ -239,11 +298,17 @@ class SVDQuantLoraLoader:
lora_name_list = [
"None",
*folder_paths.get_filename_list("loras"),
*[f"mit-han-lab/svdquant-models/svdq-flux.1-dev-lora-{n}.safetensors" for n in hf_lora_names],
*[
f"mit-han-lab/svdquant-models/svdq-flux.1-dev-lora-{n}.safetensors"
for n in hf_lora_names
],
]
return {
"required": {
"model": ("MODEL", {"tooltip": "The diffusion model the LoRA will be applied to."}),
"model": (
"MODEL",
{"tooltip": "The diffusion model the LoRA will be applied to."},
),
"lora_name": (lora_name_list, {"tooltip": "The name of the LoRA."}),
"lora_strength": (
"FLOAT",
......@@ -292,8 +357,50 @@ class SVDQuantLoraLoader:
return (model,)
class DepthPreprocesser:
@classmethod
def INPUT_TYPES(s):
model_paths = ["LiheYoung/depth-anything-large-hf"]
prefix = "models/checkpoints"
local_folders = os.listdir(prefix)
local_folders = sorted(
[
folder
for folder in local_folders
if not folder.startswith(".")
and os.path.isdir(os.path.join(prefix, folder))
]
)
model_paths.extend(local_folders)
return {
"required": {
"image": ("IMAGE", {}),
"model_path": (
model_paths,
{"tooltip": "Name of the depth preprocesser model."},
),
}
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "depth_preprocess"
CATEGORY = "Flux.1"
TITLE = "Flux.1 Depth Preprocessor"
def depth_preprocess(self, image, model_path):
prefix = "models/checkpoints"
if os.path.exists(os.path.join(prefix, model_path)):
model_path = os.path.join(prefix, model_path)
processor = DepthPreprocessor.from_pretrained(model_path)
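# Run the depth estimator on the incoming ComfyUI image and convert the predicted depth map
# back to a normalized float image tensor with a leading batch dimension.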
np_image = np.asarray(image)
np_result = np.array(processor(np_image)[0].convert("RGB"))
out_tensor = torch.from_numpy(np_result.astype(np.float32) / 255.0).unsqueeze(0)
return (out_tensor,)
NODE_CLASS_MAPPINGS = {
"SVDQuantFluxDiTLoader": SVDQuantFluxDiTLoader,
"SVDQuantTextEncoderLoader": SVDQuantTextEncoderLoader,
"SVDQuantLoRALoader": SVDQuantLoraLoader,
"DepthPreprocesser": DepthPreprocesser
}
{
"last_node_id": 38,
"last_link_id": 76,
"nodes": [
{
"id": 3,
"type": "KSampler",
"pos": [
1290,
40
],
"size": [
315,
262
],
"flags": {},
"order": 11,
"mode": 0,
"inputs": [
{
"name": "model",
"type": "MODEL",
"link": 71,
"label": "model"
},
{
"name": "positive",
"type": "CONDITIONING",
"link": 64,
"label": "positive"
},
{
"name": "negative",
"type": "CONDITIONING",
"link": 65,
"label": "negative"
},
{
"name": "latent_image",
"type": "LATENT",
"link": 66,
"label": "latent_image"
}
],
"outputs": [
{
"name": "LATENT",
"type": "LATENT",
"links": [
7
],
"slot_index": 0,
"label": "LATENT"
}
],
"properties": {
"Node name for S&R": "KSampler"
},
"widgets_values": [
646386750200194,
"randomize",
20,
1,
"euler",
"normal",
1
]
},
{
"id": 35,
"type": "InstructPixToPixConditioning",
"pos": [
1040,
50
],
"size": [
235.1999969482422,
86
],
"flags": {},
"order": 10,
"mode": 0,
"inputs": [
{
"name": "positive",
"type": "CONDITIONING",
"link": 67,
"label": "positive"
},
{
"name": "negative",
"type": "CONDITIONING",
"link": 68,
"label": "negative"
},
{
"name": "vae",
"type": "VAE",
"link": 69,
"label": "vae"
},
{
"name": "pixels",
"type": "IMAGE",
"link": 70,
"label": "pixels"
}
],
"outputs": [
{
"name": "positive",
"type": "CONDITIONING",
"links": [
64
],
"slot_index": 0,
"label": "positive"
},
{
"name": "negative",
"type": "CONDITIONING",
"links": [
65
],
"slot_index": 1,
"label": "negative"
},
{
"name": "latent",
"type": "LATENT",
"links": [
66
],
"slot_index": 2,
"label": "latent"
}
],
"properties": {
"Node name for S&R": "InstructPixToPixConditioning"
},
"widgets_values": []
},
{
"id": 8,
"type": "VAEDecode",
"pos": [
1620,
40
],
"size": [
210,
46
],
"flags": {},
"order": 12,
"mode": 0,
"inputs": [
{
"name": "samples",
"type": "LATENT",
"link": 7,
"label": "samples"
},
{
"name": "vae",
"type": "VAE",
"link": 60,
"label": "vae"
}
],
"outputs": [
{
"name": "IMAGE",
"type": "IMAGE",
"links": [
9
],
"slot_index": 0,
"label": "IMAGE"
}
],
"properties": {
"Node name for S&R": "VAEDecode"
},
"widgets_values": []
},
{
"id": 9,
"type": "SaveImage",
"pos": [
1850,
40
],
"size": [
828.9535522460938,
893.8475341796875
],
"flags": {},
"order": 13,
"mode": 0,
"inputs": [
{
"name": "images",
"type": "IMAGE",
"link": 9,
"label": "images"
}
],
"outputs": [],
"properties": {},
"widgets_values": [
"ComfyUI"
]
},
{
"id": 32,
"type": "VAELoader",
"pos": [
1290,
350
],
"size": [
315,
58
],
"flags": {},
"order": 0,
"mode": 0,
"inputs": [],
"outputs": [
{
"name": "VAE",
"type": "VAE",
"links": [
60,
69
],
"slot_index": 0,
"label": "VAE"
}
],
"properties": {
"Node name for S&R": "VAELoader"
},
"widgets_values": [
"ae.safetensors"
]
},
{
"id": 26,
"type": "FluxGuidance",
"pos": [
700,
50
],
"size": [
317.4000244140625,
58
],
"flags": {},
"order": 7,
"mode": 0,
"inputs": [
{
"name": "conditioning",
"type": "CONDITIONING",
"link": 41,
"label": "conditioning"
}
],
"outputs": [
{
"name": "CONDITIONING",
"type": "CONDITIONING",
"links": [
67
],
"slot_index": 0,
"shape": 3,
"label": "CONDITIONING"
}
],
"properties": {
"Node name for S&R": "FluxGuidance"
},
"widgets_values": [
30
]
},
{
"id": 34,
"type": "DualCLIPLoader",
"pos": [
-80,
110
],
"size": [
315,
106
],
"flags": {},
"order": 1,
"mode": 0,
"inputs": [],
"outputs": [
{
"name": "CLIP",
"type": "CLIP",
"links": [
62,
63
],
"label": "CLIP"
}
],
"properties": {
"Node name for S&R": "DualCLIPLoader"
},
"widgets_values": [
"clip_l.safetensors",
"t5xxl_fp16.safetensors",
"flux",
"default"
]
},
{
"id": 7,
"type": "CLIPTextEncode",
"pos": [
307,
282
],
"size": [
425.27801513671875,
180.6060791015625
],
"flags": {
"collapsed": true
},
"order": 5,
"mode": 0,
"inputs": [
{
"name": "clip",
"type": "CLIP",
"link": 63,
"label": "clip"
}
],
"outputs": [
{
"name": "CONDITIONING",
"type": "CONDITIONING",
"links": [
68
],
"slot_index": 0,
"label": "CONDITIONING"
}
],
"title": "CLIP Text Encode (Negative Prompt)",
"properties": {
"Node name for S&R": "CLIPTextEncode"
},
"widgets_values": [
""
],
"color": "#322",
"bgcolor": "#533"
},
{
"id": 23,
"type": "CLIPTextEncode",
"pos": [
260,
50
],
"size": [
422.84503173828125,
164.31304931640625
],
"flags": {},
"order": 4,
"mode": 0,
"inputs": [
{
"name": "clip",
"type": "CLIP",
"link": 62,
"label": "clip"
}
],
"outputs": [
{
"name": "CONDITIONING",
"type": "CONDITIONING",
"links": [
41
],
"slot_index": 0,
"label": "CONDITIONING"
}
],
"title": "CLIP Text Encode (Positive Prompt)",
"properties": {
"Node name for S&R": "CLIPTextEncode"
},
"widgets_values": [
"A robot made of exotic candies and chocolates of different kinds. The background is filled with confetti and celebratory gifts."
],
"color": "#232",
"bgcolor": "#353"
},
{
"id": 19,
"type": "PreviewImage",
"pos": [
1127.9403076171875,
554.3356323242188
],
"size": [
571.5869140625,
625.5296020507812
],
"flags": {},
"order": 9,
"mode": 0,
"inputs": [
{
"name": "images",
"type": "IMAGE",
"link": 26,
"label": "images"
}
],
"outputs": [],
"properties": {
"Node name for S&R": "PreviewImage"
},
"widgets_values": []
},
{
"id": 18,
"type": "Canny",
"pos": [
744.2684936523438,
566.853515625
],
"size": [
315,
82
],
"flags": {},
"order": 8,
"mode": 0,
"inputs": [
{
"name": "image",
"type": "IMAGE",
"link": 76,
"label": "image"
}
],
"outputs": [
{
"name": "IMAGE",
"type": "IMAGE",
"links": [
26,
70
],
"slot_index": 0,
"shape": 3,
"label": "IMAGE"
}
],
"properties": {
"Node name for S&R": "Canny"
},
"widgets_values": [
0.15,
0.3
]
},
{
"id": 36,
"type": "SVDQuantFluxDiTLoader",
"pos": [
865.4989624023438,
-95.86973571777344
],
"size": [
315,
82
],
"flags": {},
"order": 2,
"mode": 0,
"inputs": [],
"outputs": [
{
"name": "MODEL",
"type": "MODEL",
"links": [
71
],
"slot_index": 0
}
],
"properties": {
"Node name for S&R": "SVDQuantFluxDiTLoader"
},
"widgets_values": [
"mit-han-lab/svdq-int4-flux.1-canny-dev",
0
]
},
{
"id": 17,
"type": "LoadImage",
"pos": [
6.694743633270264,
562.3865966796875
],
"size": [
315,
314.0000305175781
],
"flags": {},
"order": 3,
"mode": 0,
"inputs": [],
"outputs": [
{
"name": "IMAGE",
"type": "IMAGE",
"links": [
75
],
"slot_index": 0,
"shape": 3,
"label": "IMAGE"
},
{
"name": "MASK",
"type": "MASK",
"links": null,
"shape": 3,
"label": "MASK"
}
],
"properties": {
"Node name for S&R": "LoadImage"
},
"widgets_values": [
"robot.png",
"image"
]
},
{
"id": 38,
"type": "ImageScale",
"pos": [
379.69903564453125,
565.2651977539062
],
"size": [
315,
130
],
"flags": {},
"order": 6,
"mode": 0,
"inputs": [
{
"name": "image",
"type": "IMAGE",
"link": 75
}
],
"outputs": [
{
"name": "IMAGE",
"type": "IMAGE",
"links": [
76
],
"slot_index": 0
}
],
"properties": {
"Node name for S&R": "ImageScale"
},
"widgets_values": [
"nearest-exact",
1024,
1024,
"center"
]
}
],
"links": [
[
7,
3,
0,
8,
0,
"LATENT"
],
[
9,
8,
0,
9,
0,
"IMAGE"
],
[
26,
18,
0,
19,
0,
"IMAGE"
],
[
41,
23,
0,
26,
0,
"CONDITIONING"
],
[
60,
32,
0,
8,
1,
"VAE"
],
[
62,
34,
0,
23,
0,
"CLIP"
],
[
63,
34,
0,
7,
0,
"CLIP"
],
[
64,
35,
0,
3,
1,
"CONDITIONING"
],
[
65,
35,
1,
3,
2,
"CONDITIONING"
],
[
66,
35,
2,
3,
3,
"LATENT"
],
[
67,
26,
0,
35,
0,
"CONDITIONING"
],
[
68,
7,
0,
35,
1,
"CONDITIONING"
],
[
69,
32,
0,
35,
2,
"VAE"
],
[
70,
18,
0,
35,
3,
"IMAGE"
],
[
71,
36,
0,
3,
0,
"MODEL"
],
[
75,
17,
0,
38,
0,
"IMAGE"
],
[
76,
38,
0,
18,
0,
"IMAGE"
]
],
"groups": [],
"config": {},
"extra": {
"ds": {
"scale": 0.895430243255241,
"offset": [
203.26101803057463,
215.36536277004458
]
}
},
"version": 0.4
}
\ No newline at end of file
{
"last_node_id": 44,
"last_link_id": 85,
"nodes": [
{
"id": 7,
"type": "CLIPTextEncode",
"pos": [
307,
282
],
"size": [
425.27801513671875,
180.6060791015625
],
"flags": {
"collapsed": true
},
"order": 5,
"mode": 0,
"inputs": [
{
"name": "clip",
"type": "CLIP",
"link": 63,
"label": "clip"
}
],
"outputs": [
{
"name": "CONDITIONING",
"type": "CONDITIONING",
"links": [
68
],
"slot_index": 0,
"label": "CONDITIONING"
}
],
"title": "CLIP Text Encode (Negative Prompt)",
"properties": {
"Node name for S&R": "CLIPTextEncode"
},
"widgets_values": [
""
],
"color": "#322",
"bgcolor": "#533"
},
{
"id": 34,
"type": "DualCLIPLoader",
"pos": [
-238,
112
],
"size": [
315,
106
],
"flags": {},
"order": 0,
"mode": 0,
"inputs": [],
"outputs": [
{
"name": "CLIP",
"type": "CLIP",
"links": [
62,
63
],
"label": "CLIP"
}
],
"properties": {
"Node name for S&R": "DualCLIPLoader"
},
"widgets_values": [
"clip_l.safetensors",
"t5xxl_fp16.safetensors",
"flux",
"default"
]
},
{
"id": 26,
"type": "FluxGuidance",
"pos": [
621,
8
],
"size": [
317.4000244140625,
58
],
"flags": {},
"order": 7,
"mode": 0,
"inputs": [
{
"name": "conditioning",
"type": "CONDITIONING",
"link": 41,
"label": "conditioning"
}
],
"outputs": [
{
"name": "CONDITIONING",
"type": "CONDITIONING",
"links": [
67
],
"slot_index": 0,
"shape": 3,
"label": "CONDITIONING"
}
],
"properties": {
"Node name for S&R": "FluxGuidance"
},
"widgets_values": [
10
]
},
{
"id": 32,
"type": "VAELoader",
"pos": [
656,
165
],
"size": [
315,
58
],
"flags": {},
"order": 1,
"mode": 0,
"inputs": [],
"outputs": [
{
"name": "VAE",
"type": "VAE",
"links": [
60,
69
],
"slot_index": 0,
"label": "VAE"
}
],
"properties": {
"Node name for S&R": "VAELoader"
},
"widgets_values": [
"ae.safetensors"
]
},
{
"id": 3,
"type": "KSampler",
"pos": [
1280,
100
],
"size": [
315,
262
],
"flags": {},
"order": 11,
"mode": 0,
"inputs": [
{
"name": "model",
"type": "MODEL",
"link": 78,
"label": "model"
},
{
"name": "positive",
"type": "CONDITIONING",
"link": 64,
"label": "positive"
},
{
"name": "negative",
"type": "CONDITIONING",
"link": 65,
"label": "negative"
},
{
"name": "latent_image",
"type": "LATENT",
"link": 73,
"label": "latent_image"
}
],
"outputs": [
{
"name": "LATENT",
"type": "LATENT",
"links": [
7
],
"slot_index": 0,
"label": "LATENT"
}
],
"properties": {
"Node name for S&R": "KSampler"
},
"widgets_values": [
718322679777603,
"randomize",
20,
1,
"euler",
"normal",
1
]
},
{
"id": 35,
"type": "InstructPixToPixConditioning",
"pos": [
1008,
118
],
"size": [
235.1999969482422,
86
],
"flags": {},
"order": 9,
"mode": 0,
"inputs": [
{
"name": "positive",
"type": "CONDITIONING",
"link": 67,
"label": "positive"
},
{
"name": "negative",
"type": "CONDITIONING",
"link": 68,
"label": "negative"
},
{
"name": "vae",
"type": "VAE",
"link": 69,
"label": "vae"
},
{
"name": "pixels",
"type": "IMAGE",
"link": 80,
"label": "pixels"
}
],
"outputs": [
{
"name": "positive",
"type": "CONDITIONING",
"links": [
64
],
"slot_index": 0,
"label": "positive"
},
{
"name": "negative",
"type": "CONDITIONING",
"links": [
65
],
"slot_index": 1,
"label": "negative"
},
{
"name": "latent",
"type": "LATENT",
"links": [
73
],
"slot_index": 2,
"label": "latent"
}
],
"properties": {
"Node name for S&R": "InstructPixToPixConditioning"
},
"widgets_values": []
},
{
"id": 39,
"type": "SVDQuantFluxDiTLoader",
"pos": [
707.80908203125,
-172.0343017578125
],
"size": [
315,
82
],
"flags": {},
"order": 2,
"mode": 0,
"inputs": [],
"outputs": [
{
"name": "MODEL",
"type": "MODEL",
"links": [
78
],
"slot_index": 0
}
],
"properties": {
"Node name for S&R": "SVDQuantFluxDiTLoader"
},
"widgets_values": [
"mit-han-lab/svdq-int4-flux.1-depth-dev",
0
]
},
{
"id": 42,
"type": "ImageScale",
"pos": [
378.3890686035156,
472.7001953125
],
"size": [
315,
130
],
"flags": {},
"order": 6,
"mode": 0,
"inputs": [
{
"name": "image",
"type": "IMAGE",
"link": 82
}
],
"outputs": [
{
"name": "IMAGE",
"type": "IMAGE",
"links": [
83
],
"slot_index": 0
}
],
"properties": {
"Node name for S&R": "ImageScale"
},
"widgets_values": [
"nearest-exact",
1024,
1024,
"center"
]
},
{
"id": 17,
"type": "LoadImage",
"pos": [
30.604948043823242,
419.3930358886719
],
"size": [
315,
314.0000305175781
],
"flags": {},
"order": 3,
"mode": 0,
"inputs": [],
"outputs": [
{
"name": "IMAGE",
"type": "IMAGE",
"links": [
82
],
"slot_index": 0,
"shape": 3,
"label": "IMAGE"
},
{
"name": "MASK",
"type": "MASK",
"links": null,
"shape": 3,
"label": "MASK"
}
],
"properties": {
"Node name for S&R": "LoadImage"
},
"widgets_values": [
"robot.png",
"image"
]
},
{
"id": 23,
"type": "CLIPTextEncode",
"pos": [
115,
-17
],
"size": [
422.84503173828125,
164.31304931640625
],
"flags": {},
"order": 4,
"mode": 0,
"inputs": [
{
"name": "clip",
"type": "CLIP",
"link": 62,
"label": "clip"
}
],
"outputs": [
{
"name": "CONDITIONING",
"type": "CONDITIONING",
"links": [
41
],
"slot_index": 0,
"label": "CONDITIONING"
}
],
"title": "CLIP Text Encode (Positive Prompt)",
"properties": {
"Node name for S&R": "CLIPTextEncode"
},
"widgets_values": [
"A robot made of exotic candies and chocolates of different kinds. The background is filled with confetti and celebratory gifts."
],
"color": "#232",
"bgcolor": "#353"
},
{
"id": 43,
"type": "PreviewImage",
"pos": [
1001.3873291015625,
432.09039306640625
],
"size": [
571.5869140625,
625.5296020507812
],
"flags": {},
"order": 10,
"mode": 0,
"inputs": [
{
"name": "images",
"type": "IMAGE",
"link": 84,
"label": "images"
}
],
"outputs": [],
"properties": {
"Node name for S&R": "PreviewImage"
},
"widgets_values": []
},
{
"id": 40,
"type": "DepthPreprocesser",
"pos": [
639.0159301757812,
350.06134033203125
],
"size": [
315,
58
],
"flags": {},
"order": 8,
"mode": 0,
"inputs": [
{
"name": "image",
"type": "IMAGE",
"link": 83
}
],
"outputs": [
{
"name": "IMAGE",
"type": "IMAGE",
"links": [
80,
84
],
"slot_index": 0
}
],
"properties": {
"Node name for S&R": "DepthPreprocesser"
},
"widgets_values": [
"LiheYoung/depth-anything-large-hf"
]
},
{
"id": 8,
"type": "VAEDecode",
"pos": [
1620,
98
],
"size": [
210,
46
],
"flags": {},
"order": 12,
"mode": 0,
"inputs": [
{
"name": "samples",
"type": "LATENT",
"link": 7,
"label": "samples"
},
{
"name": "vae",
"type": "VAE",
"link": 60,
"label": "vae"
}
],
"outputs": [
{
"name": "IMAGE",
"type": "IMAGE",
"links": [
85
],
"slot_index": 0,
"label": "IMAGE"
}
],
"properties": {
"Node name for S&R": "VAEDecode"
},
"widgets_values": []
},
{
"id": 44,
"type": "SaveImage",
"pos": [
1912.7984619140625,
109.0069580078125
],
"size": [
828.9535522460938,
893.8475341796875
],
"flags": {},
"order": 13,
"mode": 0,
"inputs": [
{
"name": "images",
"type": "IMAGE",
"link": 85,
"label": "images"
}
],
"outputs": [],
"properties": {},
"widgets_values": [
"ComfyUI"
]
}
],
"links": [
[
7,
3,
0,
8,
0,
"LATENT"
],
[
41,
23,
0,
26,
0,
"CONDITIONING"
],
[
60,
32,
0,
8,
1,
"VAE"
],
[
62,
34,
0,
23,
0,
"CLIP"
],
[
63,
34,
0,
7,
0,
"CLIP"
],
[
64,
35,
0,
3,
1,
"CONDITIONING"
],
[
65,
35,
1,
3,
2,
"CONDITIONING"
],
[
67,
26,
0,
35,
0,
"CONDITIONING"
],
[
68,
7,
0,
35,
1,
"CONDITIONING"
],
[
69,
32,
0,
35,
2,
"VAE"
],
[
73,
35,
2,
3,
3,
"LATENT"
],
[
78,
39,
0,
3,
0,
"MODEL"
],
[
80,
40,
0,
35,
3,
"IMAGE"
],
[
82,
17,
0,
42,
0,
"IMAGE"
],
[
83,
42,
0,
40,
0,
"IMAGE"
],
[
84,
40,
0,
43,
0,
"IMAGE"
],
[
85,
8,
0,
44,
0,
"IMAGE"
]
],
"groups": [],
"config": {},
"extra": {
"ds": {
"scale": 0.6115909044841502,
"offset": [
724.4911189218763,
518.3043483917891
]
}
},
"version": 0.4
}
\ No newline at end of file
{
"last_node_id": 58,
"last_link_id": 108,
"nodes": [
{
"id": 8,
"type": "VAEDecode",
"pos": [
1620,
98
],
"size": [
210,
46
],
"flags": {},
"order": 11,
"mode": 0,
"inputs": [
{
"name": "samples",
"type": "LATENT",
"link": 7
},
{
"name": "vae",
"type": "VAE",
"link": 60
}
],
"outputs": [
{
"name": "IMAGE",
"type": "IMAGE",
"links": [
95
],
"slot_index": 0
}
],
"properties": {
"Node name for S&R": "VAEDecode"
},
"widgets_values": []
},
{
"id": 38,
"type": "InpaintModelConditioning",
"pos": [
952,
78
],
"size": [
302.4000244140625,
138
],
"flags": {},
"order": 9,
"mode": 0,
"inputs": [
{
"name": "positive",
"type": "CONDITIONING",
"link": 80
},
{
"name": "negative",
"type": "CONDITIONING",
"link": 81
},
{
"name": "vae",
"type": "VAE",
"link": 82
},
{
"name": "pixels",
"type": "IMAGE",
"link": 107
},
{
"name": "mask",
"type": "MASK",
"link": 108
}
],
"outputs": [
{
"name": "positive",
"type": "CONDITIONING",
"links": [
77
],
"slot_index": 0
},
{
"name": "negative",
"type": "CONDITIONING",
"links": [
78
],
"slot_index": 1
},
{
"name": "latent",
"type": "LATENT",
"links": [
88
],
"slot_index": 2
}
],
"properties": {
"Node name for S&R": "InpaintModelConditioning"
},
"widgets_values": [
false
]
},
{
"id": 3,
"type": "KSampler",
"pos": [
1280,
100
],
"size": [
315,
262
],
"flags": {},
"order": 10,
"mode": 0,
"inputs": [
{
"name": "model",
"type": "MODEL",
"link": 102
},
{
"name": "positive",
"type": "CONDITIONING",
"link": 77
},
{
"name": "negative",
"type": "CONDITIONING",
"link": 78
},
{
"name": "latent_image",
"type": "LATENT",
"link": 88
}
],
"outputs": [
{
"name": "LATENT",
"type": "LATENT",
"links": [
7
],
"slot_index": 0
}
],
"properties": {
"Node name for S&R": "KSampler"
},
"widgets_values": [
1078304127779394,
"randomize",
20,
1,
"euler",
"normal",
1
]
},
{
"id": 9,
"type": "SaveImage",
"pos": [
1879,
90
],
"size": [
828.9535522460938,
893.8475341796875
],
"flags": {},
"order": 12,
"mode": 0,
"inputs": [
{
"name": "images",
"type": "IMAGE",
"link": 95
}
],
"outputs": [],
"properties": {
"Node name for S&R": "SaveImage"
},
"widgets_values": [
"ComfyUI"
]
},
{
"id": 26,
"type": "FluxGuidance",
"pos": [
596,
48
],
"size": [
317.4000244140625,
58
],
"flags": {},
"order": 8,
"mode": 0,
"inputs": [
{
"name": "conditioning",
"type": "CONDITIONING",
"link": 41
}
],
"outputs": [
{
"name": "CONDITIONING",
"type": "CONDITIONING",
"links": [
80
],
"slot_index": 0,
"shape": 3
}
],
"properties": {
"Node name for S&R": "FluxGuidance"
},
"widgets_values": [
30
]
},
{
"id": 7,
"type": "CLIPTextEncode",
"pos": [
165,
267
],
"size": [
425.27801513671875,
180.6060791015625
],
"flags": {
"collapsed": true
},
"order": 6,
"mode": 0,
"inputs": [
{
"name": "clip",
"type": "CLIP",
"link": 63
}
],
"outputs": [
{
"name": "CONDITIONING",
"type": "CONDITIONING",
"links": [
81
],
"slot_index": 0
}
],
"title": "CLIP Text Encode (Negative Prompt)",
"properties": {
"Node name for S&R": "CLIPTextEncode"
},
"widgets_values": [
""
],
"color": "#322",
"bgcolor": "#533"
},
{
"id": 34,
"type": "DualCLIPLoader",
"pos": [
-237,
76
],
"size": [
315,
106
],
"flags": {},
"order": 0,
"mode": 0,
"inputs": [],
"outputs": [
{
"name": "CLIP",
"type": "CLIP",
"links": [
62,
63
]
}
],
"properties": {
"Node name for S&R": "DualCLIPLoader"
},
"widgets_values": [
"clip_l.safetensors",
"t5xxl_fp16.safetensors",
"flux",
"default"
]
},
{
"id": 32,
"type": "VAELoader",
"pos": [
1303,
424
],
"size": [
315,
58
],
"flags": {},
"order": 1,
"mode": 0,
"inputs": [],
"outputs": [
{
"name": "VAE",
"type": "VAE",
"links": [
60,
82
],
"slot_index": 0
}
],
"properties": {
"Node name for S&R": "VAELoader"
},
"widgets_values": [
"ae.safetensors"
]
},
{
"id": 45,
"type": "SVDQuantFluxDiTLoader",
"pos": [
936.3029174804688,
-113.06819915771484
],
"size": [
315,
82
],
"flags": {},
"order": 2,
"mode": 0,
"inputs": [],
"outputs": [
{
"name": "MODEL",
"type": "MODEL",
"links": [
102
],
"slot_index": 0
}
],
"properties": {
"Node name for S&R": "SVDQuantFluxDiTLoader"
},
"widgets_values": [
"mit-han-lab/svdq-int4-flux.1-fill-dev",
0
]
},
{
"id": 48,
"type": "Note",
"pos": [
466.9884033203125,
643.9080810546875
],
"size": [
314.99755859375,
117.98363494873047
],
"flags": {},
"order": 3,
"mode": 0,
"inputs": [],
"outputs": [],
"properties": {
"text": ""
},
"widgets_values": [
"To add mask for fill inpainting, right click on the uploaded image and select \"Open in MaskEditor\". Use the brush tool to add masking and click save to continue."
],
"color": "#432",
"bgcolor": "#653"
},
{
"id": 17,
"type": "LoadImage",
"pos": [
126.66505432128906,
460.53631591796875
],
"size": [
315,
314.0000305175781
],
"flags": {},
"order": 4,
"mode": 0,
"inputs": [],
"outputs": [
{
"name": "IMAGE",
"type": "IMAGE",
"links": [
105
],
"slot_index": 0,
"shape": 3
},
{
"name": "MASK",
"type": "MASK",
"links": [
106
],
"slot_index": 1,
"shape": 3
}
],
"properties": {
"Node name for S&R": "LoadImage"
},
"widgets_values": [
"clipspace/clipspace-mask-123191.png [input]",
"image"
]
},
{
"id": 58,
"type": "ImageAndMaskResizeNode",
"pos": [
536.786865234375,
328.54388427734375
],
"size": [
315,
174
],
"flags": {},
"order": 7,
"mode": 0,
"inputs": [
{
"name": "image",
"type": "IMAGE",
"link": 105
},
{
"name": "mask",
"type": "MASK",
"link": 106
}
],
"outputs": [
{
"name": "image",
"type": "IMAGE",
"links": [
107
],
"slot_index": 0
},
{
"name": "mask",
"type": "MASK",
"links": [
108
],
"slot_index": 1
}
],
"properties": {
"Node name for S&R": "ImageAndMaskResizeNode"
},
"widgets_values": [
1024,
1024,
"nearest-exact",
"center",
10
]
},
{
"id": 23,
"type": "CLIPTextEncode",
"pos": [
144,
-7
],
"size": [
422.84503173828125,
164.31304931640625
],
"flags": {},
"order": 5,
"mode": 0,
"inputs": [
{
"name": "clip",
"type": "CLIP",
"link": 62
}
],
"outputs": [
{
"name": "CONDITIONING",
"type": "CONDITIONING",
"links": [
41
],
"slot_index": 0
}
],
"title": "CLIP Text Encode (Positive Prompt)",
"properties": {
"Node name for S&R": "CLIPTextEncode"
},
"widgets_values": [
"A robot with a closed eye pink face is giving a presentation"
],
"color": "#232",
"bgcolor": "#353"
}
],
"links": [
[
7,
3,
0,
8,
0,
"LATENT"
],
[
41,
23,
0,
26,
0,
"CONDITIONING"
],
[
60,
32,
0,
8,
1,
"VAE"
],
[
62,
34,
0,
23,
0,
"CLIP"
],
[
63,
34,
0,
7,
0,
"CLIP"
],
[
77,
38,
0,
3,
1,
"CONDITIONING"
],
[
78,
38,
1,
3,
2,
"CONDITIONING"
],
[
80,
26,
0,
38,
0,
"CONDITIONING"
],
[
81,
7,
0,
38,
1,
"CONDITIONING"
],
[
82,
32,
0,
38,
2,
"VAE"
],
[
88,
38,
2,
3,
3,
"LATENT"
],
[
95,
8,
0,
9,
0,
"IMAGE"
],
[
102,
45,
0,
3,
0,
"MODEL"
],
[
105,
17,
0,
58,
0,
"IMAGE"
],
[
106,
17,
1,
58,
1,
"MASK"
],
[
107,
58,
0,
38,
3,
"IMAGE"
],
[
108,
58,
1,
38,
4,
"MASK"
]
],
"groups": [],
"config": {},
"extra": {
"ds": {
"scale": 0.8390545288824038,
"offset": [
361.8437326514503,
242.3368651567008
]
},
"node_versions": {
"comfy-core": "0.3.14",
"comfyui-inpainteasy": "1.0.2"
}
},
"version": 0.4
}
\ No newline at end of file
......@@ -12,7 +12,7 @@ pipe = FluxFillPipeline.from_pretrained(
"black-forest-labs/FLUX.1-Fill-dev", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")
image = pipe(
prompt="A wooden basket of several individual cartons of blueberries.",
prompt="A wooden basket of a cat.",
image=image,
mask_image=mask,
height=1024,
......
......@@ -8,11 +8,6 @@ pipeline = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-schnell", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")
image = pipeline(
"A cat holding a sign that says hello world",
width=1024,
height=1024,
num_inference_steps=4,
guidance_scale=0,
generator=torch.Generator().manual_seed(2333),
"A cat holding a sign that says hello world", width=1024, height=1024, num_inference_steps=4, guidance_scale=0
).images[0]
image.save("flux.1-schnell-int4.png")
__version__ = "0.0.2beta3"
__version__ = "0.0.2beta4"
"""Convert LoRA weights to Nunchaku format."""
import argparse
import os
import typing as tp
import safetensors
import safetensors.torch
import torch
import tqdm
# region utilities
def ceil_divide(x: int, divisor: int) -> int:
"""Ceiling division.
Args:
x (`int`):
dividend.
divisor (`int`):
divisor.
Returns:
`int`:
ceiling division result.
"""
return (x + divisor - 1) // divisor
def pad(
tensor: tp.Optional[torch.Tensor],
divisor: int | tp.Sequence[int],
dim: int | tp.Sequence[int],
fill_value: float | int = 0,
) -> torch.Tensor | None:
if isinstance(divisor, int):
if divisor <= 1:
return tensor
elif all(d <= 1 for d in divisor):
return tensor
if tensor is None:
return None
shape = list(tensor.shape)
if isinstance(dim, int):
assert isinstance(divisor, int)
shape[dim] = ceil_divide(shape[dim], divisor) * divisor
else:
if isinstance(divisor, int):
divisor = [divisor] * len(dim)
for d, div in zip(dim, divisor, strict=True):
shape[d] = ceil_divide(shape[d], div) * div
result = torch.full(shape, fill_value, dtype=tensor.dtype, device=tensor.device)
result[[slice(0, extent) for extent in tensor.shape]] = tensor
return result
def update_state_dict(
lhs: dict[str, torch.Tensor], rhs: dict[str, torch.Tensor], prefix: str = ""
) -> dict[str, torch.Tensor]:
for rkey, value in rhs.items():
lkey = f"{prefix}.{rkey}" if prefix else rkey
assert lkey not in lhs, f"Key {lkey} already exists in the state dict."
lhs[lkey] = value
return lhs
def load_state_dict_in_safetensors(
path: str, device: str | torch.device = "cpu", filter_prefix: str = ""
) -> dict[str, torch.Tensor]:
"""Load state dict in SafeTensors.
Args:
path (`str`):
file path.
device (`str` | `torch.device`, optional, defaults to `"cpu"`):
device.
filter_prefix (`str`, optional, defaults to `""`):
filter prefix.
Returns:
`dict`:
loaded SafeTensors.
"""
state_dict = {}
with safetensors.safe_open(path, framework="pt", device=device) as f:
for k in f.keys():
if filter_prefix and not k.startswith(filter_prefix):
continue
state_dict[k.removeprefix(filter_prefix)] = f.get_tensor(k)
return state_dict
# endregion
def pack_lowrank_weight(weight: torch.Tensor, down: bool) -> torch.Tensor:
"""Pack Low-Rank Weight.
Args:
weight (`torch.Tensor`):
low-rank weight tensor.
down (`bool`):
whether the weight is for down projection in low-rank branch.
"""
assert weight.dtype in (torch.float16, torch.bfloat16), f"Unsupported weight dtype {weight.dtype}."
lane_n, lane_k = 1, 2 # lane_n is always 1, lane_k is 32 bits // 16 bits = 2
n_pack_size, k_pack_size = 2, 2
num_n_lanes, num_k_lanes = 8, 4
frag_n = n_pack_size * num_n_lanes * lane_n
frag_k = k_pack_size * num_k_lanes * lane_k
weight = pad(weight, divisor=(frag_n, frag_k), dim=(0, 1))
if down:
r, c = weight.shape
r_frags, c_frags = r // frag_n, c // frag_k
weight = weight.view(r_frags, frag_n, c_frags, frag_k).permute(2, 0, 1, 3)
else:
c, r = weight.shape
c_frags, r_frags = c // frag_n, r // frag_k
weight = weight.view(c_frags, frag_n, r_frags, frag_k).permute(0, 2, 1, 3)
weight = weight.reshape(c_frags, r_frags, n_pack_size, num_n_lanes, k_pack_size, num_k_lanes, lane_k)
weight = weight.permute(0, 1, 3, 5, 2, 4, 6).contiguous()
return weight.view(c, r)
def unpack_lowrank_weight(weight: torch.Tensor, down: bool) -> torch.Tensor:
"""Unpack Low-Rank Weight.
Args:
weight (`torch.Tensor`):
low-rank weight tensor.
down (`bool`):
whether the weight is for down projection in low-rank branch.
"""
c, r = weight.shape
assert weight.dtype in (torch.float16, torch.bfloat16), f"Unsupported weight dtype {weight.dtype}."
lane_n, lane_k = 1, 2 # lane_n is always 1, lane_k is 32 bits // 16 bits = 2
n_pack_size, k_pack_size = 2, 2
num_n_lanes, num_k_lanes = 8, 4
frag_n = n_pack_size * num_n_lanes * lane_n
frag_k = k_pack_size * num_k_lanes * lane_k
if down:
r_frags, c_frags = r // frag_n, c // frag_k
else:
c_frags, r_frags = c // frag_n, r // frag_k
weight = weight.view(c_frags, r_frags, num_n_lanes, num_k_lanes, n_pack_size, k_pack_size, lane_k)
weight = weight.permute(0, 1, 4, 2, 5, 3, 6).contiguous()
weight = weight.view(c_frags, r_frags, frag_n, frag_k)
if down:
weight = weight.permute(1, 2, 0, 3).contiguous().view(r, c)
else:
weight = weight.permute(0, 2, 1, 3).contiguous().view(c, r)
return weight
def reorder_adanorm_lora_up(lora_up: torch.Tensor, splits: int) -> torch.Tensor:
c, r = lora_up.shape
assert c % splits == 0
return lora_up.view(splits, c // splits, r).transpose(0, 1).reshape(c, r).contiguous()
def convert_to_nunchaku_transformer_block_lowrank_dict( # noqa: C901
orig_state_dict: dict[str, torch.Tensor],
extra_lora_dict: dict[str, torch.Tensor],
converted_block_name: str,
candidate_block_name: str,
local_name_map: dict[str, str | list[str]],
convert_map: dict[str, str],
default_dtype: torch.dtype = torch.bfloat16,
) -> dict[str, torch.Tensor]:
print(f"Converting LoRA branch for block {candidate_block_name}...")
converted: dict[str, torch.Tensor] = {}
for converted_local_name, candidate_local_names in tqdm.tqdm(
local_name_map.items(), desc=f"Converting {candidate_block_name}", dynamic_ncols=True
):
if isinstance(candidate_local_names, str):
candidate_local_names = [candidate_local_names]
# region original LoRA
orig_lora = (
orig_state_dict.get(f"{converted_block_name}.{converted_local_name}.lora_down", None),
orig_state_dict.get(f"{converted_block_name}.{converted_local_name}.lora_up", None),
)
if orig_lora[0] is None or orig_lora[1] is None:
assert orig_lora[0] is None and orig_lora[1] is None
orig_lora = None
else:
assert orig_lora[0] is not None and orig_lora[1] is not None
orig_lora = (
unpack_lowrank_weight(orig_lora[0], down=True),
unpack_lowrank_weight(orig_lora[1], down=False),
)
print(f" - Found {converted_block_name} LoRA of {converted_local_name} (rank: {orig_lora[0].shape[0]})")
# endregion
# region extra LoRA
extra_lora = [
(
extra_lora_dict.get(f"{candidate_block_name}.{candidate_local_name}.lora_A.weight", None),
extra_lora_dict.get(f"{candidate_block_name}.{candidate_local_name}.lora_B.weight", None),
)
for candidate_local_name in candidate_local_names
]
if any(lora[0] is not None or lora[1] is not None for lora in extra_lora):
# merge extra LoRAs into one LoRA
if len(extra_lora) > 1:
first_lora = None
for lora in extra_lora:
if lora[0] is not None:
assert lora[1] is not None
first_lora = lora
break
assert first_lora is not None
for lora_index in range(len(extra_lora)):
if extra_lora[lora_index][0] is None:
assert extra_lora[lora_index][1] is None
extra_lora[lora_index] = (first_lora[0].clone(), torch.zeros_like(first_lora[1]))
if all(lora[0].equal(extra_lora[0][0]) for lora in extra_lora):
# if all extra LoRAs have the same lora_down, use it
extra_lora_down = extra_lora[0][0]
extra_lora_up = torch.cat([lora[1] for lora in extra_lora], dim=0)
else:
extra_lora_down = torch.cat([lora[0] for lora in extra_lora], dim=0)
extra_lora_up_c = sum(lora[1].shape[0] for lora in extra_lora)
extra_lora_up_r = sum(lora[1].shape[1] for lora in extra_lora)
assert extra_lora_up_r == extra_lora_down.shape[0]
extra_lora_up = torch.zeros((extra_lora_up_c, extra_lora_up_r), dtype=extra_lora_down.dtype)
c, r = 0, 0
for lora in extra_lora:
c_next, r_next = c + lora[1].shape[0], r + lora[1].shape[1]
extra_lora_up[c:c_next, r:r_next] = lora[1]
c, r = c_next, r_next
else:
extra_lora_down, extra_lora_up = extra_lora[0]
extra_lora: tuple[torch.Tensor, torch.Tensor] = (extra_lora_down, extra_lora_up)
print(f" - Found {candidate_block_name} LoRA of {candidate_local_names} (rank: {extra_lora[0].shape[0]})")
else:
extra_lora = None
# endregion
# region merge LoRA
if orig_lora is None:
if extra_lora is None:
lora = None
else:
print(" - Using extra LoRA")
lora = (extra_lora[0].to(default_dtype), extra_lora[1].to(default_dtype))
elif extra_lora is None:
print(" - Using original LoRA")
lora = orig_lora
else:
lora = (
torch.cat([orig_lora[0], extra_lora[0].to(orig_lora[0].dtype)], dim=0),
torch.cat([orig_lora[1], extra_lora[1].to(orig_lora[1].dtype)], dim=1),
)
print(f" - Merging original and extra LoRA (rank: {lora[0].shape[0]})")
# endregion
if lora is not None:
if convert_map[converted_local_name] == "adanorm_single":
update_state_dict(
converted,
{
"lora_down": lora[0],
"lora_up": reorder_adanorm_lora_up(lora[1], splits=3),
},
prefix=converted_local_name,
)
elif convert_map[converted_local_name] == "adanorm_zero":
update_state_dict(
converted,
{
"lora_down": lora[0],
"lora_up": reorder_adanorm_lora_up(lora[1], splits=6),
},
prefix=converted_local_name,
)
elif convert_map[converted_local_name] == "linear":
update_state_dict(
converted,
{
"lora_down": pack_lowrank_weight(lora[0], down=True),
"lora_up": pack_lowrank_weight(lora[1], down=False),
},
prefix=converted_local_name,
)
return converted
def convert_to_nunchaku_flux_single_transformer_block_lowrank_dict(
orig_state_dict: dict[str, torch.Tensor],
extra_lora_dict: dict[str, torch.Tensor],
converted_block_name: str,
candidate_block_name: str,
default_dtype: torch.dtype = torch.bfloat16,
) -> dict[str, torch.Tensor]:
if f"{candidate_block_name}.proj_out.lora_A.weight" in extra_lora_dict:
assert f"{converted_block_name}.out_proj.qweight" in orig_state_dict
assert f"{converted_block_name}.mlp_fc2.qweight" in orig_state_dict
n1 = orig_state_dict[f"{converted_block_name}.out_proj.qweight"].shape[1] * 2
n2 = orig_state_dict[f"{converted_block_name}.mlp_fc2.qweight"].shape[1] * 2
lora_down = extra_lora_dict[f"{candidate_block_name}.proj_out.lora_A.weight"]
lora_up = extra_lora_dict[f"{candidate_block_name}.proj_out.lora_B.weight"]
assert lora_down.shape[1] == n1 + n2
extra_lora_dict[f"{candidate_block_name}.proj_out.linears.0.lora_A.weight"] = lora_down[:, :n1].clone()
extra_lora_dict[f"{candidate_block_name}.proj_out.linears.0.lora_B.weight"] = lora_up.clone()
extra_lora_dict[f"{candidate_block_name}.proj_out.linears.1.lora_A.weight"] = lora_down[:, n1:].clone()
extra_lora_dict[f"{candidate_block_name}.proj_out.linears.1.lora_B.weight"] = lora_up.clone()
extra_lora_dict.pop(f"{candidate_block_name}.proj_out.lora_A.weight")
extra_lora_dict.pop(f"{candidate_block_name}.proj_out.lora_B.weight")
return convert_to_nunchaku_transformer_block_lowrank_dict(
orig_state_dict=orig_state_dict,
extra_lora_dict=extra_lora_dict,
converted_block_name=converted_block_name,
candidate_block_name=candidate_block_name,
local_name_map={
"norm.linear": "norm.linear",
"qkv_proj": ["attn.to_q", "attn.to_k", "attn.to_v"],
"norm_q": "attn.norm_q",
"norm_k": "attn.norm_k",
"out_proj": "proj_out.linears.0",
"mlp_fc1": "proj_mlp",
"mlp_fc2": "proj_out.linears.1",
},
convert_map={
"norm.linear": "adanorm_single",
"qkv_proj": "linear",
"out_proj": "linear",
"mlp_fc1": "linear",
"mlp_fc2": "linear",
},
default_dtype=default_dtype,
)
def convert_to_nunchaku_flux_transformer_block_lowrank_dict(
orig_state_dict: dict[str, torch.Tensor],
extra_lora_dict: dict[str, torch.Tensor],
converted_block_name: str,
candidate_block_name: str,
default_dtype: torch.dtype = torch.bfloat16,
) -> dict[str, torch.Tensor]:
return convert_to_nunchaku_transformer_block_lowrank_dict(
orig_state_dict=orig_state_dict,
extra_lora_dict=extra_lora_dict,
converted_block_name=converted_block_name,
candidate_block_name=candidate_block_name,
local_name_map={
"norm1.linear": "norm1.linear",
"norm1_context.linear": "norm1_context.linear",
"qkv_proj": ["attn.to_q", "attn.to_k", "attn.to_v"],
"qkv_proj_context": ["attn.add_q_proj", "attn.add_k_proj", "attn.add_v_proj"],
"norm_q": "attn.norm_q",
"norm_k": "attn.norm_k",
"norm_added_q": "attn.norm_added_q",
"norm_added_k": "attn.norm_added_k",
"out_proj": "attn.to_out.0",
"out_proj_context": "attn.to_add_out",
"mlp_fc1": "ff.net.0.proj",
"mlp_fc2": "ff.net.2",
"mlp_context_fc1": "ff_context.net.0.proj",
"mlp_context_fc2": "ff_context.net.2",
},
convert_map={
"norm1.linear": "adanorm_zero",
"norm1_context.linear": "adanorm_zero",
"qkv_proj": "linear",
"qkv_proj_context": "linear",
"out_proj": "linear",
"out_proj_context": "linear",
"mlp_fc1": "linear",
"mlp_fc2": "linear",
"mlp_context_fc1": "linear",
"mlp_context_fc2": "linear",
},
default_dtype=default_dtype,
)
def convert_to_nunchaku_flux_lowrank_dict(
orig_state_dict: dict[str, torch.Tensor],
extra_lora_dict: dict[str, torch.Tensor],
default_dtype: torch.dtype = torch.bfloat16,
) -> dict[str, torch.Tensor]:
block_names: set[str] = set()
for param_name in orig_state_dict.keys():
if param_name.startswith(("transformer_blocks.", "single_transformer_blocks.")):
block_names.add(".".join(param_name.split(".")[:2]))
block_names = sorted(block_names, key=lambda x: (x.split(".")[0], int(x.split(".")[-1])))
print(f"Converting {len(block_names)} transformer blocks...")
converted: dict[str, torch.Tensor] = {}
for block_name in block_names:
if block_name.startswith("transformer_blocks"):
convert_fn = convert_to_nunchaku_flux_transformer_block_lowrank_dict
else:
convert_fn = convert_to_nunchaku_flux_single_transformer_block_lowrank_dict
update_state_dict(
converted,
convert_fn(
orig_state_dict=orig_state_dict,
extra_lora_dict=extra_lora_dict,
converted_block_name=block_name,
candidate_block_name=block_name,
default_dtype=default_dtype,
),
prefix=block_name,
)
return converted
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--quant-path", type=str, required=True, help="path to the quantized model safetensor file")
parser.add_argument("--lora-path", type=str, required=True, help="path to LoRA weights safetensor file")
parser.add_argument("--output-root", type=str, default="", help="root to the output safetensor file")
parser.add_argument("--lora-name", type=str, default=None, help="name of the LoRA weights")
parser.add_argument(
"--dtype",
type=str,
default="bfloat16",
choices=["bfloat16", "float16"],
help="data type of the converted weights",
)
args = parser.parse_args()
if not args.output_root:
# output to the parent directory of the quantized model safetensor file
args.output_root = os.path.dirname(args.quant_path)
if args.lora_name is None:
assert args.lora_path is not None, "LoRA name or path must be provided"
lora_name = args.lora_path.rstrip(os.sep).split(os.sep)[-1].replace(".safetensors", "")
print(f"Lora name not provided, using {lora_name} as the LoRA name")
else:
lora_name = args.lora_name
assert lora_name, "LoRA name must be provided."
assert args.quant_path.endswith(".safetensors"), "Quantized model must be a safetensor file"
assert args.lora_path.endswith(".safetensors"), "LoRA weights must be a safetensor file"
orig_state_dict = load_state_dict_in_safetensors(args.quant_path)
extra_lora_dict = load_state_dict_in_safetensors(args.lora_path, filter_prefix="transformer.")
converted = convert_to_nunchaku_flux_lowrank_dict(
orig_state_dict=orig_state_dict,
extra_lora_dict=extra_lora_dict,
default_dtype=torch.bfloat16 if args.dtype == "bfloat16" else torch.float16,
)
os.makedirs(args.output_root, exist_ok=True)
output_path = os.path.join(args.output_root, f"{lora_name}.safetensors")
safetensors.torch.save_file(converted, output_path)
print(f"Saved LoRA weights to {output_path}.")