Unverified Commit 9ed2db76 authored by Muyang Li's avatar Muyang Li Committed by GitHub
Browse files

feat: FLUX Gradio demos support FP4 (#623)

* update app

* depth supports fp4

* update

* fix the demo website

* style: make linter happy
parent 17c7154a
...@@ -11,7 +11,7 @@ ...@@ -11,7 +11,7 @@
alt="svdquant logo" style="height: 40px; width: auto;" /> alt="svdquant logo" style="height: 40px; width: auto;" />
</a> </a>
</div> </div>
<h1 style="margin-top: 0;">INT4 FLUX.1-{model_name}-dev Demo</h1> <h1 style="margin-top: 0;">{precision} FLUX.1-{model_name}-dev Demo</h1>
<div style="display: flex; justify-content: center; align-items: center; text-align: center;"> <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
{device_info} {device_info}
......
...@@ -51,10 +51,10 @@ if args.precision == "bf16": ...@@ -51,10 +51,10 @@ if args.precision == "bf16":
pipeline = pipeline.to("cuda") pipeline = pipeline.to("cuda")
pipeline.precision = "bf16" pipeline.precision = "bf16"
else: else:
assert args.precision == "int4" assert args.precision in ["int4", "fp4"]
pipeline_init_kwargs = {} pipeline_init_kwargs = {}
transformer = NunchakuFluxTransformer2dModel.from_pretrained( transformer = NunchakuFluxTransformer2dModel.from_pretrained(
f"mit-han-lab/nunchaku-flux.1-{model_name}/svdq-int4_r32-flux.1-{model_name}.safetensors" f"mit-han-lab/nunchaku-flux.1-{model_name}/svdq-{args.precision}_r32-flux.1-{model_name}.safetensors"
) )
pipeline_init_kwargs["transformer"] = transformer pipeline_init_kwargs["transformer"] = transformer
if args.use_qencoder: if args.use_qencoder:
...@@ -69,7 +69,7 @@ else: ...@@ -69,7 +69,7 @@ else:
f"black-forest-labs/FLUX.1-{model_name.capitalize()}", torch_dtype=torch.bfloat16, **pipeline_init_kwargs f"black-forest-labs/FLUX.1-{model_name.capitalize()}", torch_dtype=torch.bfloat16, **pipeline_init_kwargs
) )
pipeline = pipeline.to("cuda") pipeline = pipeline.to("cuda")
pipeline.precision = "int4" pipeline.precision = args.precision
safety_checker = SafetyChecker("cuda", disabled=args.no_safety_checker) safety_checker = SafetyChecker("cuda", disabled=args.no_safety_checker)
...@@ -154,7 +154,11 @@ with gr.Blocks(css_paths="assets/style.css", title=f"SVDQuant Flux.1-{model_name ...@@ -154,7 +154,11 @@ with gr.Blocks(css_paths="assets/style.css", title=f"SVDQuant Flux.1-{model_name
else: else:
count_info = "" count_info = ""
header_str = DESCRIPTION.format( header_str = DESCRIPTION.format(
model_name=args.model, device_info=device_info, notice=notice, count_info=count_info precision=args.precision,
model_name=args.model,
device_info=device_info,
notice=notice,
count_info=count_info,
) )
return header_str return header_str
......
...@@ -4,7 +4,7 @@ import argparse ...@@ -4,7 +4,7 @@ import argparse
def get_args() -> argparse.Namespace: def get_args() -> argparse.Namespace:
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument( parser.add_argument(
"-p", "--precision", type=str, default="int4", choices=["int4", "bf16"], help="Which precisions to use" "-p", "--precision", type=str, default="int4", choices=["int4", "fp4", "bf16"], help="Which precisions to use"
) )
parser.add_argument( parser.add_argument(
"-m", "--model", type=str, default="depth", choices=["canny", "depth"], help="Which FLUX.1 model to use" "-m", "--model", type=str, default="depth", choices=["canny", "depth"], help="Which FLUX.1 model to use"
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
</div> </div>
<!-- Title --> <!-- Title -->
<h1 style="margin-top: 0;">INT4 FLUX.1-fill-dev Demo</h1> <h1 style="margin-top: 0;">{precision} FLUX.1-fill-dev Demo</h1>
<div style="display: flex; justify-content: center; align-items: center; text-align: center;"> <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
{device_info} {device_info}
</div> </div>
......
...@@ -23,10 +23,10 @@ if args.precision == "bf16": ...@@ -23,10 +23,10 @@ if args.precision == "bf16":
pipeline = pipeline.to("cuda") pipeline = pipeline.to("cuda")
pipeline.precision = "bf16" pipeline.precision = "bf16"
else: else:
assert args.precision == "int4" assert args.precision in ["int4", "fp4"]
pipeline_init_kwargs = {} pipeline_init_kwargs = {}
transformer = NunchakuFluxTransformer2dModel.from_pretrained( transformer = NunchakuFluxTransformer2dModel.from_pretrained(
"mit-han-lab/nunchaku-flux.1-fill-dev/svdq-int4_r32-flux.1-fill-dev.safetensors" f"mit-han-lab/nunchaku-flux.1-fill-dev/svdq-{args.precision}_r32-flux.1-fill-dev.safetensors"
) )
pipeline_init_kwargs["transformer"] = transformer pipeline_init_kwargs["transformer"] = transformer
if args.use_qencoder: if args.use_qencoder:
...@@ -41,7 +41,7 @@ else: ...@@ -41,7 +41,7 @@ else:
"black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16, **pipeline_init_kwargs "black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16, **pipeline_init_kwargs
) )
pipeline = pipeline.to("cuda") pipeline = pipeline.to("cuda")
pipeline.precision = "int4" pipeline.precision = args.precision
safety_checker = SafetyChecker("cuda", disabled=args.no_safety_checker) safety_checker = SafetyChecker("cuda", disabled=args.no_safety_checker)
...@@ -125,7 +125,9 @@ with gr.Blocks(css_paths="assets/style.css", title="SVDQuant Flux.1-Fill-dev Ske ...@@ -125,7 +125,9 @@ with gr.Blocks(css_paths="assets/style.css", title="SVDQuant Flux.1-Fill-dev Ske
) )
else: else:
count_info = "" count_info = ""
header_str = DESCRIPTION.format(device_info=device_info, notice=notice, count_info=count_info) header_str = DESCRIPTION.format(
precision=args.precision, device_info=device_info, notice=notice, count_info=count_info
)
return header_str return header_str
header = gr.HTML(get_header_str()) header = gr.HTML(get_header_str())
......
...@@ -4,7 +4,7 @@ import argparse ...@@ -4,7 +4,7 @@ import argparse
def get_args() -> argparse.Namespace: def get_args() -> argparse.Namespace:
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument( parser.add_argument(
"-p", "--precision", type=str, default="int4", choices=["int4", "bf16"], help="Which precisions to use" "-p", "--precision", type=str, default="int4", choices=["int4", "fp4", "bf16"], help="Which precisions to use"
) )
parser.add_argument("--use-qencoder", action="store_true", help="Whether to use 4-bit text encoder") parser.add_argument("--use-qencoder", action="store_true", help="Whether to use 4-bit text encoder")
parser.add_argument("--no-safety-checker", action="store_true", help="Disable safety checker") parser.add_argument("--no-safety-checker", action="store_true", help="Disable safety checker")
......
...@@ -11,7 +11,7 @@ ...@@ -11,7 +11,7 @@
alt="svdquant logo" style="height: 40px; width: auto;" /> alt="svdquant logo" style="height: 40px; width: auto;" />
</a> </a>
</div> </div>
<h1 style="margin-top: 0;">INT4 FLUX.1-schnell Sketch-to-Image Demo</h1> <h1 style="margin-top: 0;">{precision} FLUX.1-schnell Sketch-to-Image Demo</h1>
<div style="display: flex; justify-content: center; align-items: center; text-align: center;"> <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
{device_info} {device_info}
......
...@@ -8,57 +8,30 @@ from PIL import Image ...@@ -8,57 +8,30 @@ from PIL import Image
from torch import nn from torch import nn
from torchvision.transforms import functional as F from torchvision.transforms import functional as F
from nunchaku.utils import load_state_dict_in_safetensors
class FluxPix2pixTurboPipeline(FluxPipeline): class FluxPix2pixTurboPipeline(FluxPipeline):
def update_alpha(self, alpha: float) -> None: def update_alpha(self, alpha: float) -> None:
self._alpha = alpha self._alpha = alpha
transformer = self.transformer transformer = self.transformer
for n, p in transformer.named_parameters():
if n in self._tuned_state_dict:
new_data = self._tuned_state_dict[n] * alpha + self._original_state_dict[n] * (1 - alpha)
new_data = new_data.to(self._execution_device).to(p.dtype)
p.data.copy_(new_data)
if self.precision == "bf16": if self.precision == "bf16":
for m in transformer.modules(): for m in transformer.modules():
if isinstance(m, lora.LoraLayer): if isinstance(m, lora.LoraLayer):
m.scaling["default_0"] = alpha m.scaling["default_0"] = alpha
else: else:
assert self.precision == "int4" assert self.precision in ["int4", "fp4"]
transformer.set_lora_strength(alpha) transformer.set_lora_strength(alpha)
def load_control_module( def load_control_module(self, pretrained_model_name_or_path: str, weight_name: str, alpha: float = 1):
self, state_dict, _ = self.lora_state_dict(pretrained_model_name_or_path, weight_name=weight_name, return_alphas=True)
pretrained_model_name_or_path: str,
weight_name: str | None = None,
svdq_lora_path: str | None = None,
alpha: float = 1,
):
state_dict, alphas = self.lora_state_dict(
pretrained_model_name_or_path, weight_name=weight_name, return_alphas=True
)
transformer = self.transformer transformer = self.transformer
original_state_dict = {}
tuned_state_dict = {}
assert isinstance(transformer, FluxTransformer2DModel) assert isinstance(transformer, FluxTransformer2DModel)
for n, p in transformer.named_parameters():
if f"transformer.{n}" in state_dict:
original_state_dict[n] = p.data.cpu()
tuned_state_dict[n] = state_dict[f"transformer.{n}"].cpu()
self._original_state_dict = original_state_dict
self._tuned_state_dict = tuned_state_dict
if self.precision == "bf16": if self.precision == "bf16":
self.load_lora_into_transformer(state_dict, {}, transformer=transformer) self.load_lora_into_transformer(state_dict, {}, transformer=transformer)
else: else:
assert svdq_lora_path is not None self.transformer.update_lora_params(state_dict)
sd = load_state_dict_in_safetensors(svdq_lora_path)
sd = {k: v for k, v in sd.items() if not k.startswith("transformer.")}
self.transformer.update_lora_params(sd)
self.update_alpha(alpha) self.update_alpha(alpha)
@torch.no_grad() @torch.no_grad()
......
...@@ -29,10 +29,10 @@ if args.precision == "bf16": ...@@ -29,10 +29,10 @@ if args.precision == "bf16":
"mit-han-lab/svdq-flux.1-schnell-pix2pix-turbo", "sketch.safetensors", alpha=DEFAULT_SKETCH_GUIDANCE "mit-han-lab/svdq-flux.1-schnell-pix2pix-turbo", "sketch.safetensors", alpha=DEFAULT_SKETCH_GUIDANCE
) )
else: else:
assert args.precision == "int4" assert args.precision in ["int4", "fp4"]
pipeline_init_kwargs = {} pipeline_init_kwargs = {}
transformer = NunchakuFluxTransformer2dModel.from_pretrained( transformer = NunchakuFluxTransformer2dModel.from_pretrained(
"mit-han-lab/nunchaku-flux.1-schnell/svdq-int4_r32-flux.1-schnell.safetensors" f"mit-han-lab/nunchaku-flux.1-schnell/svdq-{args.precision}_r32-flux.1-schnell.safetensors"
) )
pipeline_init_kwargs["transformer"] = transformer pipeline_init_kwargs["transformer"] = transformer
if args.use_qencoder: if args.use_qencoder:
...@@ -47,11 +47,10 @@ else: ...@@ -47,11 +47,10 @@ else:
"black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16, **pipeline_init_kwargs "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16, **pipeline_init_kwargs
) )
pipeline = pipeline.to("cuda") pipeline = pipeline.to("cuda")
pipeline.precision = "int4" pipeline.precision = args.precision
pipeline.load_control_module( pipeline.load_control_module(
"mit-han-lab/svdq-flux.1-schnell-pix2pix-turbo", "mit-han-lab/svdq-flux.1-schnell-pix2pix-turbo",
"sketch.safetensors", "sketch.safetensors",
svdq_lora_path="mit-han-lab/svdq-flux.1-schnell-pix2pix-turbo/svdq-int4-sketch.safetensors",
alpha=DEFAULT_SKETCH_GUIDANCE, alpha=DEFAULT_SKETCH_GUIDANCE,
) )
safety_checker = SafetyChecker("cuda", disabled=args.no_safety_checker) safety_checker = SafetyChecker("cuda", disabled=args.no_safety_checker)
...@@ -135,7 +134,9 @@ with gr.Blocks(css_paths="assets/style.css", title="SVDQuant Sketch-to-Image Dem ...@@ -135,7 +134,9 @@ with gr.Blocks(css_paths="assets/style.css", title="SVDQuant Sketch-to-Image Dem
) )
else: else:
count_info = "" count_info = ""
header_str = DESCRIPTION.format(device_info=device_info, notice=notice, count_info=count_info) header_str = DESCRIPTION.format(
precision=args.precision, device_info=device_info, notice=notice, count_info=count_info
)
return header_str return header_str
header = gr.HTML(get_header_str()) header = gr.HTML(get_header_str())
......
...@@ -4,7 +4,7 @@ import argparse ...@@ -4,7 +4,7 @@ import argparse
def get_args() -> argparse.Namespace: def get_args() -> argparse.Namespace:
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument( parser.add_argument(
"-p", "--precision", type=str, default="int4", choices=["int4", "bf16"], help="Which precisions to use" "-p", "--precision", type=str, default="int4", choices=["int4", "fp4", "bf16"], help="Which precisions to use"
) )
parser.add_argument("--use-qencoder", action="store_true", help="Whether to use 4-bit text encoder") parser.add_argument("--use-qencoder", action="store_true", help="Whether to use 4-bit text encoder")
parser.add_argument("--no-safety-checker", action="store_true", help="Disable safety checker") parser.add_argument("--no-safety-checker", action="store_true", help="Disable safety checker")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment