Unverified commit 9ed2db76, authored by Muyang Li and committed by GitHub
Browse files

feat: FLUX Gradio demos support FP4 (#623)

* update app

* depth supports fp4

* update

* fix the demo website

* style: make linter happy
parent 17c7154a
......@@ -11,7 +11,7 @@
alt="svdquant logo" style="height: 40px; width: auto;" />
</a>
</div>
<h1 style="margin-top: 0;">INT4 FLUX.1-{model_name}-dev Demo</h1>
<h1 style="margin-top: 0;">{precision} FLUX.1-{model_name}-dev Demo</h1>
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
{device_info}
......
......@@ -51,10 +51,10 @@ if args.precision == "bf16":
pipeline = pipeline.to("cuda")
pipeline.precision = "bf16"
else:
assert args.precision == "int4"
assert args.precision in ["int4", "fp4"]
pipeline_init_kwargs = {}
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
f"mit-han-lab/nunchaku-flux.1-{model_name}/svdq-int4_r32-flux.1-{model_name}.safetensors"
f"mit-han-lab/nunchaku-flux.1-{model_name}/svdq-{args.precision}_r32-flux.1-{model_name}.safetensors"
)
pipeline_init_kwargs["transformer"] = transformer
if args.use_qencoder:
......@@ -69,7 +69,7 @@ else:
f"black-forest-labs/FLUX.1-{model_name.capitalize()}", torch_dtype=torch.bfloat16, **pipeline_init_kwargs
)
pipeline = pipeline.to("cuda")
pipeline.precision = "int4"
pipeline.precision = args.precision
safety_checker = SafetyChecker("cuda", disabled=args.no_safety_checker)
......@@ -154,7 +154,11 @@ with gr.Blocks(css_paths="assets/style.css", title=f"SVDQuant Flux.1-{model_name
else:
count_info = ""
header_str = DESCRIPTION.format(
model_name=args.model, device_info=device_info, notice=notice, count_info=count_info
precision=args.precision,
model_name=args.model,
device_info=device_info,
notice=notice,
count_info=count_info,
)
return header_str
......
......@@ -4,7 +4,7 @@ import argparse
def get_args() -> argparse.Namespace:
parser = argparse.ArgumentParser()
parser.add_argument(
"-p", "--precision", type=str, default="int4", choices=["int4", "bf16"], help="Which precisions to use"
"-p", "--precision", type=str, default="int4", choices=["int4", "fp4", "bf16"], help="Which precisions to use"
)
parser.add_argument(
"-m", "--model", type=str, default="depth", choices=["canny", "depth"], help="Which FLUX.1 model to use"
......
......@@ -13,7 +13,7 @@
</div>
<!-- Title -->
<h1 style="margin-top: 0;">INT4 FLUX.1-fill-dev Demo</h1>
<h1 style="margin-top: 0;">{precision} FLUX.1-fill-dev Demo</h1>
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
{device_info}
</div>
......
......@@ -23,10 +23,10 @@ if args.precision == "bf16":
pipeline = pipeline.to("cuda")
pipeline.precision = "bf16"
else:
assert args.precision == "int4"
assert args.precision in ["int4", "fp4"]
pipeline_init_kwargs = {}
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
"mit-han-lab/nunchaku-flux.1-fill-dev/svdq-int4_r32-flux.1-fill-dev.safetensors"
f"mit-han-lab/nunchaku-flux.1-fill-dev/svdq-{args.precision}_r32-flux.1-fill-dev.safetensors"
)
pipeline_init_kwargs["transformer"] = transformer
if args.use_qencoder:
......@@ -41,7 +41,7 @@ else:
"black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16, **pipeline_init_kwargs
)
pipeline = pipeline.to("cuda")
pipeline.precision = "int4"
pipeline.precision = args.precision
safety_checker = SafetyChecker("cuda", disabled=args.no_safety_checker)
......@@ -125,7 +125,9 @@ with gr.Blocks(css_paths="assets/style.css", title="SVDQuant Flux.1-Fill-dev Ske
)
else:
count_info = ""
header_str = DESCRIPTION.format(device_info=device_info, notice=notice, count_info=count_info)
header_str = DESCRIPTION.format(
precision=args.precision, device_info=device_info, notice=notice, count_info=count_info
)
return header_str
header = gr.HTML(get_header_str())
......
......@@ -4,7 +4,7 @@ import argparse
def get_args() -> argparse.Namespace:
parser = argparse.ArgumentParser()
parser.add_argument(
"-p", "--precision", type=str, default="int4", choices=["int4", "bf16"], help="Which precisions to use"
"-p", "--precision", type=str, default="int4", choices=["int4", "fp4", "bf16"], help="Which precisions to use"
)
parser.add_argument("--use-qencoder", action="store_true", help="Whether to use 4-bit text encoder")
parser.add_argument("--no-safety-checker", action="store_true", help="Disable safety checker")
......
......@@ -11,7 +11,7 @@
alt="svdquant logo" style="height: 40px; width: auto;" />
</a>
</div>
<h1 style="margin-top: 0;">INT4 FLUX.1-schnell Sketch-to-Image Demo</h1>
<h1 style="margin-top: 0;">{precision} FLUX.1-schnell Sketch-to-Image Demo</h1>
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
{device_info}
......
......@@ -8,57 +8,30 @@ from PIL import Image
from torch import nn
from torchvision.transforms import functional as F
from nunchaku.utils import load_state_dict_in_safetensors
class FluxPix2pixTurboPipeline(FluxPipeline):
def update_alpha(self, alpha: float) -> None:
self._alpha = alpha
transformer = self.transformer
for n, p in transformer.named_parameters():
if n in self._tuned_state_dict:
new_data = self._tuned_state_dict[n] * alpha + self._original_state_dict[n] * (1 - alpha)
new_data = new_data.to(self._execution_device).to(p.dtype)
p.data.copy_(new_data)
if self.precision == "bf16":
for m in transformer.modules():
if isinstance(m, lora.LoraLayer):
m.scaling["default_0"] = alpha
else:
assert self.precision == "int4"
assert self.precision in ["int4", "fp4"]
transformer.set_lora_strength(alpha)
def load_control_module(
self,
pretrained_model_name_or_path: str,
weight_name: str | None = None,
svdq_lora_path: str | None = None,
alpha: float = 1,
):
state_dict, alphas = self.lora_state_dict(
pretrained_model_name_or_path, weight_name=weight_name, return_alphas=True
)
def load_control_module(self, pretrained_model_name_or_path: str, weight_name: str, alpha: float = 1):
state_dict, _ = self.lora_state_dict(pretrained_model_name_or_path, weight_name=weight_name, return_alphas=True)
transformer = self.transformer
original_state_dict = {}
tuned_state_dict = {}
assert isinstance(transformer, FluxTransformer2DModel)
for n, p in transformer.named_parameters():
if f"transformer.{n}" in state_dict:
original_state_dict[n] = p.data.cpu()
tuned_state_dict[n] = state_dict[f"transformer.{n}"].cpu()
self._original_state_dict = original_state_dict
self._tuned_state_dict = tuned_state_dict
if self.precision == "bf16":
self.load_lora_into_transformer(state_dict, {}, transformer=transformer)
else:
assert svdq_lora_path is not None
sd = load_state_dict_in_safetensors(svdq_lora_path)
sd = {k: v for k, v in sd.items() if not k.startswith("transformer.")}
self.transformer.update_lora_params(sd)
self.transformer.update_lora_params(state_dict)
self.update_alpha(alpha)
@torch.no_grad()
......
......@@ -29,10 +29,10 @@ if args.precision == "bf16":
"mit-han-lab/svdq-flux.1-schnell-pix2pix-turbo", "sketch.safetensors", alpha=DEFAULT_SKETCH_GUIDANCE
)
else:
assert args.precision == "int4"
assert args.precision in ["int4", "fp4"]
pipeline_init_kwargs = {}
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
"mit-han-lab/nunchaku-flux.1-schnell/svdq-int4_r32-flux.1-schnell.safetensors"
f"mit-han-lab/nunchaku-flux.1-schnell/svdq-{args.precision}_r32-flux.1-schnell.safetensors"
)
pipeline_init_kwargs["transformer"] = transformer
if args.use_qencoder:
......@@ -47,11 +47,10 @@ else:
"black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16, **pipeline_init_kwargs
)
pipeline = pipeline.to("cuda")
pipeline.precision = "int4"
pipeline.precision = args.precision
pipeline.load_control_module(
"mit-han-lab/svdq-flux.1-schnell-pix2pix-turbo",
"sketch.safetensors",
svdq_lora_path="mit-han-lab/svdq-flux.1-schnell-pix2pix-turbo/svdq-int4-sketch.safetensors",
alpha=DEFAULT_SKETCH_GUIDANCE,
)
safety_checker = SafetyChecker("cuda", disabled=args.no_safety_checker)
......@@ -135,7 +134,9 @@ with gr.Blocks(css_paths="assets/style.css", title="SVDQuant Sketch-to-Image Dem
)
else:
count_info = ""
header_str = DESCRIPTION.format(device_info=device_info, notice=notice, count_info=count_info)
header_str = DESCRIPTION.format(
precision=args.precision, device_info=device_info, notice=notice, count_info=count_info
)
return header_str
header = gr.HTML(get_header_str())
......
......@@ -4,7 +4,7 @@ import argparse
def get_args() -> argparse.Namespace:
parser = argparse.ArgumentParser()
parser.add_argument(
"-p", "--precision", type=str, default="int4", choices=["int4", "bf16"], help="Which precisions to use"
"-p", "--precision", type=str, default="int4", choices=["int4", "fp4", "bf16"], help="Which precisions to use"
)
parser.add_argument("--use-qencoder", action="store_true", help="Whether to use 4-bit text encoder")
parser.add_argument("--no-safety-checker", action="store_true", help="Disable safety checker")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment