# Changed from https://huggingface.co/spaces/playgroundai/playground-v2.5/blob/main/app.py
import argparse
import os
import random
import time
from datetime import datetime

import GPUtil
import spaces
import torch
from utils import get_pipeline
from vars import EXAMPLES, MAX_SEED

from nunchaku.models.safety_checker import SafetyChecker

# import gradio last to avoid conflicts with other imports
import gradio as gr  # noqa: isort: skip

def get_args() -> argparse.Namespace:
    """Parse command-line options for the demo.

    Returns:
        argparse.Namespace with:
            precisions (list[str]): model precisions to load ("int4" / "bf16").
            use_qencoder (bool): use the 4-bit text encoder.
            no_safety_checker (bool): disable the prompt safety checker.
            count_use (bool): persist a usage counter across runs.
            gradio_root_path (str): root path passed to Gradio's launcher.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-p",
        "--precisions",
        type=str,
        default=["int4"],
        nargs="*",
        choices=["int4", "bf16"],
        help="Which precisions to use",
    )
    parser.add_argument("--use-qencoder", action="store_true", help="Whether to use 4-bit text encoder")
    parser.add_argument("--no-safety-checker", action="store_true", help="Disable safety checker")
    parser.add_argument("--count-use", action="store_true", help="Whether to count the number of uses")
    parser.add_argument("--gradio-root-path", type=str, default="")
    return parser.parse_args()


args = get_args()


# Build one pipeline per requested precision. After the first pipeline is
# loaded, its VAE and text encoder are passed into every subsequent
# pipeline's init kwargs so those large submodules are loaded onto the GPU
# only once and shared across precisions.
pipelines = []
pipeline_init_kwargs = {}
for i, precision in enumerate(args.precisions):
    pipeline = get_pipeline(
        precision=precision,
        use_qencoder=args.use_qencoder,
        device="cuda",
        pipeline_init_kwargs={**pipeline_init_kwargs},
    )
    pipelines.append(pipeline)
    if i == 0:
        # Share the heavyweight submodules with all later pipelines.
        pipeline_init_kwargs["vae"] = pipeline.vae
        pipeline_init_kwargs["text_encoder"] = pipeline.text_encoder

# Prompt safety filter; when disabled it lets every prompt through.
safety_checker = SafetyChecker("cuda", disabled=args.no_safety_checker)


@spaces.GPU(enable_queue=True)
def generate(
    prompt: str = None,
    height: int = 1024,
    width: int = 1024,
    num_inference_steps: int = 4,
    guidance_scale: float = 0,
    pag_scale: float = 0,
    seed: int = 0,
):
    """Run every loaded pipeline on ``prompt`` and measure per-pipeline latency.

    Prompts rejected by ``safety_checker`` are replaced with the default
    prompt "A peaceful world." and every latency string is annotated so the
    UI surfaces the substitution.

    Returns:
        A flat tuple of one image per pipeline followed by one latency
        string per pipeline, matching the order of Gradio output components.
    """
    print(f"Prompt: {prompt}")
    is_unsafe_prompt = False
    if not safety_checker(prompt):
        is_unsafe_prompt = True
        prompt = "A peaceful world."
    images, latency_strs = [], []
    for i, pipeline in enumerate(pipelines):
        gr.Progress(track_tqdm=True)
        start_time = time.time()
        image = pipeline(
            prompt=prompt,
            height=height,
            width=width,
            guidance_scale=guidance_scale,
            pag_scale=pag_scale,
            num_inference_steps=num_inference_steps,
            # Fixed seed per call so all pipelines get the same noise.
            generator=torch.Generator().manual_seed(seed),
        ).images[0]
        end_time = time.time()
        latency = end_time - start_time
        # Sub-second latencies are shown in milliseconds for readability.
        if latency < 1:
            latency_str = f"{latency * 1000:.2f}ms"
        else:
            latency_str = f"{latency:.2f}s"
        images.append(image)
        latency_strs.append(latency_str)
    if is_unsafe_prompt:
        for i in range(len(latency_strs)):
            latency_strs[i] += " (Unsafe prompt detected)"
    torch.cuda.empty_cache()

    if args.count_use:
        # Best-effort usage counter persisted in the working directory.
        # NOTE(review): not safe against concurrent writers — matches the
        # original behavior; acceptable for a single-process demo.
        if os.path.exists("use_count.txt"):
            with open("use_count.txt", "r") as f:
                count = int(f.read())
        else:
            count = 0
        count += 1
        current_time = datetime.now()
        print(f"{current_time}: {count}")
        with open("use_count.txt", "w") as f:
            f.write(str(count))
        with open("use_record.txt", "a") as f:
            f.write(f"{current_time}: {count}\n")

    return *images, *latency_strs


# ---- Page header: description template filled with runtime device info ----
with open("./assets/description.html", "r") as f:
    DESCRIPTION = f.read()
gpus = GPUtil.getGPUs()
if len(gpus) > 0:
    gpu = gpus[0]
    # memoryTotal is reported in MiB; convert to GiB for display.
    memory = gpu.memoryTotal / 1024
    device_info = f"Running on {gpu.name} with {memory:.0f} GiB memory."
else:
    device_info = "Running on CPU 🥶 This demo does not work on CPU."
notice = '<strong>Notice:</strong>&nbsp;We will replace unsafe prompts with a default prompt: "A peaceful world."'

with gr.Blocks(
    css_paths=[f"assets/frame{len(args.precisions)}.css", "assets/common.css"],
    title="SVDQuant SANA-1600M Demo",
) as demo:

    def get_header_str():
        # Rebuild the header on each page load so the run counter stays fresh.
        if args.count_use:
            if os.path.exists("use_count.txt"):
                with open("use_count.txt", "r") as f:
                    count = int(f.read())
            else:
                count = 0
            count_info = (
                f"<div style='display: flex; justify-content: center; align-items: center; text-align: center;'>"
                f"<span style='font-size: 18px; font-weight: bold;'>Total inference runs: </span>"
                f"<span style='font-size: 18px; color:red; font-weight: bold;'>&nbsp;{count}</span></div>"
            )
        else:
            count_info = ""
        header_str = DESCRIPTION.format(device_info=device_info, notice=notice, count_info=count_info)
        return header_str

    header = gr.HTML(get_header_str())
    demo.load(fn=get_header_str, outputs=header)

    # One result column (image + latency) per configured precision.
    with gr.Row():
        image_results, latency_results = [], []
        for i, precision in enumerate(args.precisions):
            with gr.Column():
                gr.Markdown(f"# {precision.upper()}", elem_id="image_header")
                with gr.Group():
                    image_result = gr.Image(
                        format="png",
                        image_mode="RGB",
                        label="Result",
                        show_label=False,
                        show_download_button=True,
                        interactive=False,
                    )
                    latency_result = gr.Text(label="Inference Latency", show_label=True)
                    image_results.append(image_result)
                    latency_results.append(latency_result)
    with gr.Row():
        prompt = gr.Text(
            label="Prompt", show_label=False, max_lines=1, placeholder="Enter your prompt", container=False, scale=4
        )
        run_button = gr.Button("Run", scale=1)

    with gr.Row():
        seed = gr.Slider(label="Seed", show_label=True, minimum=0, maximum=MAX_SEED, value=233, step=1, scale=4)
        randomize_seed = gr.Button("Random Seed", scale=1, min_width=50, elem_id="random_seed")
    with gr.Accordion("Advanced options", open=False):
        with gr.Group():
            height = gr.Slider(label="Height", minimum=256, maximum=4096, step=32, value=1024)
            width = gr.Slider(label="Width", minimum=256, maximum=4096, step=32, value=1024)
        with gr.Group():
            num_inference_steps = gr.Slider(label="Sampling Steps", minimum=10, maximum=50, step=1, value=20)
            guidance_scale = gr.Slider(label="Guidance Scale", minimum=1, maximum=10, step=0.1, value=5)
            pag_scale = gr.Slider(label="PAG Scale", minimum=0, maximum=10, step=0.1, value=2.0)

    # Order must match generate()'s positional parameters.
    input_args = [prompt, height, width, num_inference_steps, guidance_scale, pag_scale, seed]

    gr.Examples(examples=EXAMPLES, inputs=input_args, outputs=[*image_results, *latency_results], fn=generate)

    # Run on prompt submit or button click.
    gr.on(
        triggers=[prompt.submit, run_button.click],
        fn=generate,
        inputs=input_args,
        outputs=[*image_results, *latency_results],
        api_name=False,
    )
    # "Random Seed" first draws a new seed, then immediately regenerates.
    randomize_seed.click(
        lambda: random.randint(0, MAX_SEED), inputs=[], outputs=seed, api_name=False, queue=False
    ).then(fn=generate, inputs=input_args, outputs=[*image_results, *latency_results], api_name=False, queue=False)

    gr.Markdown("MIT Accessibility: https://accessibility.mit.edu/", elem_id="accessibility")


if __name__ == "__main__":
    demo.queue(max_size=20).launch(server_name="0.0.0.0", debug=True, share=True, root_path=args.gradio_root_path)