"examples/vscode:/vscode.git/clone" did not exist on "2daba9764f7f0e82610dcb249293e70a95bc2c6e"
Unverified commit 07f07563, authored by Muyang Li, committed by GitHub

chore: release v0.3.1

parents 7214300d ad92b16a
@@ -125,7 +125,7 @@ If you're using a Blackwell GPU (e.g., 50-series GPUs), install a wheel with PyT
 pip install ninja wheel diffusers transformers accelerate sentencepiece protobuf huggingface_hub
 # For gradio demos
-pip install peft opencv-python gradio spaces GPUtil
+pip install peft opencv-python gradio spaces
 ```
 To enable NVFP4 on Blackwell GPUs (e.g., 50-series GPUs), please install nightly PyTorch>=2.7 with CUDA>=12.8. The installation command can be:
......
@@ -122,7 +122,7 @@ pip install https://huggingface.co/mit-han-lab/nunchaku/resolve/main/nunchaku-0.
 pip install ninja wheel diffusers transformers accelerate sentencepiece protobuf huggingface_hub
 # Dependencies for the Gradio demos
-pip install peft opencv-python gradio spaces GPUtil
+pip install peft opencv-python gradio spaces
 ```
 Blackwell users need to install PyTorch>=2.7 with CUDA>=12.8:
......
<div style="display: flex; justify-content: center; align-items: center; text-align: center;"> <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<div> <div>
<h1> <!-- Logo Row -->
<img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/svdquant.svg" <div style="display: flex; justify-content: center; align-items: center; gap: 10px; margin-bottom: 10px;">
alt="logo" <a href="https://github.com/mit-han-lab/nunchaku">
style="height: 40px; width: auto; display: block; margin: auto;"/> <img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/nunchaku.svg"
INT4 FLUX.1-{model_name}-dev Demo alt="nunchaku logo"
</h1> style="height: 150px; width: auto;"/>
<h2>
SVDQuant: Absorbing Outliers by Low-Rank Components for 4-Bit Diffusion Models
</h2>
<h3>
<a href='https://lmxyy.me' target="_blank">Muyang Li*</a>,
<a href='https://yujunlin.com' target="_blank">Yujun Lin*</a>,
<a href='https://hanlab.mit.edu/team/zhekai-zhang' target="_blank">Zhekai Zhang*</a>,
<a href='https://www.tianle.website/#/' target="_blank">Tianle Cai</a>,
<a href='https://xiuyuli.com' target="_blank">Xiuyu Li</a>,
<br>
<a href='https://github.com/JerryGJX' target="_blank">Junxian Guo</a>,
<a href='https://xieenze.github.io' target="_blank">Enze Xie</a>,
<a href='https://cs.stanford.edu/~chenlin/' target="_blank">Chenlin Meng</a>,
<a href='https://www.cs.cmu.edu/~junyanz/' target="_blank">Jun-Yan Zhu</a>,
and <a href='https://hanlab.mit.edu/songhan' target="_blank">Song Han</a>
</h3>
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<a href="https://arxiv.org/abs/2411.05007" target="_blank">[Paper]</a>
&nbsp;
<a href='https://github.com/mit-han-lab/nunchaku' target="_blank">
[Code]
</a>
&nbsp;
<a href='https://hanlab.mit.edu/projects/svdquant' target="_blank">
[Website]
</a> </a>
&nbsp; <a href="https://hanlab.mit.edu/projects/svdquant">
<a href='https://hanlab.mit.edu/blog/svdquant' target="_blank"> <img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/svdquant.svg"
[Blog] alt="svdquant logo"
style="height: 40px; width: auto;"/>
</a> </a>
</div> </div>
<h4>Quantization Library: <h1 style="margin-top: 0;">INT4 FLUX.1-{model_name}-dev Demo</h1>
<a href='https://github.com/mit-han-lab/deepcompressor' target="_blank">DeepCompressor</a>&nbsp;
Inference Engine: <a href='https://github.com/mit-han-lab/nunchaku' target="_blank">Nunchaku</a>&nbsp;
</h4>
<div style="display: flex; justify-content: center; align-items: center; text-align: center;"> <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
{device_info} {device_info}
</div> </div>
......
@@ -4,7 +4,6 @@ import random
 import time
 from datetime import datetime

-import GPUtil
 import torch
 from controlnet_aux import CannyDetector
 from diffusers import FluxControlPipeline
@@ -54,12 +53,16 @@ if args.precision == "bf16":
 else:
     assert args.precision == "int4"
     pipeline_init_kwargs = {}
-    transformer = NunchakuFluxTransformer2dModel.from_pretrained(f"mit-han-lab/svdq-int4-flux.1-{model_name}")
+    transformer = NunchakuFluxTransformer2dModel.from_pretrained(
+        f"mit-han-lab/nunchaku-flux.1-{model_name}/svdq-int4_r32-flux.1-{model_name}.safetensors"
+    )
     pipeline_init_kwargs["transformer"] = transformer
     if args.use_qencoder:
         from nunchaku.models.text_encoders.t5_encoder import NunchakuT5EncoderModel

-        text_encoder_2 = NunchakuT5EncoderModel.from_pretrained("mit-han-lab/svdq-flux.1-t5")
+        text_encoder_2 = NunchakuT5EncoderModel.from_pretrained(
+            "mit-han-lab/nunchaku-t5/awq-int4-flux.1-t5xxl.safetensors"
+        )
         pipeline_init_kwargs["text_encoder_2"] = text_encoder_2

 pipeline = pipeline_class.from_pretrained(
@@ -125,11 +128,12 @@ def run(
 with gr.Blocks(css_paths="assets/style.css", title=f"SVDQuant Flux.1-{model_name} Demo") as demo:
     with open("assets/description.html", "r") as f:
         DESCRIPTION = f.read()
-    gpus = GPUtil.getGPUs()
-    if len(gpus) > 0:
-        gpu = gpus[0]
-        memory = gpu.memoryTotal / 1024
-        device_info = f"Running on {gpu.name} with {memory:.0f} GiB memory."
+    # Get the GPU properties
+    if torch.cuda.device_count() > 0:
+        gpu_properties = torch.cuda.get_device_properties(0)
+        gpu_memory = gpu_properties.total_memory / (1024**3)  # Convert to GiB
+        gpu_name = torch.cuda.get_device_name(0)
+        device_info = f"Running on {gpu_name} with {gpu_memory:.0f} GiB memory."
     else:
         device_info = "Running on CPU 🥶 This demo does not work on CPU."
     notice = '<strong>Notice:</strong>&nbsp;We will replace unsafe prompts with a default prompt: "A peaceful world."'
......
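Across the Gradio demos touched by this release, the GPUtil dependency is dropped and the device banner is built directly from `torch.cuda`. A minimal, self-contained sketch of the new pattern (variable names here are illustrative, not taken from the repo):

```python
import torch

# Probe the first CUDA device the same way the updated demos do.
if torch.cuda.device_count() > 0:
    props = torch.cuda.get_device_properties(0)
    memory_gib = props.total_memory / (1024**3)  # bytes -> GiB
    device_info = f"Running on {torch.cuda.get_device_name(0)} with {memory_gib:.0f} GiB memory."
else:
    device_info = "Running on CPU"

print(device_info)
```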
<div style="display: flex; justify-content: center; align-items: center; text-align: center;"> <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<div> <div>
<h1> <!-- Logo Row -->
<img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/svdquant.svg" <div style="display: flex; justify-content: center; align-items: center; gap: 10px; margin-bottom: 10px;">
alt="logo" <a href="https://github.com/mit-han-lab/nunchaku">
style="height: 40px; width: auto; display: block; margin: auto;"/> <img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/nunchaku.svg"
INT4 FLUX.1-fill-dev Demo alt="nunchaku logo"
</h1> style="height: 150px; width: auto;"/>
<h2>
SVDQuant: Absorbing Outliers by Low-Rank Components for 4-Bit Diffusion Models
</h2>
<h3>
<a href='https://lmxyy.me' target="_blank">Muyang Li*</a>,
<a href='https://yujunlin.com' target="_blank">Yujun Lin*</a>,
<a href='https://hanlab.mit.edu/team/zhekai-zhang' target="_blank">Zhekai Zhang*</a>,
<a href='https://www.tianle.website/#/' target="_blank">Tianle Cai</a>,
<a href='https://xiuyuli.com' target="_blank">Xiuyu Li</a>,
<br>
<a href='https://github.com/JerryGJX' target="_blank">Junxian Guo</a>,
<a href='https://xieenze.github.io' target="_blank">Enze Xie</a>,
<a href='https://cs.stanford.edu/~chenlin/' target="_blank">Chenlin Meng</a>,
<a href='https://www.cs.cmu.edu/~junyanz/' target="_blank">Jun-Yan Zhu</a>,
and <a href='https://hanlab.mit.edu/songhan' target="_blank">Song Han</a>
</h3>
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<a href="https://arxiv.org/abs/2411.05007" target="_blank">[Paper]</a>
&nbsp;
<a href='https://github.com/mit-han-lab/nunchaku' target="_blank">
[Code]
</a>
&nbsp;
<a href='https://hanlab.mit.edu/projects/svdquant' target="_blank">
[Website]
</a> </a>
&nbsp; <a href="https://hanlab.mit.edu/projects/svdquant">
<a href='https://hanlab.mit.edu/blog/svdquant' target="_blank"> <img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/svdquant.svg"
[Blog] alt="svdquant logo"
style="height: 40px; width: auto;"/>
</a> </a>
</div> </div>
<!-- Title -->
<h1 style="margin-top: 0;">INT4 FLUX.1-fill-dev Demo</h1>
<h4>Quantization Library: <h4>Quantization Library:
<a href='https://github.com/mit-han-lab/deepcompressor' target="_blank">DeepCompressor</a>&nbsp; <a href='https://github.com/mit-han-lab/deepcompressor' target="_blank">DeepCompressor</a>&nbsp;
Inference Engine: <a href='https://github.com/mit-han-lab/nunchaku' target="_blank">Nunchaku</a>&nbsp; Inference Engine: <a href='https://github.com/mit-han-lab/nunchaku' target="_blank">Nunchaku</a>&nbsp;
......
@@ -4,7 +4,6 @@ import random
 import time
 from datetime import datetime

-import GPUtil
 import torch
 from diffusers import FluxFillPipeline
 from PIL import Image
@@ -26,12 +25,16 @@ if args.precision == "bf16":
 else:
     assert args.precision == "int4"
     pipeline_init_kwargs = {}
-    transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-fill-dev")
+    transformer = NunchakuFluxTransformer2dModel.from_pretrained(
+        "mit-han-lab/nunchaku-flux.1-fill-dev/svdq-int4_r32-flux.1-fill-dev.safetensors"
+    )
     pipeline_init_kwargs["transformer"] = transformer
     if args.use_qencoder:
         from nunchaku.models.text_encoders.t5_encoder import NunchakuT5EncoderModel

-        text_encoder_2 = NunchakuT5EncoderModel.from_pretrained("mit-han-lab/svdq-flux.1-t5")
+        text_encoder_2 = NunchakuT5EncoderModel.from_pretrained(
+            "mit-han-lab/nunchaku-t5/awq-int4-flux.1-t5xxl.safetensors"
+        )
         pipeline_init_kwargs["text_encoder_2"] = text_encoder_2

 pipeline = FluxFillPipeline.from_pretrained(
@@ -97,11 +100,12 @@ def run(
 with gr.Blocks(css_paths="assets/style.css", title="SVDQuant Flux.1-Fill-dev Sketch-to-Image Demo") as demo:
     with open("assets/description.html", "r") as f:
         DESCRIPTION = f.read()
-    gpus = GPUtil.getGPUs()
-    if len(gpus) > 0:
-        gpu = gpus[0]
-        memory = gpu.memoryTotal / 1024
-        device_info = f"Running on {gpu.name} with {memory:.0f} GiB memory."
+    # Get the GPU properties
+    if torch.cuda.device_count() > 0:
+        gpu_properties = torch.cuda.get_device_properties(0)
+        gpu_memory = gpu_properties.total_memory / (1024**3)  # Convert to GiB
+        gpu_name = torch.cuda.get_device_name(0)
+        device_info = f"Running on {gpu_name} with {gpu_memory:.0f} GiB memory."
     else:
         device_info = "Running on CPU 🥶 This demo does not work on CPU."
     notice = '<strong>Notice:</strong>&nbsp;We will replace unsafe prompts with a default prompt: "A peaceful world."'
......
 # Nunchaku INT4 FLUX.1 Redux Demo
-![demo](https://huggingface.co/mit-han-lab/nunchaku-artifacts/resolve/main/nunchaku/app/flux.1/redux/assets/demo.jpg)
+![demo](https://huggingface.co/mit-han-lab/nunchaku-artifacts/resolve/main/nunchaku/app/flux.1/redux/assets/demo.png)
 This interactive Gradio application lets you generate image variations. The base model is [FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev). We use [FLUX.1-Redux-dev](https://huggingface.co/black-forest-labs/FLUX.1-Redux-dev) to preprocess the image before feeding it into FLUX.1-dev. To launch the application, run:
......
<div style="display: flex; justify-content: center; align-items: center; text-align: center;"> <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<div> <div>
<h1>
<img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/svdquant.svg" <!-- Logo Row -->
alt="logo" <div style="display: flex; justify-content: center; align-items: center; gap: 10px; margin-bottom: 10px;">
style="height: 40px; width: auto; display: block; margin: auto;"/> <a href="https://github.com/mit-han-lab/nunchaku">
INT4 FLUX.1-redux-dev Demo <img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/nunchaku.svg"
</h1> alt="nunchaku logo"
<h2> style="height: 150px; width: auto;"/>
SVDQuant: Absorbing Outliers by Low-Rank Components for 4-Bit Diffusion Models
</h2>
<h3>
<a href='https://lmxyy.me'>Muyang Li*</a>,
<a href='https://yujunlin.com'>Yujun Lin*</a>,
<a href='https://hanlab.mit.edu/team/zhekai-zhang'>Zhekai Zhang*</a>,
<a href='https://www.tianle.website/#/'>Tianle Cai</a>,
<a href='https://xiuyuli.com'>Xiuyu Li</a>,
<br>
<a href='https://github.com/JerryGJX'>Junxian Guo</a>,
<a href='https://xieenze.github.io'>Enze Xie</a>,
<a href='https://cs.stanford.edu/~chenlin/'>Chenlin Meng</a>,
<a href='https://www.cs.cmu.edu/~junyanz/'>Jun-Yan Zhu</a>,
and <a href='https://hanlab.mit.edu/songhan'>Song Han</a>
</h3>
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<a href="https://arxiv.org/abs/2411.05007">[Paper]</a>
&nbsp;
<a href='https://github.com/mit-han-lab/nunchaku'>
[Code]
</a>
&nbsp;
<a href='https://hanlab.mit.edu/projects/svdquant'>
[Website]
</a> </a>
&nbsp; <a href="https://hanlab.mit.edu/projects/svdquant">
<a href='https://hanlab.mit.edu/blog/svdquant'> <img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/svdquant.svg"
[Blog] alt="svdquant logo"
style="height: 40px; width: auto;"/>
</a> </a>
</div> </div>
<h4>Quantization Library:
<a href='https://github.com/mit-han-lab/deepcompressor'>DeepCompressor</a>&nbsp; <!-- Title -->
Inference Engine: <a href='https://github.com/mit-han-lab/nunchaku'>Nunchaku</a>&nbsp; <h1 style="margin-top: 0;">INT4 FLUX.1-redux-dev Demo</h1>
</h4>
<div style="display: flex; justify-content: center; align-items: center; text-align: center;"> <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
{device_info} {device_info}
</div> </div>
......
 @import url('https://cdnjs.cloudflare.com/ajax/libs/font-awesome/5.15.1/css/all.min.css');

-.gradio-container{max-width: 1200px !important}
-h1{text-align:center}
+.gradio-container {
+    max-width: 1200px !important;
+    margin: auto; /* Centers the element horizontally */
+}
+
+h1 {
+    text-align: center
+}

 .wrap.svelte-p4aq0j.svelte-p4aq0j {
     display: none;
@@ -22,8 +28,13 @@ h1{text-align:center}
 #accessibility {
     text-align: center; /* Center-aligns the text */
     margin: auto; /* Centers the element horizontally */
 }

-#random_seed {height: 71px;}
-#run_button {height: 87px;}
+#random_seed {
+    height: 71px;
+}
+
+#run_button {
+    height: 87px;
+}
@@ -4,10 +4,6 @@ import random
 import time
 from datetime import datetime

-import GPUtil
-
-# import gradio last to avoid conflicts with other imports
-import gradio as gr
 import torch
 from diffusers import FluxPipeline, FluxPriorReduxPipeline
 from PIL import Image
@@ -16,6 +12,9 @@ from vars import DEFAULT_GUIDANCE, DEFAULT_INFERENCE_STEP, EXAMPLES, MAX_SEED
 from nunchaku.models.transformers.transformer_flux import NunchakuFluxTransformer2dModel

+# import gradio last to avoid conflicts with other imports
+import gradio as gr  # noqa: isort: skip
+
 args = get_args()

 pipe_prior_redux = FluxPriorReduxPipeline.from_pretrained(
@@ -29,7 +28,9 @@ if args.precision == "bf16":
 else:
     assert args.precision == "int4"
     pipeline_init_kwargs = {}
-    transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-dev")
+    transformer = NunchakuFluxTransformer2dModel.from_pretrained(
+        "mit-han-lab/nunchaku-flux.1-dev/svdq-int4_r32-flux.1-dev.safetensors"
+    )

 pipeline = FluxPipeline.from_pretrained(
     "black-forest-labs/FLUX.1-dev",
     text_encoder=None,
@@ -79,11 +80,12 @@ def run(image, num_inference_steps: int, guidance_scale: float, seed: int) -> tu
 with gr.Blocks(css_paths="assets/style.css", title="SVDQuant Flux.1-redux-dev Demo") as demo:
     with open("assets/description.html", "r") as f:
         DESCRIPTION = f.read()
-    gpus = GPUtil.getGPUs()
-    if len(gpus) > 0:
-        gpu = gpus[0]
-        memory = gpu.memoryTotal / 1024
-        device_info = f"Running on {gpu.name} with {memory:.0f} GiB memory."
+    # Get the GPU properties
+    if torch.cuda.device_count() > 0:
+        gpu_properties = torch.cuda.get_device_properties(0)
+        gpu_memory = gpu_properties.total_memory / (1024**3)  # Convert to GiB
+        gpu_name = torch.cuda.get_device_name(0)
+        device_info = f"Running on {gpu_name} with {gpu_memory:.0f} GiB memory."
     else:
         device_info = "Running on CPU 🥶 This demo does not work on CPU."
......
<div style="display: flex; justify-content: center; align-items: center; text-align: center;"> <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<div> <div>
<h1> <!-- Logo Row -->
<img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/svdquant.svg" <div style="display: flex; justify-content: center; align-items: center; gap: 10px; margin-bottom: 10px;">
alt="logo" <a href="https://github.com/mit-han-lab/nunchaku">
style="height: 40px; width: auto; display: block; margin: auto;"/> <img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/nunchaku.svg"
INT4 FLUX.1-schnell Sketch-to-Image Demo alt="nunchaku logo"
</h1> style="height: 150px; width: auto;"/>
<h2>
SVDQuant: Absorbing Outliers by Low-Rank Components for 4-Bit Diffusion Models
</h2>
<h3>
<a href='https://lmxyy.me' target="_blank">Muyang Li*</a>,
<a href='https://yujunlin.com' target="_blank">Yujun Lin*</a>,
<a href='https://hanlab.mit.edu/team/zhekai-zhang' target="_blank">Zhekai Zhang*</a>,
<a href='https://www.tianle.website/#/' target="_blank">Tianle Cai</a>,
<a href='https://xiuyuli.com' target="_blank">Xiuyu Li</a>,
<br>
<a href='https://github.com/JerryGJX' target="_blank">Junxian Guo</a>,
<a href='https://xieenze.github.io' target="_blank">Enze Xie</a>,
<a href='https://cs.stanford.edu/~chenlin/' target="_blank">Chenlin Meng</a>,
<a href='https://www.cs.cmu.edu/~junyanz/' target="_blank">Jun-Yan Zhu</a>,
and <a href='https://hanlab.mit.edu/songhan' target="_blank">Song Han</a>
</h3>
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<a href="https://arxiv.org/abs/2411.05007" target="_blank">[Paper]</a>
&nbsp;
<a href='https://github.com/mit-han-lab/nunchaku' target="_blank">
[Code]
</a>
&nbsp;
<a href='https://hanlab.mit.edu/projects/svdquant' target="_blank">
[Website]
</a> </a>
&nbsp; <a href="https://hanlab.mit.edu/projects/svdquant">
<a href='https://hanlab.mit.edu/blog/svdquant' target="_blank"> <img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/svdquant.svg"
[Blog] alt="svdquant logo"
style="height: 40px; width: auto;"/>
</a> </a>
</div> </div>
<h4>Quantization Library: <h1 style="margin-top: 0;">INT4 FLUX.1-schnell Sketch-to-Image Demo</h1>
<a href='https://github.com/mit-han-lab/deepcompressor' target="_blank">DeepCompressor</a>&nbsp;
Inference Engine: <a href='https://github.com/mit-han-lab/nunchaku' target="_blank">Nunchaku</a>&nbsp;
Image Control: <a href="https://github.com/GaParmar/img2img-turbo" target="_blank">img2img-turbo</a>
</h4>
<div style="display: flex; justify-content: center; align-items: center; text-align: center;"> <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
{device_info} {device_info}
</div> </div>
......
 # Changed from https://github.com/GaParmar/img2img-turbo/blob/main/gradio_sketch2image.py
 import os
 import random
-import tempfile
 import time
 from datetime import datetime

-import GPUtil
 import numpy as np
 import torch
 from flux_pix2pix_pipeline import FluxPix2pixTurboPipeline
@@ -33,12 +31,16 @@ if args.precision == "bf16":
 else:
     assert args.precision == "int4"
     pipeline_init_kwargs = {}
-    transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-schnell")
+    transformer = NunchakuFluxTransformer2dModel.from_pretrained(
+        "mit-han-lab/nunchaku-flux.1-schnell/svdq-int4_r32-flux.1-schnell.safetensors"
+    )
    pipeline_init_kwargs["transformer"] = transformer
     if args.use_qencoder:
         from nunchaku.models.text_encoders.t5_encoder import NunchakuT5EncoderModel

-        text_encoder_2 = NunchakuT5EncoderModel.from_pretrained("mit-han-lab/svdq-flux.1-t5")
+        text_encoder_2 = NunchakuT5EncoderModel.from_pretrained(
+            "mit-han-lab/nunchaku-t5/awq-int4-flux.1-t5xxl.safetensors"
+        )
         pipeline_init_kwargs["text_encoder_2"] = text_encoder_2

 pipeline = FluxPix2pixTurboPipeline.from_pretrained(
@@ -55,14 +57,6 @@ else:
 safety_checker = SafetyChecker("cuda", disabled=args.no_safety_checker)

-def save_image(img):
-    if isinstance(img, dict):
-        img = img["composite"]
-    temp_file = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
-    img.save(temp_file.name)
-    return temp_file.name

 def run(image, prompt: str, prompt_template: str, sketch_guidance: float, seed: int) -> tuple[Image, str]:
     print(f"Prompt: {prompt}")
@@ -116,11 +110,12 @@ def run(image, prompt: str, prompt_template: str, sketch_guidance: float, seed:
 with gr.Blocks(css_paths="assets/style.css", title="SVDQuant Sketch-to-Image Demo") as demo:
     with open("assets/description.html", "r") as f:
         DESCRIPTION = f.read()
-    gpus = GPUtil.getGPUs()
-    if len(gpus) > 0:
-        gpu = gpus[0]
-        memory = gpu.memoryTotal / 1024
-        device_info = f"Running on {gpu.name} with {memory:.0f} GiB memory."
+    # Get the GPU properties
+    if torch.cuda.device_count() > 0:
+        gpu_properties = torch.cuda.get_device_properties(0)
+        gpu_memory = gpu_properties.total_memory / (1024**3)  # Convert to GiB
+        gpu_name = torch.cuda.get_device_name(0)
+        device_info = f"Running on {gpu_name} with {gpu_memory:.0f} GiB memory."
     else:
         device_info = "Running on CPU 🥶 This demo does not work on CPU."
     notice = '<strong>Notice:</strong>&nbsp;We will replace unsafe prompts with a default prompt: "A peaceful world."'
@@ -170,7 +165,6 @@ with gr.Blocks(css_paths="assets/style.css", title="SVDQuant Sketch-to-Image Dem
             with gr.Row():
                 prompt = gr.Text(label="Prompt", placeholder="Enter your prompt", scale=6)
                 run_button = gr.Button("Run", scale=1, elem_id="run_button")
-                download_sketch = gr.DownloadButton("Download Sketch", scale=1, elem_id="download_sketch")
             with gr.Row():
                 style = gr.Dropdown(label="Style", choices=STYLE_NAMES, value=DEFAULT_STYLE_NAME, scale=1)
                 prompt_template = gr.Textbox(
@@ -207,7 +201,6 @@ with gr.Blocks(css_paths="assets/style.css", title="SVDQuant Sketch-to-Image Dem
             )
             latency_result = gr.Text(label="Inference Latency", show_label=True)
-            download_result = gr.DownloadButton("Download Result", elem_id="download_result")

     gr.Markdown("### Instructions")
     gr.Markdown("**1**. Enter a text prompt (e.g. a cat)")
     gr.Markdown("**2**. Start sketching")
@@ -235,8 +228,6 @@ with gr.Blocks(css_paths="assets/style.css", title="SVDQuant Sketch-to-Image Dem
         api_name=False,
     )
-    download_sketch.click(fn=save_image, inputs=canvas, outputs=download_sketch)
-    download_result.click(fn=save_image, inputs=result, outputs=download_result)

     gr.Markdown("MIT Accessibility: https://accessibility.mit.edu/", elem_id="accessibility")
......
<div style="display: flex; justify-content: center; align-items: center; text-align: center;"> <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<div> <div>
<h1> <!-- Logo Row -->
<img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/svdquant.svg" <div style="display: flex; justify-content: center; align-items: center; gap: 10px; margin-bottom: 10px;">
alt="logo" <a href="https://github.com/mit-han-lab/nunchaku">
style="height: 40px; width: auto; display: block; margin: auto;"/> <img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/nunchaku.svg"
FLUX.1-{model} Demo alt="nunchaku logo"
</h1> style="height: 150px; width: auto;"/>
<h2>
SVDQuant: Absorbing Outliers by Low-Rank Components for 4-Bit Diffusion Models
</h2>
<h3>
<a href='https://lmxyy.me' target="_blank">Muyang Li*</a>,
<a href='https://yujunlin.com' target="_blank">Yujun Lin*</a>,
<a href='https://hanlab.mit.edu/team/zhekai-zhang' target="_blank">Zhekai Zhang*</a>,
<a href='https://www.tianle.website/#/' target="_blank">Tianle Cai</a>,
<a href='https://xiuyuli.com' target="_blank">Xiuyu Li</a>,
<br>
<a href='https://github.com/JerryGJX' target="_blank">Junxian Guo</a>,
<a href='https://xieenze.github.io' target="_blank">Enze Xie</a>,
<a href='https://cs.stanford.edu/~chenlin/' target="_blank">Chenlin Meng</a>,
<a href='https://www.cs.cmu.edu/~junyanz/' target="_blank">Jun-Yan Zhu</a>,
and <a href='https://hanlab.mit.edu/songhan' target="_blank">Song Han</a>
</h3>
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<a href="https://arxiv.org/abs/2411.05007" target="_blank">[Paper]</a>
&nbsp;
<a href='https://github.com/mit-han-lab/nunchaku' target="_blank">
[Code]
</a>
&nbsp;
<a href='https://hanlab.mit.edu/projects/svdquant' target="_blank">
[Website]
</a> </a>
&nbsp; <a href="https://hanlab.mit.edu/projects/svdquant">
<a href='https://hanlab.mit.edu/blog/svdquant' target="_blank"> <img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/svdquant.svg"
[Blog] alt="svdquant logo"
style="height: 40px; width: auto;"/>
</a> </a>
</div> </div>
<h4>Quantization Library:
<a href='https://github.com/mit-han-lab/deepcompressor' target="_blank">DeepCompressor</a> <!-- Title -->
&nbsp; <h1 style="margin-top: 0;">FLUX.1-{model} Demo</h1>
Inference Engine: <a href='https://github.com/mit-han-lab/nunchaku' target="_blank">Nunchaku</a>
</h4> <!-- Device Info -->
<div style="display: flex; justify-content: center; align-items: center; text-align: center;"> <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
{device_info} {device_info}
</div> </div>
<!-- Notice -->
<div style="display: flex; justify-content: center; align-items: center; text-align: center;"> <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
{notice} {notice}
</div> </div>
<!-- Count Info -->
{count_info} {count_info}
</div> </div>
</div> </div>
@@ -5,12 +5,11 @@ import random
 import time
 from datetime import datetime

-import GPUtil
 import spaces
 import torch
 from peft.tuners import lora
 from utils import get_pipeline
-from vars import DEFAULT_HEIGHT, DEFAULT_WIDTH, EXAMPLES, MAX_SEED, PROMPT_TEMPLATES, SVDQ_LORA_PATHS
+from vars import DEFAULT_HEIGHT, DEFAULT_WIDTH, EXAMPLES, LORA_PATHS, MAX_SEED, PROMPT_TEMPLATES

 from nunchaku.models.safety_checker import SafetyChecker
@@ -98,7 +97,9 @@ def generate(
     else:
         assert precision == "int4"
         if lora_name != "None":
-            pipeline.transformer.update_lora_params(SVDQ_LORA_PATHS[lora_name])
+            lora_path = LORA_PATHS[lora_name]
+            lora_path = os.path.join(lora_path["name_or_path"], lora_path["weight_name"])
+            pipeline.transformer.update_lora_params(lora_path)
             pipeline.transformer.set_lora_strength(lora_weight)
         else:
             pipeline.transformer.set_lora_strength(0)
@@ -157,11 +158,13 @@ def generate(
 with open("./assets/description.html", "r") as f:
     DESCRIPTION = f.read()

-gpus = GPUtil.getGPUs()
-if len(gpus) > 0:
-    gpu = gpus[0]
-    memory = gpu.memoryTotal / 1024
-    device_info = f"Running on {gpu.name} with {memory:.0f} GiB memory."
+# Get the GPU properties
+if torch.cuda.device_count() > 0:
+    gpu_properties = torch.cuda.get_device_properties(0)
+    gpu_memory = gpu_properties.total_memory / (1024**3)  # Convert to GiB
+    gpu_name = torch.cuda.get_device_name(0)
+    device_info = f"Running on {gpu_name} with {gpu_memory:.0f} GiB memory."
 else:
     device_info = "Running on CPU 🥶 This demo does not work on CPU."
 notice = '<strong>Notice:</strong>&nbsp;We will replace unsafe prompts with a default prompt: "A peaceful world."'
......
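The updated LoRA lookup above reads two fields, `name_or_path` and `weight_name`, from each `LORA_PATHS` entry and joins them into a single path for `update_lora_params`. The real table lives in the demo's `vars.py`; the entry below is purely hypothetical and only illustrates the shape that code expects:

```python
import os

# Hypothetical LORA_PATHS entry; the actual repo ids and file names are defined in vars.py.
LORA_PATHS = {
    "Example LoRA": {
        "name_or_path": "some-user/some-flux-lora",  # hypothetical Hub repo or local folder
        "weight_name": "example-lora.safetensors",   # hypothetical weight file inside it
    },
}

entry = LORA_PATHS["Example LoRA"]
lora_path = os.path.join(entry["name_or_path"], entry["weight_name"])
print(lora_path)  # some-user/some-flux-lora/example-lora.safetensors
```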
@@ -28,17 +28,21 @@ def get_pipeline(
     if precision in ["int4", "fp4"]:
         assert torch.device(device).type == "cuda", "int4 only supported on CUDA devices"
         if precision == "int4":
-            transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-schnell")
+            transformer = NunchakuFluxTransformer2dModel.from_pretrained(
+                "mit-han-lab/nunchaku-flux.1-schnell/svdq-int4_r32-flux.1-schnell.safetensors"
+            )
         else:
             assert precision == "fp4"
             transformer = NunchakuFluxTransformer2dModel.from_pretrained(
-                "mit-han-lab/svdq-fp4-flux.1-schnell", precision="fp4"
+                "mit-han-lab/nunchaku-flux.1-schnell/svdq-fp4_r32-flux.1-schnell.safetensors", precision="fp4"
             )
         pipeline_init_kwargs["transformer"] = transformer
         if use_qencoder:
             from nunchaku.models.text_encoders.t5_encoder import NunchakuT5EncoderModel

-            text_encoder_2 = NunchakuT5EncoderModel.from_pretrained("mit-han-lab/svdq-flux.1-t5")
+            text_encoder_2 = NunchakuT5EncoderModel.from_pretrained(
+                "mit-han-lab/nunchaku-t5/awq-int4-flux.1-t5xxl.safetensors"
+            )
             pipeline_init_kwargs["text_encoder_2"] = text_encoder_2
     else:
         assert precision == "bf16"
@@ -47,7 +51,9 @@ def get_pipeline(
     )
 elif model_name == "dev":
     if precision == "int4":
-        transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-dev")
+        transformer = NunchakuFluxTransformer2dModel.from_pretrained(
+            "mit-han-lab/nunchaku-flux.1-dev/svdq-int4_r32-flux.1-dev.safetensors"
+        )
         if lora_name not in ["All", "None"]:
             transformer.update_lora_params(SVDQ_LORA_PATHS[lora_name])
             transformer.set_lora_strength(lora_weight)
@@ -55,7 +61,9 @@ def get_pipeline(
         if use_qencoder:
             from nunchaku.models.text_encoders.t5_encoder import NunchakuT5EncoderModel

-            text_encoder_2 = NunchakuT5EncoderModel.from_pretrained("mit-han-lab/svdq-flux.1-t5")
+            text_encoder_2 = NunchakuT5EncoderModel.from_pretrained(
+                "mit-han-lab/nunchaku-t5/awq-int4-flux.1-t5xxl.safetensors"
+            )
             pipeline_init_kwargs["text_encoder_2"] = text_encoder_2
     pipeline = FluxPipeline.from_pretrained(
         "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16, **pipeline_init_kwargs
......
import torch
from diffusers import FluxPipeline

from nunchaku import NunchakuFluxTransformer2dModel
from nunchaku.caching.diffusers_adapters import apply_cache_on_pipe
from nunchaku.utils import get_precision

# Pick the quantization precision ("int4" or "fp4") based on the GPU architecture
precision = get_precision()
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
    f"mit-han-lab/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors",
    offload=True,  # enable CPU offloading of the transformer to reduce GPU memory usage
)
pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")

# Enable double first-block caching, with separate residual-difference thresholds
# for the double-stream and single-stream transformer blocks
apply_cache_on_pipe(
    pipeline,
    use_double_fb_cache=True,
    residual_diff_threshold_multi=0.09,
    residual_diff_threshold_single=0.12,
)
image = pipeline(["A cat holding a sign that says hello world"], num_inference_steps=50).images[0]
image.save(f"flux.1-dev-cache-{precision}.png")
__version__ = "0.3.0" __version__ = "0.3.1"
@@ -42,7 +42,6 @@ public:
     if (net) {
         pybind11::object cb = residual_callback;
         net->set_residual_callback([cb](const Tensor &x) -> Tensor {
-            pybind11::gil_scoped_acquire gil;
             torch::Tensor torch_x = to_torch(x);
             pybind11::object result = cb(torch_x);
             torch::Tensor torch_y = result.cast<torch::Tensor>();
@@ -143,9 +142,17 @@ public:
     temb = temb.contiguous();
     rotary_emb_single = rotary_emb_single.contiguous();

+    if (net->isOffloadEnabled()) {
+        net->single_transformer_blocks.at(idx)->loadLazyParams();
+    }
+
     Tensor result = net->single_transformer_blocks.at(idx)->forward(
         from_torch(hidden_states), from_torch(temb), from_torch(rotary_emb_single));

+    if (net->isOffloadEnabled()) {
+        net->single_transformer_blocks.at(idx)->releaseLazyParams();
+    }
+
     hidden_states = to_torch(result);
     Tensor::synchronizeDevice();
......
@@ -5,12 +5,19 @@ import torch
 from safetensors.torch import save_file

 from .diffusers_converter import to_diffusers
-from .utils import is_nunchaku_format
+from .utils import is_nunchaku_format, load_state_dict_in_safetensors


 def compose_lora(
     loras: list[tuple[str | dict[str, torch.Tensor], float]], output_path: str | None = None
 ) -> dict[str, torch.Tensor]:
+    if len(loras) == 1:
+        if is_nunchaku_format(loras[0][0]) and (loras[0][1] - 1) < 1e-5:
+            if isinstance(loras[0][0], str):
+                return load_state_dict_in_safetensors(loras[0][0], device="cpu")
+            else:
+                return loras[0][0]
     composed = {}
     for lora, strength in loras:
         assert not is_nunchaku_format(lora)
......
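Per the signature above, `compose_lora` takes a list of `(LoRA, strength)` pairs (paths or state dicts) plus an optional output path; the new early return short-circuits the single-LoRA case, handing back a LoRA that is already in Nunchaku format with strength 1 without re-composing it. A hedged usage sketch, with placeholder file names and an import path assumed from the repository layout:

```python
# Module path assumed; the function itself is the one shown in the diff above.
from nunchaku.lora.flux.compose import compose_lora

# Compose two diffusers-format LoRAs with different strengths and write the result to disk.
# Both input .safetensors paths are placeholders.
composed = compose_lora(
    [("style-lora.safetensors", 0.8), ("detail-lora.safetensors", 0.5)],
    output_path="composed-lora.safetensors",
)
print(f"Composed LoRA has {len(composed)} tensors")
```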
 import argparse
+import logging
 import os
-import warnings

 import torch
 from diffusers.loaders import FluxLoraLoaderMixin
@@ -9,6 +9,52 @@ from safetensors.torch import save_file

 from .utils import load_state_dict_in_safetensors

+# Get log level from environment variable (default to INFO)
+log_level = os.getenv("LOG_LEVEL", "INFO").upper()
+
+# Configure logging
+logging.basicConfig(level=getattr(logging, log_level, logging.INFO), format="%(asctime)s - %(levelname)s - %(message)s")
+logger = logging.getLogger(__name__)
+
+
+def handle_kohya_lora(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
+    # first check if the state_dict is in the kohya format
+    # like: https://civitai.com/models/1118358?modelVersionId=1256866
+    if any([not k.startswith("lora_transformer_") for k in state_dict.keys()]):
+        return state_dict
+    else:
+        new_state_dict = {}
+        for k, v in state_dict.items():
+            new_k = k.replace("lora_transformer_", "transformer.")
+            new_k = new_k.replace("norm_out_", "norm_out.")
+            new_k = new_k.replace("time_text_embed_", "time_text_embed.")
+            new_k = new_k.replace("guidance_embedder_", "guidance_embedder.")
+            new_k = new_k.replace("text_embedder_", "text_embedder.")
+            new_k = new_k.replace("timestep_embedder_", "timestep_embedder.")
+            new_k = new_k.replace("single_transformer_blocks_", "single_transformer_blocks.")
+            new_k = new_k.replace("_attn_", ".attn.")
+            new_k = new_k.replace("_norm_linear.", ".norm.linear.")
+            new_k = new_k.replace("_proj_mlp.", ".proj_mlp.")
+            new_k = new_k.replace("_proj_out.", ".proj_out.")
+            new_k = new_k.replace("transformer_blocks_", "transformer_blocks.")
+            new_k = new_k.replace("to_out_0.", "to_out.0.")
+            new_k = new_k.replace("_ff_context_net_0_proj.", ".ff_context.net.0.proj.")
+            new_k = new_k.replace("_ff_context_net_2.", ".ff_context.net.2.")
+            new_k = new_k.replace("_ff_net_0_proj.", ".ff.net.0.proj.")
+            new_k = new_k.replace("_ff_net_2.", ".ff.net.2.")
+            new_k = new_k.replace("_norm1_context_linear.", ".norm1_context.linear.")
+            new_k = new_k.replace("_norm1_linear.", ".norm1.linear.")
+            new_k = new_k.replace(".lora_down.", ".lora_A.")
+            new_k = new_k.replace(".lora_up.", ".lora_B.")
+            new_state_dict[new_k] = v
+        return new_state_dict
+
+
 def to_diffusers(input_lora: str | dict[str, torch.Tensor], output_path: str | None = None) -> dict[str, torch.Tensor]:
     if isinstance(input_lora, str):
@@ -16,6 +62,8 @@ def to_diffusers(input_lora: str | dict[str, torch.Tensor], output_path: str | N
     else:
         tensors = {k: v for k, v in input_lora.items()}

+    tensors = handle_kohya_lora(tensors)
+
     ### convert the FP8 tensors to BF16
     for k, v in tensors.items():
         if v.dtype not in [torch.float64, torch.float32, torch.bfloat16, torch.float16]:
@@ -25,7 +73,14 @@ def to_diffusers(input_lora: str | dict[str, torch.Tensor], output_path: str | N
     new_tensors = convert_unet_state_dict_to_peft(new_tensors)

     if alphas is not None and len(alphas) > 0:
-        warnings.warn("Alpha values are not used in the conversion to diffusers format.")
+        for k, v in alphas.items():
+            key_A = k.replace(".alpha", ".lora_A.weight")
+            key_B = k.replace(".alpha", ".lora_B.weight")
+            assert key_A in new_tensors, f"Key {key_A} not found in new tensors."
+            assert key_B in new_tensors, f"Key {key_B} not found in new tensors."
+            rank = new_tensors[key_A].shape[0]
+            assert new_tensors[key_B].shape[1] == rank, f"Rank mismatch for {key_B}."
+            new_tensors[key_A] = new_tensors[key_A] * v / rank

     if output_path is not None:
         output_dir = os.path.dirname(os.path.abspath(output_path))
......
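Two behavioral notes on the converter changes above: `handle_kohya_lora` renames Kohya-style keys (for example, a key of roughly the form `lora_transformer_single_transformer_blocks_0_attn_to_q.lora_down.weight` becomes `transformer.single_transformer_blocks.0.attn.to_q.lora_A.weight`), and the new alpha loop folds each `alpha` into the corresponding `lora_A` weight, so the converted tensors already carry the usual LoRA scaling ΔW = (α/r)·B·A instead of emitting a warning. A hedged usage sketch, with placeholder file names and an import path assumed from the relative imports shown above:

```python
# Module path assumed from the surrounding package layout.
from nunchaku.lora.flux.diffusers_converter import to_diffusers

# Convert a Kohya/FP8 LoRA into the diffusers/PEFT layout; both file names are placeholders.
state_dict = to_diffusers("my-kohya-lora.safetensors", output_path="my-diffusers-lora.safetensors")
print(f"Converted {len(state_dict)} tensors")
```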