Commit c17a2f6e authored by muyangli

[major] add flux.1-redux; update the flux.1-tools demos

parent de9b25d6
......@@ -82,7 +82,7 @@ def run(image, prompt: str, prompt_template: str, sketch_guidance: float, seed:
image_type="sketch",
alpha=sketch_guidance,
prompt=prompt,
generator=torch.Generator().manual_seed(int(seed)),
generator=torch.Generator().manual_seed(seed),
).images[0]
latency = time.time() - start_time
......@@ -229,7 +229,7 @@ with gr.Blocks(css_paths="assets/style.css", title=f"SVDQuant Sketch-to-Image De
outputs=[prompt_template],
api_name=False,
queue=False,
).then(fn=run, inputs=run_inputs, outputs=run_outputs, api_name=False)
)
gr.on(
triggers=[prompt.submit, run_button.click, canvas.change],
fn=run,
......@@ -244,4 +244,4 @@ with gr.Blocks(css_paths="assets/style.css", title=f"SVDQuant Sketch-to-Image De
if __name__ == "__main__":
demo.queue().launch(debug=True, share=True)
demo.queue().launch(debug=True, share=True, root_path=args.gradio_root_path)
......@@ -9,5 +9,6 @@ def get_args() -> argparse.Namespace:
parser.add_argument("--use-qencoder", action="store_true", help="Whether to use 4-bit text encoder")
parser.add_argument("--no-safety-checker", action="store_true", help="Disable safety checker")
parser.add_argument("--count-use", action="store_true", help="Whether to count the number of uses")
parser.add_argument("--gradio-root-path", type=str, default="")
args = parser.parse_args()
return args
......@@ -2,6 +2,12 @@
## Text-to-Image Gradio Demo
![demo](./assets/demo.jpg)
This interactive Gradio application can generate an image based on your provided text prompt. The base model can be either [FLUX.1-schnell](https://huggingface.co/black-forest-labs/FLUX.1-schnell) or [FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev).
To launch the application, simply run:
```shell
python run_gradio.py
```
......
......@@ -35,6 +35,7 @@ def get_args() -> argparse.Namespace:
parser.add_argument("--use-qencoder", action="store_true", help="Whether to use 4-bit text encoder")
parser.add_argument("--no-safety-checker", action="store_true", help="Disable safety checker")
parser.add_argument("--count-use", action="store_true", help="Whether to count the number of uses")
parser.add_argument("--gradio-root-path", type=str, default="")
return parser.parse_args()
......@@ -282,4 +283,4 @@ with gr.Blocks(
if __name__ == "__main__":
demo.queue(max_size=20).launch(server_name="0.0.0.0", debug=True, share=True)
demo.queue(max_size=20).launch(server_name="0.0.0.0", debug=True, share=True, root_path=args.gradio_root_path)
......@@ -2,6 +2,10 @@
## Text-to-Image Gradio Demo
![demo](./assets/demo.jpg)
This interactive Gradio application can generate an image based on your provided text prompt. The base model is [SANA-1.6B](https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers).
To launch the application, simply run:
```shell
python run_gradio.py
```
......
......@@ -6,9 +6,6 @@ import time
from datetime import datetime
import GPUtil
# import gradio last to avoid conflicts with other imports
import gradio as gr
import spaces
import torch
......@@ -16,6 +13,9 @@ from nunchaku.models.safety_checker import SafetyChecker
from utils import get_pipeline
from vars import EXAMPLES, MAX_SEED
# import gradio last to avoid conflicts with other imports
import gradio as gr
def get_args() -> argparse.Namespace:
parser = argparse.ArgumentParser()
......@@ -31,6 +31,7 @@ def get_args() -> argparse.Namespace:
parser.add_argument("--use-qencoder", action="store_true", help="Whether to use 4-bit text encoder")
parser.add_argument("--no-safety-checker", action="store_true", help="Disable safety checker")
parser.add_argument("--count-use", action="store_true", help="Whether to count the number of uses")
parser.add_argument("--gradio-root-path", type=str, default="")
return parser.parse_args()
......@@ -205,4 +206,4 @@ with gr.Blocks(
if __name__ == "__main__":
demo.queue(max_size=20).launch(server_name="0.0.0.0", debug=True, share=True)
demo.queue(max_size=20).launch(server_name="0.0.0.0", debug=True, share=True, root_path=args.gradio_root_path)
......@@ -8,4 +8,4 @@ pipeline = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")
image = pipeline("A cat holding a sign that says hello world", num_inference_steps=50, guidance_scale=3.5).images[0]
image.save("flux.1-dev.png")
image.save("flux.1-dev-int4.png")
......@@ -12,7 +12,7 @@ pipe = FluxFillPipeline.from_pretrained(
"black-forest-labs/FLUX.1-Fill-dev", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")
image = pipe(
prompt="A wooden basked of several individual cartons of blueberries.",
prompt="A wooden basket of several individual cartons of blueberries.",
image=image,
mask_image=mask,
height=1024,
......@@ -21,4 +21,4 @@ image = pipe(
num_inference_steps=50,
max_sequence_length=512,
).images[0]
image.save("flux.1-fill-dev-int4.png")
image.save("flux.1-fill-dev.png")
import torch
from diffusers import FluxPriorReduxPipeline, FluxPipeline
from diffusers.utils import load_image
from nunchaku.models.transformer_flux import NunchakuFluxTransformer2dModel
pipe_prior_redux = FluxPriorReduxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-Redux-dev", torch_dtype=torch.bfloat16
).to("cuda")
transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-dev")
pipe = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
text_encoder=None,
text_encoder_2=None,
transformer=transformer,
torch_dtype=torch.bfloat16,
).to("cuda")
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png")
pipe_prior_output = pipe_prior_redux(image)
images = pipe(guidance_scale=2.5, num_inference_steps=50, **pipe_prior_output).images
images[0].save("flux.1-redux-dev.png")
......@@ -7,5 +7,12 @@ transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-i
pipeline = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-schnell", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")
image = pipeline("A cat holding a sign that says hello world", num_inference_steps=4, guidance_scale=0).images[0]
image.save("flux.1-schnell.png")
image = pipeline(
"A cat holding a sign that says hello world",
width=1024,
height=1024,
num_inference_steps=4,
guidance_scale=0,
generator=torch.Generator().manual_seed(2333),
).images[0]
image.save("flux.1-schnell-int4.png")
__version__ = "0.0.2beta2"
__version__ = "0.0.2beta3"
......@@ -8,7 +8,7 @@ from huggingface_hub import hf_hub_download, utils
from packaging.version import Version
from torch import nn
from .utils import NunchakuModelLoaderMixin
from .utils import NunchakuModelLoaderMixin, pad_tensor
from .._C import QuantizedFluxModel, utils as cutils
SVD_RANK = 32
......@@ -52,6 +52,10 @@ class NunchakuFluxTransformerBlocks(nn.Module):
rotary_emb_img = image_rotary_emb[:, txt_tokens:, ...] # .to(self.dtype)
rotary_emb_single = image_rotary_emb # .to(self.dtype)
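# pad the token dimension of each rotary embedding up to a multiple of 256 (presumably an alignment requirement of the quantized kernels)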
rotary_emb_txt = pad_tensor(rotary_emb_txt, 256, 1)
rotary_emb_img = pad_tensor(rotary_emb_img, 256, 1)
rotary_emb_single = pad_tensor(rotary_emb_single, 256, 1)
hidden_states = self.m.forward(
hidden_states, encoder_hidden_states, temb, rotary_emb_img, rotary_emb_txt, rotary_emb_single
)
......
......@@ -4,6 +4,7 @@ import torch
from diffusers import __version__
from huggingface_hub import constants, hf_hub_download
from safetensors.torch import load_file
from typing import Optional, Any
class NunchakuModelLoaderMixin:
......@@ -64,3 +65,20 @@ class NunchakuModelLoaderMixin:
transformer.load_state_dict(state_dict, strict=False)
return transformer, transformer_block_path
def ceil_div(x: int, y: int) -> int:
return (x + y - 1) // y
def pad_tensor(tensor: Optional[torch.Tensor], multiples: int, dim: int, fill: Any = 0) -> Optional[torch.Tensor]:
if multiples <= 1:
return tensor
if tensor is None:
return None
shape = list(tensor.shape)
if shape[dim] % multiples == 0:
return tensor
shape[dim] = ceil_div(shape[dim], multiples) * multiples
result = torch.empty(shape, dtype=tensor.dtype, device=tensor.device)
result.fill_(fill)
result[[slice(0, extent) for extent in tensor.shape]] = tensor
return result
\ No newline at end of file
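For reference, a minimal sketch of what `pad_tensor` does, with hypothetical shapes (the module path is assumed from the relative import `from .utils import NunchakuModelLoaderMixin, pad_tensor` shown above):
```python
import torch

from nunchaku.models.utils import pad_tensor  # assumed module path

emb = torch.randn(1, 300, 64)     # 300 entries along dim 1
padded = pad_tensor(emb, 256, 1)  # 300 -> ceil(300 / 256) * 256 = 512
assert padded.shape == (1, 512, 64)
assert torch.equal(padded[:, :300], emb)  # original values preserved; the tail is zero-filled
```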
......@@ -3,6 +3,7 @@ import os
import setuptools
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
class CustomBuildExtension(BuildExtension):
def build_extensions(self):
for ext in self.extensions:
......@@ -17,6 +18,7 @@ class CustomBuildExtension(BuildExtension):
ext.extra_compile_args["cxx"] += ext.extra_compile_args["gcc"]
super().build_extensions()
if __name__ == "__main__":
fp = open("nunchaku/__version__.py", "r").read()
version = eval(fp.strip().split()[-1])
......@@ -53,8 +55,12 @@ if __name__ == "__main__":
NVCC_FLAGS = [
"-DENABLE_BF16=1",
"-DBUILD_NUNCHAKU=1",
"-gencode", "arch=compute_86,code=sm_86",
"-gencode", "arch=compute_89,code=sm_89",
"-gencode",
"arch=compute_86,code=sm_86",
"-gencode",
"arch=compute_89,code=sm_89",
# "-gencode",
# "arch=compute_89,code=sm_120a",
"-g",
"-std=c++20",
"-UNDEBUG",
......@@ -76,9 +82,7 @@ if __name__ == "__main__":
"--ptxas-options=--allow-expensive-optimizations=true",
]
# https://github.com/NVIDIA/cutlass/pull/1479#issuecomment-2052300487
NVCC_MSVC_FLAGS = [
"-Xcompiler", "/Zc:__cplusplus"
]
NVCC_MSVC_FLAGS = ["-Xcompiler", "/Zc:__cplusplus"]
nunchaku_extension = CUDAExtension(
name="nunchaku._C",
......@@ -97,8 +101,12 @@ if __name__ == "__main__":
*ncond("third_party/Block-Sparse-Attention/csrc/block_sparse_attn/src/flash_fwd_hdim128_bf16_sm80.cu"),
*ncond("third_party/Block-Sparse-Attention/csrc/block_sparse_attn/src/flash_fwd_block_hdim64_fp16_sm80.cu"),
*ncond("third_party/Block-Sparse-Attention/csrc/block_sparse_attn/src/flash_fwd_block_hdim64_bf16_sm80.cu"),
*ncond("third_party/Block-Sparse-Attention/csrc/block_sparse_attn/src/flash_fwd_block_hdim128_fp16_sm80.cu"),
*ncond("third_party/Block-Sparse-Attention/csrc/block_sparse_attn/src/flash_fwd_block_hdim128_bf16_sm80.cu"),
*ncond(
"third_party/Block-Sparse-Attention/csrc/block_sparse_attn/src/flash_fwd_block_hdim128_fp16_sm80.cu"
),
*ncond(
"third_party/Block-Sparse-Attention/csrc/block_sparse_attn/src/flash_fwd_block_hdim128_bf16_sm80.cu"
),
"src/kernels/activation_kernels.cu",
"src/kernels/layernorm_kernels.cu",
"src/kernels/misc_kernels.cu",
......