Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
fengzch-das
nunchaku
Commits
57e50f8d
Unverified
Commit
57e50f8d
authored
May 01, 2025
by
Muyang Li
Committed by
GitHub
May 01, 2025
Browse files
style: upgrade the linter (#339)
* style: reformated codes * style: reformated codes
parent
b737368d
Changes
174
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
41 additions
and
47 deletions
+41
-47
app/flux.1/fill/assets/description.html
app/flux.1/fill/assets/description.html
+2
-2
app/flux.1/fill/assets/style.css
app/flux.1/fill/assets/style.css
+1
-1
app/flux.1/fill/run_gradio.py
app/flux.1/fill/run_gradio.py
+8
-8
app/flux.1/redux/README.md
app/flux.1/redux/README.md
+1
-1
app/flux.1/redux/assets/description.html
app/flux.1/redux/assets/description.html
+2
-2
app/flux.1/redux/assets/style.css
app/flux.1/redux/assets/style.css
+1
-1
app/flux.1/redux/run_gradio.py
app/flux.1/redux/run_gradio.py
+5
-5
app/flux.1/redux/utils.py
app/flux.1/redux/utils.py
+1
-3
app/flux.1/sketch/README.md
app/flux.1/sketch/README.md
+1
-1
app/flux.1/sketch/assets/description.html
app/flux.1/sketch/assets/description.html
+2
-2
app/flux.1/sketch/assets/style.css
app/flux.1/sketch/assets/style.css
+1
-1
app/flux.1/sketch/run.py
app/flux.1/sketch/run.py
+0
-1
app/flux.1/sketch/run_gradio.py
app/flux.1/sketch/run_gradio.py
+6
-6
app/flux.1/t2i/assets/common.css
app/flux.1/t2i/assets/common.css
+1
-1
app/flux.1/t2i/assets/description.html
app/flux.1/t2i/assets/description.html
+2
-2
app/flux.1/t2i/evaluate.py
app/flux.1/t2i/evaluate.py
+1
-2
app/flux.1/t2i/generate.py
app/flux.1/t2i/generate.py
+0
-1
app/flux.1/t2i/latency.py
app/flux.1/t2i/latency.py
+0
-1
app/flux.1/t2i/metrics/image_reward.py
app/flux.1/t2i/metrics/image_reward.py
+1
-1
app/flux.1/t2i/run_gradio.py
app/flux.1/t2i/run_gradio.py
+5
-5
No files found.
app/flux.1/fill/assets/description.html
View file @
57e50f8d
<div
style=
"display: flex; justify-content: center; align-items: center; text-align: center;"
>
<div
style=
"display: flex; justify-content: center; align-items: center; text-align: center;"
>
<div>
<div>
<h1>
<h1>
<img
src=
"https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/
logo
.svg"
<img
src=
"https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/
svdquant
.svg"
alt=
"logo"
alt=
"logo"
style=
"height: 40px; width: auto; display: block; margin: auto;"
/>
style=
"height: 40px; width: auto; display: block; margin: auto;"
/>
INT4 FLUX.1-fill-dev Demo
INT4 FLUX.1-fill-dev Demo
...
@@ -49,4 +49,4 @@
...
@@ -49,4 +49,4 @@
</div>
</div>
{count_info}
{count_info}
</div>
</div>
</div>
</div>
\ No newline at end of file
app/flux.1/fill/assets/style.css
View file @
57e50f8d
...
@@ -37,4 +37,4 @@ h1 {
...
@@ -37,4 +37,4 @@ h1 {
#run_button
{
#run_button
{
height
:
87px
;
height
:
87px
;
}
}
\ No newline at end of file
app/flux.1/fill/run_gradio.py
View file @
57e50f8d
...
@@ -8,25 +8,25 @@ import GPUtil
...
@@ -8,25 +8,25 @@ import GPUtil
import
torch
import
torch
from
diffusers
import
FluxFillPipeline
from
diffusers
import
FluxFillPipeline
from
PIL
import
Image
from
PIL
import
Image
from
utils
import
get_args
from
vars
import
DEFAULT_GUIDANCE
,
DEFAULT_INFERENCE_STEP
,
DEFAULT_STYLE_NAME
,
EXAMPLES
,
MAX_SEED
,
STYLE_NAMES
,
STYLES
from
nunchaku.models.safety_checker
import
SafetyChecker
from
nunchaku.models.safety_checker
import
SafetyChecker
from
nunchaku.models.transformers.transformer_flux
import
NunchakuFluxTransformer2dModel
from
nunchaku.models.transformers.transformer_flux
import
NunchakuFluxTransformer2dModel
from
utils
import
get_args
from
vars
import
DEFAULT_GUIDANCE
,
DEFAULT_INFERENCE_STEP
,
DEFAULT_STYLE_NAME
,
EXAMPLES
,
MAX_SEED
,
STYLE_NAMES
,
STYLES
# import gradio last to avoid conflicts with other imports
# import gradio last to avoid conflicts with other imports
import
gradio
as
gr
import
gradio
as
gr
# noqa: isort: skip
args
=
get_args
()
args
=
get_args
()
if
args
.
precision
==
"bf16"
:
if
args
.
precision
==
"bf16"
:
pipeline
=
FluxFillPipeline
.
from_pretrained
(
f
"black-forest-labs/FLUX.1-Fill-dev"
,
torch_dtype
=
torch
.
bfloat16
)
pipeline
=
FluxFillPipeline
.
from_pretrained
(
"black-forest-labs/FLUX.1-Fill-dev"
,
torch_dtype
=
torch
.
bfloat16
)
pipeline
=
pipeline
.
to
(
"cuda"
)
pipeline
=
pipeline
.
to
(
"cuda"
)
pipeline
.
precision
=
"bf16"
pipeline
.
precision
=
"bf16"
else
:
else
:
assert
args
.
precision
==
"int4"
assert
args
.
precision
==
"int4"
pipeline_init_kwargs
=
{}
pipeline_init_kwargs
=
{}
transformer
=
NunchakuFluxTransformer2dModel
.
from_pretrained
(
f
"mit-han-lab/svdq-int4-flux.1-fill-dev"
)
transformer
=
NunchakuFluxTransformer2dModel
.
from_pretrained
(
"mit-han-lab/svdq-int4-flux.1-fill-dev"
)
pipeline_init_kwargs
[
"transformer"
]
=
transformer
pipeline_init_kwargs
[
"transformer"
]
=
transformer
if
args
.
use_qencoder
:
if
args
.
use_qencoder
:
from
nunchaku.models.text_encoders.t5_encoder
import
NunchakuT5EncoderModel
from
nunchaku.models.text_encoders.t5_encoder
import
NunchakuT5EncoderModel
...
@@ -35,7 +35,7 @@ else:
...
@@ -35,7 +35,7 @@ else:
pipeline_init_kwargs
[
"text_encoder_2"
]
=
text_encoder_2
pipeline_init_kwargs
[
"text_encoder_2"
]
=
text_encoder_2
pipeline
=
FluxFillPipeline
.
from_pretrained
(
pipeline
=
FluxFillPipeline
.
from_pretrained
(
f
"black-forest-labs/FLUX.1-Fill-dev"
,
torch_dtype
=
torch
.
bfloat16
,
**
pipeline_init_kwargs
"black-forest-labs/FLUX.1-Fill-dev"
,
torch_dtype
=
torch
.
bfloat16
,
**
pipeline_init_kwargs
)
)
pipeline
=
pipeline
.
to
(
"cuda"
)
pipeline
=
pipeline
.
to
(
"cuda"
)
pipeline
.
precision
=
"int4"
pipeline
.
precision
=
"int4"
...
@@ -94,7 +94,7 @@ def run(
...
@@ -94,7 +94,7 @@ def run(
return
result_image
,
latency_str
return
result_image
,
latency_str
with
gr
.
Blocks
(
css_paths
=
"assets/style.css"
,
title
=
f
"SVDQuant Flux.1-Fill-dev Sketch-to-Image Demo"
)
as
demo
:
with
gr
.
Blocks
(
css_paths
=
"assets/style.css"
,
title
=
"SVDQuant Flux.1-Fill-dev Sketch-to-Image Demo"
)
as
demo
:
with
open
(
"assets/description.html"
,
"r"
)
as
f
:
with
open
(
"assets/description.html"
,
"r"
)
as
f
:
DESCRIPTION
=
f
.
read
()
DESCRIPTION
=
f
.
read
()
gpus
=
GPUtil
.
getGPUs
()
gpus
=
GPUtil
.
getGPUs
()
...
@@ -104,7 +104,7 @@ with gr.Blocks(css_paths="assets/style.css", title=f"SVDQuant Flux.1-Fill-dev Sk
...
@@ -104,7 +104,7 @@ with gr.Blocks(css_paths="assets/style.css", title=f"SVDQuant Flux.1-Fill-dev Sk
device_info
=
f
"Running on
{
gpu
.
name
}
with
{
memory
:.
0
f
}
GiB memory."
device_info
=
f
"Running on
{
gpu
.
name
}
with
{
memory
:.
0
f
}
GiB memory."
else
:
else
:
device_info
=
"Running on CPU 🥶 This demo does not work on CPU."
device_info
=
"Running on CPU 🥶 This demo does not work on CPU."
notice
=
f
'<strong>Notice:</strong> We will replace unsafe prompts with a default prompt: "A peaceful world."'
notice
=
'<strong>Notice:</strong> We will replace unsafe prompts with a default prompt: "A peaceful world."'
def
get_header_str
():
def
get_header_str
():
...
...
app/flux.1/redux/README.md
View file @
57e50f8d
...
@@ -8,4 +8,4 @@ This interactive Gradio application allows you to interactively generate image v
...
@@ -8,4 +8,4 @@ This interactive Gradio application allows you to interactively generate image v
python run_gradio.py
python run_gradio.py
```
```
*
By default, we use our INT4 model. Use
`-p bf16`
to switch to the BF16 model.
*
By default, we use our INT4 model. Use
`-p bf16`
to switch to the BF16 model.
\ No newline at end of file
app/flux.1/redux/assets/description.html
View file @
57e50f8d
<div
style=
"display: flex; justify-content: center; align-items: center; text-align: center;"
>
<div
style=
"display: flex; justify-content: center; align-items: center; text-align: center;"
>
<div>
<div>
<h1>
<h1>
<img
src=
"https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/
logo
.svg"
<img
src=
"https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/
svdquant
.svg"
alt=
"logo"
alt=
"logo"
style=
"height: 40px; width: auto; display: block; margin: auto;"
/>
style=
"height: 40px; width: auto; display: block; margin: auto;"
/>
INT4 FLUX.1-redux-dev Demo
INT4 FLUX.1-redux-dev Demo
...
@@ -46,4 +46,4 @@
...
@@ -46,4 +46,4 @@
</div>
</div>
{count_info}
{count_info}
</div>
</div>
</div>
</div>
\ No newline at end of file
app/flux.1/redux/assets/style.css
View file @
57e50f8d
...
@@ -26,4 +26,4 @@ h1{text-align:center}
...
@@ -26,4 +26,4 @@ h1{text-align:center}
}
}
#random_seed
{
height
:
71px
;}
#random_seed
{
height
:
71px
;}
#run_button
{
height
:
87px
;}
#run_button
{
height
:
87px
;}
\ No newline at end of file
app/flux.1/redux/run_gradio.py
View file @
57e50f8d
...
@@ -5,16 +5,16 @@ import time
...
@@ -5,16 +5,16 @@ import time
from
datetime
import
datetime
from
datetime
import
datetime
import
GPUtil
import
GPUtil
# import gradio last to avoid conflicts with other imports
import
gradio
as
gr
import
torch
import
torch
from
diffusers
import
FluxPipeline
,
FluxPriorReduxPipeline
from
diffusers
import
FluxPipeline
,
FluxPriorReduxPipeline
from
PIL
import
Image
from
PIL
import
Image
from
nunchaku.models.transformers.transformer_flux
import
NunchakuFluxTransformer2dModel
from
utils
import
get_args
from
utils
import
get_args
from
vars
import
DEFAULT_GUIDANCE
,
DEFAULT_INFERENCE_STEP
,
EXAMPLES
,
MAX_SEED
from
vars
import
DEFAULT_GUIDANCE
,
DEFAULT_INFERENCE_STEP
,
EXAMPLES
,
MAX_SEED
# import gradio last to avoid conflicts with other imports
from
nunchaku.models.transformers.transformer_flux
import
NunchakuFluxTransformer2dModel
import
gradio
as
gr
args
=
get_args
()
args
=
get_args
()
...
@@ -76,7 +76,7 @@ def run(image, num_inference_steps: int, guidance_scale: float, seed: int) -> tu
...
@@ -76,7 +76,7 @@ def run(image, num_inference_steps: int, guidance_scale: float, seed: int) -> tu
return
result_image
,
latency_str
return
result_image
,
latency_str
with
gr
.
Blocks
(
css_paths
=
"assets/style.css"
,
title
=
f
"SVDQuant Flux.1-redux-dev Demo"
)
as
demo
:
with
gr
.
Blocks
(
css_paths
=
"assets/style.css"
,
title
=
"SVDQuant Flux.1-redux-dev Demo"
)
as
demo
:
with
open
(
"assets/description.html"
,
"r"
)
as
f
:
with
open
(
"assets/description.html"
,
"r"
)
as
f
:
DESCRIPTION
=
f
.
read
()
DESCRIPTION
=
f
.
read
()
gpus
=
GPUtil
.
getGPUs
()
gpus
=
GPUtil
.
getGPUs
()
...
...
app/flux.1/redux/utils.py
View file @
57e50f8d
...
@@ -11,9 +11,7 @@ def get_args() -> argparse.Namespace:
...
@@ -11,9 +11,7 @@ def get_args() -> argparse.Namespace:
choices
=
[
"int4"
,
"bf16"
],
choices
=
[
"int4"
,
"bf16"
],
help
=
"Which precisions to use"
,
help
=
"Which precisions to use"
,
)
)
parser
.
add_argument
(
parser
.
add_argument
(
"--count-use"
,
action
=
"store_true"
,
help
=
"Whether to count the number of uses"
)
"--count-use"
,
action
=
"store_true"
,
help
=
"Whether to count the number of uses"
)
parser
.
add_argument
(
"--gradio-root-path"
,
type
=
str
,
default
=
""
)
parser
.
add_argument
(
"--gradio-root-path"
,
type
=
str
,
default
=
""
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
return
args
return
args
app/flux.1/sketch/README.md
View file @
57e50f8d
...
@@ -12,4 +12,4 @@ python run_gradio.py
...
@@ -12,4 +12,4 @@ python run_gradio.py
*
The demo loads the Gemma-2B model as a safety checker by default. To disable this feature, use
`--no-safety-checker`
.
*
The demo loads the Gemma-2B model as a safety checker by default. To disable this feature, use
`--no-safety-checker`
.
*
To further reduce GPU memory usage, you can enable the W4A16 text encoder by specifying
`--use-qencoder`
.
*
To further reduce GPU memory usage, you can enable the W4A16 text encoder by specifying
`--use-qencoder`
.
*
By default, we use our INT4 model. Use
`-p bf16`
to switch to the BF16 model.
*
By default, we use our INT4 model. Use
`-p bf16`
to switch to the BF16 model.
\ No newline at end of file
app/flux.1/sketch/assets/description.html
View file @
57e50f8d
<div
style=
"display: flex; justify-content: center; align-items: center; text-align: center;"
>
<div
style=
"display: flex; justify-content: center; align-items: center; text-align: center;"
>
<div>
<div>
<h1>
<h1>
<img
src=
"https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/
logo
.svg"
<img
src=
"https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/
svdquant
.svg"
alt=
"logo"
alt=
"logo"
style=
"height: 40px; width: auto; display: block; margin: auto;"
/>
style=
"height: 40px; width: auto; display: block; margin: auto;"
/>
INT4 FLUX.1-schnell Sketch-to-Image Demo
INT4 FLUX.1-schnell Sketch-to-Image Demo
...
@@ -50,4 +50,4 @@
...
@@ -50,4 +50,4 @@
</div>
</div>
{count_info}
{count_info}
</div>
</div>
</div>
</div>
\ No newline at end of file
app/flux.1/sketch/assets/style.css
View file @
57e50f8d
...
@@ -37,4 +37,4 @@ h1 {
...
@@ -37,4 +37,4 @@ h1 {
#run_button
{
#run_button
{
height
:
87px
;
height
:
87px
;
}
}
\ No newline at end of file
app/flux.1/sketch/run.py
View file @
57e50f8d
import
argparse
import
argparse
import
torch
import
torch
from
flux_pix2pix_pipeline
import
FluxPix2pixTurboPipeline
from
flux_pix2pix_pipeline
import
FluxPix2pixTurboPipeline
...
...
app/flux.1/sketch/run_gradio.py
View file @
57e50f8d
...
@@ -8,16 +8,16 @@ from datetime import datetime
...
@@ -8,16 +8,16 @@ from datetime import datetime
import
GPUtil
import
GPUtil
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
from
flux_pix2pix_pipeline
import
FluxPix2pixTurboPipeline
from
PIL
import
Image
from
PIL
import
Image
from
utils
import
get_args
from
vars
import
DEFAULT_SKETCH_GUIDANCE
,
DEFAULT_STYLE_NAME
,
MAX_SEED
,
STYLE_NAMES
,
STYLES
from
flux_pix2pix_pipeline
import
FluxPix2pixTurboPipeline
from
nunchaku.models.safety_checker
import
SafetyChecker
from
nunchaku.models.safety_checker
import
SafetyChecker
from
nunchaku.models.transformers.transformer_flux
import
NunchakuFluxTransformer2dModel
from
nunchaku.models.transformers.transformer_flux
import
NunchakuFluxTransformer2dModel
from
utils
import
get_args
from
vars
import
DEFAULT_SKETCH_GUIDANCE
,
DEFAULT_STYLE_NAME
,
MAX_SEED
,
STYLE_NAMES
,
STYLES
# import gradio last to avoid conflicts with other imports
# import gradio last to avoid conflicts with other imports
import
gradio
as
gr
import
gradio
as
gr
# noqa: isort: skip
blank_image
=
Image
.
new
(
"RGB"
,
(
1024
,
1024
),
(
255
,
255
,
255
))
blank_image
=
Image
.
new
(
"RGB"
,
(
1024
,
1024
),
(
255
,
255
,
255
))
...
@@ -109,7 +109,7 @@ def run(image, prompt: str, prompt_template: str, sketch_guidance: float, seed:
...
@@ -109,7 +109,7 @@ def run(image, prompt: str, prompt_template: str, sketch_guidance: float, seed:
return
result_image
,
latency_str
return
result_image
,
latency_str
with
gr
.
Blocks
(
css_paths
=
"assets/style.css"
,
title
=
f
"SVDQuant Sketch-to-Image Demo"
)
as
demo
:
with
gr
.
Blocks
(
css_paths
=
"assets/style.css"
,
title
=
"SVDQuant Sketch-to-Image Demo"
)
as
demo
:
with
open
(
"assets/description.html"
,
"r"
)
as
f
:
with
open
(
"assets/description.html"
,
"r"
)
as
f
:
DESCRIPTION
=
f
.
read
()
DESCRIPTION
=
f
.
read
()
gpus
=
GPUtil
.
getGPUs
()
gpus
=
GPUtil
.
getGPUs
()
...
@@ -119,7 +119,7 @@ with gr.Blocks(css_paths="assets/style.css", title=f"SVDQuant Sketch-to-Image De
...
@@ -119,7 +119,7 @@ with gr.Blocks(css_paths="assets/style.css", title=f"SVDQuant Sketch-to-Image De
device_info
=
f
"Running on
{
gpu
.
name
}
with
{
memory
:.
0
f
}
GiB memory."
device_info
=
f
"Running on
{
gpu
.
name
}
with
{
memory
:.
0
f
}
GiB memory."
else
:
else
:
device_info
=
"Running on CPU 🥶 This demo does not work on CPU."
device_info
=
"Running on CPU 🥶 This demo does not work on CPU."
notice
=
f
'<strong>Notice:</strong> We will replace unsafe prompts with a default prompt: "A peaceful world."'
notice
=
'<strong>Notice:</strong> We will replace unsafe prompts with a default prompt: "A peaceful world."'
def
get_header_str
():
def
get_header_str
():
...
...
app/flux.1/t2i/assets/common.css
View file @
57e50f8d
...
@@ -6,4 +6,4 @@ h2{text-align:center}
...
@@ -6,4 +6,4 @@ h2{text-align:center}
#accessibility
{
#accessibility
{
text-align
:
center
;
/* Center-aligns the text */
text-align
:
center
;
/* Center-aligns the text */
margin
:
auto
;
/* Centers the element horizontally */
margin
:
auto
;
/* Centers the element horizontally */
}
}
\ No newline at end of file
app/flux.1/t2i/assets/description.html
View file @
57e50f8d
<div
style=
"display: flex; justify-content: center; align-items: center; text-align: center;"
>
<div
style=
"display: flex; justify-content: center; align-items: center; text-align: center;"
>
<div>
<div>
<h1>
<h1>
<img
src=
"https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/
logo
.svg"
<img
src=
"https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/
svdquant
.svg"
alt=
"logo"
alt=
"logo"
style=
"height: 40px; width: auto; display: block; margin: auto;"
/>
style=
"height: 40px; width: auto; display: block; margin: auto;"
/>
FLUX.1-{model} Demo
FLUX.1-{model} Demo
...
@@ -50,4 +50,4 @@
...
@@ -50,4 +50,4 @@
</div>
</div>
{count_info}
{count_info}
</div>
</div>
</div>
</div>
\ No newline at end of file
app/flux.1/t2i/evaluate.py
View file @
57e50f8d
...
@@ -2,9 +2,8 @@ import argparse
...
@@ -2,9 +2,8 @@ import argparse
import
os
import
os
import
torch
import
torch
from
tqdm
import
tqdm
from
data
import
get_dataset
from
data
import
get_dataset
from
tqdm
import
tqdm
from
utils
import
get_pipeline
,
hash_str_to_int
from
utils
import
get_pipeline
,
hash_str_to_int
...
...
app/flux.1/t2i/generate.py
View file @
57e50f8d
...
@@ -2,7 +2,6 @@ import argparse
...
@@ -2,7 +2,6 @@ import argparse
import
os
import
os
import
torch
import
torch
from
utils
import
get_pipeline
from
utils
import
get_pipeline
from
vars
import
PROMPT_TEMPLATES
from
vars
import
PROMPT_TEMPLATES
...
...
app/flux.1/t2i/latency.py
View file @
57e50f8d
...
@@ -4,7 +4,6 @@ import time
...
@@ -4,7 +4,6 @@ import time
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
from
tqdm
import
trange
from
tqdm
import
trange
from
utils
import
get_pipeline
from
utils
import
get_pipeline
...
...
app/flux.1/t2i/metrics/image_reward.py
View file @
57e50f8d
import
os
import
os
import
ImageReward
as
RM
import
datasets
import
datasets
import
ImageReward
as
RM
import
torch
import
torch
from
tqdm
import
tqdm
from
tqdm
import
tqdm
...
...
app/flux.1/t2i/run_gradio.py
View file @
57e50f8d
...
@@ -9,13 +9,13 @@ import GPUtil
...
@@ -9,13 +9,13 @@ import GPUtil
import
spaces
import
spaces
import
torch
import
torch
from
peft.tuners
import
lora
from
peft.tuners
import
lora
from
nunchaku.models.safety_checker
import
SafetyChecker
from
utils
import
get_pipeline
from
utils
import
get_pipeline
from
vars
import
DEFAULT_HEIGHT
,
DEFAULT_WIDTH
,
EXAMPLES
,
MAX_SEED
,
PROMPT_TEMPLATES
,
SVDQ_LORA_PATHS
from
vars
import
DEFAULT_HEIGHT
,
DEFAULT_WIDTH
,
EXAMPLES
,
MAX_SEED
,
PROMPT_TEMPLATES
,
SVDQ_LORA_PATHS
from
nunchaku.models.safety_checker
import
SafetyChecker
# import gradio last to avoid conflicts with other imports
# import gradio last to avoid conflicts with other imports
import
gradio
as
gr
import
gradio
as
gr
# noqa: isort: skip
def
get_args
()
->
argparse
.
Namespace
:
def
get_args
()
->
argparse
.
Namespace
:
...
@@ -84,7 +84,7 @@ def generate(
...
@@ -84,7 +84,7 @@ def generate(
images
,
latency_strs
=
[],
[]
images
,
latency_strs
=
[],
[]
for
i
,
pipeline
in
enumerate
(
pipelines
):
for
i
,
pipeline
in
enumerate
(
pipelines
):
precision
=
args
.
precisions
[
i
]
precision
=
args
.
precisions
[
i
]
progress
=
gr
.
Progress
(
track_tqdm
=
True
)
gr
.
Progress
(
track_tqdm
=
True
)
if
pipeline
.
cur_lora_name
!=
lora_name
:
if
pipeline
.
cur_lora_name
!=
lora_name
:
if
precision
==
"bf16"
:
if
precision
==
"bf16"
:
for
m
in
pipeline
.
transformer
.
modules
():
for
m
in
pipeline
.
transformer
.
modules
():
...
@@ -164,7 +164,7 @@ if len(gpus) > 0:
...
@@ -164,7 +164,7 @@ if len(gpus) > 0:
device_info
=
f
"Running on
{
gpu
.
name
}
with
{
memory
:.
0
f
}
GiB memory."
device_info
=
f
"Running on
{
gpu
.
name
}
with
{
memory
:.
0
f
}
GiB memory."
else
:
else
:
device_info
=
"Running on CPU 🥶 This demo does not work on CPU."
device_info
=
"Running on CPU 🥶 This demo does not work on CPU."
notice
=
f
'<strong>Notice:</strong> We will replace unsafe prompts with a default prompt: "A peaceful world."'
notice
=
'<strong>Notice:</strong> We will replace unsafe prompts with a default prompt: "A peaceful world."'
with
gr
.
Blocks
(
with
gr
.
Blocks
(
css_paths
=
[
f
"assets/frame
{
len
(
args
.
precisions
)
}
.css"
,
"assets/common.css"
],
css_paths
=
[
f
"assets/frame
{
len
(
args
.
precisions
)
}
.css"
,
"assets/common.css"
],
...
...
Prev
1
2
3
4
5
6
…
9
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment