# Changed from https://huggingface.co/spaces/playgroundai/playground-v2.5/blob/main/app.py
import argparse
import os
import random
import time
from datetime import datetime

import GPUtil

# import gradio last to avoid conflicts with other imports
import gradio as gr
import spaces
import torch

from nunchaku.models.safety_checker import SafetyChecker
from utils import get_pipeline
from vars import EXAMPLES, MAX_SEED


def get_args(argv=None) -> argparse.Namespace:
    """Parse the demo's command-line options.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]`` (useful for testing). ``None`` keeps the
            original behavior of reading the process arguments.

    Returns:
        The parsed :class:`argparse.Namespace` with attributes
        ``precisions``, ``use_qencoder``, ``no_safety_checker`` and
        ``count_use``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-p",
        "--precisions",
        type=str,
        default=["int4"],
        nargs="*",
        choices=["int4", "bf16"],
        help="Which precisions to use",
    )
    parser.add_argument("--use-qencoder", action="store_true", help="Whether to use 4-bit text encoder")
    parser.add_argument("--no-safety-checker", action="store_true", help="Disable safety checker")
    parser.add_argument("--count-use", action="store_true", help="Whether to count the number of uses")
    return parser.parse_args(argv)


args = get_args()


# Build one pipeline per requested precision. The first pipeline's VAE and
# text encoder are reused by every later pipeline to avoid loading those
# heavy submodules more than once.
pipelines = []
pipeline_init_kwargs = {}
for index, precision in enumerate(args.precisions):
    pipeline = get_pipeline(
        precision=precision,
        use_qencoder=args.use_qencoder,
        device="cuda",
        pipeline_init_kwargs=dict(pipeline_init_kwargs),
    )
    pipelines.append(pipeline)
    if index == 0:
        # Share these submodules with all subsequent pipelines.
        pipeline_init_kwargs["vae"] = pipeline.vae
        pipeline_init_kwargs["text_encoder"] = pipeline.text_encoder

safety_checker = SafetyChecker("cuda", disabled=args.no_safety_checker)


@spaces.GPU(enable_queue=True)
def generate(
    prompt: str = None,
    height: int = 1024,
    width: int = 1024,
    num_inference_steps: int = 4,
    guidance_scale: float = 0,
    pag_scale: float = 0,
    seed: int = 0,
):
    """Run every loaded pipeline on ``prompt`` and time each run.

    Unsafe prompts (as judged by ``safety_checker``) are replaced with a
    harmless default prompt instead of failing, and the latency strings are
    annotated accordingly.

    Returns:
        A flat tuple: one image per pipeline followed by one latency string
        per pipeline, matching the Gradio output component order.
    """
    print(f"Prompt: {prompt}")
    is_unsafe_prompt = False
    if not safety_checker(prompt):
        is_unsafe_prompt = True
        prompt = "A peaceful world."
    images, latency_strs = [], []
    for pipeline in pipelines:
        # Creating gr.Progress(track_tqdm=True) inside the handler lets
        # Gradio surface the pipeline's tqdm progress bar in the UI.
        progress = gr.Progress(track_tqdm=True)  # noqa: F841
        start_time = time.time()
        image = pipeline(
            prompt=prompt,
            height=height,
            width=width,
            guidance_scale=guidance_scale,
            pag_scale=pag_scale,
            num_inference_steps=num_inference_steps,
            generator=torch.Generator().manual_seed(seed),
        ).images[0]
        latency = time.time() - start_time
        images.append(image)
        latency_strs.append(_format_latency(latency))
    if is_unsafe_prompt:
        latency_strs = [s + " (Unsafe prompt detected)" for s in latency_strs]
    torch.cuda.empty_cache()

    if args.count_use:
        _record_use()

    return *images, *latency_strs


def _format_latency(seconds: float) -> str:
    """Format a latency in seconds as 'X.XXms' below one second, else 'X.XXs'."""
    if seconds < 1:
        return f"{seconds * 1000:.2f}ms"
    return f"{seconds:.2f}s"


def _record_use() -> None:
    """Increment the persistent use counter and append a timestamped record.

    Tolerates a missing, empty, or corrupt ``use_count.txt`` by restarting
    the counter from zero instead of crashing the request handler.
    """
    count = 0
    if os.path.exists("use_count.txt"):
        with open("use_count.txt", "r") as f:
            try:
                count = int(f.read())
            except ValueError:
                # Corrupt or empty counter file: reset rather than crash.
                count = 0
    count += 1
    current_time = datetime.now()
    print(f"{current_time}: {count}")
    with open("use_count.txt", "w") as f:
        f.write(str(count))
    with open("use_record.txt", "a") as f:
        f.write(f"{current_time}: {count}\n")


# Static page header loaded from the bundled HTML template.
with open("./assets/description.html", "r") as f:
    DESCRIPTION = f.read()

# Describe the first GPU reported by GPUtil, or warn that CPU is unsupported.
gpus = GPUtil.getGPUs()
if gpus:
    first_gpu = gpus[0]
    total_gib = first_gpu.memoryTotal / 1024
    device_info = f"Running on {first_gpu.name} with {total_gib:.0f} GiB memory."
else:
    device_info = "Running on CPU 🥶 This demo does not work on CPU."
notice = '<strong>Notice:</strong>&nbsp;We will replace unsafe prompts with a default prompt: "A peaceful world."'

# Build the Gradio UI: one result column per precision, a prompt row, seed
# controls, advanced sliders, and the event wiring that drives `generate`.
with gr.Blocks(
    css_paths=[f"assets/frame{len(args.precisions)}.css", "assets/common.css"],
    title=f"SVDQuant SANA-1600M Demo",
) as demo:

    def get_header_str():
        """Render the page header HTML, optionally including the use count."""
        if args.count_use:
            if os.path.exists("use_count.txt"):
                with open("use_count.txt", "r") as f:
                    count = int(f.read())
            else:
                count = 0
            count_info = (
                f"<div style='display: flex; justify-content: center; align-items: center; text-align: center;'>"
                f"<span style='font-size: 18px; font-weight: bold;'>Total inference runs: </span>"
                f"<span style='font-size: 18px; color:red; font-weight: bold;'>&nbsp;{count}</span></div>"
            )
        else:
            count_info = ""
        header_str = DESCRIPTION.format(device_info=device_info, notice=notice, count_info=count_info)
        return header_str

    header = gr.HTML(get_header_str())
    # Refresh the header (and its use count) every time the page loads.
    demo.load(fn=get_header_str, outputs=header)

    # One image + latency column per precision, in the same order as
    # `pipelines`, so `generate`'s flat return tuple lines up with them.
    with gr.Row():
        image_results, latency_results = [], []
        for i, precision in enumerate(args.precisions):
            with gr.Column():
                gr.Markdown(f"# {precision.upper()}", elem_id="image_header")
                with gr.Group():
                    image_result = gr.Image(
                        format="png",
                        image_mode="RGB",
                        label="Result",
                        show_label=False,
                        show_download_button=True,
                        interactive=False,
                    )
                    latency_result = gr.Text(label="Inference Latency", show_label=True)
                    image_results.append(image_result)
                    latency_results.append(latency_result)
    with gr.Row():
        prompt = gr.Text(
            label="Prompt", show_label=False, max_lines=1, placeholder="Enter your prompt", container=False, scale=4
        )
        run_button = gr.Button("Run", scale=1)

    with gr.Row():
        seed = gr.Slider(label="Seed", show_label=True, minimum=0, maximum=MAX_SEED, value=233, step=1, scale=4)
        randomize_seed = gr.Button("Random Seed", scale=1, min_width=50, elem_id="random_seed")
    with gr.Accordion("Advanced options", open=False):
        with gr.Group():
            height = gr.Slider(label="Height", minimum=256, maximum=4096, step=32, value=1024)
            width = gr.Slider(label="Width", minimum=256, maximum=4096, step=32, value=1024)
        with gr.Group():
            num_inference_steps = gr.Slider(label="Sampling Steps", minimum=10, maximum=50, step=1, value=20)
            guidance_scale = gr.Slider(label="Guidance Scale", minimum=1, maximum=10, step=0.1, value=5)
            pag_scale = gr.Slider(label="PAG Scale", minimum=0, maximum=10, step=0.1, value=2.0)

    # Input order must match `generate`'s positional parameters.
    input_args = [prompt, height, width, num_inference_steps, guidance_scale, pag_scale, seed]

    gr.Examples(examples=EXAMPLES, inputs=input_args, outputs=[*image_results, *latency_results], fn=generate)

    # Run on both Enter-in-prompt and the Run button; exposed as API "run".
    gr.on(
        triggers=[prompt.submit, run_button.click],
        fn=generate,
        inputs=input_args,
        outputs=[*image_results, *latency_results],
        api_name="run",
    )
    # Picking a random seed immediately re-generates with the new seed.
    randomize_seed.click(
        lambda: random.randint(0, MAX_SEED), inputs=[], outputs=seed, api_name=False, queue=False
    ).then(fn=generate, inputs=input_args, outputs=[*image_results, *latency_results], api_name=False, queue=False)

    gr.Markdown("MIT Accessibility: https://accessibility.mit.edu/", elem_id="accessibility")


if __name__ == "__main__":
    demo.queue(max_size=20).launch(server_name="0.0.0.0", debug=True, share=True)