Unverified Commit 57e50f8d authored by Muyang Li's avatar Muyang Li Committed by GitHub
Browse files

style: upgrade the linter (#339)

* style: reformated codes

* style: reformated codes
parent b737368d
<div style="display: flex; justify-content: center; align-items: center; text-align: center;"> <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<div> <div>
<h1> <h1>
<img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/logo.svg" <img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/svdquant.svg"
alt="logo" alt="logo"
style="height: 40px; width: auto; display: block; margin: auto;"/> style="height: 40px; width: auto; display: block; margin: auto;"/>
INT4 FLUX.1-fill-dev Demo INT4 FLUX.1-fill-dev Demo
...@@ -49,4 +49,4 @@ ...@@ -49,4 +49,4 @@
</div> </div>
{count_info} {count_info}
</div> </div>
</div> </div>
\ No newline at end of file
...@@ -37,4 +37,4 @@ h1 { ...@@ -37,4 +37,4 @@ h1 {
#run_button { #run_button {
height: 87px; height: 87px;
} }
\ No newline at end of file
...@@ -8,25 +8,25 @@ import GPUtil ...@@ -8,25 +8,25 @@ import GPUtil
import torch import torch
from diffusers import FluxFillPipeline from diffusers import FluxFillPipeline
from PIL import Image from PIL import Image
from utils import get_args
from vars import DEFAULT_GUIDANCE, DEFAULT_INFERENCE_STEP, DEFAULT_STYLE_NAME, EXAMPLES, MAX_SEED, STYLE_NAMES, STYLES
from nunchaku.models.safety_checker import SafetyChecker from nunchaku.models.safety_checker import SafetyChecker
from nunchaku.models.transformers.transformer_flux import NunchakuFluxTransformer2dModel from nunchaku.models.transformers.transformer_flux import NunchakuFluxTransformer2dModel
from utils import get_args
from vars import DEFAULT_GUIDANCE, DEFAULT_INFERENCE_STEP, DEFAULT_STYLE_NAME, EXAMPLES, MAX_SEED, STYLE_NAMES, STYLES
# import gradio last to avoid conflicts with other imports # import gradio last to avoid conflicts with other imports
import gradio as gr import gradio as gr # noqa: isort: skip
args = get_args() args = get_args()
if args.precision == "bf16": if args.precision == "bf16":
pipeline = FluxFillPipeline.from_pretrained(f"black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16) pipeline = FluxFillPipeline.from_pretrained("black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16)
pipeline = pipeline.to("cuda") pipeline = pipeline.to("cuda")
pipeline.precision = "bf16" pipeline.precision = "bf16"
else: else:
assert args.precision == "int4" assert args.precision == "int4"
pipeline_init_kwargs = {} pipeline_init_kwargs = {}
transformer = NunchakuFluxTransformer2dModel.from_pretrained(f"mit-han-lab/svdq-int4-flux.1-fill-dev") transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-fill-dev")
pipeline_init_kwargs["transformer"] = transformer pipeline_init_kwargs["transformer"] = transformer
if args.use_qencoder: if args.use_qencoder:
from nunchaku.models.text_encoders.t5_encoder import NunchakuT5EncoderModel from nunchaku.models.text_encoders.t5_encoder import NunchakuT5EncoderModel
...@@ -35,7 +35,7 @@ else: ...@@ -35,7 +35,7 @@ else:
pipeline_init_kwargs["text_encoder_2"] = text_encoder_2 pipeline_init_kwargs["text_encoder_2"] = text_encoder_2
pipeline = FluxFillPipeline.from_pretrained( pipeline = FluxFillPipeline.from_pretrained(
f"black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16, **pipeline_init_kwargs "black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16, **pipeline_init_kwargs
) )
pipeline = pipeline.to("cuda") pipeline = pipeline.to("cuda")
pipeline.precision = "int4" pipeline.precision = "int4"
...@@ -94,7 +94,7 @@ def run( ...@@ -94,7 +94,7 @@ def run(
return result_image, latency_str return result_image, latency_str
with gr.Blocks(css_paths="assets/style.css", title=f"SVDQuant Flux.1-Fill-dev Sketch-to-Image Demo") as demo: with gr.Blocks(css_paths="assets/style.css", title="SVDQuant Flux.1-Fill-dev Sketch-to-Image Demo") as demo:
with open("assets/description.html", "r") as f: with open("assets/description.html", "r") as f:
DESCRIPTION = f.read() DESCRIPTION = f.read()
gpus = GPUtil.getGPUs() gpus = GPUtil.getGPUs()
...@@ -104,7 +104,7 @@ with gr.Blocks(css_paths="assets/style.css", title=f"SVDQuant Flux.1-Fill-dev Sk ...@@ -104,7 +104,7 @@ with gr.Blocks(css_paths="assets/style.css", title=f"SVDQuant Flux.1-Fill-dev Sk
device_info = f"Running on {gpu.name} with {memory:.0f} GiB memory." device_info = f"Running on {gpu.name} with {memory:.0f} GiB memory."
else: else:
device_info = "Running on CPU 🥶 This demo does not work on CPU." device_info = "Running on CPU 🥶 This demo does not work on CPU."
notice = f'<strong>Notice:</strong>&nbsp;We will replace unsafe prompts with a default prompt: "A peaceful world."' notice = '<strong>Notice:</strong>&nbsp;We will replace unsafe prompts with a default prompt: "A peaceful world."'
def get_header_str(): def get_header_str():
......
...@@ -8,4 +8,4 @@ This interactive Gradio application allows you to interactively generate image v ...@@ -8,4 +8,4 @@ This interactive Gradio application allows you to interactively generate image v
python run_gradio.py python run_gradio.py
``` ```
* By default, we use our INT4 model. Use `-p bf16` to switch to the BF16 model. * By default, we use our INT4 model. Use `-p bf16` to switch to the BF16 model.
\ No newline at end of file
<div style="display: flex; justify-content: center; align-items: center; text-align: center;"> <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<div> <div>
<h1> <h1>
<img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/logo.svg" <img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/svdquant.svg"
alt="logo" alt="logo"
style="height: 40px; width: auto; display: block; margin: auto;"/> style="height: 40px; width: auto; display: block; margin: auto;"/>
INT4 FLUX.1-redux-dev Demo INT4 FLUX.1-redux-dev Demo
...@@ -46,4 +46,4 @@ ...@@ -46,4 +46,4 @@
</div> </div>
{count_info} {count_info}
</div> </div>
</div> </div>
\ No newline at end of file
...@@ -26,4 +26,4 @@ h1{text-align:center} ...@@ -26,4 +26,4 @@ h1{text-align:center}
} }
#random_seed {height: 71px;} #random_seed {height: 71px;}
#run_button {height: 87px;} #run_button {height: 87px;}
\ No newline at end of file
...@@ -5,16 +5,16 @@ import time ...@@ -5,16 +5,16 @@ import time
from datetime import datetime from datetime import datetime
import GPUtil import GPUtil
# import gradio last to avoid conflicts with other imports
import gradio as gr
import torch import torch
from diffusers import FluxPipeline, FluxPriorReduxPipeline from diffusers import FluxPipeline, FluxPriorReduxPipeline
from PIL import Image from PIL import Image
from nunchaku.models.transformers.transformer_flux import NunchakuFluxTransformer2dModel
from utils import get_args from utils import get_args
from vars import DEFAULT_GUIDANCE, DEFAULT_INFERENCE_STEP, EXAMPLES, MAX_SEED from vars import DEFAULT_GUIDANCE, DEFAULT_INFERENCE_STEP, EXAMPLES, MAX_SEED
# import gradio last to avoid conflicts with other imports from nunchaku.models.transformers.transformer_flux import NunchakuFluxTransformer2dModel
import gradio as gr
args = get_args() args = get_args()
...@@ -76,7 +76,7 @@ def run(image, num_inference_steps: int, guidance_scale: float, seed: int) -> tu ...@@ -76,7 +76,7 @@ def run(image, num_inference_steps: int, guidance_scale: float, seed: int) -> tu
return result_image, latency_str return result_image, latency_str
with gr.Blocks(css_paths="assets/style.css", title=f"SVDQuant Flux.1-redux-dev Demo") as demo: with gr.Blocks(css_paths="assets/style.css", title="SVDQuant Flux.1-redux-dev Demo") as demo:
with open("assets/description.html", "r") as f: with open("assets/description.html", "r") as f:
DESCRIPTION = f.read() DESCRIPTION = f.read()
gpus = GPUtil.getGPUs() gpus = GPUtil.getGPUs()
......
...@@ -11,9 +11,7 @@ def get_args() -> argparse.Namespace: ...@@ -11,9 +11,7 @@ def get_args() -> argparse.Namespace:
choices=["int4", "bf16"], choices=["int4", "bf16"],
help="Which precisions to use", help="Which precisions to use",
) )
parser.add_argument( parser.add_argument("--count-use", action="store_true", help="Whether to count the number of uses")
"--count-use", action="store_true", help="Whether to count the number of uses"
)
parser.add_argument("--gradio-root-path", type=str, default="") parser.add_argument("--gradio-root-path", type=str, default="")
args = parser.parse_args() args = parser.parse_args()
return args return args
...@@ -12,4 +12,4 @@ python run_gradio.py ...@@ -12,4 +12,4 @@ python run_gradio.py
* The demo loads the Gemma-2B model as a safety checker by default. To disable this feature, use `--no-safety-checker`. * The demo loads the Gemma-2B model as a safety checker by default. To disable this feature, use `--no-safety-checker`.
* To further reduce GPU memory usage, you can enable the W4A16 text encoder by specifying `--use-qencoder`. * To further reduce GPU memory usage, you can enable the W4A16 text encoder by specifying `--use-qencoder`.
* By default, we use our INT4 model. Use `-p bf16` to switch to the BF16 model. * By default, we use our INT4 model. Use `-p bf16` to switch to the BF16 model.
\ No newline at end of file
<div style="display: flex; justify-content: center; align-items: center; text-align: center;"> <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<div> <div>
<h1> <h1>
<img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/logo.svg" <img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/svdquant.svg"
alt="logo" alt="logo"
style="height: 40px; width: auto; display: block; margin: auto;"/> style="height: 40px; width: auto; display: block; margin: auto;"/>
INT4 FLUX.1-schnell Sketch-to-Image Demo INT4 FLUX.1-schnell Sketch-to-Image Demo
...@@ -50,4 +50,4 @@ ...@@ -50,4 +50,4 @@
</div> </div>
{count_info} {count_info}
</div> </div>
</div> </div>
\ No newline at end of file
...@@ -37,4 +37,4 @@ h1 { ...@@ -37,4 +37,4 @@ h1 {
#run_button { #run_button {
height: 87px; height: 87px;
} }
\ No newline at end of file
import argparse import argparse
import torch import torch
from flux_pix2pix_pipeline import FluxPix2pixTurboPipeline from flux_pix2pix_pipeline import FluxPix2pixTurboPipeline
......
...@@ -8,16 +8,16 @@ from datetime import datetime ...@@ -8,16 +8,16 @@ from datetime import datetime
import GPUtil import GPUtil
import numpy as np import numpy as np
import torch import torch
from flux_pix2pix_pipeline import FluxPix2pixTurboPipeline
from PIL import Image from PIL import Image
from utils import get_args
from vars import DEFAULT_SKETCH_GUIDANCE, DEFAULT_STYLE_NAME, MAX_SEED, STYLE_NAMES, STYLES
from flux_pix2pix_pipeline import FluxPix2pixTurboPipeline
from nunchaku.models.safety_checker import SafetyChecker from nunchaku.models.safety_checker import SafetyChecker
from nunchaku.models.transformers.transformer_flux import NunchakuFluxTransformer2dModel from nunchaku.models.transformers.transformer_flux import NunchakuFluxTransformer2dModel
from utils import get_args
from vars import DEFAULT_SKETCH_GUIDANCE, DEFAULT_STYLE_NAME, MAX_SEED, STYLE_NAMES, STYLES
# import gradio last to avoid conflicts with other imports # import gradio last to avoid conflicts with other imports
import gradio as gr import gradio as gr # noqa: isort: skip
blank_image = Image.new("RGB", (1024, 1024), (255, 255, 255)) blank_image = Image.new("RGB", (1024, 1024), (255, 255, 255))
...@@ -109,7 +109,7 @@ def run(image, prompt: str, prompt_template: str, sketch_guidance: float, seed: ...@@ -109,7 +109,7 @@ def run(image, prompt: str, prompt_template: str, sketch_guidance: float, seed:
return result_image, latency_str return result_image, latency_str
with gr.Blocks(css_paths="assets/style.css", title=f"SVDQuant Sketch-to-Image Demo") as demo: with gr.Blocks(css_paths="assets/style.css", title="SVDQuant Sketch-to-Image Demo") as demo:
with open("assets/description.html", "r") as f: with open("assets/description.html", "r") as f:
DESCRIPTION = f.read() DESCRIPTION = f.read()
gpus = GPUtil.getGPUs() gpus = GPUtil.getGPUs()
...@@ -119,7 +119,7 @@ with gr.Blocks(css_paths="assets/style.css", title=f"SVDQuant Sketch-to-Image De ...@@ -119,7 +119,7 @@ with gr.Blocks(css_paths="assets/style.css", title=f"SVDQuant Sketch-to-Image De
device_info = f"Running on {gpu.name} with {memory:.0f} GiB memory." device_info = f"Running on {gpu.name} with {memory:.0f} GiB memory."
else: else:
device_info = "Running on CPU 🥶 This demo does not work on CPU." device_info = "Running on CPU 🥶 This demo does not work on CPU."
notice = f'<strong>Notice:</strong>&nbsp;We will replace unsafe prompts with a default prompt: "A peaceful world."' notice = '<strong>Notice:</strong>&nbsp;We will replace unsafe prompts with a default prompt: "A peaceful world."'
def get_header_str(): def get_header_str():
......
...@@ -6,4 +6,4 @@ h2{text-align:center} ...@@ -6,4 +6,4 @@ h2{text-align:center}
#accessibility { #accessibility {
text-align: center; /* Center-aligns the text */ text-align: center; /* Center-aligns the text */
margin: auto; /* Centers the element horizontally */ margin: auto; /* Centers the element horizontally */
} }
\ No newline at end of file
<div style="display: flex; justify-content: center; align-items: center; text-align: center;"> <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<div> <div>
<h1> <h1>
<img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/logo.svg" <img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/svdquant.svg"
alt="logo" alt="logo"
style="height: 40px; width: auto; display: block; margin: auto;"/> style="height: 40px; width: auto; display: block; margin: auto;"/>
FLUX.1-{model} Demo FLUX.1-{model} Demo
...@@ -50,4 +50,4 @@ ...@@ -50,4 +50,4 @@
</div> </div>
{count_info} {count_info}
</div> </div>
</div> </div>
\ No newline at end of file
...@@ -2,9 +2,8 @@ import argparse ...@@ -2,9 +2,8 @@ import argparse
import os import os
import torch import torch
from tqdm import tqdm
from data import get_dataset from data import get_dataset
from tqdm import tqdm
from utils import get_pipeline, hash_str_to_int from utils import get_pipeline, hash_str_to_int
......
...@@ -2,7 +2,6 @@ import argparse ...@@ -2,7 +2,6 @@ import argparse
import os import os
import torch import torch
from utils import get_pipeline from utils import get_pipeline
from vars import PROMPT_TEMPLATES from vars import PROMPT_TEMPLATES
......
...@@ -4,7 +4,6 @@ import time ...@@ -4,7 +4,6 @@ import time
import torch import torch
from torch import nn from torch import nn
from tqdm import trange from tqdm import trange
from utils import get_pipeline from utils import get_pipeline
......
import os import os
import ImageReward as RM
import datasets import datasets
import ImageReward as RM
import torch import torch
from tqdm import tqdm from tqdm import tqdm
......
...@@ -9,13 +9,13 @@ import GPUtil ...@@ -9,13 +9,13 @@ import GPUtil
import spaces import spaces
import torch import torch
from peft.tuners import lora from peft.tuners import lora
from nunchaku.models.safety_checker import SafetyChecker
from utils import get_pipeline from utils import get_pipeline
from vars import DEFAULT_HEIGHT, DEFAULT_WIDTH, EXAMPLES, MAX_SEED, PROMPT_TEMPLATES, SVDQ_LORA_PATHS from vars import DEFAULT_HEIGHT, DEFAULT_WIDTH, EXAMPLES, MAX_SEED, PROMPT_TEMPLATES, SVDQ_LORA_PATHS
from nunchaku.models.safety_checker import SafetyChecker
# import gradio last to avoid conflicts with other imports # import gradio last to avoid conflicts with other imports
import gradio as gr import gradio as gr # noqa: isort: skip
def get_args() -> argparse.Namespace: def get_args() -> argparse.Namespace:
...@@ -84,7 +84,7 @@ def generate( ...@@ -84,7 +84,7 @@ def generate(
images, latency_strs = [], [] images, latency_strs = [], []
for i, pipeline in enumerate(pipelines): for i, pipeline in enumerate(pipelines):
precision = args.precisions[i] precision = args.precisions[i]
progress = gr.Progress(track_tqdm=True) gr.Progress(track_tqdm=True)
if pipeline.cur_lora_name != lora_name: if pipeline.cur_lora_name != lora_name:
if precision == "bf16": if precision == "bf16":
for m in pipeline.transformer.modules(): for m in pipeline.transformer.modules():
...@@ -164,7 +164,7 @@ if len(gpus) > 0: ...@@ -164,7 +164,7 @@ if len(gpus) > 0:
device_info = f"Running on {gpu.name} with {memory:.0f} GiB memory." device_info = f"Running on {gpu.name} with {memory:.0f} GiB memory."
else: else:
device_info = "Running on CPU 🥶 This demo does not work on CPU." device_info = "Running on CPU 🥶 This demo does not work on CPU."
notice = f'<strong>Notice:</strong>&nbsp;We will replace unsafe prompts with a default prompt: "A peaceful world."' notice = '<strong>Notice:</strong>&nbsp;We will replace unsafe prompts with a default prompt: "A peaceful world."'
with gr.Blocks( with gr.Blocks(
css_paths=[f"assets/frame{len(args.precisions)}.css", "assets/common.css"], css_paths=[f"assets/frame{len(args.precisions)}.css", "assets/common.css"],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment