run_gradio.py 9.86 KB
Newer Older
April Hu's avatar
April Hu committed
1
2
3
4
5
6
7
8
9
# Changed from https://github.com/GaParmar/img2img-turbo/blob/main/gradio_sketch2image.py
import os
import random
import time
from datetime import datetime

import torch
from controlnet_aux import CannyDetector
from diffusers import FluxControlPipeline
10
11
from image_gen_aux import DepthPreprocessor
from PIL import Image
April Hu's avatar
April Hu committed
12
from utils import get_args
13
14
15
16
17
18
19
20
21
22
23
24
25
from vars import (
    DEFAULT_GUIDANCE_CANNY,
    DEFAULT_GUIDANCE_DEPTH,
    DEFAULT_INFERENCE_STEP_CANNY,
    DEFAULT_INFERENCE_STEP_DEPTH,
    DEFAULT_STYLE_NAME,
    EXAMPLES,
    HEIGHT,
    MAX_SEED,
    STYLE_NAMES,
    STYLES,
    WIDTH,
)
April Hu's avatar
April Hu committed
26

Muyang Li's avatar
Muyang Li committed
27
28
29
from nunchaku.models.safety_checker import SafetyChecker
from nunchaku.models.transformers.transformer_flux import NunchakuFluxTransformer2dModel

April Hu's avatar
April Hu committed
30
# import gradio last to avoid conflicts with other imports
Muyang Li's avatar
Muyang Li committed
31
import gradio as gr  # noqa: isort: skip
April Hu's avatar
April Hu committed
32
33
34
35
36
37
38
39
40
41
42
43

args = get_args()

# Both supported variants ("canny" / "depth") run through FluxControlPipeline; they
# differ only in the checkpoint name and in the preprocessor that turns the user's
# image into a control image.
model_name = f"{args.model}-dev"
pipeline_class = FluxControlPipeline
if args.model == "canny":
    processor = CannyDetector()
elif args.model == "depth":
    processor = DepthPreprocessor.from_pretrained("LiheYoung/depth-anything-large-hf")
else:
    # Explicit raise instead of `assert`: asserts are stripped under `python -O`,
    # which would let an unsupported model fall through with no preprocessor.
    raise ValueError(f"Model {args.model} not supported")

if args.precision == "bf16":
    # Baseline: stock BF16 FLUX.1 control pipeline from the Hugging Face Hub.
    pipeline = pipeline_class.from_pretrained(
        f"black-forest-labs/FLUX.1-{model_name.capitalize()}", torch_dtype=torch.bfloat16
    )
elif args.precision == "int4":
    # SVDQuant INT4 transformer swapped into the otherwise-BF16 pipeline.
    pipeline_init_kwargs = {}
    transformer = NunchakuFluxTransformer2dModel.from_pretrained(
        f"mit-han-lab/nunchaku-flux.1-{model_name}/svdq-int4_r32-flux.1-{model_name}.safetensors"
    )
    pipeline_init_kwargs["transformer"] = transformer
    if args.use_qencoder:
        # Imported lazily: only needed when the quantized T5 text encoder is requested.
        from nunchaku.models.text_encoders.t5_encoder import NunchakuT5EncoderModel

        text_encoder_2 = NunchakuT5EncoderModel.from_pretrained(
            "mit-han-lab/nunchaku-t5/awq-int4-flux.1-t5xxl.safetensors"
        )
        pipeline_init_kwargs["text_encoder_2"] = text_encoder_2

    pipeline = pipeline_class.from_pretrained(
        f"black-forest-labs/FLUX.1-{model_name.capitalize()}", torch_dtype=torch.bfloat16, **pipeline_init_kwargs
    )
else:
    raise ValueError(f"Precision {args.precision} not supported")

pipeline = pipeline.to("cuda")
# Recorded on the pipeline object; equals "bf16" or "int4" after the checks above.
pipeline.precision = args.precision

safety_checker = SafetyChecker("cuda", disabled=args.no_safety_checker)


77
78
79
def run(
    image, prompt: str, style: str, prompt_template: str, num_inference_steps: int, guidance_scale: float, seed: int
) -> tuple[Image.Image, str]:
    """Generate one image conditioned on the user's input image and prompt.

    Args:
        image: Value of the ``gr.ImageEditor`` input; its ``"composite"`` entry holds
            the flattened PIL image that is fed to the control preprocessor.
        prompt: Raw user prompt. Replaced by "A peaceful world." when the safety
            checker rejects it.
        style: Selected style name. Not read here -- the chosen style is already
            reflected in ``prompt_template``; it is accepted because Gradio passes
            every component listed in ``run_inputs`` positionally.
        prompt_template: Format string with a ``{prompt}`` placeholder.
        num_inference_steps: Number of denoising steps.
        guidance_scale: Guidance scale passed to the pipeline.
        seed: Seed for the pipeline's random generator.

    Returns:
        ``(generated image, latency string)``. The latency string gets an
        " (Unsafe prompt detected)" suffix when the prompt was replaced.

    Raises:
        gr.Error: if no input image has been provided.
    """
    # Fail with a readable UI message instead of a TypeError inside the preprocessor.
    if image is None or image.get("composite") is None:
        raise gr.Error("Please upload or paste an input image first.")

    # Turn the user's image into the control signal this model variant expects.
    if args.model == "canny":
        processed_img = processor(image["composite"]).convert("RGB")
    else:
        assert args.model == "depth"
        # The depth preprocessor's output is indexed; its first element is the map used.
        processed_img = processor(image["composite"])[0].convert("RGB")

    is_unsafe_prompt = not safety_checker(prompt)
    if is_unsafe_prompt:
        prompt = "A peaceful world."
    prompt = prompt_template.format(prompt=prompt)
    print(f"Prompt: {prompt}")

    # perf_counter is monotonic, so wall-clock adjustments cannot skew the latency.
    start_time = time.perf_counter()
    result_image = pipeline(
        prompt=prompt,
        control_image=processed_img,
        height=HEIGHT,
        width=WIDTH,
        num_inference_steps=int(num_inference_steps),
        guidance_scale=guidance_scale,
        # torch.Generator.manual_seed requires an int; slider values may arrive as floats.
        generator=torch.Generator().manual_seed(int(seed)),
    ).images[0]
    latency_str = _format_latency(time.perf_counter() - start_time)

    if is_unsafe_prompt:
        latency_str += " (Unsafe prompt detected)"
    torch.cuda.empty_cache()
    if args.count_use:
        _record_use()
    return result_image, latency_str


def _format_latency(seconds: float) -> str:
    """Render a duration as milliseconds below one second, otherwise as seconds."""
    if seconds < 1:
        return f"{seconds * 1000:.2f}ms"
    return f"{seconds:.2f}s"


def _record_use() -> None:
    """Increment the per-model use counter file and append a timestamped record line."""
    count_path = f"{args.model}-use_count.txt"
    if os.path.exists(count_path):
        with open(count_path, "r") as f:
            count = int(f.read())
    else:
        count = 0
    count += 1
    current_time = datetime.now()
    print(f"{current_time}: {count}")
    # Write to a temp file and atomically swap it in: a crash mid-write can no longer
    # leave a truncated counter that a later int() parse would choke on.
    tmp_path = f"{count_path}.tmp"
    with open(tmp_path, "w") as f:
        f.write(str(count))
    os.replace(tmp_path, count_path)
    with open(f"{args.model}-use_record.txt", "a") as f:
        f.write(f"{current_time}: {count}\n")


with gr.Blocks(css_paths="assets/style.css", title=f"SVDQuant Flux.1-{model_name} Demo") as demo:
    # HTML template for the page header; its {model_name}, {device_info}, {notice}
    # and {count_info} placeholders are filled in by get_header_str() below.
    with open("assets/description.html", "r") as f:
        DESCRIPTION = f.read()
    # Describe the first CUDA device (if any) in the header.
    if torch.cuda.device_count() > 0:
        gpu_properties = torch.cuda.get_device_properties(0)
        gpu_memory = gpu_properties.total_memory / (1024**3)  # Convert to GiB
        gpu_name = torch.cuda.get_device_name(0)
        device_info = f"Running on {gpu_name} with {gpu_memory:.0f} GiB memory."
    else:
        device_info = "Running on CPU 🥶 This demo does not work on CPU."
    notice = '<strong>Notice:</strong>&nbsp;We will replace unsafe prompts with a default prompt: "A peaceful world."'

    def get_header_str():
        """Render the header HTML, including the total-run counter when counting is enabled."""
        if args.count_use:
            if os.path.exists(f"{args.model}-use_count.txt"):
                with open(f"{args.model}-use_count.txt", "r") as f:
                    count = int(f.read())
            else:
                count = 0
            count_info = (
                f"<div style='display: flex; justify-content: center; align-items: center; text-align: center;'>"
                f"<span style='font-size: 18px; font-weight: bold;'>Total inference runs: </span>"
                f"<span style='font-size: 18px; color:red; font-weight: bold;'>&nbsp;{count}</span></div>"
            )
        else:
            count_info = ""
        header_str = DESCRIPTION.format(
            model_name=args.model, device_info=device_info, notice=notice, count_info=count_info
        )
        return header_str

    header = gr.HTML(get_header_str())
    # Re-render the header on every page load so the run counter is current.
    demo.load(fn=get_header_str, outputs=header)

    with gr.Row(elem_id="main_row"):
        with gr.Column(elem_id="column_input"):
            gr.Markdown("## INPUT", elem_id="input_header")
            with gr.Group():
                canvas = gr.ImageEditor(
                    height=640,
                    image_mode="RGB",
                    sources=["upload", "clipboard"],
                    type="pil",
                    label="Input",
                    show_label=False,
                    show_download_button=True,
                    interactive=True,
                    transforms=[],
                    canvas_size=(1024, 1024),
                    scale=1,
                    format="png",
                    layers=False,
                )
                with gr.Row():
                    prompt = gr.Text(label="Prompt", placeholder="Enter your prompt", scale=6)
                    run_button = gr.Button("Run", scale=1, elem_id="run_button")
            with gr.Row():
                style = gr.Dropdown(label="Style", choices=STYLE_NAMES, value=DEFAULT_STYLE_NAME, scale=1)
                prompt_template = gr.Textbox(
                    label="Prompt Style Template", value=STYLES[DEFAULT_STYLE_NAME], scale=2, max_lines=1
                )

            with gr.Row():
                seed = gr.Slider(label="Seed", show_label=True, minimum=0, maximum=MAX_SEED, value=233, step=1, scale=4)
                randomize_seed = gr.Button("Random Seed", scale=1, min_width=50, elem_id="random_seed")
            with gr.Accordion("Advanced options", open=False):
                with gr.Group():
                    # Defaults differ per control variant (see vars.py).
                    num_inference_steps = gr.Slider(
                        label="Inference Steps",
                        minimum=10,
                        maximum=50,
                        step=1,
                        value=DEFAULT_INFERENCE_STEP_CANNY if args.model == "canny" else DEFAULT_INFERENCE_STEP_DEPTH,
                    )
                    guidance_scale = gr.Slider(
                        label="Guidance Scale",
                        minimum=1,
                        maximum=50,
                        step=1,
                        value=DEFAULT_GUIDANCE_CANNY if args.model == "canny" else DEFAULT_GUIDANCE_DEPTH,
                    )

        with gr.Column(elem_id="column_output"):
            gr.Markdown("## OUTPUT", elem_id="output_header")
            with gr.Group():
                result = gr.Image(
                    format="png",
                    height=640,
                    image_mode="RGB",
                    type="pil",
                    label="Result",
                    show_label=False,
                    show_download_button=True,
                    interactive=False,
                    elem_id="output_image",
                )
                latency_result = gr.Text(label="Inference Latency", show_label=True)

            gr.Markdown("### Instructions")
            gr.Markdown("**1**. Enter a text prompt (e.g., a cat)")
            gr.Markdown("**2**. Upload or paste an image")
            gr.Markdown("**3**. Change the image style using a style template")
            # This demo has no sketch guidance; the tunable sliders live under "Advanced options".
            gr.Markdown("**4**. Adjust the inference steps and guidance scale under Advanced options")
            gr.Markdown("**5**. Try different seeds to generate different results")

    # Order must match run()'s positional parameters.
    run_inputs = [canvas, prompt, style, prompt_template, num_inference_steps, guidance_scale, seed]
    run_outputs = [result, latency_result]

    gr.Examples(examples=EXAMPLES[args.model], inputs=run_inputs, outputs=run_outputs, fn=run)

    # Pick a fresh seed, then immediately regenerate with it.
    randomize_seed.click(
        lambda: random.randint(0, MAX_SEED),
        inputs=[],
        outputs=seed,
        api_name=False,
        queue=False,
    ).then(run, inputs=run_inputs, outputs=run_outputs, api_name=False)

    # Selecting a style replaces the editable template with that style's template.
    style.change(
        lambda x: STYLES[x],
        inputs=[style],
        outputs=[prompt_template],
        api_name=False,
        queue=False,
    )
    gr.on(
        triggers=[prompt.submit, run_button.click],
        fn=run,
        inputs=run_inputs,
        outputs=run_outputs,
        api_name=False,
    )

    gr.Markdown("MIT Accessibility: https://accessibility.mit.edu/", elem_id="accessibility")


if __name__ == "__main__":
    # queue() enables Gradio's request queue before serving; root_path comes from the
    # CLI (presumably for serving under a URL prefix / reverse proxy -- confirm in utils.get_args).
    demo.queue().launch(debug=True, share=True, root_path=args.gradio_root_path)