fengzch-das / nunchaku · Commits · e6cd772c

Commit e6cd772c, authored Feb 10, 2025 by April Hu
Add flux1 demo for depth and canny
parent 6c333071

Showing 5 changed files with 376 additions and 0 deletions (+376, -0):

app/flux.1/depth_canny/assets/description.html  (+53, -0)
app/flux.1/depth_canny/assets/style.css         (+29, -0)
app/flux.1/depth_canny/run_gradio.py            (+257, -0)
app/flux.1/depth_canny/utils.py                 (+16, -0)
app/flux.1/depth_canny/vars.py                  (+21, -0)
app/flux.1/depth_canny/assets/description.html  0 → 100644
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
  <div>
    <h1>
      <img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/logo.svg"
           alt="logo" style="height: 40px; width: auto; display: block; margin: auto;"/>
      INT4 FLUX.1-{model_name}-dev Demo
    </h1>
    <h2>SVDQuant: Absorbing Outliers by Low-Rank Components for 4-Bit Diffusion Models</h2>
    <h3>
      <a href='https://lmxyy.me'>Muyang Li*</a>,
      <a href='https://yujunlin.com'>Yujun Lin*</a>,
      <a href='https://hanlab.mit.edu/team/zhekai-zhang'>Zhekai Zhang*</a>,
      <a href='https://www.tianle.website/#/'>Tianle Cai</a>,
      <a href='https://xiuyuli.com'>Xiuyu Li</a>,
      <br>
      <a href='https://github.com/JerryGJX'>Junxian Guo</a>,
      <a href='https://xieenze.github.io'>Enze Xie</a>,
      <a href='https://cs.stanford.edu/~chenlin/'>Chenlin Meng</a>,
      <a href='https://www.cs.cmu.edu/~junyanz/'>Jun-Yan Zhu</a>,
      and
      <a href='https://hanlab.mit.edu/songhan'>Song Han</a>
    </h3>
    <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
      <a href="https://arxiv.org/abs/2411.05007">[Paper]</a>
      <a href='https://github.com/mit-han-lab/nunchaku'>[Code]</a>
      <a href='https://hanlab.mit.edu/projects/svdquant'>[Website]</a>
      <a href='https://hanlab.mit.edu/blog/svdquant'>[Blog]</a>
    </div>
    <h4>
      Quantization Library:
      <a href='https://github.com/mit-han-lab/deepcompressor'>DeepCompressor</a>
      Inference Engine:
      <a href='https://github.com/mit-han-lab/nunchaku'>Nunchaku</a>
      Image Control:
      <a href="https://github.com/GaParmar/img2img-turbo">img2img-turbo</a>
    </h4>
    <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
      {device_info}
    </div>
    <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
      {notice}
    </div>
    {count_info}
  </div>
</div>
\ No newline at end of file
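This HTML is a Python format template rather than a static page: {model_name}, {device_info}, {notice}, and {count_info} are placeholders that run_gradio.py fills in at startup. A minimal sketch of that substitution, using example values (the real ones come from get_header_str in run_gradio.py below):

with open("assets/description.html") as f:
    DESCRIPTION = f.read()

# Example values only; at runtime these are derived from the CLI args and GPUtil.
header = DESCRIPTION.format(
    model_name="canny",
    device_info="Running on NVIDIA GeForce RTX 4090 with 24 GiB memory.",
    notice='<strong>Notice:</strong> ...',
    count_info="",
)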
app/flux.1/depth_canny/assets/style.css  0 → 100644
@import url('https://cdnjs.cloudflare.com/ajax/libs/font-awesome/5.15.1/css/all.min.css');

.gradio-container {
  max-width: 1200px !important
}

h1 {
  text-align: center
}

.wrap.svelte-p4aq0j.svelte-p4aq0j {
  display: none;
}

#column_input,
#column_output {
  width: 500px;
  display: flex;
  align-items: center;
}

#input_header,
#output_header {
  display: flex;
  justify-content: center;
  align-items: center;
  width: 400px;
}

#accessibility {
  text-align: center; /* Center-aligns the text */
  margin: auto;       /* Centers the element horizontally */
}

#random_seed {
  height: 71px;
}

#run_button {
  height: 87px;
}
\ No newline at end of file
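Note that the ID selectors here (#column_input, #column_output, #input_header, #output_header, #accessibility, #random_seed, #run_button) match the elem_id values assigned to Gradio components in run_gradio.py, while .gradio-container and the .wrap.svelte-* rule target Gradio's own generated markup, so this stylesheet is coupled to that script and to the Gradio version in use.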
app/flux.1/depth_canny/run_gradio.py  0 → 100644
# Changed from https://github.com/GaParmar/img2img-turbo/blob/main/gradio_sketch2image.py
import logging
import os
import random
import tempfile
import time
from datetime import datetime

import GPUtil
import numpy as np
import torch
from PIL import Image
from image_gen_aux import DepthPreprocessor
from controlnet_aux import CannyDetector
from diffusers import FluxControlPipeline
from nunchaku.models.safety_checker import SafetyChecker
from nunchaku.models.transformer_flux import NunchakuFluxTransformer2dModel
from utils import get_args
from vars import (
    DEFAULT_INFERENCE_STEP_CANNY,
    DEFAULT_GUIDANCE_CANNY,
    DEFAULT_INFERENCE_STEP_DEPTH,
    DEFAULT_GUIDANCE_DEPTH,
    DEFAULT_STYLE_NAME,
    MAX_SEED,
    STYLE_NAMES,
    STYLES,
)

# import gradio last to avoid conflicts with other imports
import gradio as gr

blank_image = Image.new("RGB", (1024, 1024), (255, 255, 255))

args = get_args()

model_name = f"{args.model}-dev"
pipeline_class = FluxControlPipeline
if args.model == "canny":
    processor = CannyDetector()
else:
    assert args.model == "depth", f"Model {args.model} not supported"
    processor = DepthPreprocessor.from_pretrained("LiheYoung/depth-anything-large-hf")

if args.precision == "bf16":
    pipeline = pipeline_class.from_pretrained(
        f"black-forest-labs/FLUX.1-{model_name.capitalize()}", torch_dtype=torch.bfloat16
    )
    pipeline = pipeline.to("cuda")
    pipeline.precision = "bf16"
else:
    assert args.precision == "int4"
    pipeline_init_kwargs = {}
    # Swap in the SVDQuant INT4 transformer (and optionally the 4-bit T5 encoder).
    transformer = NunchakuFluxTransformer2dModel.from_pretrained(f"mit-han-lab/svdq-int4-flux.1-{model_name}")
    pipeline_init_kwargs["transformer"] = transformer
    if args.use_qencoder:
        from nunchaku.models.text_encoder import NunchakuT5EncoderModel

        text_encoder_2 = NunchakuT5EncoderModel.from_pretrained("mit-han-lab/svdq-flux.1-t5")
        pipeline_init_kwargs["text_encoder_2"] = text_encoder_2
    pipeline = pipeline_class.from_pretrained(
        f"black-forest-labs/FLUX.1-{model_name.capitalize()}", torch_dtype=torch.bfloat16, **pipeline_init_kwargs
    )
    pipeline = pipeline.to("cuda")
    pipeline.precision = "int4"

safety_checker = SafetyChecker("cuda", disabled=args.no_safety_checker)


def save_image(img):
    if isinstance(img, dict):
        img = img["composite"]
    temp_file = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
    img.save(temp_file.name)
    return temp_file.name


def run(
    image, prompt: str, prompt_template: str, num_inference_steps: int, guidance_scale: float, seed: int
) -> tuple[Image.Image, str]:
    print(f"Prompt: {prompt}")
    if args.model == "canny":
        processed_img = processor(image["composite"]).convert("RGB")
    else:
        assert args.model == "depth"
        # DepthPreprocessor returns a list; take the first (and only) image.
        processed_img = processor(image["composite"])[0].convert("RGB")
    image_numpy = np.array(processed_img)
    # Skip generation when the prompt is empty and the canvas is essentially blank
    # (nearly all channel values are 255 or 0).
    if prompt.strip() == "" and (np.sum(image_numpy == 255) >= 3145628 or np.sum(image_numpy == 0) >= 3145628):
        return blank_image, "Please input the prompt or draw something."
    is_unsafe_prompt = False
    if not safety_checker(prompt):
        is_unsafe_prompt = True
        prompt = "A peaceful world."
    prompt = prompt_template.format(prompt=prompt)
    start_time = time.time()
    result_image = pipeline(
        prompt=prompt,
        control_image=processed_img,
        height=1024,
        width=1024,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        generator=torch.Generator().manual_seed(int(seed)),
    ).images[0]
    latency = time.time() - start_time
    if latency < 1:
        latency = latency * 1000
        latency_str = f"{latency:.2f}ms"
    else:
        latency_str = f"{latency:.2f}s"
    if is_unsafe_prompt:
        latency_str += " (Unsafe prompt detected)"
    torch.cuda.empty_cache()
    if args.count_use:
        if os.path.exists("use_count.txt"):
            with open("use_count.txt", "r") as f:
                count = int(f.read())
        else:
            count = 0
        count += 1
        current_time = datetime.now()
        print(f"{current_time}: {count}")
        with open("use_count.txt", "w") as f:
            f.write(str(count))
        with open("use_record.txt", "a") as f:
            f.write(f"{current_time}: {count}\n")
    return result_image, latency_str


with gr.Blocks(css_paths="assets/style.css", title=f"SVDQuant Flux.1-{model_name} Demo") as demo:
    with open("assets/description.html", "r") as f:
        DESCRIPTION = f.read()
    gpus = GPUtil.getGPUs()
    if len(gpus) > 0:
        gpu = gpus[0]
        memory = gpu.memoryTotal / 1024
        device_info = f"Running on {gpu.name} with {memory:.0f} GiB memory."
    else:
        device_info = "Running on CPU 🥶 This demo does not work on CPU."
    notice = f'<strong>Notice:</strong> We will replace unsafe prompts with a default prompt: "A peaceful world."'

    def get_header_str():
        if args.count_use:
            if os.path.exists("use_count.txt"):
                with open("use_count.txt", "r") as f:
                    count = int(f.read())
            else:
                count = 0
            count_info = (
                f"<div style='display: flex; justify-content: center; align-items: center; text-align: center;'>"
                f"<span style='font-size: 18px; font-weight: bold;'>Total inference runs: </span>"
                f"<span style='font-size: 18px; color:red; font-weight: bold;'>{count}</span></div>"
            )
        else:
            count_info = ""
        header_str = DESCRIPTION.format(
            model_name=args.model, device_info=device_info, notice=notice, count_info=count_info
        )
        return header_str

    header = gr.HTML(get_header_str())
    demo.load(fn=get_header_str, outputs=header)

    with gr.Row(elem_id="main_row"):
        with gr.Column(elem_id="column_input"):
            gr.Markdown("## INPUT", elem_id="input_header")
            with gr.Group():
                canvas = gr.Sketchpad(
                    value=blank_image,
                    height=640,
                    image_mode="RGB",
                    sources=["upload", "clipboard"],
                    type="pil",
                    label="Sketch",
                    show_label=False,
                    show_download_button=True,
                    interactive=True,
                    transforms=[],
                    canvas_size=(1024, 1024),
                    scale=1,
                    format="png",
                    layers=False,
                )
                with gr.Row():
                    prompt = gr.Text(label="Prompt", placeholder="Enter your prompt", scale=6)
                    run_button = gr.Button("Run", scale=1, elem_id="run_button")
                    download_sketch = gr.DownloadButton("Download Sketch", scale=1, elem_id="download_sketch")
            with gr.Row():
                style = gr.Dropdown(label="Style", choices=STYLE_NAMES, value=DEFAULT_STYLE_NAME, scale=1)
                prompt_template = gr.Textbox(
                    label="Prompt Style Template", value=STYLES[DEFAULT_STYLE_NAME], scale=2, max_lines=1
                )
            with gr.Row():
                seed = gr.Slider(
                    label="Seed", show_label=True, minimum=0, maximum=MAX_SEED, value=233, step=1, scale=4
                )
                randomize_seed = gr.Button("Random Seed", scale=1, min_width=50, elem_id="random_seed")
            with gr.Accordion("Advanced options", open=False):
                with gr.Group():
                    num_inference_steps = gr.Slider(
                        label="Inference Steps",
                        minimum=10,
                        maximum=50,
                        step=1,
                        value=DEFAULT_INFERENCE_STEP_CANNY if args.model == "canny" else DEFAULT_INFERENCE_STEP_DEPTH,
                    )
                    guidance_scale = gr.Slider(
                        label="Guidance Scale",
                        minimum=1,
                        maximum=50,
                        step=1,
                        value=DEFAULT_GUIDANCE_CANNY if args.model == "canny" else DEFAULT_GUIDANCE_DEPTH,
                    )
        with gr.Column(elem_id="column_output"):
            gr.Markdown("## OUTPUT", elem_id="output_header")
            with gr.Group():
                result = gr.Image(
                    format="png",
                    height=640,
                    image_mode="RGB",
                    type="pil",
                    label="Result",
                    show_label=False,
                    show_download_button=True,
                    interactive=False,
                    elem_id="output_image",
                )
                latency_result = gr.Text(label="Inference Latency", show_label=True)
            download_result = gr.DownloadButton("Download Result", elem_id="download_result")
            gr.Markdown("### Instructions")
            gr.Markdown("**1**. Enter a text prompt (e.g. a cat)")
            gr.Markdown("**2**. Start sketching")
            gr.Markdown("**3**. Change the image style using a style template")
            gr.Markdown("**4**. Adjust the effect of sketch guidance using the slider (typically between 0.2 and 0.4)")
            gr.Markdown("**5**. Try different seeds to generate different results")

    run_inputs = [canvas, prompt, prompt_template, num_inference_steps, guidance_scale, seed]
    run_outputs = [result, latency_result]

    randomize_seed.click(
        lambda: random.randint(0, MAX_SEED),
        inputs=[],
        outputs=seed,
        api_name=False,
        queue=False,
    ).then(run, inputs=run_inputs, outputs=run_outputs, api_name=False)

    style.change(
        lambda x: STYLES[x],
        inputs=[style],
        outputs=[prompt_template],
        api_name=False,
        queue=False,
    ).then(fn=run, inputs=run_inputs, outputs=run_outputs, api_name=False)

    gr.on(
        triggers=[prompt.submit, run_button.click, canvas.change],
        fn=run,
        inputs=run_inputs,
        outputs=run_outputs,
        api_name=False,
    )

    download_sketch.click(fn=save_image, inputs=canvas, outputs=download_sketch)
    download_result.click(fn=save_image, inputs=result, outputs=download_result)

    gr.Markdown("MIT Accessibility: https://accessibility.mit.edu/", elem_id="accessibility")

if __name__ == "__main__":
    demo.queue().launch(debug=True, share=True)
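For reference, the INT4 code path above reduces to the following standalone sketch without the Gradio UI. It assumes the canny variant and reuses the model IDs and defaults from this commit; input.png and output.png are placeholder paths, and the depth variant would instead use DepthPreprocessor (whose output is a list, hence the [0] indexing in run()):

import torch
from controlnet_aux import CannyDetector
from diffusers import FluxControlPipeline
from diffusers.utils import load_image
from nunchaku.models.transformer_flux import NunchakuFluxTransformer2dModel

# Plug the SVDQuant INT4 transformer into the standard BF16 FLUX.1 control pipeline.
transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-canny-dev")
pipeline = FluxControlPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Canny-dev", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")

# Condition generation on a canny edge map, mirroring run() above.
control_image = CannyDetector()(load_image("input.png")).convert("RGB")
image = pipeline(
    prompt="professional 3d model a cat. octane render, highly detailed, volumetric, dramatic lighting",
    control_image=control_image,
    height=1024,
    width=1024,
    num_inference_steps=50,  # DEFAULT_INFERENCE_STEP_CANNY
    guidance_scale=30.0,  # DEFAULT_GUIDANCE_CANNY
    generator=torch.Generator().manual_seed(233),
).images[0]
image.save("output.png")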
app/flux.1/depth_canny/utils.py  0 → 100644
import argparse


def get_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-p", "--precision", type=str, default="int4", choices=["int4", "bf16"], help="Which precision to use"
    )
    parser.add_argument(
        "-m", "--model", type=str, default="canny", choices=["canny", "depth"], help="Which FLUX.1 model to use"
    )
    parser.add_argument("--use-qencoder", action="store_true", help="Whether to use the 4-bit text encoder")
    parser.add_argument("--no-safety-checker", action="store_true", help="Disable the safety checker")
    parser.add_argument("--count-use", action="store_true", help="Whether to count the number of uses")
    args = parser.parse_args()
    return args
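With these flags, python run_gradio.py launches the INT4 canny demo by default, while e.g. python run_gradio.py -m depth -p bf16 selects the BF16 depth demo. A small sketch of how the parsed namespace looks, simulating the command line via sys.argv purely for illustration:

import sys
from utils import get_args

# Simulate "python run_gradio.py -m depth --use-qencoder".
sys.argv = ["run_gradio.py", "-m", "depth", "--use-qencoder"]
args = get_args()
assert args.model == "depth"
assert args.precision == "int4"  # default
assert args.use_qencoder and not args.no_safety_checker and not args.count_use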
app/flux.1/depth_canny/vars.py  0 → 100644
STYLES = {
    "None": "{prompt}",
    "Cinematic": "cinematic still {prompt}. emotional, harmonious, vignette, highly detailed, high budget, bokeh, cinemascope, moody, epic, gorgeous, film grain, grainy",
    "3D Model": "professional 3d model {prompt}. octane render, highly detailed, volumetric, dramatic lighting",
    "Anime": "anime artwork {prompt}. anime style, key visual, vibrant, studio anime, highly detailed",
    "Digital Art": "concept art {prompt}. digital artwork, illustrative, painterly, matte painting, highly detailed",
    "Photographic": "cinematic photo {prompt}. 35mm photograph, film, bokeh, professional, 4k, highly detailed",
    "Pixel art": "pixel-art {prompt}. low-res, blocky, pixel art style, 8-bit graphics",
    "Fantasy art": "ethereal fantasy concept art of {prompt}. magnificent, celestial, ethereal, painterly, epic, majestic, magical, fantasy art, cover art, dreamy",
    "Neonpunk": "neonpunk style {prompt}. cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, ultra detailed, intricate, professional",
    "Manga": "manga style {prompt}. vibrant, high-energy, detailed, iconic, Japanese comic style",
}
DEFAULT_STYLE_NAME = "3D Model"
STYLE_NAMES = list(STYLES.keys())
MAX_SEED = 1000000000
DEFAULT_INFERENCE_STEP_CANNY = 50
DEFAULT_GUIDANCE_CANNY = 30.0
DEFAULT_INFERENCE_STEP_DEPTH = 30
DEFAULT_GUIDANCE_DEPTH = 10.0
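run_gradio.py applies these templates with prompt_template.format(prompt=prompt), so the user prompt is simply spliced into the selected style string. A quick sketch:

from vars import STYLES, DEFAULT_STYLE_NAME

# With the default "3D Model" style, the user prompt "a cat" becomes:
final_prompt = STYLES[DEFAULT_STYLE_NAME].format(prompt="a cat")
# -> "professional 3d model a cat. octane render, highly detailed, volumetric, dramatic lighting"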