fengzch-das / nunchaku / Commits / df981d24

Commit df981d24, authored Nov 28, 2024 by muyangli:

[major] fix the evaluation scripts; no need to download the entire model

Parent: 25ce8942
Showing 11 changed files with 95 additions and 48 deletions (+95 -48):

- README.md (+4 -4)
- app/i2i/flux_pix2pix_pipeline.py (+11 -10)
- app/i2i/run_gradio.py (+6 -4)
- app/i2i/utils.py (+0 -3)
- app/t2i/data/DCI/DCI.py (+1 -1)
- app/t2i/get_metrics.py (+3 -1)
- app/t2i/latency.py (+23 -7)
- app/t2i/run_gradio.py (+3 -1)
- app/t2i/utils.py (+2 -2)
- nunchaku/models/flux.py (+10 -7)
- nunchaku/pipelines/flux.py (+32 -8)
README.md (+4 -4)

In the hunk at `@@ -60,9 +60,9 @@`, the build instructions change only in formatting (the inline view shows identical text for the removed and added lines); the block now reads:

```shell
git clone https://github.com/mit-han-lab/nunchaku.git
cd nunchaku
git submodule init
git submodule update
pip install -e .
```

The usage example now passes a Hugging Face repo id as `qmodel_path` instead of a path to a single `.safetensors` file:

```diff
@@ -78,7 +78,7 @@ from nunchaku.pipelines import flux as nunchaku_flux
 pipeline = nunchaku_flux.from_pretrained(
     "black-forest-labs/FLUX.1-schnell",
     torch_dtype=torch.bfloat16,
-    qmodel_path="mit-han-lab/svdquant-models/svdq-int4-flux.1-schnell.safetensors",  # download from Huggingface
+    qmodel_path="mit-han-lab/svdq-int4-flux.1-schnell",  # download from Huggingface
 ).to("cuda")
 image = pipeline("A cat holding a sign that says hello world", num_inference_steps=4, guidance_scale=0).images[0]
 image.save("example.png")
```
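Assembled from the hunk above, the updated usage is runnable end to end; a consolidated sketch, assuming nunchaku is installed, a CUDA GPU is available, and the Hugging Face repos named in the diff are reachable:

```python
import torch

from nunchaku.pipelines import flux as nunchaku_flux

# qmodel_path is now a Hugging Face repo id: the quantized weights are
# fetched into the local cache on first use instead of needing a local file.
pipeline = nunchaku_flux.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    torch_dtype=torch.bfloat16,
    qmodel_path="mit-han-lab/svdq-int4-flux.1-schnell",
).to("cuda")

image = pipeline(
    "A cat holding a sign that says hello world",
    num_inference_steps=4,
    guidance_scale=0,
).images[0]
image.save("example.png")
```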
app/i2i/flux_pix2pix_pipeline.py (+11 -10)

`snapshot_download` replaces the single-file download, leftover debug `save_image` calls are removed, and `qmodel_device` is validated as a CUDA device:

```diff
@@ -4,11 +4,11 @@ from typing import Any, Callable, Optional, Union
 import torch
 import torchvision.transforms.functional as F
 import torchvision.utils
-from PIL import Image
 from diffusers.pipelines.flux.pipeline_flux import FluxPipeline, FluxPipelineOutput, FluxTransformer2DModel
 from einops import rearrange
-from huggingface_hub import hf_hub_download
+from huggingface_hub import hf_hub_download, snapshot_download
 from peft.tuners import lora
+from PIL import Image
 from torch import nn
 
 from nunchaku.models.flux import inject_pipeline, load_quantized_model
@@ -145,9 +145,7 @@ class FluxPix2pixTurboPipeline(FluxPipeline):
         self.erosion_kernel = erosion_kernel
 
-        torchvision.utils.save_image(image_t[0], "before.png")
-        image_t = (nn.functional.conv2d(image_t[:, :1], erosion_kernel, padding=kernel_size // 2) > kernel_size**2 - 0.1)
+        image_t = nn.functional.conv2d(image_t[:, :1], erosion_kernel, padding=kernel_size // 2) > kernel_size**2 - 0.1
         image_t = torch.concat([image_t, image_t, image_t], dim=1).to(self.dtype)
-        torchvision.utils.save_image(image_t[0], "after.png")
@@ -219,6 +217,8 @@ class FluxPix2pixTurboPipeline(FluxPipeline):
     def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
         qmodel_device = kwargs.pop("qmodel_device", "cuda:0")
+        qmodel_device = torch.device(qmodel_device)
+        if qmodel_device.type != "cuda":
+            raise ValueError(f"qmodel_device = {qmodel_device} is not a CUDA device")
         qmodel_path = kwargs.pop("qmodel_path", None)
         qencoder_path = kwargs.pop("qencoder_path", None)
@@ -229,11 +229,12 @@ class FluxPix2pixTurboPipeline(FluxPipeline):
         if qmodel_path is not None:
             assert isinstance(qmodel_path, str)
             if not os.path.exists(qmodel_path):
-                hf_repo_id = os.path.dirname(qmodel_path)
-                filename = os.path.basename(qmodel_path)
-                qmodel_path = hf_hub_download(repo_id=hf_repo_id, filename=filename)
-            m = load_quantized_model(qmodel_path, 0 if qmodel_device.index is None else qmodel_device.index)
-            inject_pipeline(pipeline, m)
+                qmodel_path = snapshot_download(qmodel_path)
+            m = load_quantized_model(
+                os.path.join(qmodel_path, "transformer_blocks.safetensors"),
+                0 if qmodel_device.index is None else qmodel_device.index,
+            )
+            inject_pipeline(pipeline, m, qmodel_device)
         pipeline.precision = "int4"
         if qencoder_path is not None:
```
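The substantive change in `from_pretrained` above is the switch from the single-file `hf_hub_download` to `snapshot_download`, which resolves a Hugging Face repo id to a cached local directory; `transformer_blocks.safetensors` is then read from inside it. A minimal sketch of that resolution step, assuming the repo layout implied by the diff:

```python
import os

from huggingface_hub import snapshot_download


def resolve_qmodel(qmodel_path: str) -> str:
    """Return a local path to transformer_blocks.safetensors.

    A path that already exists locally is used as-is; anything else is
    treated as a Hugging Face repo id and downloaded into the local cache
    (a no-op on later calls thanks to caching).
    """
    if not os.path.exists(qmodel_path):
        qmodel_path = snapshot_download(qmodel_path)
    return os.path.join(qmodel_path, "transformer_blocks.safetensors")


# e.g. resolve_qmodel("mit-han-lab/svdq-int4-flux.1-schnell")
```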
app/i2i/run_gradio.py (+6 -4)

The gradio import moves to the end of the import block with an explanatory comment, and `qmodel_path` switches to the Hugging Face repo id:

```diff
@@ -5,15 +5,17 @@ import tempfile
 import time
 
 import GPUtil
-import gradio as gr
-import numpy as np
 import torch
 from PIL import Image
 from flux_pix2pix_pipeline import FluxPix2pixTurboPipeline
 from nunchaku.models.safety_checker import SafetyChecker
 from utils import get_args
-from vars import DEFAULT_SKETCH_GUIDANCE, DEFAULT_STYLE_NAME, MAX_SEED, STYLES, STYLE_NAMES
+import numpy as np
+from vars import DEFAULT_SKETCH_GUIDANCE, DEFAULT_STYLE_NAME, MAX_SEED, STYLE_NAMES, STYLES
+
+# import gradio last to avoid conflicts with other imports
+import gradio as gr
 
 blank_image = Image.new("RGB", (1024, 1024), (255, 255, 255))
@@ -30,7 +32,7 @@ else:
     pipeline = FluxPix2pixTurboPipeline.from_pretrained(
         "black-forest-labs/FLUX.1-schnell",
         torch_dtype=torch.bfloat16,
-        qmodel_path="mit-han-lab/svdquant-models/svdq-int4-flux.1-schnell.safetensors",
+        qmodel_path="mit-han-lab/svdq-int4-flux.1-schnell",
         qencoder_path="mit-han-lab/svdquant-models/svdq-w4a16-t5.pt" if args.use_qencoder else None,
     )
     pipeline = pipeline.to("cuda")
```
app/i2i/utils.py (+0 -3)

```diff
-import cv2
-import numpy as np
-from PIL import Image
 import argparse
```
app/t2i/data/DCI/DCI.py (+1 -1)

```diff
@@ -26,7 +26,7 @@ _HOMEPAGE = "https://github.com/facebookresearch/DCI"
 _LICENSE = "Attribution-NonCommercial 4.0 International (https://github.com/facebookresearch/DCI/blob/main/LICENSE)"
 
-IMAGE_URL = "https://scontent.xx.fbcdn.net/m1/v/t6/An_zz_Te0EtVC_cHtUwnyNKODapWXuNNPeBgZn_3XY8yDFzwHrNb-zwN9mYCbAeWUKQooCI9mVbwvzZDZzDUlscRjYxLKsw.tar?ccb=10-5&oh=00_AYBnKR-fSIir-E49Q7-qO2tjmY0BGJhCciHS__B5QyiBAg&oe=673FFA8A&_nc_sid=0fdd51"
+IMAGE_URL = "https://huggingface.co/datasets/mit-han-lab/svdquant-datasets/resolve/main/sDCI.gz"
 PROMPT_URLS = {"sDCI": "https://huggingface.co/datasets/mit-han-lab/svdquant-datasets/resolve/main/sDCI.yaml"}
```
app/t2i/get_metrics.py (+3 -1)

```diff
@@ -31,7 +31,9 @@ def main():
     results = {}
     dataset_names = sorted(os.listdir(image_root1))
     for dataset_name in dataset_names:
-        print("##Results for dataset:", dataset_name)
+        if image_root2 is not None and dataset_name not in os.listdir(image_root2):
+            continue
+        print("Results for dataset:", dataset_name)
         results[dataset_name] = {}
         dataset = get_dataset(name=dataset_name, return_gt=True)
         fid = compute_fid(ref_dirpath_or_dataset=dataset, gen_dirpath=os.path.join(image_root1, dataset_name))
```
app/t2i/latency.py (+23 -7)

The end-to-end benchmark calls are reformatted onto single lines, and the previously empty `elif args.mode == "step"` branch now captures the transformer's inputs with a forward pre-hook and times the transformer forward pass alone:

```diff
@@ -2,6 +2,7 @@ import argparse
 import time
 
 import torch
+from torch import nn
 from tqdm import trange
 from utils import get_pipeline
@@ -51,23 +52,38 @@ def main():
     pipeline.set_progress_bar_config(position=1, desc="Step", leave=False)
 
         for _ in trange(args.warmup_times, desc="Warmup", position=0, leave=False):
-            pipeline(
-                prompt=dummy_prompt, num_inference_steps=args.num_inference_steps, guidance_scale=args.guidance_scale,
-            )
+            pipeline(prompt=dummy_prompt, num_inference_steps=args.num_inference_steps, guidance_scale=args.guidance_scale)
         torch.cuda.synchronize()
         for _ in trange(args.test_times, desc="Warmup", position=0, leave=False):
             start_time = time.time()
-            pipeline(
-                prompt=dummy_prompt, num_inference_steps=args.num_inference_steps, guidance_scale=args.guidance_scale,
-            )
+            pipeline(prompt=dummy_prompt, num_inference_steps=args.num_inference_steps, guidance_scale=args.guidance_scale)
             torch.cuda.synchronize()
             end_time = time.time()
             latency_list.append(end_time - start_time)
     elif args.mode == "step":
-        pass
+        inputs = {}
+
+        def get_input_hook(module: nn.Module, input_args, input_kwargs):
+            inputs["args"] = input_args
+            inputs["kwargs"] = input_kwargs
+
+        pipeline.transformer.register_forward_pre_hook(get_input_hook, with_kwargs=True)
+        pipeline(prompt=dummy_prompt, num_inference_steps=1, guidance_scale=args.guidance_scale, output_type="latent")
+        for _ in trange(args.warmup_times, desc="Warmup", position=0, leave=False):
+            pipeline.transformer(*inputs["args"], **inputs["kwargs"])
+        torch.cuda.synchronize()
+        for _ in trange(args.test_times, desc="Warmup", position=0, leave=False):
+            start_time = time.time()
+            pipeline.transformer(*inputs["args"], **inputs["kwargs"])
+            torch.cuda.synchronize()
+            end_time = time.time()
+            latency_list.append(end_time - start_time)
 
     latency_list = sorted(latency_list)
     ignored_count = int(args.ignore_ratio * len(latency_list) / 2)
     if ignored_count > 0:
```
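The new "step" mode isolates the transformer's forward pass: a forward pre-hook captures the transformer's inputs during one real pipeline call, and the benchmark then replays those inputs against the module directly. Here is the same technique on a toy module (the `nn.Linear` and shapes are illustrative, not from the repo):

```python
import torch
from torch import nn

inputs = {}


def get_input_hook(module: nn.Module, input_args, input_kwargs):
    # Runs before every forward(); keep the most recent call's inputs.
    inputs["args"] = input_args
    inputs["kwargs"] = input_kwargs


model = nn.Linear(16, 16)
model.register_forward_pre_hook(get_input_hook, with_kwargs=True)
model(torch.randn(2, 16))  # one real call populates `inputs`

# Replay the captured inputs to time the module in isolation, without
# rerunning everything around it (text encoder, scheduler, VAE decode).
for _ in range(3):
    model(*inputs["args"], **inputs["kwargs"])
```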
app/t2i/run_gradio.py (+3 -1)

```diff
@@ -5,7 +5,6 @@ import random
 import time
 
 import GPUtil
-import gradio as gr
 import spaces
 import torch
 from peft.tuners import lora
@@ -14,6 +13,9 @@ from nunchaku.models.safety_checker import SafetyChecker
 from utils import get_pipeline
 from vars import DEFAULT_HEIGHT, DEFAULT_WIDTH, EXAMPLES, MAX_SEED, PROMPT_TEMPLATES, SVDQ_LORA_PATHS
 
+# import gradio last to avoid conflicts with other imports
+import gradio as gr
+
 
 def get_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser()
```
app/t2i/utils.py (+2 -2)

```diff
@@ -29,7 +29,7 @@ def get_pipeline(
         pipeline = nunchaku_flux.from_pretrained(
             "black-forest-labs/FLUX.1-schnell",
             torch_dtype=torch.bfloat16,
-            qmodel_path="mit-han-lab/svdquant-models/svdq-int4-flux.1-schnell.safetensors",
+            qmodel_path="mit-han-lab/svdq-int4-flux.1-schnell",
             qencoder_path="mit-han-lab/svdquant-models/svdq-w4a16-t5.pt" if use_qencoder else None,
             qmodel_device=device,
         )
@@ -41,7 +41,7 @@ def get_pipeline(
         pipeline = nunchaku_flux.from_pretrained(
             "black-forest-labs/FLUX.1-dev",
             torch_dtype=torch.bfloat16,
-            qmodel_path="mit-han-lab/svdquant-models/svdq-int4-flux.1-dev.safetensors",
+            qmodel_path="mit-han-lab/svdq-int4-flux.1-dev",
             qencoder_path="mit-han-lab/svdquant-models/svdq-w4a16-t5.pt" if use_qencoder else None,
             qmodel_device=device,
         )
```
nunchaku/models/flux.py (+10 -7)

`NunchakuFluxModel` and `inject_pipeline` now take an explicit `device`, and the forward pass moves activations onto that device (and back) in addition to the existing dtype handling:

```diff
@@ -14,10 +14,11 @@ SVD_RANK = 32
 class NunchakuFluxModel(nn.Module):
-    def __init__(self, m: QuantizedFluxModel):
+    def __init__(self, m: QuantizedFluxModel, device: torch.device):
         super().__init__()
         self.m = m
         self.dtype = torch.bfloat16
+        self.device = device
 
     def forward(
         self,
@@ -33,10 +34,12 @@ class NunchakuFluxModel(nn.Module):
         img_tokens = hidden_states.shape[1]
 
         original_dtype = hidden_states.dtype
+        original_device = hidden_states.device
 
-        hidden_states = hidden_states.to(self.dtype)
-        encoder_hidden_states = encoder_hidden_states.to(self.dtype)
-        temb = temb.to(self.dtype)
+        hidden_states = hidden_states.to(self.dtype).to(self.device)
+        encoder_hidden_states = encoder_hidden_states.to(self.dtype).to(self.device)
+        temb = temb.to(self.dtype).to(self.device)
+        image_rotary_emb = image_rotary_emb.to(self.device)
 
         assert image_rotary_emb.ndim == 6
         assert image_rotary_emb.shape[0] == 1
@@ -52,7 +55,7 @@ class NunchakuFluxModel(nn.Module):
             hidden_states, encoder_hidden_states, temb, rotary_emb_img, rotary_emb_txt, rotary_emb_single
         )
 
-        hidden_states = hidden_states.to(original_dtype)
+        hidden_states = hidden_states.to(original_dtype).to(original_device)
 
         encoder_hidden_states = hidden_states[:, :txt_tokens, ...]
         hidden_states = hidden_states[:, txt_tokens:, ...]
@@ -110,11 +113,11 @@ def load_quantized_model(path: str, device: str | torch.device) -> QuantizedFluxModel:
     return m
 
 
-def inject_pipeline(pipe: FluxPipeline, m: QuantizedFluxModel) -> FluxPipeline:
+def inject_pipeline(pipe: FluxPipeline, m: QuantizedFluxModel, device: torch.device) -> FluxPipeline:
     net: FluxTransformer2DModel = pipe.transformer
     net.pos_embed = EmbedND(dim=net.inner_dim, theta=10000, axes_dim=[16, 56, 56])
 
-    net.transformer_blocks = torch.nn.ModuleList([NunchakuFluxModel(m)])
+    net.transformer_blocks = torch.nn.ModuleList([NunchakuFluxModel(m, device)])
     net.single_transformer_blocks = torch.nn.ModuleList([])
 
     def update_params(self: FluxTransformer2DModel, path: str):
```
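The forward-pass change follows a save/cast/restore pattern: record the incoming dtype and device, move activations onto the quantized model's own dtype and device, and convert outputs back before returning. The same idea in isolation, as a hypothetical wrapper (not the repo's class; it assumes `inner` already lives on `device` in `dtype`):

```python
import torch
from torch import nn


class DeviceCastWrapper(nn.Module):
    """Run an inner module at a fixed dtype/device, restoring the caller's."""

    def __init__(self, inner: nn.Module, device: torch.device, dtype: torch.dtype = torch.bfloat16):
        super().__init__()
        self.inner = inner  # assumed to already be on `device` in `dtype`
        self.device = device
        self.dtype = dtype

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        original_dtype, original_device = x.dtype, x.device
        y = self.inner(x.to(self.dtype).to(self.device))
        # Cast back so surrounding layers see the dtype/device they sent in.
        return y.to(original_dtype).to(original_device)
```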
nunchaku/pipelines/flux.py (+32 -8)

`from_pretrained` now builds the FLUX transformer from its config and loads only the unquantized layers from the quantized checkpoint, so the full original transformer weights no longer need to be downloaded:

```diff
 import os
 
 import torch
-from diffusers import FluxPipeline
-from huggingface_hub import hf_hub_download
+from diffusers import __version__
+from diffusers import FluxPipeline, FluxTransformer2DModel
+from huggingface_hub import hf_hub_download, snapshot_download
+from safetensors.torch import load_file
+from torch import nn
 
 from ..models.flux import inject_pipeline, load_quantized_model
@@ -45,13 +48,34 @@ def from_pretrained(pretrained_model_name_or_path: str | os.PathLike, **kwargs)
     qencoder_path = kwargs.pop("qencoder_path", None)
 
     if not os.path.exists(qmodel_path):
-        hf_repo_id = os.path.dirname(qmodel_path)
-        filename = os.path.basename(qmodel_path)
-        qmodel_path = hf_hub_download(repo_id=hf_repo_id, filename=filename)
-    pipeline = FluxPipeline.from_pretrained(pretrained_model_name_or_path, **kwargs)
-    m = load_quantized_model(qmodel_path, 0 if qmodel_device.index is None else qmodel_device.index)
-    inject_pipeline(pipeline, m)
+        qmodel_path = snapshot_download(qmodel_path)
+
+    assert kwargs.pop("transformer", None) is None
+    config, unused_kwargs, commit_hash = FluxTransformer2DModel.load_config(
+        pretrained_model_name_or_path,
+        cache_dir=kwargs.get("cache_dir", None),
+        return_unused_kwargs=True,
+        return_commit_hash=True,
+        force_download=kwargs.get("force_download", False),
+        proxies=kwargs.get("proxies", None),
+        local_files_only=kwargs.get("local_files_only", None),
+        token=kwargs.get("token", None),
+        revision=kwargs.get("revision", None),
+        subfolder="transformer",
+        user_agent={"diffusers": __version__, "file_type": "model", "framework": "pytorch"},
+        **kwargs,
+    )
+    transformer: nn.Module = FluxTransformer2DModel.from_config(config).to(kwargs.get("torch_dtype", torch.bfloat16))
+    state_dict = load_file(os.path.join(qmodel_path, "unquantized_layers.safetensors"))
+    transformer.load_state_dict(state_dict, strict=False)
+    pipeline = FluxPipeline.from_pretrained(pretrained_model_name_or_path, transformer=transformer, **kwargs)
+    m = load_quantized_model(
+        os.path.join(qmodel_path, "transformer_blocks.safetensors"),
+        0 if qmodel_device.index is None else qmodel_device.index,
+    )
+    inject_pipeline(pipeline, m, qmodel_device)
 
     if qencoder_path is not None:
         assert isinstance(qencoder_path, str)
```
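This file carries the change behind the commit title: instead of downloading the full transformer weights, the transformer is instantiated from its config alone, and only the unquantized layers are loaded from the quantized checkpoint, with `strict=False` tolerating the keys that the injected quantized blocks supply instead. A minimal sketch of that partial-loading pattern, with a toy two-layer model standing in for `FluxTransformer2DModel`:

```python
from safetensors.torch import load_file, save_file
from torch import nn

# Toy stand-in: layer "0" plays the unquantized part of the network,
# layer "1" the part the quantized checkpoint replaces.
model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))
partial = {k: v for k, v in model.state_dict().items() if k.startswith("0.")}
save_file(partial, "unquantized_layers.safetensors")

# Fresh model from "config" (random init), then load only the partial
# weights; strict=False skips the keys the quantized blocks will provide.
restored = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))
state_dict = load_file("unquantized_layers.safetensors")
missing, unexpected = restored.load_state_dict(state_dict, strict=False)
assert not unexpected and all(k.startswith("1.") for k in missing)
```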