Unverified commit de9b25d6, authored by Muyang Li, committed by GitHub

Merge pull request #92 from Aprilhuu/main

[Major] Add demo for Flux.1 canny, depth and fill
parents 6c333071 50139c73
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<div>
<h1>
<img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/logo.svg"
alt="logo"
style="height: 40px; width: auto; display: block; margin: auto;"/>
INT4 FLUX.1-{model_name}-dev Demo
</h1>
<h2>
SVDQuant: Absorbing Outliers by Low-Rank Components for 4-Bit Diffusion Models
</h2>
<h3>
<a href='https://lmxyy.me'>Muyang Li*</a>,
<a href='https://yujunlin.com'>Yujun Lin*</a>,
<a href='https://hanlab.mit.edu/team/zhekai-zhang'>Zhekai Zhang*</a>,
<a href='https://www.tianle.website/#/'>Tianle Cai</a>,
<a href='https://xiuyuli.com'>Xiuyu Li</a>,
<br>
<a href='https://github.com/JerryGJX'>Junxian Guo</a>,
<a href='https://xieenze.github.io'>Enze Xie</a>,
<a href='https://cs.stanford.edu/~chenlin/'>Chenlin Meng</a>,
<a href='https://www.cs.cmu.edu/~junyanz/'>Jun-Yan Zhu</a>,
and <a href='https://hanlab.mit.edu/songhan'>Song Han</a>
</h3>
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<a href="https://arxiv.org/abs/2411.05007">[Paper]</a>
&nbsp;
<a href='https://github.com/mit-han-lab/nunchaku'>
[Code]
</a>
&nbsp;
<a href='https://hanlab.mit.edu/projects/svdquant'>
[Website]
</a>
&nbsp;
<a href='https://hanlab.mit.edu/blog/svdquant'>
[Blog]
</a>
</div>
<h4>Quantization Library:
<a href='https://github.com/mit-han-lab/deepcompressor'>DeepCompressor</a>&nbsp;
Inference Engine: <a href='https://github.com/mit-han-lab/nunchaku'>Nunchaku</a>&nbsp;
Image Control: <a href="https://github.com/GaParmar/img2img-turbo">img2img-turbo</a>
</h4>
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
{device_info}
</div>
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
{notice}
</div>
{count_info}
</div>
</div>
@import url('https://cdnjs.cloudflare.com/ajax/libs/font-awesome/5.15.1/css/all.min.css');
.gradio-container{max-width: 1200px !important}
h1{text-align:center}
.wrap.svelte-p4aq0j.svelte-p4aq0j {
display: none;
}
#column_input, #column_output {
width: 500px;
display: flex;
align-items: center;
}
#input_header, #output_header {
display: flex;
justify-content: center;
align-items: center;
width: 400px;
}
#accessibility {
text-align: center; /* Center-aligns the text */
margin: auto; /* Centers the element horizontally */
}
#random_seed {height: 71px;}
#run_button {height: 87px;}
# Changed from https://github.com/GaParmar/img2img-turbo/blob/main/gradio_sketch2image.py
import logging
import os
import random
import tempfile
import time
from datetime import datetime
import GPUtil
import numpy as np
import torch
from PIL import Image
from image_gen_aux import DepthPreprocessor
from controlnet_aux import CannyDetector
from diffusers import FluxControlPipeline
from nunchaku.models.safety_checker import SafetyChecker
from nunchaku.models.transformer_flux import NunchakuFluxTransformer2dModel
from utils import get_args
from vars import DEFAULT_INFERENCE_STEP_CANNY, DEFAULT_GUIDANCE_CANNY, DEFAULT_INFERENCE_STEP_DEPTH, \
DEFAULT_GUIDANCE_DEPTH, DEFAULT_STYLE_NAME, MAX_SEED, STYLE_NAMES, STYLES
# import gradio last to avoid conflicts with other imports
import gradio as gr
blank_image = Image.new("RGB", (1024, 1024), (255, 255, 255))
args = get_args()
model_name = f"{args.model}-dev"
pipeline_class = FluxControlPipeline
if args.model == "canny":
processor = CannyDetector()
else:
assert args.model == "depth", f"Model {args.model} is not supported"
processor = DepthPreprocessor.from_pretrained("LiheYoung/depth-anything-large-hf")
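# The processor chosen above turns the user's input image into the control signal for
# FluxControlPipeline: a Canny edge map for the canny model, or a Depth-Anything depth
# map for the depth model. It is passed as control_image in run() below.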
if args.precision == "bf16":
pipeline = pipeline_class.from_pretrained(f"black-forest-labs/FLUX.1-{model_name.capitalize()}", torch_dtype=torch.bfloat16)
pipeline = pipeline.to("cuda")
pipeline.precision = "bf16"
else:
assert args.precision == "int4"
pipeline_init_kwargs = {}
transformer = NunchakuFluxTransformer2dModel.from_pretrained(f"mit-han-lab/svdq-int4-flux.1-{model_name}")
pipeline_init_kwargs["transformer"] = transformer
if args.use_qencoder:
from nunchaku.models.text_encoder import NunchakuT5EncoderModel
text_encoder_2 = NunchakuT5EncoderModel.from_pretrained("mit-han-lab/svdq-flux.1-t5")
pipeline_init_kwargs["text_encoder_2"] = text_encoder_2
pipeline = pipeline_class.from_pretrained(
f"black-forest-labs/FLUX.1-{model_name.capitalize()}", torch_dtype=torch.bfloat16, **pipeline_init_kwargs
)
pipeline = pipeline.to("cuda")
pipeline.precision = "int4"
safety_checker = SafetyChecker("cuda", disabled=args.no_safety_checker)
def save_image(img):
if isinstance(img, dict):
img = img["composite"]
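# delete=False keeps the temporary PNG on disk after the handle closes so Gradio's
# DownloadButton can serve it; the file is left for the OS temp directory to clean up.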
temp_file = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
img.save(temp_file.name)
return temp_file.name
def run(image, prompt: str, prompt_template: str, num_inference_steps: int, guidance_scale: float, seed: int) -> tuple[Image.Image, str]:
print(f"Prompt: {prompt}")
if args.model == "canny":
processed_img = processor(image["composite"]).convert("RGB")
else:
assert args.model == "depth"
processed_img = processor(image["composite"])[0].convert("RGB")
image_numpy = np.array(processed_img)
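# A 1024x1024 RGB image has 1024 * 1024 * 3 = 3,145,728 values; if nearly all of them
# are pure white (255) or pure black (0) and the prompt is empty, treat the canvas as blank.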
if prompt.strip() == "" and (np.sum(image_numpy == 255) >= 3145628 or np.sum(image_numpy == 0) >= 3145628):
return blank_image, "Please input the prompt or draw something."
is_unsafe_prompt = False
if not safety_checker(prompt):
is_unsafe_prompt = True
prompt = "A peaceful world."
prompt = prompt_template.format(prompt=prompt)
start_time = time.time()
result_image = pipeline(
prompt=prompt,
control_image=processed_img,
height=1024,
width=1024,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
generator=torch.Generator().manual_seed(int(seed)),
).images[0]
latency = time.time() - start_time
if latency < 1:
latency = latency * 1000
latency_str = f"{latency:.2f}ms"
else:
latency_str = f"{latency:.2f}s"
if is_unsafe_prompt:
latency_str += " (Unsafe prompt detected)"
torch.cuda.empty_cache()
if args.count_use:
if os.path.exists("use_count.txt"):
with open("use_count.txt", "r") as f:
count = int(f.read())
else:
count = 0
count += 1
current_time = datetime.now()
print(f"{current_time}: {count}")
with open("use_count.txt", "w") as f:
f.write(str(count))
with open("use_record.txt", "a") as f:
f.write(f"{current_time}: {count}\n")
return result_image, latency_str
with gr.Blocks(css_paths="assets/style.css", title=f"SVDQuant Flux.1-{model_name} Demo") as demo:
with open("assets/description.html", "r") as f:
DESCRIPTION = f.read()
gpus = GPUtil.getGPUs()
if len(gpus) > 0:
gpu = gpus[0]
memory = gpu.memoryTotal / 1024
device_info = f"Running on {gpu.name} with {memory:.0f} GiB memory."
else:
device_info = "Running on CPU 🥶 This demo does not work on CPU."
notice = f'<strong>Notice:</strong>&nbsp;We will replace unsafe prompts with a default prompt: "A peaceful world."'
def get_header_str():
if args.count_use:
if os.path.exists("use_count.txt"):
with open("use_count.txt", "r") as f:
count = int(f.read())
else:
count = 0
count_info = (
f"<div style='display: flex; justify-content: center; align-items: center; text-align: center;'>"
f"<span style='font-size: 18px; font-weight: bold;'>Total inference runs: </span>"
f"<span style='font-size: 18px; color:red; font-weight: bold;'>&nbsp;{count}</span></div>"
)
else:
count_info = ""
header_str = DESCRIPTION.format(model_name=args.model, device_info=device_info, notice=notice, count_info=count_info)
return header_str
header = gr.HTML(get_header_str())
demo.load(fn=get_header_str, outputs=header)
with gr.Row(elem_id="main_row"):
with gr.Column(elem_id="column_input"):
gr.Markdown("## INPUT", elem_id="input_header")
with gr.Group():
canvas = gr.Sketchpad(
value=blank_image,
height=640,
image_mode="RGB",
sources=["upload", "clipboard"],
type="pil",
label="Sketch",
show_label=False,
show_download_button=True,
interactive=True,
transforms=[],
canvas_size=(1024, 1024),
scale=1,
format="png",
layers=False,
)
with gr.Row():
prompt = gr.Text(label="Prompt", placeholder="Enter your prompt", scale=6)
run_button = gr.Button("Run", scale=1, elem_id="run_button")
download_sketch = gr.DownloadButton("Download Sketch", scale=1, elem_id="download_sketch")
with gr.Row():
style = gr.Dropdown(label="Style", choices=STYLE_NAMES, value=DEFAULT_STYLE_NAME, scale=1)
prompt_template = gr.Textbox(
label="Prompt Style Template", value=STYLES[DEFAULT_STYLE_NAME], scale=2, max_lines=1
)
with gr.Row():
seed = gr.Slider(label="Seed", show_label=True, minimum=0, maximum=MAX_SEED, value=233, step=1, scale=4)
randomize_seed = gr.Button("Random Seed", scale=1, min_width=50, elem_id="random_seed")
with gr.Accordion("Advanced options", open=False):
with gr.Group():
num_inference_steps = gr.Slider(label="Inference Steps", minimum=10, maximum=50, step=1, \
value=DEFAULT_INFERENCE_STEP_CANNY if args.model == "canny" else DEFAULT_INFERENCE_STEP_DEPTH)
guidance_scale = gr.Slider(label="Guidance Scale", minimum=1, maximum=50, step=1, \
value=DEFAULT_GUIDANCE_CANNY if args.model == "canny" else DEFAULT_GUIDANCE_DEPTH)
with gr.Column(elem_id="column_output"):
gr.Markdown("## OUTPUT", elem_id="output_header")
with gr.Group():
result = gr.Image(
format="png",
height=640,
image_mode="RGB",
type="pil",
label="Result",
show_label=False,
show_download_button=True,
interactive=False,
elem_id="output_image",
)
latency_result = gr.Text(label="Inference Latency", show_label=True)
download_result = gr.DownloadButton("Download Result", elem_id="download_result")
gr.Markdown("### Instructions")
gr.Markdown("**1**. Enter a text prompt (e.g. a cat)")
gr.Markdown("**2**. Start sketching")
gr.Markdown("**3**. Change the image style using a style template")
gr.Markdown("**4**. Adjust the effect of sketch guidance using the slider (typically between 0.2 and 0.4)")
gr.Markdown("**5**. Try different seeds to generate different results")
run_inputs = [canvas, prompt, prompt_template, num_inference_steps, guidance_scale, seed]
run_outputs = [result, latency_result]
randomize_seed.click(
lambda: random.randint(0, MAX_SEED),
inputs=[],
outputs=seed,
api_name=False,
queue=False,
).then(run, inputs=run_inputs, outputs=run_outputs, api_name=False)
style.change(
lambda x: STYLES[x],
inputs=[style],
outputs=[prompt_template],
api_name=False,
queue=False,
).then(fn=run, inputs=run_inputs, outputs=run_outputs, api_name=False)
gr.on(
triggers=[prompt.submit, run_button.click, canvas.change],
fn=run,
inputs=run_inputs,
outputs=run_outputs,
api_name=False,
)
download_sketch.click(fn=save_image, inputs=canvas, outputs=download_sketch)
download_result.click(fn=save_image, inputs=result, outputs=download_result)
gr.Markdown("MIT Accessibility: https://accessibility.mit.edu/", elem_id="accessibility")
if __name__ == "__main__":
demo.queue().launch(debug=True, share=True)
import argparse
def get_args() -> argparse.Namespace:
parser = argparse.ArgumentParser()
parser.add_argument(
"-p", "--precision", type=str, default="int4", choices=["int4", "bf16"], help="Which precisions to use"
)
parser.add_argument(
"-m", "--model", type=str, default="canny", choices=["canny", "depth"], help="Which FLUX.1 model to use"
)
parser.add_argument("--use-qencoder", action="store_true", help="Whether to use 4-bit text encoder")
parser.add_argument("--no-safety-checker", action="store_true", help="Disable safety checker")
parser.add_argument("--count-use", action="store_true", help="Whether to count the number of uses")
args = parser.parse_args()
return args
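# Illustrative launch commands for the canny/depth demo, using the flags defined above.
# The entry-point file name "app.py" is an assumption and not part of this change:
#   python app.py --model canny --precision int4
#   python app.py -m depth -p bf16 --no-safety-checker
#   python app.py -m canny --use-qencoder --count-use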
STYLES = {
"None": "{prompt}",
"Cinematic": "cinematic still {prompt}. emotional, harmonious, vignette, highly detailed, high budget, bokeh, cinemascope, moody, epic, gorgeous, film grain, grainy",
"3D Model": "professional 3d model {prompt}. octane render, highly detailed, volumetric, dramatic lighting",
"Anime": "anime artwork {prompt}. anime style, key visual, vibrant, studio anime, highly detailed",
"Digital Art": "concept art {prompt}. digital artwork, illustrative, painterly, matte painting, highly detailed",
"Photographic": "cinematic photo {prompt}. 35mm photograph, film, bokeh, professional, 4k, highly detailed",
"Pixel art": "pixel-art {prompt}. low-res, blocky, pixel art style, 8-bit graphics",
"Fantasy art": "ethereal fantasy concept art of {prompt}. magnificent, celestial, ethereal, painterly, epic, majestic, magical, fantasy art, cover art, dreamy",
"Neonpunk": "neonpunk style {prompt}. cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, ultra detailed, intricate, professional",
"Manga": "manga style {prompt}. vibrant, high-energy, detailed, iconic, Japanese comic style",
}
DEFAULT_STYLE_NAME = "3D Model"
STYLE_NAMES = list(STYLES.keys())
MAX_SEED = 1000000000
DEFAULT_INFERENCE_STEP_CANNY = 50
DEFAULT_GUIDANCE_CANNY = 30.0
DEFAULT_INFERENCE_STEP_DEPTH = 30
DEFAULT_GUIDANCE_DEPTH = 10.0
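# Illustrative example (not used by the demo) of how a style template composes with a raw
# prompt, mirroring prompt_template.format(prompt=prompt) in the app:
#   STYLES[DEFAULT_STYLE_NAME].format(prompt="a cat")
#   -> "professional 3d model a cat. octane render, highly detailed, volumetric, dramatic lighting"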
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<div>
<h1>
<img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/logo.svg"
alt="logo"
style="height: 40px; width: auto; display: block; margin: auto;"/>
INT4 FLUX.1-fill-dev Demo
</h1>
<h2>
SVDQuant: Absorbing Outliers by Low-Rank Components for 4-Bit Diffusion Models
</h2>
<h3>
<a href='https://lmxyy.me'>Muyang Li*</a>,
<a href='https://yujunlin.com'>Yujun Lin*</a>,
<a href='https://hanlab.mit.edu/team/zhekai-zhang'>Zhekai Zhang*</a>,
<a href='https://www.tianle.website/#/'>Tianle Cai</a>,
<a href='https://xiuyuli.com'>Xiuyu Li</a>,
<br>
<a href='https://github.com/JerryGJX'>Junxian Guo</a>,
<a href='https://xieenze.github.io'>Enze Xie</a>,
<a href='https://cs.stanford.edu/~chenlin/'>Chenlin Meng</a>,
<a href='https://www.cs.cmu.edu/~junyanz/'>Jun-Yan Zhu</a>,
and <a href='https://hanlab.mit.edu/songhan'>Song Han</a>
</h3>
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<a href="https://arxiv.org/abs/2411.05007">[Paper]</a>
&nbsp;
<a href='https://github.com/mit-han-lab/nunchaku'>
[Code]
</a>
&nbsp;
<a href='https://hanlab.mit.edu/projects/svdquant'>
[Website]
</a>
&nbsp;
<a href='https://hanlab.mit.edu/blog/svdquant'>
[Blog]
</a>
</div>
<h4>Quantization Library:
<a href='https://github.com/mit-han-lab/deepcompressor'>DeepCompressor</a>&nbsp;
Inference Engine: <a href='https://github.com/mit-han-lab/nunchaku'>Nunchaku</a>&nbsp;
Image Control: <a href="https://github.com/GaParmar/img2img-turbo">img2img-turbo</a>
</h4>
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
{device_info}
</div>
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
{notice}
</div>
{count_info}
</div>
</div>
@import url('https://cdnjs.cloudflare.com/ajax/libs/font-awesome/5.15.1/css/all.min.css');
.gradio-container{max-width: 1200px !important}
h1{text-align:center}
.wrap.svelte-p4aq0j.svelte-p4aq0j {
display: none;
}
#column_input, #column_output {
width: 500px;
display: flex;
align-items: center;
}
#input_header, #output_header {
display: flex;
justify-content: center;
align-items: center;
width: 400px;
}
#accessibility {
text-align: center; /* Center-aligns the text */
margin: auto; /* Centers the element horizontally */
}
#random_seed {height: 71px;}
#run_button {height: 87px;}
# Changed from https://github.com/GaParmar/img2img-turbo/blob/main/gradio_sketch2image.py
import logging
import os
import random
import tempfile
import time
from datetime import datetime
import GPUtil
import numpy as np
import torch
from PIL import Image
from diffusers import FluxFillPipeline
from nunchaku.models.safety_checker import SafetyChecker
from nunchaku.models.transformer_flux import NunchakuFluxTransformer2dModel
from utils import get_args
from vars import DEFAULT_GUIDANCE, DEFAULT_INFERENCE_STEP, DEFAULT_STYLE_NAME, MAX_SEED, STYLE_NAMES, STYLES
# import gradio last to avoid conflicts with other imports
import gradio as gr
blank_image = Image.new("RGB", (1024, 1024), (255, 255, 255))
args = get_args()
if args.precision == "bf16":
pipeline = FluxFillPipeline.from_pretrained(f"black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16)
pipeline = pipeline.to("cuda")
pipeline.precision = "bf16"
else:
assert args.precision == "int4"
pipeline_init_kwargs = {}
transformer = NunchakuFluxTransformer2dModel.from_pretrained(f"mit-han-lab/svdq-int4-flux.1-fill-dev")
pipeline_init_kwargs["transformer"] = transformer
if args.use_qencoder:
from nunchaku.models.text_encoder import NunchakuT5EncoderModel
text_encoder_2 = NunchakuT5EncoderModel.from_pretrained("mit-han-lab/svdq-flux.1-t5")
pipeline_init_kwargs["text_encoder_2"] = text_encoder_2
pipeline = FluxFillPipeline.from_pretrained(
f"black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16, **pipeline_init_kwargs
)
pipeline = pipeline.to("cuda")
pipeline.precision = "int4"
safety_checker = SafetyChecker("cuda", disabled=args.no_safety_checker)
def save_image(img):
if isinstance(img, dict):
img = img["composite"]
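# delete=False keeps the temporary PNG on disk after the handle closes so Gradio's
# DownloadButton can serve it; the file is left for the OS temp directory to clean up.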
temp_file = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
img.save(temp_file.name)
return temp_file.name
def run(image, prompt: str, prompt_template: str, num_inference_steps: int, guidance_scale: float, seed: int) -> tuple[Image.Image, str]:
print(f"Prompt: {prompt}")
image_numpy = np.array(image["composite"].convert("RGB"))
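# A 1024x1024 RGB image has 1024 * 1024 * 3 = 3,145,728 values; if nearly all of them
# are pure white (255) or pure black (0) and the prompt is empty, treat the input as blank.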
if prompt.strip() == "" and (np.sum(image_numpy == 255) >= 3145628 or np.sum(image_numpy == 0) >= 3145628):
return blank_image, "Please input the prompt or draw something."
is_unsafe_prompt = False
if not safety_checker(prompt):
is_unsafe_prompt = True
prompt = "A peaceful world."
prompt = prompt_template.format(prompt=prompt)
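# gr.ImageMask (type="pil") returns an editor dict: "background" is the uploaded photo,
# "layers"[0] holds the painted mask in its alpha channel, and "composite" is the merged
# preview used for the blank-input check above.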
mask = image["layers"][0].getchannel(3) # Mask is stored in the last channel
pic = image["background"].convert("RGB") # This is the original photo
start_time = time.time()
result_image = pipeline(
prompt=prompt,
image=pic,
mask_image=mask,
guidance_scale=guidance_scale,
height=1024,
width=1024,
num_inference_steps=num_inference_steps,
max_sequence_length=512,
generator=torch.Generator().manual_seed(int(seed)),
).images[0]
latency = time.time() - start_time
if latency < 1:
latency = latency * 1000
latency_str = f"{latency:.2f}ms"
else:
latency_str = f"{latency:.2f}s"
if is_unsafe_prompt:
latency_str += " (Unsafe prompt detected)"
torch.cuda.empty_cache()
if args.count_use:
if os.path.exists("use_count.txt"):
with open("use_count.txt", "r") as f:
count = int(f.read())
else:
count = 0
count += 1
current_time = datetime.now()
print(f"{current_time}: {count}")
with open("use_count.txt", "w") as f:
f.write(str(count))
with open("use_record.txt", "a") as f:
f.write(f"{current_time}: {count}\n")
return result_image, latency_str
with gr.Blocks(css_paths="assets/style.css", title="SVDQuant Flux.1-Fill-dev Demo") as demo:
with open("assets/description.html", "r") as f:
DESCRIPTION = f.read()
gpus = GPUtil.getGPUs()
if len(gpus) > 0:
gpu = gpus[0]
memory = gpu.memoryTotal / 1024
device_info = f"Running on {gpu.name} with {memory:.0f} GiB memory."
else:
device_info = "Running on CPU 🥶 This demo does not work on CPU."
notice = f'<strong>Notice:</strong>&nbsp;We will replace unsafe prompts with a default prompt: "A peaceful world."'
def get_header_str():
if args.count_use:
if os.path.exists("use_count.txt"):
with open("use_count.txt", "r") as f:
count = int(f.read())
else:
count = 0
count_info = (
f"<div style='display: flex; justify-content: center; align-items: center; text-align: center;'>"
f"<span style='font-size: 18px; font-weight: bold;'>Total inference runs: </span>"
f"<span style='font-size: 18px; color:red; font-weight: bold;'>&nbsp;{count}</span></div>"
)
else:
count_info = ""
header_str = DESCRIPTION.format(device_info=device_info, notice=notice, count_info=count_info)
return header_str
header = gr.HTML(get_header_str())
demo.load(fn=get_header_str, outputs=header)
with gr.Row(elem_id="main_row"):
with gr.Column(elem_id="column_input"):
gr.Markdown("## INPUT", elem_id="input_header")
with gr.Group():
canvas = gr.ImageMask(
value=blank_image,
height=640,
image_mode="RGBA",
sources=["upload", "clipboard"],
type="pil",
label="Sketch",
show_label=False,
show_download_button=True,
interactive=True,
transforms=[],
canvas_size=(1024, 1024),
scale=1,
format="png",
layers=False,
)
with gr.Row():
prompt = gr.Text(label="Prompt", placeholder="Enter your prompt", scale=6)
run_button = gr.Button("Run", scale=1, elem_id="run_button")
download_sketch = gr.DownloadButton("Download Sketch", scale=1, elem_id="download_sketch")
with gr.Row():
style = gr.Dropdown(label="Style", choices=STYLE_NAMES, value=DEFAULT_STYLE_NAME, scale=1)
prompt_template = gr.Textbox(
label="Prompt Style Template", value=STYLES[DEFAULT_STYLE_NAME], scale=2, max_lines=1
)
with gr.Row():
seed = gr.Slider(label="Seed", show_label=True, minimum=0, maximum=MAX_SEED, value=233, step=1, scale=4)
randomize_seed = gr.Button("Random Seed", scale=1, min_width=50, elem_id="random_seed")
with gr.Accordion("Advanced options", open=False):
with gr.Group():
num_inference_steps = gr.Slider(label="Inference Steps", minimum=10, maximum=50, step=1, value=DEFAULT_INFERENCE_STEP)
guidance_scale = gr.Slider(label="Guidance Scale", minimum=1, maximum=50, step=1, value=DEFAULT_GUIDANCE)
with gr.Column(elem_id="column_output"):
gr.Markdown("## OUTPUT", elem_id="output_header")
with gr.Group():
result = gr.Image(
format="png",
height=640,
image_mode="RGB",
type="pil",
label="Result",
show_label=False,
show_download_button=True,
interactive=False,
elem_id="output_image",
)
latency_result = gr.Text(label="Inference Latency", show_label=True)
download_result = gr.DownloadButton("Download Result", elem_id="download_result")
gr.Markdown("### Instructions")
gr.Markdown("**1**. Enter a text prompt (e.g. a cat)")
gr.Markdown("**2**. Start sketching")
gr.Markdown("**3**. Change the image style using a style template")
gr.Markdown("**4**. Adjust the effect of sketch guidance using the slider (typically between 0.2 and 0.4)")
gr.Markdown("**5**. Try different seeds to generate different results")
run_inputs = [canvas, prompt, prompt_template, num_inference_steps, guidance_scale, seed]
run_outputs = [result, latency_result]
randomize_seed.click(
lambda: random.randint(0, MAX_SEED),
inputs=[],
outputs=seed,
api_name=False,
queue=False,
).then(run, inputs=run_inputs, outputs=run_outputs, api_name=False)
style.change(
lambda x: STYLES[x],
inputs=[style],
outputs=[prompt_template],
api_name=False,
queue=False,
).then(fn=run, inputs=run_inputs, outputs=run_outputs, api_name=False)
gr.on(
triggers=[prompt.submit, run_button.click, canvas.change],
fn=run,
inputs=run_inputs,
outputs=run_outputs,
api_name=False,
)
download_sketch.click(fn=save_image, inputs=canvas, outputs=download_sketch)
download_result.click(fn=save_image, inputs=result, outputs=download_result)
gr.Markdown("MIT Accessibility: https://accessibility.mit.edu/", elem_id="accessibility")
if __name__ == "__main__":
demo.queue().launch(debug=True, share=True)
import argparse
def get_args() -> argparse.Namespace:
parser = argparse.ArgumentParser()
parser.add_argument(
"-p", "--precision", type=str, default="int4", choices=["int4", "bf16"], help="Which precisions to use"
)
parser.add_argument("--use-qencoder", action="store_true", help="Whether to use 4-bit text encoder")
parser.add_argument("--no-safety-checker", action="store_true", help="Disable safety checker")
parser.add_argument("--count-use", action="store_true", help="Whether to count the number of uses")
args = parser.parse_args()
return args
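# Illustrative launch commands for the fill demo, using the flags defined above.
# The entry-point file name "app.py" is an assumption and not part of this change:
#   python app.py --precision int4
#   python app.py -p bf16 --no-safety-checker --count-use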
STYLES = {
"None": "{prompt}",
"Cinematic": "cinematic still {prompt}. emotional, harmonious, vignette, highly detailed, high budget, bokeh, cinemascope, moody, epic, gorgeous, film grain, grainy",
"3D Model": "professional 3d model {prompt}. octane render, highly detailed, volumetric, dramatic lighting",
"Anime": "anime artwork {prompt}. anime style, key visual, vibrant, studio anime, highly detailed",
"Digital Art": "concept art {prompt}. digital artwork, illustrative, painterly, matte painting, highly detailed",
"Photographic": "cinematic photo {prompt}. 35mm photograph, film, bokeh, professional, 4k, highly detailed",
"Pixel art": "pixel-art {prompt}. low-res, blocky, pixel art style, 8-bit graphics",
"Fantasy art": "ethereal fantasy concept art of {prompt}. magnificent, celestial, ethereal, painterly, epic, majestic, magical, fantasy art, cover art, dreamy",
"Neonpunk": "neonpunk style {prompt}. cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, ultra detailed, intricate, professional",
"Manga": "manga style {prompt}. vibrant, high-energy, detailed, iconic, Japanese comic style",
}
DEFAULT_STYLE_NAME = "3D Model"
STYLE_NAMES = list(STYLES.keys())
MAX_SEED = 1000000000
DEFAULT_GUIDANCE = 30
DEFAULT_INFERENCE_STEP = 50