Unverified commit f060b8da, authored by Muyang Li, committed by GitHub

[Major] Release v0.1.4

Support 4-bit text encoder and per-layer CPU offloading, reducing FLUX's minimum memory requirement to just 4 GiB while maintaining a 2–3× speedup. Fix various issues related to resolution, LoRA, pin memory, and runtime stability. Check out the release notes for full details!
parents f549dfc6 873a35be
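
A note for readers of the hunks below: v0.1.4 re-exports the public model classes from the package root (see the __init__.py hunk further down), so the deep module imports used in v0.1.3 can be shortened. A minimal before/after sketch:

# v0.1.3: deep module paths
from nunchaku.models.transformer_flux import NunchakuFluxTransformer2dModel
from nunchaku.models.transformer_sana import NunchakuSanaTransformer2DModel

# v0.1.4: top-level re-exports (new in this release)
from nunchaku import (
    NunchakuFluxTransformer2dModel,
    NunchakuSanaTransformer2DModel,
    NunchakuT5EncoderModel,
)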
@@ -3,7 +3,7 @@
 from controlnet_aux import CannyDetector
 from diffusers import FluxControlPipeline
 from diffusers.utils import load_image
-from nunchaku.models.transformer_flux import NunchakuFluxTransformer2dModel
+from nunchaku import NunchakuFluxTransformer2dModel
 transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-canny-dev")
 pipe = FluxControlPipeline.from_pretrained(
...
@@ -3,7 +3,7 @@
 from diffusers import FluxControlPipeline
 from diffusers.utils import load_image
 from image_gen_aux import DepthPreprocessor
-from nunchaku.models.transformer_flux import NunchakuFluxTransformer2dModel
+from nunchaku import NunchakuFluxTransformer2dModel
 transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-depth-dev")
...
 import torch
 from diffusers import FluxPipeline
-from nunchaku.models.transformer_flux import NunchakuFluxTransformer2dModel
+from nunchaku import NunchakuFluxTransformer2dModel
 transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-dev")
 pipeline = FluxPipeline.from_pretrained(
...
+import torch
+from diffusers import FluxPipeline
+from nunchaku import NunchakuFluxTransformer2dModel, NunchakuT5EncoderModel
+transformer = NunchakuFluxTransformer2dModel.from_pretrained(
+    "mit-han-lab/svdq-int4-flux.1-dev", offload=True
+)  # set offload to False if you want to disable offloading
+text_encoder_2 = NunchakuT5EncoderModel.from_pretrained("mit-han-lab/svdq-flux.1-t5")
+pipeline = FluxPipeline.from_pretrained(
+    "black-forest-labs/FLUX.1-dev", text_encoder_2=text_encoder_2, transformer=transformer, torch_dtype=torch.bfloat16
+).to("cuda")
+pipeline.enable_sequential_cpu_offload()  # remove this line to disable CPU offloading
+image = pipeline("A cat holding a sign that says hello world", num_inference_steps=50, guidance_scale=3.5).images[0]
+image.save("flux.1-dev.png")
+import torch
+from diffusers import FluxPipeline
+from nunchaku import NunchakuFluxTransformer2dModel, NunchakuT5EncoderModel
+transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-dev")
+text_encoder_2 = NunchakuT5EncoderModel.from_pretrained("mit-han-lab/svdq-flux.1-t5")
+pipeline = FluxPipeline.from_pretrained(
+    "black-forest-labs/FLUX.1-dev",
+    text_encoder_2=text_encoder_2,
+    transformer=transformer,
+    torch_dtype=torch.bfloat16,
+).to("cuda")
+image = pipeline("A cat holding a sign that says hello world", num_inference_steps=50, guidance_scale=3.5).images[0]
+image.save("flux.1-dev.png")
 import torch
 from diffusers import FluxPipeline
-from nunchaku.models.transformer_flux import NunchakuFluxTransformer2dModel
+from nunchaku import NunchakuFluxTransformer2dModel
 transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-dev")
 pipeline = FluxPipeline.from_pretrained(
...
@@ -2,7 +2,7 @@
 import torch
 from diffusers import FluxFillPipeline
 from diffusers.utils import load_image
-from nunchaku.models.transformer_flux import NunchakuFluxTransformer2dModel
+from nunchaku import NunchakuFluxTransformer2dModel
 image = load_image("https://huggingface.co/mit-han-lab/svdq-int4-flux.1-fill-dev/resolve/main/example.png")
 mask = load_image("https://huggingface.co/mit-han-lab/svdq-int4-flux.1-fill-dev/resolve/main/mask.png")
...
@@ -2,7 +2,7 @@
 import torch
 from diffusers import FluxPipeline, FluxPriorReduxPipeline
 from diffusers.utils import load_image
-from nunchaku.models.transformer_flux import NunchakuFluxTransformer2dModel
+from nunchaku import NunchakuFluxTransformer2dModel
 pipe_prior_redux = FluxPriorReduxPipeline.from_pretrained(
     "black-forest-labs/FLUX.1-Redux-dev", torch_dtype=torch.bfloat16
...
+import torch
+from diffusers import FluxPipeline
+from nunchaku import NunchakuFluxTransformer2dModel, NunchakuT5EncoderModel
+transformer = NunchakuFluxTransformer2dModel.from_pretrained(
+    "mit-han-lab/svdq-int4-flux.1-schnell", offload=True
+)  # set offload to False if you want to disable offloading
+text_encoder_2 = NunchakuT5EncoderModel.from_pretrained("mit-han-lab/svdq-flux.1-t5")
+pipeline = FluxPipeline.from_pretrained(
+    "black-forest-labs/FLUX.1-schnell",
+    text_encoder_2=text_encoder_2,
+    transformer=transformer,
+    torch_dtype=torch.bfloat16,
+).to("cuda")
+pipeline.enable_sequential_cpu_offload()  # remove this line to disable CPU offloading
+image = pipeline(
+    "A cat holding a sign that says hello world", width=1024, height=1024, num_inference_steps=4, guidance_scale=0
+).images[0]
+image.save("flux.1-schnell.png")
+import torch
+from diffusers import FluxPipeline
+from nunchaku import NunchakuFluxTransformer2dModel, NunchakuT5EncoderModel
+transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-schnell")
+text_encoder_2 = NunchakuT5EncoderModel.from_pretrained("mit-han-lab/svdq-flux.1-t5")
+pipeline = FluxPipeline.from_pretrained(
+    "black-forest-labs/FLUX.1-schnell",
+    text_encoder_2=text_encoder_2,
+    transformer=transformer,
+    torch_dtype=torch.bfloat16,
+).to("cuda")
+image = pipeline(
+    "A cat holding a sign that says hello world", width=1024, height=1024, num_inference_steps=4, guidance_scale=0
+).images[0]
+image.save("flux.1-schnell.png")
 import torch
 from diffusers import FluxPipeline
-from nunchaku.models.transformer_flux import NunchakuFluxTransformer2dModel
+from nunchaku import NunchakuFluxTransformer2dModel
 transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-schnell")
 pipeline = FluxPipeline.from_pretrained(
...
 import torch
 from diffusers import SanaPipeline
-from nunchaku.models.transformer_sana import NunchakuSanaTransformer2DModel
+from nunchaku import NunchakuSanaTransformer2DModel
 transformer = NunchakuSanaTransformer2DModel.from_pretrained("mit-han-lab/svdq-int4-sana-1600m")
 pipe = SanaPipeline.from_pretrained(
...
 import torch
 from diffusers import SanaPAGPipeline
-from nunchaku.models.transformer_sana import NunchakuSanaTransformer2DModel
+from nunchaku import NunchakuSanaTransformer2DModel
 transformer = NunchakuSanaTransformer2DModel.from_pretrained("mit-han-lab/svdq-int4-sana-1600m", pag_layers=8)
 pipe = SanaPAGPipeline.from_pretrained(
...
+from .models import NunchakuFluxTransformer2dModel, NunchakuSanaTransformer2DModel, NunchakuT5EncoderModel
-__version__ = "0.1.3"
+__version__ = "0.1.4"
@@ -9,9 +9,12 @@
 class QuantizedFluxModel : public ModuleWrapper<FluxModel> { // : public torch::CustomClassHolder {
 public:
-    void init(bool use_fp4, bool bf16, int8_t deviceId) {
+    void init(bool use_fp4, bool offload, bool bf16, int8_t deviceId) {
         spdlog::info("Initializing QuantizedFluxModel");
-        net = std::make_unique<FluxModel>(use_fp4, bf16 ? Tensor::BF16 : Tensor::FP16, Device::cuda((int)deviceId));
+        if (offload) {
+            spdlog::info("Layer offloading enabled");
+        }
+        net = std::make_unique<FluxModel>(use_fp4, offload, bf16 ? Tensor::BF16 : Tensor::FP16, Device::cuda((int)deviceId));
     }

     torch::Tensor forward(
...
@@ -3,6 +3,7 @@
 #include "interop/torch.h"
 #include "kernels/zgemm/zgemm.h"
 #include "kernels/awq/gemv_awq.h"
+#include "kernels/awq/gemm_awq.h"

 namespace nunchaku::ops {
@@ -71,7 +72,7 @@
             alpha,
             getTensor(wcscales)
         );
-        Tensor::synchronizeDevice();
+        // Tensor::synchronizeDevice();
     }

     torch::Tensor gemv_awq(
@@ -96,8 +97,31 @@
         );
         torch::Tensor output = to_torch(result);
-        Tensor::synchronizeDevice();
+        // Tensor::synchronizeDevice();
         return output;
     }
+
+    torch::Tensor gemm_awq(
+        torch::Tensor _in_feats,
+        torch::Tensor _kernel,
+        torch::Tensor _scaling_factors,
+        torch::Tensor _zeros)
+    {
+        Tensor result = ::awq_gemm_forward_cuda(
+            from_torch(_in_feats.contiguous()),
+            from_torch(_kernel.contiguous()),
+            from_torch(_scaling_factors.contiguous()),
+            from_torch(_zeros.contiguous())
+        );
+        // TODO: allocate output in torch and use from_torch instead (to_torch needs an extra copy)
+        torch::Tensor output = to_torch(result);
+        // Tensor::synchronizeDevice();
+        return output;
+    }
 };
\ No newline at end of file
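
A note on the commented-out Tensor::synchronizeDevice() calls above: the ops now return without blocking on the device, which removes per-call overhead but means kernel errors and timings surface asynchronously. A caller that needs the old blocking behavior (for benchmarking, say) can synchronize explicitly with the standard PyTorch API; the matmul below is just a stand-in for any queued CUDA work:

import torch

x = torch.randn(1024, 1024, device="cuda")
y = x @ x                  # queued asynchronously on the current CUDA stream
torch.cuda.synchronize()   # block until all queued kernels finish, e.g. before reading a timer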
@@ -4,6 +4,7 @@
 #include "sana.h"
 #include "ops.h"
 #include "utils.h"
+#include <torch/extension.h>

 #include <pybind11/pybind11.h>

@@ -12,6 +13,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         .def(py::init<>())
         .def("init", &QuantizedFluxModel::init,
             py::arg("use_fp4"),
+            py::arg("offload"),
             py::arg("bf16"),
             py::arg("deviceId")
         )

@@ -72,7 +74,7 @@
     ;
     m.def_submodule("ops")
-        .def("gemm_w4a4", nunchaku::ops::gemm_w4a4)
+        .def("gemm_awq", nunchaku::ops::gemm_awq)
         .def("gemv_awq", nunchaku::ops::gemv_awq)
     ;
...
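
For orientation, a hedged sketch of how the updated binding looks from Python. The extension module name (`nunchaku._C`) is an assumption for illustration; the keyword names come straight from the py::arg declarations above:

from nunchaku import _C  # the compiled extension's import path is assumed here

model = _C.QuantizedFluxModel()
# init() now takes the new offload flag between use_fp4 and bf16:
model.init(use_fp4=False, offload=True, bf16=True, deviceId=0)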
+from .comfyui_converter import comfyui2diffusers
+from .diffusers_converter import convert_to_nunchaku_flux_lowrank_dict
+from .utils import detect_format
+from .xlab_converter import xlab2diffusers
@@ -362,6 +362,11 @@ def convert_to_nunchaku_flux_lowrank_dict(
     else:
         extra_lora_dict = filter_state_dict(lora, filter_prefix="transformer.")

+    unquantized_lora_dict = {}
+    for k in list(extra_lora_dict.keys()):
+        if "transformer_blocks" not in k:
+            unquantized_lora_dict[k] = extra_lora_dict.pop(k)
+
     for k in extra_lora_dict.keys():
         fc1_k = k
         if "ff.net.0.proj" in k:
...
@@ -408,4 +413,5 @@
             prefix=block_name,
         )
+    converted.update(unquantized_lora_dict)
     return converted
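
In plain terms, the converter change above splits out LoRA keys that live outside the quantized transformer blocks, converts only the transformer-block keys, and merges the pass-through keys back at the end. A self-contained sketch of the split (key names and shapes are illustrative, not from the commit):

import torch

extra_lora_dict = {
    "x_embedder.lora_A.weight": torch.zeros(16, 64),                      # outside the quantized blocks
    "transformer_blocks.0.attn.to_q.lora_A.weight": torch.zeros(16, 64),  # inside a quantized block
}

# Pull out keys that do not belong to the quantized transformer blocks:
unquantized_lora_dict = {}
for k in list(extra_lora_dict.keys()):
    if "transformer_blocks" not in k:
        unquantized_lora_dict[k] = extra_lora_dict.pop(k)

converted = {}                           # stands in for the converted transformer-block weights
converted.update(unquantized_lora_dict)  # pass-through keys are kept verbatim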