fengzch-das / nunchaku

Commit 50139c73 authored Feb 10, 2025 by April Hu

Add flux1 demo for fill

Parent: e6cd772c
Showing 5 changed files with 351 additions and 0 deletions (+351, -0)
app/flux.1/fill/assets/description.html  +53  -0
app/flux.1/fill/assets/style.css         +29  -0
app/flux.1/fill/run_gradio.py            +238 -0
app/flux.1/fill/utils.py                 +13  -0
app/flux.1/fill/vars.py                  +18  -0
app/flux.1/fill/assets/description.html  0 → 100644
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
  <div>
    <h1>
      <img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/logo.svg" alt="logo"
           style="height: 40px; width: auto; display: block; margin: auto;"/>
      INT4 FLUX.1-fill-dev Demo
    </h1>
    <h2>SVDQuant: Absorbing Outliers by Low-Rank Components for 4-Bit Diffusion Models</h2>
    <h3>
      <a href='https://lmxyy.me'>Muyang Li*</a>,
      <a href='https://yujunlin.com'>Yujun Lin*</a>,
      <a href='https://hanlab.mit.edu/team/zhekai-zhang'>Zhekai Zhang*</a>,
      <a href='https://www.tianle.website/#/'>Tianle Cai</a>,
      <a href='https://xiuyuli.com'>Xiuyu Li</a>,
      <br>
      <a href='https://github.com/JerryGJX'>Junxian Guo</a>,
      <a href='https://xieenze.github.io'>Enze Xie</a>,
      <a href='https://cs.stanford.edu/~chenlin/'>Chenlin Meng</a>,
      <a href='https://www.cs.cmu.edu/~junyanz/'>Jun-Yan Zhu</a>,
      and
      <a href='https://hanlab.mit.edu/songhan'>Song Han</a>
    </h3>
    <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
      <a href="https://arxiv.org/abs/2411.05007">[Paper]</a>
      <a href='https://github.com/mit-han-lab/nunchaku'>[Code]</a>
      <a href='https://hanlab.mit.edu/projects/svdquant'>[Website]</a>
      <a href='https://hanlab.mit.edu/blog/svdquant'>[Blog]</a>
    </div>
    <h4>
      Quantization Library:
      <a href='https://github.com/mit-han-lab/deepcompressor'>DeepCompressor</a>
      Inference Engine:
      <a href='https://github.com/mit-han-lab/nunchaku'>Nunchaku</a>
      Image Control:
      <a href="https://github.com/GaParmar/img2img-turbo">img2img-turbo</a>
    </h4>
    <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
      {device_info}
    </div>
    <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
      {notice}
    </div>
    {count_info}
  </div>
</div>
\ No newline at end of file
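Note that description.html is a Python str.format template: the {device_info}, {notice}, and {count_info} placeholders are filled at runtime by get_header_str in run_gradio.py below. A minimal sketch of that substitution; the device string and values here are purely illustrative:

# Sketch only: mirrors how run_gradio.py fills the header template; values are illustrative.
with open("assets/description.html", "r") as f:
    DESCRIPTION = f.read()

header_html = DESCRIPTION.format(
    device_info="Running on NVIDIA GeForce RTX 4090 with 24 GiB memory.",  # example string, not measured
    notice='<strong>Notice:</strong> We will replace unsafe prompts with a default prompt: "A peaceful world."',
    count_info="",  # stays empty unless the demo is launched with --count-use
)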
app/flux.1/fill/assets/style.css  0 → 100644
@import url('https://cdnjs.cloudflare.com/ajax/libs/font-awesome/5.15.1/css/all.min.css');

.gradio-container {
    max-width: 1200px !important
}

h1 {
    text-align: center
}

.wrap.svelte-p4aq0j.svelte-p4aq0j {
    display: none;
}

#column_input,
#column_output {
    width: 500px;
    display: flex;
    align-items: center;
}

#input_header,
#output_header {
    display: flex;
    justify-content: center;
    align-items: center;
    width: 400px;
}

#accessibility {
    text-align: center; /* Center-aligns the text */
    margin: auto; /* Centers the element horizontally */
}

#random_seed {
    height: 71px;
}

#run_button {
    height: 87px;
}
\ No newline at end of file
app/flux.1/fill/run_gradio.py  0 → 100644
# Changed from https://github.com/GaParmar/img2img-turbo/blob/main/gradio_sketch2image.py
import logging
import os
import random
import tempfile
import time
from datetime import datetime

import GPUtil
import numpy as np
import torch
from PIL import Image
from diffusers import FluxFillPipeline

from nunchaku.models.safety_checker import SafetyChecker
from nunchaku.models.transformer_flux import NunchakuFluxTransformer2dModel
from utils import get_args
from vars import DEFAULT_GUIDANCE, DEFAULT_INFERENCE_STEP, DEFAULT_STYLE_NAME, MAX_SEED, STYLE_NAMES, STYLES

# import gradio last to avoid conflicts with other imports
import gradio as gr

blank_image = Image.new("RGB", (1024, 1024), (255, 255, 255))

args = get_args()

if args.precision == "bf16":
    pipeline = FluxFillPipeline.from_pretrained(f"black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16)
    pipeline = pipeline.to("cuda")
    pipeline.precision = "bf16"
else:
    assert args.precision == "int4"
    pipeline_init_kwargs = {}
    transformer = NunchakuFluxTransformer2dModel.from_pretrained(f"mit-han-lab/svdq-int4-flux.1-fill-dev")
    pipeline_init_kwargs["transformer"] = transformer
    if args.use_qencoder:
        from nunchaku.models.text_encoder import NunchakuT5EncoderModel

        text_encoder_2 = NunchakuT5EncoderModel.from_pretrained("mit-han-lab/svdq-flux.1-t5")
        pipeline_init_kwargs["text_encoder_2"] = text_encoder_2
    pipeline = FluxFillPipeline.from_pretrained(
        f"black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16, **pipeline_init_kwargs
    )
    pipeline = pipeline.to("cuda")
    pipeline.precision = "int4"

safety_checker = SafetyChecker("cuda", disabled=args.no_safety_checker)


def save_image(img):
    if isinstance(img, dict):
        img = img["composite"]
    temp_file = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
    img.save(temp_file.name)
    return temp_file.name


def run(
    image, prompt: str, prompt_template: str, num_inference_steps: int, guidance_scale: float, seed: int
) -> tuple[Image, str]:
    print(f"Prompt: {prompt}")
    image_numpy = np.array(image["composite"].convert("RGB"))
    if prompt.strip() == "" and (np.sum(image_numpy == 255) >= 3145628 or np.sum(image_numpy == 0) >= 3145628):
        return blank_image, "Please input the prompt or draw something."
    is_unsafe_prompt = False
    if not safety_checker(prompt):
        is_unsafe_prompt = True
        prompt = "A peaceful world."
    prompt = prompt_template.format(prompt=prompt)
    mask = image["layers"][0].getchannel(3)  # Mask is stored in the last channel
    pic = image["background"].convert("RGB")  # This is the original photo
    start_time = time.time()
    result_image = pipeline(
        prompt=prompt,
        image=pic,
        mask_image=mask,
        guidance_scale=guidance_scale,
        height=1024,
        width=1024,
        num_inference_steps=num_inference_steps,
        max_sequence_length=512,
        generator=torch.Generator().manual_seed(seed),
    ).images[0]
    latency = time.time() - start_time
    if latency < 1:
        latency = latency * 1000
        latency_str = f"{latency:.2f}ms"
    else:
        latency_str = f"{latency:.2f}s"
    if is_unsafe_prompt:
        latency_str += " (Unsafe prompt detected)"
    torch.cuda.empty_cache()
    if args.count_use:
        if os.path.exists("use_count.txt"):
            with open("use_count.txt", "r") as f:
                count = int(f.read())
        else:
            count = 0
        count += 1
        current_time = datetime.now()
        print(f"{current_time}: {count}")
        with open("use_count.txt", "w") as f:
            f.write(str(count))
        with open("use_record.txt", "a") as f:
            f.write(f"{current_time}: {count}\n")
    return result_image, latency_str


with gr.Blocks(css_paths="assets/style.css", title=f"SVDQuant Flux.1-Fill-dev Sketch-to-Image Demo") as demo:
    with open("assets/description.html", "r") as f:
        DESCRIPTION = f.read()
    gpus = GPUtil.getGPUs()
    if len(gpus) > 0:
        gpu = gpus[0]
        memory = gpu.memoryTotal / 1024
        device_info = f"Running on {gpu.name} with {memory:.0f} GiB memory."
    else:
        device_info = "Running on CPU 🥶 This demo does not work on CPU."
    notice = f'<strong>Notice:</strong> We will replace unsafe prompts with a default prompt: "A peaceful world."'

    def get_header_str():
        if args.count_use:
            if os.path.exists("use_count.txt"):
                with open("use_count.txt", "r") as f:
                    count = int(f.read())
            else:
                count = 0
            count_info = (
                f"<div style='display: flex; justify-content: center; align-items: center; text-align: center;'>"
                f"<span style='font-size: 18px; font-weight: bold;'>Total inference runs: </span>"
                f"<span style='font-size: 18px; color:red; font-weight: bold;'>{count}</span></div>"
            )
        else:
            count_info = ""
        header_str = DESCRIPTION.format(device_info=device_info, notice=notice, count_info=count_info)
        return header_str

    header = gr.HTML(get_header_str())
    demo.load(fn=get_header_str, outputs=header)

    with gr.Row(elem_id="main_row"):
        with gr.Column(elem_id="column_input"):
            gr.Markdown("## INPUT", elem_id="input_header")
            with gr.Group():
                canvas = gr.ImageMask(
                    value=blank_image,
                    height=640,
                    image_mode="RGBA",
                    sources=["upload", "clipboard"],
                    type="pil",
                    label="Sketch",
                    show_label=False,
                    show_download_button=True,
                    interactive=True,
                    transforms=[],
                    canvas_size=(1024, 1024),
                    scale=1,
                    format="png",
                    layers=False,
                )
                with gr.Row():
                    prompt = gr.Text(label="Prompt", placeholder="Enter your prompt", scale=6)
                    run_button = gr.Button("Run", scale=1, elem_id="run_button")
                    download_sketch = gr.DownloadButton("Download Sketch", scale=1, elem_id="download_sketch")
                with gr.Row():
                    style = gr.Dropdown(label="Style", choices=STYLE_NAMES, value=DEFAULT_STYLE_NAME, scale=1)
                    prompt_template = gr.Textbox(
                        label="Prompt Style Template", value=STYLES[DEFAULT_STYLE_NAME], scale=2, max_lines=1
                    )
                with gr.Row():
                    seed = gr.Slider(
                        label="Seed", show_label=True, minimum=0, maximum=MAX_SEED, value=233, step=1, scale=4
                    )
                    randomize_seed = gr.Button("Random Seed", scale=1, min_width=50, elem_id="random_seed")
            with gr.Accordion("Advanced options", open=False):
                with gr.Group():
                    num_inference_steps = gr.Slider(
                        label="Inference Steps", minimum=10, maximum=50, step=1, value=DEFAULT_INFERENCE_STEP
                    )
                    guidance_scale = gr.Slider(
                        label="Guidance Scale", minimum=1, maximum=50, step=1, value=DEFAULT_GUIDANCE
                    )

        with gr.Column(elem_id="column_output"):
            gr.Markdown("## OUTPUT", elem_id="output_header")
            with gr.Group():
                result = gr.Image(
                    format="png",
                    height=640,
                    image_mode="RGB",
                    type="pil",
                    label="Result",
                    show_label=False,
                    show_download_button=True,
                    interactive=False,
                    elem_id="output_image",
                )
                latency_result = gr.Text(label="Inference Latency", show_label=True)
                download_result = gr.DownloadButton("Download Result", elem_id="download_result")
            gr.Markdown("### Instructions")
            gr.Markdown("**1**. Enter a text prompt (e.g. a cat)")
            gr.Markdown("**2**. Start sketching")
            gr.Markdown("**3**. Change the image style using a style template")
            gr.Markdown("**4**. Adjust the effect of sketch guidance using the slider (typically between 0.2 and 0.4)")
            gr.Markdown("**5**. Try different seeds to generate different results")

    run_inputs = [canvas, prompt, prompt_template, num_inference_steps, guidance_scale, seed]
    run_outputs = [result, latency_result]

    randomize_seed.click(
        lambda: random.randint(0, MAX_SEED),
        inputs=[],
        outputs=seed,
        api_name=False,
        queue=False,
    ).then(run, inputs=run_inputs, outputs=run_outputs, api_name=False)

    style.change(
        lambda x: STYLES[x],
        inputs=[style],
        outputs=[prompt_template],
        api_name=False,
        queue=False,
    ).then(fn=run, inputs=run_inputs, outputs=run_outputs, api_name=False)

    gr.on(
        triggers=[prompt.submit, run_button.click, canvas.change],
        fn=run,
        inputs=run_inputs,
        outputs=run_outputs,
        api_name=False,
    )

    download_sketch.click(fn=save_image, inputs=canvas, outputs=download_sketch)
    download_result.click(fn=save_image, inputs=result, outputs=download_result)

    gr.Markdown("MIT Accessibility: https://accessibility.mit.edu/", elem_id="accessibility")

if __name__ == "__main__":
    demo.queue().launch(debug=True, share=True)
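For reference, the same int4 pipeline can be exercised without the Gradio UI. The following is a minimal sketch assembled from the loading and inference calls in run_gradio.py above; the input/output file names, prompt, and seed are placeholders and are not part of this commit:

import torch
from diffusers import FluxFillPipeline
from PIL import Image

from nunchaku.models.transformer_flux import NunchakuFluxTransformer2dModel

# Plug the SVDQuant int4 transformer into the FLUX.1-Fill pipeline, as the int4 branch above does.
transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-fill-dev")
pipeline = FluxFillPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Fill-dev", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")

# Placeholder inputs: the picture to edit and a mask marking the region to fill.
image = Image.open("image.png").convert("RGB").resize((1024, 1024))
mask = Image.open("mask.png").convert("L").resize((1024, 1024))

result = pipeline(
    prompt="A wooden bench in a sunlit park",  # example prompt
    image=image,
    mask_image=mask,
    height=1024,
    width=1024,
    guidance_scale=30,  # DEFAULT_GUIDANCE from vars.py
    num_inference_steps=50,  # DEFAULT_INFERENCE_STEP from vars.py
    max_sequence_length=512,
    generator=torch.Generator().manual_seed(233),
).images[0]
result.save("result.png")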
app/flux.1/fill/utils.py  0 → 100644
import argparse


def get_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-p", "--precision", type=str, default="int4", choices=["int4", "bf16"], help="Which precisions to use"
    )
    parser.add_argument("--use-qencoder", action="store_true", help="Whether to use 4-bit text encoder")
    parser.add_argument("--no-safety-checker", action="store_true", help="Disable safety checker")
    parser.add_argument("--count-use", action="store_true", help="Whether to count the number of uses")
    args = parser.parse_args()
    return args
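A small sketch of how these flags parse, with sys.argv simulated for illustration (assumes utils.py is importable, i.e., the script runs from the demo directory):

import sys

from utils import get_args

# Simulate a launch such as: python run_gradio.py -p int4 --use-qencoder
sys.argv = ["run_gradio.py", "-p", "int4", "--use-qencoder"]
args = get_args()
print(args.precision)          # "int4"
print(args.use_qencoder)       # True
print(args.no_safety_checker)  # False
print(args.count_use)          # False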
app/flux.1/fill/vars.py  0 → 100644
STYLES = {
    "None": "{prompt}",
    "Cinematic": "cinematic still {prompt}. emotional, harmonious, vignette, highly detailed, high budget, bokeh, cinemascope, moody, epic, gorgeous, film grain, grainy",
    "3D Model": "professional 3d model {prompt}. octane render, highly detailed, volumetric, dramatic lighting",
    "Anime": "anime artwork {prompt}. anime style, key visual, vibrant, studio anime, highly detailed",
    "Digital Art": "concept art {prompt}. digital artwork, illustrative, painterly, matte painting, highly detailed",
    "Photographic": "cinematic photo {prompt}. 35mm photograph, film, bokeh, professional, 4k, highly detailed",
    "Pixel art": "pixel-art {prompt}. low-res, blocky, pixel art style, 8-bit graphics",
    "Fantasy art": "ethereal fantasy concept art of {prompt}. magnificent, celestial, ethereal, painterly, epic, majestic, magical, fantasy art, cover art, dreamy",
    "Neonpunk": "neonpunk style {prompt}. cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, ultra detailed, intricate, professional",
    "Manga": "manga style {prompt}. vibrant, high-energy, detailed, iconic, Japanese comic style",
}
DEFAULT_STYLE_NAME = "3D Model"
STYLE_NAMES = list(STYLES.keys())
MAX_SEED = 1000000000
DEFAULT_GUIDANCE = 30
DEFAULT_INFERENCE_STEP = 50
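The style entries are plain str.format templates; run_gradio.py applies them with prompt_template.format(prompt=prompt). A small illustrative example (assumes vars.py is importable from the demo directory; the prompt text is made up):

from vars import DEFAULT_STYLE_NAME, STYLES

user_prompt = "a cat sitting on a windowsill"  # example user input
styled_prompt = STYLES[DEFAULT_STYLE_NAME].format(prompt=user_prompt)
print(styled_prompt)
# professional 3d model a cat sitting on a windowsill. octane render, highly detailed, volumetric, dramatic lighting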