"magic_pdf/vscode:/vscode.git/clone" did not exist on "d1a9d1db2f69455843f5f726b3ded002226cda4c"
Unverified commit 57e50f8d, authored by Muyang Li, committed by GitHub

style: upgrade the linter (#339)

* style: reformatted code

* style: reformatted code
parent b737368d
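Most of the hunks below are mechanical lint fixes that fall into a few recurring patterns: f-strings with no placeholders become plain string literals (flake8's F541), import blocks are regrouped and re-sorted (isort), with a `# noqa: isort: skip` comment pinning `import gradio as gr` at the bottom where ordering matters, several files gain a missing trailing newline, one unused variable binding is dropped, and one short multi-line call is joined onto a single line. A minimal sketch of the most common fix, the F541 f-string change:

```python
# Illustrative sketch of the F541 fix repeated throughout this commit:
# an f-string with no placeholders is just a plain string literal, so
# the f-prefix is dropped.
repo = f"black-forest-labs/FLUX.1-Fill-dev"  # flagged by flake8 (F541): no placeholders
repo = "black-forest-labs/FLUX.1-Fill-dev"  # equivalent plain literal
print(repo)
```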
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<div>
<h1>
<img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/logo.svg"
<img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/svdquant.svg"
alt="logo"
style="height: 40px; width: auto; display: block; margin: auto;"/>
INT4 FLUX.1-fill-dev Demo
......@@ -49,4 +49,4 @@
</div>
{count_info}
</div>
</div>
\ No newline at end of file
</div>
@@ -37,4 +37,4 @@ h1 {
 #run_button {
 height: 87px;
-}
\ No newline at end of file
+}
@@ -8,25 +8,25 @@ import GPUtil
 import torch
 from diffusers import FluxFillPipeline
 from PIL import Image
-from utils import get_args
-from vars import DEFAULT_GUIDANCE, DEFAULT_INFERENCE_STEP, DEFAULT_STYLE_NAME, EXAMPLES, MAX_SEED, STYLE_NAMES, STYLES
 from nunchaku.models.safety_checker import SafetyChecker
 from nunchaku.models.transformers.transformer_flux import NunchakuFluxTransformer2dModel
+from utils import get_args
+from vars import DEFAULT_GUIDANCE, DEFAULT_INFERENCE_STEP, DEFAULT_STYLE_NAME, EXAMPLES, MAX_SEED, STYLE_NAMES, STYLES
 # import gradio last to avoid conflicts with other imports
-import gradio as gr
+import gradio as gr  # noqa: isort: skip
 args = get_args()
 if args.precision == "bf16":
-    pipeline = FluxFillPipeline.from_pretrained(f"black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16)
+    pipeline = FluxFillPipeline.from_pretrained("black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16)
     pipeline = pipeline.to("cuda")
     pipeline.precision = "bf16"
 else:
     assert args.precision == "int4"
     pipeline_init_kwargs = {}
-    transformer = NunchakuFluxTransformer2dModel.from_pretrained(f"mit-han-lab/svdq-int4-flux.1-fill-dev")
+    transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-fill-dev")
     pipeline_init_kwargs["transformer"] = transformer
     if args.use_qencoder:
         from nunchaku.models.text_encoders.t5_encoder import NunchakuT5EncoderModel
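The `# noqa: isort: skip` directive added above is what lets `import gradio as gr` stay below the other imports even though isort would normally sort it alphabetically: the trailing comment quiets flake8 and tells isort to leave the line in place. A minimal sketch of the idiom, assuming torch and gradio are installed:

```python
import os

import torch

# import gradio last to avoid conflicts with other imports
import gradio as gr  # noqa: isort: skip

print(os.name, torch.__version__, gr.__version__)
```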
@@ -35,7 +35,7 @@ else:
         pipeline_init_kwargs["text_encoder_2"] = text_encoder_2
     pipeline = FluxFillPipeline.from_pretrained(
-        f"black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16, **pipeline_init_kwargs
+        "black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16, **pipeline_init_kwargs
     )
     pipeline = pipeline.to("cuda")
     pipeline.precision = "int4"
@@ -94,7 +94,7 @@ def run(
     return result_image, latency_str
-with gr.Blocks(css_paths="assets/style.css", title=f"SVDQuant Flux.1-Fill-dev Sketch-to-Image Demo") as demo:
+with gr.Blocks(css_paths="assets/style.css", title="SVDQuant Flux.1-Fill-dev Sketch-to-Image Demo") as demo:
     with open("assets/description.html", "r") as f:
         DESCRIPTION = f.read()
     gpus = GPUtil.getGPUs()
@@ -104,7 +104,7 @@ with gr.Blocks(css_paths="assets/style.css", title=f"SVDQuant Flux.1-Fill-dev Sk
         device_info = f"Running on {gpu.name} with {memory:.0f} GiB memory."
     else:
         device_info = "Running on CPU 🥶 This demo does not work on CPU."
-    notice = f'<strong>Notice:</strong>&nbsp;We will replace unsafe prompts with a default prompt: "A peaceful world."'
+    notice = '<strong>Notice:</strong>&nbsp;We will replace unsafe prompts with a default prompt: "A peaceful world."'
     def get_header_str():
...
@@ -8,4 +8,4 @@ This interactive Gradio application allows you to interactively generate image v
 python run_gradio.py
 ```
-* By default, we use our INT4 model. Use `-p bf16` to switch to the BF16 model.
\ No newline at end of file
+* By default, we use our INT4 model. Use `-p bf16` to switch to the BF16 model.
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<div>
<h1>
<img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/logo.svg"
<img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/svdquant.svg"
alt="logo"
style="height: 40px; width: auto; display: block; margin: auto;"/>
INT4 FLUX.1-redux-dev Demo
......@@ -46,4 +46,4 @@
</div>
{count_info}
</div>
</div>
\ No newline at end of file
</div>
@@ -26,4 +26,4 @@ h1{text-align:center}
 }
 #random_seed {height: 71px;}
-#run_button {height: 87px;}
\ No newline at end of file
+#run_button {height: 87px;}
@@ -5,16 +5,16 @@ import time
 from datetime import datetime
 import GPUtil
-# import gradio last to avoid conflicts with other imports
-import gradio as gr
 import torch
 from diffusers import FluxPipeline, FluxPriorReduxPipeline
 from PIL import Image
-from nunchaku.models.transformers.transformer_flux import NunchakuFluxTransformer2dModel
 from utils import get_args
 from vars import DEFAULT_GUIDANCE, DEFAULT_INFERENCE_STEP, EXAMPLES, MAX_SEED
+# import gradio last to avoid conflicts with other imports
+import gradio as gr
+from nunchaku.models.transformers.transformer_flux import NunchakuFluxTransformer2dModel
 args = get_args()
@@ -76,7 +76,7 @@ def run(image, num_inference_steps: int, guidance_scale: float, seed: int) -> tu
     return result_image, latency_str
-with gr.Blocks(css_paths="assets/style.css", title=f"SVDQuant Flux.1-redux-dev Demo") as demo:
+with gr.Blocks(css_paths="assets/style.css", title="SVDQuant Flux.1-redux-dev Demo") as demo:
     with open("assets/description.html", "r") as f:
         DESCRIPTION = f.read()
     gpus = GPUtil.getGPUs()
...
@@ -11,9 +11,7 @@ def get_args() -> argparse.Namespace:
         choices=["int4", "bf16"],
         help="Which precisions to use",
     )
-    parser.add_argument(
-        "--count-use", action="store_true", help="Whether to count the number of uses"
-    )
+    parser.add_argument("--count-use", action="store_true", help="Whether to count the number of uses")
     parser.add_argument("--gradio-root-path", type=str, default="")
     args = parser.parse_args()
     return args
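Joining `parser.add_argument(...)` onto one line is consistent with a formatter configured for a longer line length: the joined call runs past 100 characters, so it only survives reformatting if the limit allows it. A self-contained sketch of this parser follows; the `-p`/`--precision` flag names, type, and default are assumptions inferred from the READMEs (which pass `-p bf16`), since only the `choices`/`help` lines and the other two arguments are visible in the hunk:

```python
import argparse


def get_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser()
    # Assumed registration of the precision flag; choices and help text
    # are taken from the hunk above.
    parser.add_argument(
        "-p",
        "--precision",
        type=str,
        default="int4",
        choices=["int4", "bf16"],
        help="Which precisions to use",
    )
    parser.add_argument("--count-use", action="store_true", help="Whether to count the number of uses")
    parser.add_argument("--gradio-root-path", type=str, default="")
    args = parser.parse_args()
    return args


if __name__ == "__main__":
    print(get_args())
```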
@@ -12,4 +12,4 @@ python run_gradio.py
 * The demo loads the Gemma-2B model as a safety checker by default. To disable this feature, use `--no-safety-checker`.
 * To further reduce GPU memory usage, you can enable the W4A16 text encoder by specifying `--use-qencoder`.
-* By default, we use our INT4 model. Use `-p bf16` to switch to the BF16 model.
\ No newline at end of file
+* By default, we use our INT4 model. Use `-p bf16` to switch to the BF16 model.
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<div>
<h1>
<img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/logo.svg"
<img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/svdquant.svg"
alt="logo"
style="height: 40px; width: auto; display: block; margin: auto;"/>
INT4 FLUX.1-schnell Sketch-to-Image Demo
......@@ -50,4 +50,4 @@
</div>
{count_info}
</div>
</div>
\ No newline at end of file
</div>
@@ -37,4 +37,4 @@ h1 {
 #run_button {
 height: 87px;
-}
\ No newline at end of file
+}
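Several hunks in this commit do nothing but add a newline at the end of a file, which is why the old side carries the `\ No newline at end of file` marker while the new side repeats the same closing line. This is the same fix that pre-commit's end-of-file-fixer hook applies; an illustrative Python equivalent:

```python
from pathlib import Path


def ensure_trailing_newline(path: str) -> bool:
    """Append a final newline if the file lacks one; return True if changed.

    Illustrative re-implementation of the end-of-file fix seen in these hunks.
    """
    p = Path(path)
    data = p.read_bytes()
    if data and not data.endswith(b"\n"):
        p.write_bytes(data + b"\n")
        return True
    return False
```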
 import argparse
 import torch
 from flux_pix2pix_pipeline import FluxPix2pixTurboPipeline
...
@@ -8,16 +8,16 @@ from datetime import datetime
 import GPUtil
 import numpy as np
 import torch
-from flux_pix2pix_pipeline import FluxPix2pixTurboPipeline
 from PIL import Image
-from utils import get_args
-from vars import DEFAULT_SKETCH_GUIDANCE, DEFAULT_STYLE_NAME, MAX_SEED, STYLE_NAMES, STYLES
+from flux_pix2pix_pipeline import FluxPix2pixTurboPipeline
 from nunchaku.models.safety_checker import SafetyChecker
 from nunchaku.models.transformers.transformer_flux import NunchakuFluxTransformer2dModel
+from utils import get_args
+from vars import DEFAULT_SKETCH_GUIDANCE, DEFAULT_STYLE_NAME, MAX_SEED, STYLE_NAMES, STYLES
 # import gradio last to avoid conflicts with other imports
-import gradio as gr
+import gradio as gr  # noqa: isort: skip
 blank_image = Image.new("RGB", (1024, 1024), (255, 255, 255))
@@ -109,7 +109,7 @@ def run(image, prompt: str, prompt_template: str, sketch_guidance: float, seed:
     return result_image, latency_str
-with gr.Blocks(css_paths="assets/style.css", title=f"SVDQuant Sketch-to-Image Demo") as demo:
+with gr.Blocks(css_paths="assets/style.css", title="SVDQuant Sketch-to-Image Demo") as demo:
     with open("assets/description.html", "r") as f:
         DESCRIPTION = f.read()
     gpus = GPUtil.getGPUs()
@@ -119,7 +119,7 @@ with gr.Blocks(css_paths="assets/style.css", title=f"SVDQuant Sketch-to-Image De
         device_info = f"Running on {gpu.name} with {memory:.0f} GiB memory."
     else:
         device_info = "Running on CPU 🥶 This demo does not work on CPU."
-    notice = f'<strong>Notice:</strong>&nbsp;We will replace unsafe prompts with a default prompt: "A peaceful world."'
+    notice = '<strong>Notice:</strong>&nbsp;We will replace unsafe prompts with a default prompt: "A peaceful world."'
     def get_header_str():
...
@@ -6,4 +6,4 @@ h2{text-align:center}
 #accessibility {
 text-align: center; /* Center-aligns the text */
 margin: auto; /* Centers the element horizontally */
-}
\ No newline at end of file
+}
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<div>
<h1>
<img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/logo.svg"
<img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/svdquant.svg"
alt="logo"
style="height: 40px; width: auto; display: block; margin: auto;"/>
FLUX.1-{model} Demo
......@@ -50,4 +50,4 @@
</div>
{count_info}
</div>
</div>
\ No newline at end of file
</div>
@@ -2,9 +2,8 @@ import argparse
 import os
 import torch
+from tqdm import tqdm
 from data import get_dataset
-from tqdm import tqdm
 from utils import get_pipeline, hash_str_to_int
...
@@ -2,7 +2,6 @@ import argparse
 import os
 import torch
 from utils import get_pipeline
 from vars import PROMPT_TEMPLATES
...
@@ -4,7 +4,6 @@ import time
 import torch
 from torch import nn
 from tqdm import trange
 from utils import get_pipeline
...
 import os
-import ImageReward as RM
 import datasets
+import ImageReward as RM
 import torch
 from tqdm import tqdm
...
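The swap of `import datasets` and `import ImageReward as RM` above is consistent with a case-insensitive alphabetical sort ("datasets" before "ImageReward"), which matches isort's default ordering; a raw ASCII sort would put the uppercase `I` first. A runnable sketch of the difference:

```python
# Case-insensitive ordering (isort's default) vs. raw ASCII ordering.
modules = ["ImageReward", "datasets", "torch", "tqdm"]
print(sorted(modules, key=str.lower))  # ['datasets', 'ImageReward', 'torch', 'tqdm']
print(sorted(modules))                 # ['ImageReward', 'datasets', 'torch', 'tqdm']
```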
@@ -9,13 +9,13 @@ import GPUtil
 import spaces
 import torch
 from peft.tuners import lora
+from nunchaku.models.safety_checker import SafetyChecker
 from utils import get_pipeline
 from vars import DEFAULT_HEIGHT, DEFAULT_WIDTH, EXAMPLES, MAX_SEED, PROMPT_TEMPLATES, SVDQ_LORA_PATHS
-from nunchaku.models.safety_checker import SafetyChecker
 # import gradio last to avoid conflicts with other imports
-import gradio as gr
+import gradio as gr  # noqa: isort: skip
 def get_args() -> argparse.Namespace:
@@ -84,7 +84,7 @@ def generate(
     images, latency_strs = [], []
     for i, pipeline in enumerate(pipelines):
         precision = args.precisions[i]
-        progress = gr.Progress(track_tqdm=True)
+        gr.Progress(track_tqdm=True)
         if pipeline.cur_lora_name != lora_name:
             if precision == "bf16":
                 for m in pipeline.transformer.modules():
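Dropping the `progress =` binding above fixes flake8's F841 (local variable assigned but never used); the `gr.Progress(track_tqdm=True)` call itself is kept, presumably for its side effect of wiring tqdm output into Gradio's progress display. A minimal sketch of the pattern, assuming gradio is installed:

```python
import gradio as gr  # assumption: gradio is available in the environment


def generate() -> None:
    # Before: progress = gr.Progress(track_tqdm=True)  # F841: name never read
    # After: keep the call, drop the unused binding.
    gr.Progress(track_tqdm=True)
```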
@@ -164,7 +164,7 @@ if len(gpus) > 0:
     device_info = f"Running on {gpu.name} with {memory:.0f} GiB memory."
 else:
     device_info = "Running on CPU 🥶 This demo does not work on CPU."
-notice = f'<strong>Notice:</strong>&nbsp;We will replace unsafe prompts with a default prompt: "A peaceful world."'
+notice = '<strong>Notice:</strong>&nbsp;We will replace unsafe prompts with a default prompt: "A peaceful world."'
 with gr.Blocks(
     css_paths=[f"assets/frame{len(args.precisions)}.css", "assets/common.css"],
...