# Changed from https://github.com/GaParmar/img2img-turbo/blob/main/gradio_sketch2image.py
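"""Gradio demo for FLUX.1 Canny/Depth control image generation, running either BF16 or SVDQuant INT4 weights."""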
import os
import random
import time
from datetime import datetime

import GPUtil
import torch
from controlnet_aux import CannyDetector
from diffusers import FluxControlPipeline
from image_gen_aux import DepthPreprocessor
from PIL import Image

from nunchaku.models.safety_checker import SafetyChecker
from nunchaku.models.transformers.transformer_flux import NunchakuFluxTransformer2dModel
from utils import get_args
from vars import (
    DEFAULT_GUIDANCE_CANNY,
    DEFAULT_GUIDANCE_DEPTH,
    DEFAULT_INFERENCE_STEP_CANNY,
    DEFAULT_INFERENCE_STEP_DEPTH,
    DEFAULT_STYLE_NAME,
    EXAMPLES,
    HEIGHT,
    MAX_SEED,
    STYLE_NAMES,
    STYLES,
    WIDTH,
)

# import gradio last to avoid conflicts with other imports
import gradio as gr

args = get_args()

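# Select the control variant and its preprocessor: "canny" uses a Canny edge
# detector, "depth" uses Depth Anything; both run through FluxControlPipeline.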
model_name = f"{args.model}-dev"
pipeline_class = FluxControlPipeline
if args.model == "canny":
    processor = CannyDetector()
else:
    assert args.model == "depth", f"Model {args.model} not supported"
    processor = DepthPreprocessor.from_pretrained("LiheYoung/depth-anything-large-hf")

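# Build the pipeline: "bf16" loads the original BF16 weights, while "int4"
# swaps in the SVDQuant 4-bit transformer (and, optionally, a 4-bit T5 encoder).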
if args.precision == "bf16":
    pipeline = pipeline_class.from_pretrained(
        f"black-forest-labs/FLUX.1-{model_name.capitalize()}", torch_dtype=torch.bfloat16
    )
    pipeline = pipeline.to("cuda")
    pipeline.precision = "bf16"
else:
    assert args.precision == "int4"
    pipeline_init_kwargs = {}
    transformer = NunchakuFluxTransformer2dModel.from_pretrained(f"mit-han-lab/svdq-int4-flux.1-{model_name}")
    pipeline_init_kwargs["transformer"] = transformer
    if args.use_qencoder:
        from nunchaku.models.text_encoders.t5_encoder import NunchakuT5EncoderModel

        text_encoder_2 = NunchakuT5EncoderModel.from_pretrained("mit-han-lab/svdq-flux.1-t5")
        pipeline_init_kwargs["text_encoder_2"] = text_encoder_2

    pipeline = pipeline_class.from_pretrained(
        f"black-forest-labs/FLUX.1-{model_name.capitalize()}", torch_dtype=torch.bfloat16, **pipeline_init_kwargs
    )
    pipeline = pipeline.to("cuda")
    pipeline.precision = "int4"

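# Screens prompts for unsafe content; run() substitutes a default prompt when one is flagged.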
safety_checker = SafetyChecker("cuda", disabled=args.no_safety_checker)


def run(
    image, prompt: str, style: str, prompt_template: str, num_inference_steps: int, guidance_scale: float, seed: int
) -> tuple[Image.Image, str]:
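    """Run one generation pass and return the result image plus a latency string."""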
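    # Derive the control image (edge map or depth map) from the drawn/uploaded image.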
    if args.model == "canny":
        processed_img = processor(image["composite"]).convert("RGB")
    else:
        assert args.model == "depth"
        processed_img = processor(image["composite"])[0].convert("RGB")

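    # Swap flagged prompts for a harmless default, then apply the style template.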
    is_unsafe_prompt = False
    if not safety_checker(prompt):
        is_unsafe_prompt = True
        prompt = "A peaceful world."
    prompt = prompt_template.format(prompt=prompt)
    print(f"Prompt: {prompt}")
    start_time = time.time()
    result_image = pipeline(
        prompt=prompt,
        control_image=processed_img,
        height=HEIGHT,
        width=WIDTH,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        generator=torch.Generator().manual_seed(seed),
    ).images[0]

    latency = time.time() - start_time
    if latency < 1:
        latency = latency * 1000
        latency_str = f"{latency:.2f}ms"
    else:
        latency_str = f"{latency:.2f}s"
    if is_unsafe_prompt:
        latency_str += " (Unsafe prompt detected)"
    torch.cuda.empty_cache()
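    # Optionally persist a running usage count and a timestamped record to disk.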
    if args.count_use:
        if os.path.exists(f"{args.model}-use_count.txt"):
            with open(f"{args.model}-use_count.txt", "r") as f:
                count = int(f.read())
        else:
            count = 0
        count += 1
        current_time = datetime.now()
        print(f"{current_time}: {count}")
        with open(f"{args.model}-use_count.txt", "w") as f:
            f.write(str(count))
        with open(f"{args.model}-use_record.txt", "a") as f:
            f.write(f"{current_time}: {count}\n")
    return result_image, latency_str


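# Build the Gradio UI: header, input canvas and controls on the left, output on the right.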
with gr.Blocks(css_paths="assets/style.css", title=f"SVDQuant Flux.1-{model_name} Demo") as demo:
    with open("assets/description.html", "r") as f:
        DESCRIPTION = f.read()
    gpus = GPUtil.getGPUs()
    if len(gpus) > 0:
        gpu = gpus[0]
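        # GPUtil reports memoryTotal in MiB; convert to GiB for display.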
        memory = gpu.memoryTotal / 1024
        device_info = f"Running on {gpu.name} with {memory:.0f} GiB memory."
    else:
        device_info = "Running on CPU 🥶 This demo does not work on CPU."
    notice = '<strong>Notice:</strong>&nbsp;We will replace unsafe prompts with a default prompt: "A peaceful world."'

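    # The header shows model/device info and, when enabled, a live usage counter.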
    def get_header_str():
        if args.count_use:
            if os.path.exists(f"{args.model}-use_count.txt"):
                with open(f"{args.model}-use_count.txt", "r") as f:
                    count = int(f.read())
            else:
                count = 0
            count_info = (
                f"<div style='display: flex; justify-content: center; align-items: center; text-align: center;'>"
                f"<span style='font-size: 18px; font-weight: bold;'>Total inference runs: </span>"
                f"<span style='font-size: 18px; color:red; font-weight: bold;'>&nbsp;{count}</span></div>"
            )
        else:
            count_info = ""
        header_str = DESCRIPTION.format(
            model_name=args.model, device_info=device_info, notice=notice, count_info=count_info
        )
        return header_str

    header = gr.HTML(get_header_str())
    demo.load(fn=get_header_str, outputs=header)

    with gr.Row(elem_id="main_row"):
        with gr.Column(elem_id="column_input"):
            gr.Markdown("## INPUT", elem_id="input_header")
            with gr.Group():
                canvas = gr.ImageEditor(
                    height=640,
                    image_mode="RGB",
                    sources=["upload", "clipboard"],
                    type="pil",
                    label="Input",
                    show_label=False,
                    show_download_button=True,
                    interactive=True,
                    transforms=[],
                    canvas_size=(1024, 1024),
                    scale=1,
                    format="png",
                    layers=False,
                )
                with gr.Row():
                    prompt = gr.Text(label="Prompt", placeholder="Enter your prompt", scale=6)
                    run_button = gr.Button("Run", scale=1, elem_id="run_button")
            with gr.Row():
                style = gr.Dropdown(label="Style", choices=STYLE_NAMES, value=DEFAULT_STYLE_NAME, scale=1)
                prompt_template = gr.Textbox(
                    label="Prompt Style Template", value=STYLES[DEFAULT_STYLE_NAME], scale=2, max_lines=1
                )

            with gr.Row():
                seed = gr.Slider(label="Seed", show_label=True, minimum=0, maximum=MAX_SEED, value=233, step=1, scale=4)
                randomize_seed = gr.Button("Random Seed", scale=1, min_width=50, elem_id="random_seed")
            with gr.Accordion("Advanced options", open=False):
                with gr.Group():
                    num_inference_steps = gr.Slider(
                        label="Inference Steps",
                        minimum=10,
                        maximum=50,
                        step=1,
                        value=DEFAULT_INFERENCE_STEP_CANNY if args.model == "canny" else DEFAULT_INFERENCE_STEP_DEPTH,
                    )
                    guidance_scale = gr.Slider(
                        label="Guidance Scale",
                        minimum=1,
                        maximum=50,
                        step=1,
                        value=DEFAULT_GUIDANCE_CANNY if args.model == "canny" else DEFAULT_GUIDANCE_DEPTH,
                    )

        with gr.Column(elem_id="column_output"):
            gr.Markdown("## OUTPUT", elem_id="output_header")
            with gr.Group():
                result = gr.Image(
                    format="png",
                    height=640,
                    image_mode="RGB",
                    type="pil",
                    label="Result",
                    show_label=False,
                    show_download_button=True,
                    interactive=False,
                    elem_id="output_image",
                )
                latency_result = gr.Text(label="Inference Latency", show_label=True)

            gr.Markdown("### Instructions")
            gr.Markdown("**1**. Enter a text prompt (e.g., a cat)")
            gr.Markdown("**2**. Upload or paste an image")
            gr.Markdown("**3**. Change the image style using a style template")
            gr.Markdown("**4**. Adjust the effect of sketch guidance using the slider")
            gr.Markdown("**5**. Try different seeds to generate different results")

    run_inputs = [canvas, prompt, style, prompt_template, num_inference_steps, guidance_scale, seed]
    run_outputs = [result, latency_result]

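    # Preset example inputs for the selected model.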
    gr.Examples(examples=EXAMPLES[args.model], inputs=run_inputs, outputs=run_outputs, fn=run)

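    # "Random Seed" draws a fresh seed, then immediately re-runs inference.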
    randomize_seed.click(
        lambda: random.randint(0, MAX_SEED),
        inputs=[],
        outputs=seed,
        api_name=False,
        queue=False,
    ).then(run, inputs=run_inputs, outputs=run_outputs, api_name=False)

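    # Keep the prompt template textbox in sync with the selected style.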
    style.change(
        lambda x: STYLES[x],
        inputs=[style],
        outputs=[prompt_template],
        api_name=False,
        queue=False,
    )
    gr.on(
        triggers=[prompt.submit, run_button.click],
        fn=run,
        inputs=run_inputs,
        outputs=run_outputs,
        api_name=False,
    )

    gr.Markdown("MIT Accessibility: https://accessibility.mit.edu/", elem_id="accessibility")


if __name__ == "__main__":
    demo.queue().launch(debug=True, share=True, root_path=args.gradio_root_path)