Project: fengzch-das/nunchaku

Unverified commit 57e50f8d, authored May 01, 2025 by Muyang Li, committed by GitHub on May 01, 2025.
style: upgrade the linter (#339)

* style: reformatted code
* style: reformatted code
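The hunks on this page are mechanical lint fixes of two recurring kinds: f-strings that contain no placeholders lose their `f` prefix (the warning pyflakes/ruff report as F541), and import blocks are regrouped while `import gradio as gr` stays pinned at the bottom with a skip comment. A minimal before/after sketch of both patterns; the model path is copied from the diffs below, `repo_id` is only an illustrative name, and the commit does not show its linter configuration, so treat the rule name as an assumption:

# Before (flagged by the upgraded linter):
#   repo_id = f"black-forest-labs/FLUX.1-Fill-dev"  # f-string without any placeholders
#   import gradio as gr                             # isort would hoist this above the heavier imports

# After (the form used throughout this commit):
# import gradio last to avoid conflicts with other imports
import gradio as gr  # noqa: isort: skip

repo_id = "black-forest-labs/FLUX.1-Fill-dev"  # plain string literal: no placeholders, so no f prefix
print(gr.__name__, repo_id)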
Parent: b737368d

Changes: 174. Showing 20 changed files with 41 additions and 47 deletions.
app/flux.1/fill/assets/description.html    +2  -2
app/flux.1/fill/assets/style.css           +1  -1
app/flux.1/fill/run_gradio.py              +8  -8
app/flux.1/redux/README.md                 +1  -1
app/flux.1/redux/assets/description.html   +2  -2
app/flux.1/redux/assets/style.css          +1  -1
app/flux.1/redux/run_gradio.py             +5  -5
app/flux.1/redux/utils.py                  +1  -3
app/flux.1/sketch/README.md                +1  -1
app/flux.1/sketch/assets/description.html  +2  -2
app/flux.1/sketch/assets/style.css         +1  -1
app/flux.1/sketch/run.py                   +0  -1
app/flux.1/sketch/run_gradio.py            +6  -6
app/flux.1/t2i/assets/common.css           +1  -1
app/flux.1/t2i/assets/description.html     +2  -2
app/flux.1/t2i/evaluate.py                 +1  -2
app/flux.1/t2i/generate.py                 +0  -1
app/flux.1/t2i/latency.py                  +0  -1
app/flux.1/t2i/metrics/image_reward.py     +1  -1
app/flux.1/t2i/run_gradio.py               +5  -5

app/flux.1/fill/assets/description.html

<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
  <div>
    <h1>
      <img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/logo.svg"
      <img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/svdquant.svg"
        alt="logo" style="height: 40px; width: auto; display: block; margin: auto;" />
      INT4 FLUX.1-fill-dev Demo
...

@@ -49,4 +49,4 @@
    </div>
    {count_info}
  </div>
</div>
\ No newline at end of file
</div>

app/flux.1/fill/assets/style.css

...

@@ -37,4 +37,4 @@ h1 {
#run_button {
  height: 87px;
}
\ No newline at end of file
}

app/flux.1/fill/run_gradio.py

...

@@ -8,25 +8,25 @@ import GPUtil
import torch
from diffusers import FluxFillPipeline
from PIL import Image
from utils import get_args
from vars import DEFAULT_GUIDANCE, DEFAULT_INFERENCE_STEP, DEFAULT_STYLE_NAME, EXAMPLES, MAX_SEED, STYLE_NAMES, STYLES
from nunchaku.models.safety_checker import SafetyChecker
from nunchaku.models.transformers.transformer_flux import NunchakuFluxTransformer2dModel
from utils import get_args
from vars import DEFAULT_GUIDANCE, DEFAULT_INFERENCE_STEP, DEFAULT_STYLE_NAME, EXAMPLES, MAX_SEED, STYLE_NAMES, STYLES

# import gradio last to avoid conflicts with other imports
import gradio as gr
import gradio as gr  # noqa: isort: skip

args = get_args()

if args.precision == "bf16":
    pipeline = FluxFillPipeline.from_pretrained(f"black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16)
    pipeline = FluxFillPipeline.from_pretrained("black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16)
    pipeline = pipeline.to("cuda")
    pipeline.precision = "bf16"
else:
    assert args.precision == "int4"
    pipeline_init_kwargs = {}
    transformer = NunchakuFluxTransformer2dModel.from_pretrained(f"mit-han-lab/svdq-int4-flux.1-fill-dev")
    transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-fill-dev")
    pipeline_init_kwargs["transformer"] = transformer
    if args.use_qencoder:
        from nunchaku.models.text_encoders.t5_encoder import NunchakuT5EncoderModel
...

@@ -35,7 +35,7 @@ else:
        pipeline_init_kwargs["text_encoder_2"] = text_encoder_2
    pipeline = FluxFillPipeline.from_pretrained(
        f"black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16, **pipeline_init_kwargs
        "black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16, **pipeline_init_kwargs
    )
    pipeline = pipeline.to("cuda")
    pipeline.precision = "int4"
...

@@ -94,7 +94,7 @@ def run(
    return result_image, latency_str

with gr.Blocks(css_paths="assets/style.css", title=f"SVDQuant Flux.1-Fill-dev Sketch-to-Image Demo") as demo:
with gr.Blocks(css_paths="assets/style.css", title="SVDQuant Flux.1-Fill-dev Sketch-to-Image Demo") as demo:
    with open("assets/description.html", "r") as f:
        DESCRIPTION = f.read()
    gpus = GPUtil.getGPUs()
...

@@ -104,7 +104,7 @@ with gr.Blocks(css_paths="assets/style.css", title=f"SVDQuant Flux.1-Fill-dev Sk
        device_info = f"Running on {gpu.name} with {memory:.0f} GiB memory."
    else:
        device_info = "Running on CPU 🥶 This demo does not work on CPU."
    notice = f'<strong>Notice:</strong> We will replace unsafe prompts with a default prompt: "A peaceful world."'
    notice = '<strong>Notice:</strong> We will replace unsafe prompts with a default prompt: "A peaceful world."'

    def get_header_str():
...
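Beyond the lint fixes, the hunks above show how this demo picks its backend: BF16 loads the stock FluxFillPipeline, while INT4 swaps in the Nunchaku SVDQuant transformer before building the pipeline. A condensed sketch of that branch, using only the calls visible in the diff; the `load_pipeline` helper is illustrative, and the safety checker, the optional W4A16 text encoder, and the Gradio wiring are omitted:

import torch
from diffusers import FluxFillPipeline

from nunchaku.models.transformers.transformer_flux import NunchakuFluxTransformer2dModel


def load_pipeline(precision: str = "int4") -> FluxFillPipeline:
    # For INT4, replace the default transformer with the SVDQuant INT4 weights.
    init_kwargs = {}
    if precision == "int4":
        init_kwargs["transformer"] = NunchakuFluxTransformer2dModel.from_pretrained(
            "mit-han-lab/svdq-int4-flux.1-fill-dev"
        )
    else:
        assert precision == "bf16"
    pipeline = FluxFillPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16, **init_kwargs
    )
    return pipeline.to("cuda")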

app/flux.1/redux/README.md

...

@@ -8,4 +8,4 @@ This interactive Gradio application allows you to interactively generate image v
python run_gradio.py
```

* By default, we use our INT4 model. Use `-p bf16` to switch to the BF16 model.
\ No newline at end of file
* By default, we use our INT4 model. Use `-p bf16` to switch to the BF16 model.

app/flux.1/redux/assets/description.html

<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
  <div>
    <h1>
      <img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/logo.svg"
      <img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/svdquant.svg"
        alt="logo" style="height: 40px; width: auto; display: block; margin: auto;" />
      INT4 FLUX.1-redux-dev Demo
...

@@ -46,4 +46,4 @@
    </div>
    {count_info}
  </div>
</div>
\ No newline at end of file
</div>

app/flux.1/redux/assets/style.css

...

@@ -26,4 +26,4 @@ h1{text-align:center}
}
#random_seed { height: 71px;}
#run_button { height: 87px;}
\ No newline at end of file
#run_button { height: 87px;}

app/flux.1/redux/run_gradio.py

...

@@ -5,16 +5,16 @@ import time
from datetime import datetime

import GPUtil
# import gradio last to avoid conflicts with other imports
import gradio as gr
import torch
from diffusers import FluxPipeline, FluxPriorReduxPipeline
from PIL import Image
from nunchaku.models.transformers.transformer_flux import NunchakuFluxTransformer2dModel
from utils import get_args
from vars import DEFAULT_GUIDANCE, DEFAULT_INFERENCE_STEP, EXAMPLES, MAX_SEED
# import gradio last to avoid conflicts with other imports
import gradio as gr
from nunchaku.models.transformers.transformer_flux import NunchakuFluxTransformer2dModel

args = get_args()
...

@@ -76,7 +76,7 @@ def run(image, num_inference_steps: int, guidance_scale: float, seed: int) -> tu
    return result_image, latency_str

with gr.Blocks(css_paths="assets/style.css", title=f"SVDQuant Flux.1-redux-dev Demo") as demo:
with gr.Blocks(css_paths="assets/style.css", title="SVDQuant Flux.1-redux-dev Demo") as demo:
    with open("assets/description.html", "r") as f:
        DESCRIPTION = f.read()
    gpus = GPUtil.getGPUs()
...

app/flux.1/redux/utils.py

...

@@ -11,9 +11,7 @@ def get_args() -> argparse.Namespace:
        choices=["int4", "bf16"],
        help="Which precisions to use",
    )
    parser.add_argument("--count-use", action="store_true", help="Whether to count the number of uses")
    parser.add_argument("--count-use", action="store_true", help="Whether to count the number of uses")
    parser.add_argument("--gradio-root-path", type=str, default="")
    args = parser.parse_args()
    return args
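For reference, get_args() reads roughly as follows after this hunk. The `--count-use` and `--gradio-root-path` lines are verbatim from the diff; the precision option is only partially visible (its choices and help text), so its exact flags and default are assumptions based on the READMEs' `-p bf16` usage:

import argparse


def get_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-p",  # assumed flag spelling and default; only choices/help appear in the hunk
        "--precision",
        type=str,
        default="int4",
        choices=["int4", "bf16"],
        help="Which precisions to use",
    )
    parser.add_argument("--count-use", action="store_true", help="Whether to count the number of uses")
    parser.add_argument("--gradio-root-path", type=str, default="")
    args = parser.parse_args()
    return args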

app/flux.1/sketch/README.md

...

@@ -12,4 +12,4 @@ python run_gradio.py
* The demo loads the Gemma-2B model as a safety checker by default. To disable this feature, use `--no-safety-checker`.
* To further reduce GPU memory usage, you can enable the W4A16 text encoder by specifying `--use-qencoder`.
* By default, we use our INT4 model. Use `-p bf16` to switch to the BF16 model.
\ No newline at end of file
* By default, we use our INT4 model. Use `-p bf16` to switch to the BF16 model.

app/flux.1/sketch/assets/description.html

<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
  <div>
    <h1>
      <img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/logo.svg"
      <img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/svdquant.svg"
        alt="logo" style="height: 40px; width: auto; display: block; margin: auto;" />
      INT4 FLUX.1-schnell Sketch-to-Image Demo
...

@@ -50,4 +50,4 @@
    </div>
    {count_info}
  </div>
</div>
\ No newline at end of file
</div>

app/flux.1/sketch/assets/style.css

...

@@ -37,4 +37,4 @@ h1 {
#run_button {
  height: 87px;
}
\ No newline at end of file
}

app/flux.1/sketch/run.py

import argparse
import torch
from flux_pix2pix_pipeline import FluxPix2pixTurboPipeline
...

app/flux.1/sketch/run_gradio.py

...

@@ -8,16 +8,16 @@ from datetime import datetime
import GPUtil
import numpy as np
import torch
from flux_pix2pix_pipeline import FluxPix2pixTurboPipeline
from PIL import Image
from utils import get_args
from vars import DEFAULT_SKETCH_GUIDANCE, DEFAULT_STYLE_NAME, MAX_SEED, STYLE_NAMES, STYLES
from flux_pix2pix_pipeline import FluxPix2pixTurboPipeline
from nunchaku.models.safety_checker import SafetyChecker
from nunchaku.models.transformers.transformer_flux import NunchakuFluxTransformer2dModel
from utils import get_args
from vars import DEFAULT_SKETCH_GUIDANCE, DEFAULT_STYLE_NAME, MAX_SEED, STYLE_NAMES, STYLES

# import gradio last to avoid conflicts with other imports
import gradio as gr
import gradio as gr  # noqa: isort: skip

blank_image = Image.new("RGB", (1024, 1024), (255, 255, 255))
...

@@ -109,7 +109,7 @@ def run(image, prompt: str, prompt_template: str, sketch_guidance: float, seed:
    return result_image, latency_str

with gr.Blocks(css_paths="assets/style.css", title=f"SVDQuant Sketch-to-Image Demo") as demo:
with gr.Blocks(css_paths="assets/style.css", title="SVDQuant Sketch-to-Image Demo") as demo:
    with open("assets/description.html", "r") as f:
        DESCRIPTION = f.read()
    gpus = GPUtil.getGPUs()
...

@@ -119,7 +119,7 @@ with gr.Blocks(css_paths="assets/style.css", title=f"SVDQuant Sketch-to-Image De
        device_info = f"Running on {gpu.name} with {memory:.0f} GiB memory."
    else:
        device_info = "Running on CPU 🥶 This demo does not work on CPU."
    notice = f'<strong>Notice:</strong> We will replace unsafe prompts with a default prompt: "A peaceful world."'
    notice = '<strong>Notice:</strong> We will replace unsafe prompts with a default prompt: "A peaceful world."'

    def get_header_str():
...

app/flux.1/t2i/assets/common.css

...

@@ -6,4 +6,4 @@ h2{text-align:center}
#accessibility {
  text-align: center; /* Center-aligns the text */
  margin: auto; /* Centers the element horizontally */
}
\ No newline at end of file
}

app/flux.1/t2i/assets/description.html

<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
  <div>
    <h1>
      <img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/logo.svg"
      <img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/svdquant.svg"
        alt="logo" style="height: 40px; width: auto; display: block; margin: auto;" />
      FLUX.1-{model} Demo
...

@@ -50,4 +50,4 @@
    </div>
    {count_info}
  </div>
</div>
\ No newline at end of file
</div>

app/flux.1/t2i/evaluate.py

...

@@ -2,9 +2,8 @@ import argparse
import os
import torch
from tqdm import tqdm
from data import get_dataset
from tqdm import tqdm
from utils import get_pipeline, hash_str_to_int
...

app/flux.1/t2i/generate.py

...

@@ -2,7 +2,6 @@ import argparse
import os
import torch
from utils import get_pipeline
from vars import PROMPT_TEMPLATES
...

app/flux.1/t2i/latency.py

...

@@ -4,7 +4,6 @@ import time
import torch
from torch import nn
from tqdm import trange
from utils import get_pipeline
...

app/flux.1/t2i/metrics/image_reward.py

import os
import ImageReward as RM
import datasets
import ImageReward as RM
import torch
from tqdm import tqdm
...

app/flux.1/t2i/run_gradio.py

...

@@ -9,13 +9,13 @@ import GPUtil
import spaces
import torch
from peft.tuners import lora
from nunchaku.models.safety_checker import SafetyChecker
from utils import get_pipeline
from vars import DEFAULT_HEIGHT, DEFAULT_WIDTH, EXAMPLES, MAX_SEED, PROMPT_TEMPLATES, SVDQ_LORA_PATHS
from nunchaku.models.safety_checker import SafetyChecker

# import gradio last to avoid conflicts with other imports
import gradio as gr
import gradio as gr  # noqa: isort: skip

def get_args() -> argparse.Namespace:
...

@@ -84,7 +84,7 @@ def generate(
    images, latency_strs = [], []
    for i, pipeline in enumerate(pipelines):
        precision = args.precisions[i]
        progress = gr.Progress(track_tqdm=True)
        gr.Progress(track_tqdm=True)
        if pipeline.cur_lora_name != lora_name:
            if precision == "bf16":
                for m in pipeline.transformer.modules():
...

@@ -164,7 +164,7 @@ if len(gpus) > 0:
    device_info = f"Running on {gpu.name} with {memory:.0f} GiB memory."
else:
    device_info = "Running on CPU 🥶 This demo does not work on CPU."
notice = f'<strong>Notice:</strong> We will replace unsafe prompts with a default prompt: "A peaceful world."'
notice = '<strong>Notice:</strong> We will replace unsafe prompts with a default prompt: "A peaceful world."'

with gr.Blocks(
    css_paths=[f"assets/frame{len(args.precisions)}.css", "assets/common.css"],
...
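The second hunk above shows the other recurring fix on this page: an assignment whose name is never read (`progress = gr.Progress(track_tqdm=True)`) is reduced to a bare call, keeping the call itself and dropping only the unused name. A minimal sketch of the pattern; `generate_one` is a stand-in, not a function from the repository:

import gradio as gr


def generate_one(prompt: str) -> str:
    # Before: progress = gr.Progress(track_tqdm=True)  -- flagged as an unused local variable.
    # After: keep the call as in the diff, drop the unused name.
    gr.Progress(track_tqdm=True)
    return prompt.upper()  # placeholder for the real image generation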