Commit 7bb83833 authored by April Hu

Merge branch 'main' of github.com:Aprilhuu/nunchaku into main

parents 35a4d011 420ad33d
@@ -82,7 +82,7 @@ def run(image, prompt: str, prompt_template: str, sketch_guidance: float, seed:
        image_type="sketch",
        alpha=sketch_guidance,
        prompt=prompt,
-        generator=torch.Generator().manual_seed(int(seed)),
+        generator=torch.Generator().manual_seed(seed),
    ).images[0]
    latency = time.time() - start_time
@@ -229,7 +229,7 @@ with gr.Blocks(css_paths="assets/style.css", title=f"SVDQuant Sketch-to-Image De
        outputs=[prompt_template],
        api_name=False,
        queue=False,
-    ).then(fn=run, inputs=run_inputs, outputs=run_outputs, api_name=False)
+    )
    gr.on(
        triggers=[prompt.submit, run_button.click, canvas.change],
        fn=run,
@@ -244,4 +244,4 @@ with gr.Blocks(css_paths="assets/style.css", title=f"SVDQuant Sketch-to-Image De
if __name__ == "__main__":
-    demo.queue().launch(debug=True, share=True)
+    demo.queue().launch(debug=True, share=True, root_path=args.gradio_root_path)
@@ -9,5 +9,6 @@ def get_args() -> argparse.Namespace:
    parser.add_argument("--use-qencoder", action="store_true", help="Whether to use 4-bit text encoder")
    parser.add_argument("--no-safety-checker", action="store_true", help="Disable safety checker")
    parser.add_argument("--count-use", action="store_true", help="Whether to count the number of uses")
+    parser.add_argument("--gradio-root-path", type=str, default="")
    args = parser.parse_args()
    return args
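The new `--gradio-root-path` flag above is threaded into `launch(root_path=...)` in each demo. Gradio's `root_path` tells the app which URL prefix it is served under, which matters when the demo sits behind a reverse proxy. Below is a minimal sketch of the same wiring, assuming a proxy that forwards `/svdquant` to the app; the prefix and the placeholder UI are illustrative, not part of this commit:

```python
import gradio as gr

# Stand-in for the demos in this commit; only the launch wiring is the point.
with gr.Blocks() as demo:
    gr.Markdown("placeholder UI")

# Equivalent to running: python run_gradio.py --gradio-root-path /svdquant
# Without root_path, the app generates asset URLs against "/" and they 404
# once a proxy serves it under a prefix.
demo.queue().launch(root_path="/svdquant")
```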
@@ -2,6 +2,12 @@
## Text-to-Image Gradio Demo

+![demo](./assets/demo.jpg)
+
+This interactive Gradio application can generate an image based on your provided text prompt. The base model can be either [FLUX.1-schnell](https://huggingface.co/black-forest-labs/FLUX.1-schnell) or [FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev).
+
+To launch the application, simply run:
+
```shell
python run_gradio.py
```
...
@@ -35,6 +35,7 @@ def get_args() -> argparse.Namespace:
    parser.add_argument("--use-qencoder", action="store_true", help="Whether to use 4-bit text encoder")
    parser.add_argument("--no-safety-checker", action="store_true", help="Disable safety checker")
    parser.add_argument("--count-use", action="store_true", help="Whether to count the number of uses")
+    parser.add_argument("--gradio-root-path", type=str, default="")
    return parser.parse_args()
@@ -282,4 +283,4 @@ with gr.Blocks(
if __name__ == "__main__":
-    demo.queue(max_size=20).launch(server_name="0.0.0.0", debug=True, share=True)
+    demo.queue(max_size=20).launch(server_name="0.0.0.0", debug=True, share=True, root_path=args.gradio_root_path)
@@ -2,6 +2,10 @@
## Text-to-Image Gradio Demo

+![demo](./assets/demo.jpg)
+
+This interactive Gradio application can generate an image based on your provided text prompt. The base model is [SANA-1.6B](https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers).
+
```shell
python run_gradio.py
```
...
@@ -6,9 +6,6 @@ import time
from datetime import datetime

import GPUtil
-# import gradio last to avoid conflicts with other imports
-import gradio as gr
import spaces
import torch
@@ -16,6 +13,9 @@ from nunchaku.models.safety_checker import SafetyChecker
from utils import get_pipeline
from vars import EXAMPLES, MAX_SEED

+# import gradio last to avoid conflicts with other imports
+import gradio as gr

def get_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser()
@@ -31,6 +31,7 @@ def get_args() -> argparse.Namespace:
    parser.add_argument("--use-qencoder", action="store_true", help="Whether to use 4-bit text encoder")
    parser.add_argument("--no-safety-checker", action="store_true", help="Disable safety checker")
    parser.add_argument("--count-use", action="store_true", help="Whether to count the number of uses")
+    parser.add_argument("--gradio-root-path", type=str, default="")
    return parser.parse_args()
@@ -205,4 +206,4 @@ with gr.Blocks(
if __name__ == "__main__":
-    demo.queue(max_size=20).launch(server_name="0.0.0.0", debug=True, share=True)
+    demo.queue(max_size=20).launch(server_name="0.0.0.0", debug=True, share=True, root_path=args.gradio_root_path)
@@ -8,4 +8,4 @@ pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")
image = pipeline("A cat holding a sign that says hello world", num_inference_steps=50, guidance_scale=3.5).images[0]
-image.save("flux.1-dev.png")
+image.save("flux.1-dev-int4.png")
@@ -12,7 +12,7 @@ pipe = FluxFillPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Fill-dev", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")
image = pipe(
-    prompt="A wooden basked of several individual cartons of blueberries.",
+    prompt="A wooden basket of several individual cartons of blueberries.",
    image=image,
    mask_image=mask,
    height=1024,
@@ -21,4 +21,4 @@ image = pipe(
    num_inference_steps=50,
    max_sequence_length=512,
).images[0]
-image.save("flux.1-fill-dev-int4.png")
+image.save("flux.1-fill-dev.png")
+import torch
+from diffusers import FluxPriorReduxPipeline, FluxPipeline
+from diffusers.utils import load_image
+
+from nunchaku.models.transformer_flux import NunchakuFluxTransformer2dModel
+
+pipe_prior_redux = FluxPriorReduxPipeline.from_pretrained(
+    "black-forest-labs/FLUX.1-Redux-dev", torch_dtype=torch.bfloat16
+).to("cuda")
+
+transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-dev")
+pipe = FluxPipeline.from_pretrained(
+    "black-forest-labs/FLUX.1-dev",
+    text_encoder=None,
+    text_encoder_2=None,
+    transformer=transformer,
+    torch_dtype=torch.bfloat16,
+).to("cuda")
+
+image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png")
+pipe_prior_output = pipe_prior_redux(image)
+images = pipe(guidance_scale=2.5, num_inference_steps=50, **pipe_prior_output).images
+images[0].save("flux.1-redux-dev.png")
@@ -7,5 +7,12 @@ transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-i
pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")
-image = pipeline("A cat holding a sign that says hello world", num_inference_steps=4, guidance_scale=0).images[0]
-image.save("flux.1-schnell.png")
+image = pipeline(
+    "A cat holding a sign that says hello world",
+    width=1024,
+    height=1024,
+    num_inference_steps=4,
+    guidance_scale=0,
+    generator=torch.Generator().manual_seed(2333),
+).images[0]
+image.save("flux.1-schnell-int4.png")
-__version__ = "0.0.2beta2"
+__version__ = "0.0.2beta3"
@@ -8,7 +8,7 @@ from huggingface_hub import hf_hub_download, utils
from packaging.version import Version
from torch import nn

-from .utils import NunchakuModelLoaderMixin
+from .utils import NunchakuModelLoaderMixin, pad_tensor
from .._C import QuantizedFluxModel, utils as cutils

SVD_RANK = 32
@@ -52,6 +52,10 @@ class NunchakuFluxTransformerBlocks(nn.Module):
        rotary_emb_img = image_rotary_emb[:, txt_tokens:, ...]  # .to(self.dtype)
        rotary_emb_single = image_rotary_emb  # .to(self.dtype)

+        rotary_emb_txt = pad_tensor(rotary_emb_txt, 256, 1)
+        rotary_emb_img = pad_tensor(rotary_emb_img, 256, 1)
+        rotary_emb_single = pad_tensor(rotary_emb_single, 256, 1)
+
        hidden_states = self.m.forward(
            hidden_states, encoder_hidden_states, temb, rotary_emb_img, rotary_emb_txt, rotary_emb_single
        )
...
@@ -4,6 +4,7 @@ import torch
from diffusers import __version__
from huggingface_hub import constants, hf_hub_download
from safetensors.torch import load_file
+from typing import Optional, Any

class NunchakuModelLoaderMixin:
@@ -64,3 +65,20 @@ class NunchakuModelLoaderMixin:
        transformer.load_state_dict(state_dict, strict=False)
        return transformer, transformer_block_path
+
+
+def ceil_div(x: int, y: int) -> int:
+    return (x + y - 1) // y
+
+
+def pad_tensor(tensor: Optional[torch.Tensor], multiples: int, dim: int, fill: Any = 0) -> torch.Tensor:
+    if multiples <= 1:
+        return tensor
+    if tensor is None:
+        return None
+    shape = list(tensor.shape)
+    if shape[dim] % multiples == 0:
+        return tensor
+    shape[dim] = ceil_div(shape[dim], multiples) * multiples
+    result = torch.empty(shape, dtype=tensor.dtype, device=tensor.device)
+    result.fill_(fill)
+    result[[slice(0, extent) for extent in tensor.shape]] = tensor
+    return result
\ No newline at end of file
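The `pad_tensor` helper added above zero-fills a tensor along one dimension up to the next multiple, leaving the original values in place; the transformer-block change earlier in this commit uses it to round the rotary embeddings up to a multiple of 256 tokens. A quick usage sketch, assuming nunchaku is installed and that the helper lives in `nunchaku.models.utils` (inferred from the relative import `.utils` above):

```python
import torch
from nunchaku.models.utils import pad_tensor  # module path inferred from the diff

# A rotary-embedding-shaped tensor with 300 tokens along dim 1;
# 300 is not a multiple of 256, so it gets padded up to 512.
emb = torch.randn(1, 300, 64)
padded = pad_tensor(emb, 256, 1)

print(padded.shape)  # torch.Size([1, 512, 64])
assert torch.equal(padded[:, :300, :], emb)  # original values preserved
assert padded[:, 300:, :].abs().sum() == 0   # padding region holds the fill value (0)
```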
@@ -3,6 +3,7 @@ import os
import setuptools
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

+
class CustomBuildExtension(BuildExtension):
    def build_extensions(self):
        for ext in self.extensions:
@@ -17,6 +18,7 @@ class CustomBuildExtension(BuildExtension):
            ext.extra_compile_args["cxx"] += ext.extra_compile_args["gcc"]
        super().build_extensions()

+
if __name__ == "__main__":
    fp = open("nunchaku/__version__.py", "r").read()
    version = eval(fp.strip().split()[-1])
@@ -53,8 +55,12 @@ if __name__ == "__main__":
    NVCC_FLAGS = [
        "-DENABLE_BF16=1",
        "-DBUILD_NUNCHAKU=1",
-        "-gencode", "arch=compute_86,code=sm_86",
-        "-gencode", "arch=compute_89,code=sm_89",
+        "-gencode",
+        "arch=compute_86,code=sm_86",
+        "-gencode",
+        "arch=compute_89,code=sm_89",
+        # "-gencode",
+        # "arch=compute_89,code=sm_120a",
        "-g",
        "-std=c++20",
        "-UNDEBUG",
@@ -76,9 +82,7 @@ if __name__ == "__main__":
        "--ptxas-options=--allow-expensive-optimizations=true",
    ]

    # https://github.com/NVIDIA/cutlass/pull/1479#issuecomment-2052300487
-    NVCC_MSVC_FLAGS = [
-        "-Xcompiler", "/Zc:__cplusplus"
-    ]
+    NVCC_MSVC_FLAGS = ["-Xcompiler", "/Zc:__cplusplus"]

    nunchaku_extension = CUDAExtension(
        name="nunchaku._C",
@@ -97,8 +101,12 @@ if __name__ == "__main__":
            *ncond("third_party/Block-Sparse-Attention/csrc/block_sparse_attn/src/flash_fwd_hdim128_bf16_sm80.cu"),
            *ncond("third_party/Block-Sparse-Attention/csrc/block_sparse_attn/src/flash_fwd_block_hdim64_fp16_sm80.cu"),
            *ncond("third_party/Block-Sparse-Attention/csrc/block_sparse_attn/src/flash_fwd_block_hdim64_bf16_sm80.cu"),
-            *ncond("third_party/Block-Sparse-Attention/csrc/block_sparse_attn/src/flash_fwd_block_hdim128_fp16_sm80.cu"),
-            *ncond("third_party/Block-Sparse-Attention/csrc/block_sparse_attn/src/flash_fwd_block_hdim128_bf16_sm80.cu"),
+            *ncond(
+                "third_party/Block-Sparse-Attention/csrc/block_sparse_attn/src/flash_fwd_block_hdim128_fp16_sm80.cu"
+            ),
+            *ncond(
+                "third_party/Block-Sparse-Attention/csrc/block_sparse_attn/src/flash_fwd_block_hdim128_bf16_sm80.cu"
+            ),
            "src/kernels/activation_kernels.cu",
            "src/kernels/layernorm_kernels.cu",
            "src/kernels/misc_kernels.cu",
...
@@ -128,7 +128,7 @@ Tensor Attention::forward(Tensor qkv, Tensor pool_qkv, float sparsityRatio) {
    assert(qkv.shape[2] == num_heads * dim_head * 3);

    constexpr int POOL_SIZE = 128;
-    const int pool_tokens = num_tokens / POOL_SIZE;
+    const int pool_tokens = ceilDiv(num_tokens, POOL_SIZE);

    Tensor blockmask;
...
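The switch from integer division to `ceilDiv` matters whenever `num_tokens` is not a multiple of `POOL_SIZE`: floor division undercounts the pooled blocks, leaving the trailing partial block unaccounted for, which is presumably why this commit rounds up now that sequences are padded rather than exact multiples. A small Python sketch of the off-by-one, with illustrative values:

```python
POOL_SIZE = 128

def ceil_div(x: int, y: int) -> int:
    # Same rounding as the ceilDiv call above.
    return (x + y - 1) // y

num_tokens = 300  # e.g. a sequence length that is not a multiple of POOL_SIZE
print(num_tokens // POOL_SIZE)          # 2: the last 44 tokens get no pool block
print(ceil_div(num_tokens, POOL_SIZE))  # 3: every token is covered
```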
@@ -1209,7 +1209,7 @@ public:
        const bool is_q = bn < binfo.numBlocksN / 3;
        const bool is_k = !is_q && bn < binfo.numBlocksN / 3 * 2;

-        assert(args.actualM == M);
+        assert(!args.pool_out || args.actualM == M);
        assert(args.actualN == N);

        if (is_q || is_k) {
...