fengzch-das / nunchaku / Commits / df981d24
"vscode:/vscode.git/clone" did not exist on "4dc5518e4d2ae89a687709bcbe05d2f3f80e00ad"
Commit df981d24 authored Nov 28, 2024 by muyangli

[major] fix the evaluation scripts; no need to download the entire model

Parent: 25ce8942
Showing 11 changed files with 95 additions and 48 deletions (+95, -48)
- README.md (+4, -4)
- app/i2i/flux_pix2pix_pipeline.py (+11, -10)
- app/i2i/run_gradio.py (+6, -4)
- app/i2i/utils.py (+0, -3)
- app/t2i/data/DCI/DCI.py (+1, -1)
- app/t2i/get_metrics.py (+3, -1)
- app/t2i/latency.py (+23, -7)
- app/t2i/run_gradio.py (+3, -1)
- app/t2i/utils.py (+2, -2)
- nunchaku/models/flux.py (+10, -7)
- nunchaku/pipelines/flux.py (+32, -8)
README.md

@@ -60,9 +60,9 @@ SVDQuant is a post-training quantization technique for 4-bit weights and activat

Then build the package from source:

```shell
git clone https://github.com/mit-han-lab/nunchaku.git
cd nunchaku
git submodule init
git submodule update
pip install -e .
```

@@ -78,7 +78,7 @@ from nunchaku.pipelines import flux as nunchaku_flux

```diff
 pipeline = nunchaku_flux.from_pretrained(
     "black-forest-labs/FLUX.1-schnell",
     torch_dtype=torch.bfloat16,
-    qmodel_path="mit-han-lab/svdquant-models/svdq-int4-flux.1-schnell.safetensors",  # download from Huggingface
+    qmodel_path="mit-han-lab/svdq-int4-flux.1-schnell",  # download from Huggingface
 ).to("cuda")
 image = pipeline("A cat holding a sign that says hello world", num_inference_steps=4, guidance_scale=0).images[0]
 image.save("example.png")
```
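The `qmodel_path` in the updated snippet is a Hugging Face repo id rather than a path to a single `.safetensors` file. A minimal sketch of how such an id is resolved to local files, mirroring the `snapshot_download` logic introduced in the diffs below (repo and file names taken from this commit; not the library's exact code):

```python
import os

from huggingface_hub import snapshot_download

qmodel_path = "mit-han-lab/svdq-int4-flux.1-schnell"
if not os.path.exists(qmodel_path):
    # Not a local directory: treat it as a Hugging Face repo id and fetch
    # (or reuse from the local cache) only the files of that repo.
    qmodel_path = snapshot_download(qmodel_path)

# The quantized transformer blocks then live inside the downloaded snapshot.
print(os.path.join(qmodel_path, "transformer_blocks.safetensors"))
```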
app/i2i/flux_pix2pix_pipeline.py

```diff
@@ -4,11 +4,11 @@ from typing import Any, Callable, Optional, Union
 import torch
 import torchvision.transforms.functional as F
 import torchvision.utils
-from PIL import Image
 from diffusers.pipelines.flux.pipeline_flux import FluxPipeline, FluxPipelineOutput, FluxTransformer2DModel
 from einops import rearrange
-from huggingface_hub import hf_hub_download
+from huggingface_hub import hf_hub_download, snapshot_download
 from peft.tuners import lora
+from PIL import Image
 from torch import nn
 
 from nunchaku.models.flux import inject_pipeline, load_quantized_model
@@ -145,9 +145,7 @@ class FluxPix2pixTurboPipeline(FluxPipeline):
         self.erosion_kernel = erosion_kernel
 
         torchvision.utils.save_image(image_t[0], "before.png")
-        image_t = (
-            nn.functional.conv2d(image_t[:, :1], erosion_kernel, padding=kernel_size // 2) > kernel_size**2 - 0.1
-        )
+        image_t = nn.functional.conv2d(image_t[:, :1], erosion_kernel, padding=kernel_size // 2) > kernel_size**2 - 0.1
         image_t = torch.concat([image_t, image_t, image_t], dim=1).to(self.dtype)
         torchvision.utils.save_image(image_t[0], "after.png")
@@ -219,6 +217,8 @@ class FluxPix2pixTurboPipeline(FluxPipeline):
     def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
         qmodel_device = kwargs.pop("qmodel_device", "cuda:0")
         qmodel_device = torch.device(qmodel_device)
+        if qmodel_device.type != "cuda":
+            raise ValueError(f"qmodel_device = {qmodel_device} is not a CUDA device")
         qmodel_path = kwargs.pop("qmodel_path", None)
         qencoder_path = kwargs.pop("qencoder_path", None)
@@ -229,11 +229,12 @@ class FluxPix2pixTurboPipeline(FluxPipeline):
         if qmodel_path is not None:
             assert isinstance(qmodel_path, str)
             if not os.path.exists(qmodel_path):
-                hf_repo_id = os.path.dirname(qmodel_path)
-                filename = os.path.basename(qmodel_path)
-                qmodel_path = hf_hub_download(repo_id=hf_repo_id, filename=filename)
-            m = load_quantized_model(qmodel_path, 0 if qmodel_device.index is None else qmodel_device.index)
-            inject_pipeline(pipeline, m)
+                qmodel_path = snapshot_download(qmodel_path)
+            m = load_quantized_model(
+                os.path.join(qmodel_path, "transformer_blocks.safetensors"),
+                0 if qmodel_device.index is None else qmodel_device.index,
+            )
+            inject_pipeline(pipeline, m, qmodel_device)
             pipeline.precision = "int4"
         if qencoder_path is not None:
```
app/i2i/run_gradio.py

```diff
@@ -5,15 +5,17 @@ import tempfile
 import time
 
 import GPUtil
-import gradio as gr
-import numpy as np
 import torch
 from PIL import Image
 from flux_pix2pix_pipeline import FluxPix2pixTurboPipeline
 
 from nunchaku.models.safety_checker import SafetyChecker
 from utils import get_args
-from vars import DEFAULT_SKETCH_GUIDANCE, DEFAULT_STYLE_NAME, MAX_SEED, STYLES, STYLE_NAMES
+from vars import DEFAULT_SKETCH_GUIDANCE, DEFAULT_STYLE_NAME, MAX_SEED, STYLE_NAMES, STYLES
+
+import numpy as np
+# import gradio last to avoid conflicts with other imports
+import gradio as gr
 
 blank_image = Image.new("RGB", (1024, 1024), (255, 255, 255))
@@ -30,7 +32,7 @@ else:
     pipeline = FluxPix2pixTurboPipeline.from_pretrained(
         "black-forest-labs/FLUX.1-schnell",
         torch_dtype=torch.bfloat16,
-        qmodel_path="mit-han-lab/svdquant-models/svdq-int4-flux.1-schnell.safetensors",
+        qmodel_path="mit-han-lab/svdq-int4-flux.1-schnell",
         qencoder_path="mit-han-lab/svdquant-models/svdq-w4a16-t5.pt" if args.use_qencoder else None,
     )
     pipeline = pipeline.to("cuda")
```
app/i2i/utils.py

```diff
-import cv2
-import numpy as np
-from PIL import Image
 import argparse
```
app/t2i/data/DCI/DCI.py

```diff
@@ -26,7 +26,7 @@ _HOMEPAGE = "https://github.com/facebookresearch/DCI"
 _LICENSE = "Attribution-NonCommercial 4.0 International (https://github.com/facebookresearch/DCI/blob/main/LICENSE)"
 
-IMAGE_URL = "https://scontent.xx.fbcdn.net/m1/v/t6/An_zz_Te0EtVC_cHtUwnyNKODapWXuNNPeBgZn_3XY8yDFzwHrNb-zwN9mYCbAeWUKQooCI9mVbwvzZDZzDUlscRjYxLKsw.tar?ccb=10-5&oh=00_AYBnKR-fSIir-E49Q7-qO2tjmY0BGJhCciHS__B5QyiBAg&oe=673FFA8A&_nc_sid=0fdd51"
+IMAGE_URL = "https://huggingface.co/datasets/mit-han-lab/svdquant-datasets/resolve/main/sDCI.gz"
 PROMPT_URLS = {"sDCI": "https://huggingface.co/datasets/mit-han-lab/svdquant-datasets/resolve/main/sDCI.yaml"}
```
app/t2i/get_metrics.py

```diff
@@ -31,7 +31,9 @@ def main():
     results = {}
     dataset_names = sorted(os.listdir(image_root1))
     for dataset_name in dataset_names:
-        print("##Results for dataset:", dataset_name)
+        if image_root2 is not None and dataset_name not in os.listdir(image_root2):
+            continue
+        print("Results for dataset:", dataset_name)
         results[dataset_name] = {}
         dataset = get_dataset(name=dataset_name, return_gt=True)
         fid = compute_fid(ref_dirpath_or_dataset=dataset, gen_dirpath=os.path.join(image_root1, dataset_name))
```
app/t2i/latency.py

```diff
@@ -2,6 +2,7 @@ import argparse
 import time
 
 import torch
+from torch import nn
 from tqdm import trange
 
 from utils import get_pipeline
@@ -51,23 +52,38 @@ def main():
         pipeline.set_progress_bar_config(position=1, desc="Step", leave=False)
         for _ in trange(args.warmup_times, desc="Warmup", position=0, leave=False):
             pipeline(
                 prompt=dummy_prompt,
                 num_inference_steps=args.num_inference_steps,
-                guidance_scale=args.guidance_scale
+                guidance_scale=args.guidance_scale,
             )
         torch.cuda.synchronize()
         for _ in trange(args.test_times, desc="Warmup", position=0, leave=False):
             start_time = time.time()
             pipeline(
                 prompt=dummy_prompt,
                 num_inference_steps=args.num_inference_steps,
-                guidance_scale=args.guidance_scale
+                guidance_scale=args.guidance_scale,
             )
             torch.cuda.synchronize()
             end_time = time.time()
             latency_list.append(end_time - start_time)
     elif args.mode == "step":
-        pass
+        inputs = {}
+
+        def get_input_hook(module: nn.Module, input_args, input_kwargs):
+            inputs["args"] = input_args
+            inputs["kwargs"] = input_kwargs
+
+        pipeline.transformer.register_forward_pre_hook(get_input_hook, with_kwargs=True)
+        pipeline(prompt=dummy_prompt, num_inference_steps=1, guidance_scale=args.guidance_scale, output_type="latent")
+        for _ in trange(args.warmup_times, desc="Warmup", position=0, leave=False):
+            pipeline.transformer(*inputs["args"], **inputs["kwargs"])
+        torch.cuda.synchronize()
+        for _ in trange(args.test_times, desc="Warmup", position=0, leave=False):
+            start_time = time.time()
+            pipeline.transformer(*inputs["args"], **inputs["kwargs"])
+            torch.cuda.synchronize()
+            end_time = time.time()
+            latency_list.append(end_time - start_time)
 
     latency_list = sorted(latency_list)
     ignored_count = int(args.ignore_ratio * len(latency_list) / 2)
     if ignored_count > 0:
```
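The new `step` mode captures the real transformer inputs with a forward pre-hook during one pipeline call and then replays them to time the transformer in isolation. A self-contained sketch of that capture-and-replay pattern (a toy `nn.Linear` stands in for the FLUX transformer; timing is simplified):

```python
import time

import torch
from torch import nn

inputs = {}


def get_input_hook(module: nn.Module, input_args, input_kwargs):
    # Record the positional and keyword arguments of the next forward call.
    inputs["args"] = input_args
    inputs["kwargs"] = input_kwargs


model = nn.Linear(16, 16)
model.register_forward_pre_hook(get_input_hook, with_kwargs=True)

# One "real" call populates `inputs` via the hook.
model(torch.randn(4, 16))

# Replay the captured inputs to time the module in isolation.
start = time.time()
model(*inputs["args"], **inputs["kwargs"])
print(f"latency: {time.time() - start:.6f} s")
```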
app/t2i/run_gradio.py

```diff
@@ -5,7 +5,6 @@ import random
 import time
 
 import GPUtil
-import gradio as gr
 import spaces
 import torch
 from peft.tuners import lora
@@ -14,6 +13,9 @@ from nunchaku.models.safety_checker import SafetyChecker
 from utils import get_pipeline
 from vars import DEFAULT_HEIGHT, DEFAULT_WIDTH, EXAMPLES, MAX_SEED, PROMPT_TEMPLATES, SVDQ_LORA_PATHS
+
+# import gradio last to avoid conflicts with other imports
+import gradio as gr
 
 
 def get_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser()
```
app/t2i/utils.py

```diff
@@ -29,7 +29,7 @@ def get_pipeline(
         pipeline = nunchaku_flux.from_pretrained(
             "black-forest-labs/FLUX.1-schnell",
             torch_dtype=torch.bfloat16,
-            qmodel_path="mit-han-lab/svdquant-models/svdq-int4-flux.1-schnell.safetensors",
+            qmodel_path="mit-han-lab/svdq-int4-flux.1-schnell",
             qencoder_path="mit-han-lab/svdquant-models/svdq-w4a16-t5.pt" if use_qencoder else None,
             qmodel_device=device,
         )
@@ -41,7 +41,7 @@ def get_pipeline(
         pipeline = nunchaku_flux.from_pretrained(
             "black-forest-labs/FLUX.1-dev",
             torch_dtype=torch.bfloat16,
-            qmodel_path="mit-han-lab/svdquant-models/svdq-int4-flux.1-dev.safetensors",
+            qmodel_path="mit-han-lab/svdq-int4-flux.1-dev",
             qencoder_path="mit-han-lab/svdquant-models/svdq-w4a16-t5.pt" if use_qencoder else None,
             qmodel_device=device,
         )
```
nunchaku/models/flux.py

```diff
@@ -14,10 +14,11 @@ SVD_RANK = 32
 class NunchakuFluxModel(nn.Module):
-    def __init__(self, m: QuantizedFluxModel):
+    def __init__(self, m: QuantizedFluxModel, device: torch.device):
         super().__init__()
         self.m = m
         self.dtype = torch.bfloat16
+        self.device = device
 
     def forward(
         self,
@@ -33,10 +34,12 @@ class NunchakuFluxModel(nn.Module):
         img_tokens = hidden_states.shape[1]
 
         original_dtype = hidden_states.dtype
+        original_device = hidden_states.device
-        hidden_states = hidden_states.to(self.dtype)
-        encoder_hidden_states = encoder_hidden_states.to(self.dtype)
-        temb = temb.to(self.dtype)
+        hidden_states = hidden_states.to(self.dtype).to(self.device)
+        encoder_hidden_states = encoder_hidden_states.to(self.dtype).to(self.device)
+        temb = temb.to(self.dtype).to(self.device)
+        image_rotary_emb = image_rotary_emb.to(self.device)
 
         assert image_rotary_emb.ndim == 6
         assert image_rotary_emb.shape[0] == 1
@@ -52,7 +55,7 @@ class NunchakuFluxModel(nn.Module):
             hidden_states, encoder_hidden_states, temb, rotary_emb_img, rotary_emb_txt, rotary_emb_single
         )
-        hidden_states = hidden_states.to(original_dtype)
+        hidden_states = hidden_states.to(original_dtype).to(original_device)
 
         encoder_hidden_states = hidden_states[:, :txt_tokens, ...]
         hidden_states = hidden_states[:, txt_tokens:, ...]
@@ -110,11 +113,11 @@ def load_quantized_model(path: str, device: str | torch.device) -> QuantizedFluxModel:
     return m
 
 
-def inject_pipeline(pipe: FluxPipeline, m: QuantizedFluxModel) -> FluxPipeline:
+def inject_pipeline(pipe: FluxPipeline, m: QuantizedFluxModel, device: torch.device) -> FluxPipeline:
     net: FluxTransformer2DModel = pipe.transformer
     net.pos_embed = EmbedND(dim=net.inner_dim, theta=10000, axes_dim=[16, 56, 56])
-    net.transformer_blocks = torch.nn.ModuleList([NunchakuFluxModel(m)])
+    net.transformer_blocks = torch.nn.ModuleList([NunchakuFluxModel(m, device)])
     net.single_transformer_blocks = torch.nn.ModuleList([])
 
     def update_params(self: FluxTransformer2DModel, path: str):
```
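The changes above make `NunchakuFluxModel` remember the device of the quantized blocks, cast incoming activations to its own dtype and device, and cast the output back to the caller's original dtype and device. A hypothetical stand-alone wrapper illustrating the same round-trip (not the actual `NunchakuFluxModel` code):

```python
import torch
from torch import nn


class DeviceRoundTrip(nn.Module):
    """Hypothetical wrapper showing the dtype/device round-trip pattern."""

    def __init__(self, inner: nn.Module, device: torch.device, dtype: torch.dtype = torch.bfloat16):
        super().__init__()
        self.inner = inner.to(device=device, dtype=dtype)
        self.device = device
        self.dtype = dtype

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        original_dtype, original_device = x.dtype, x.device
        # Run the wrapped module on its own device and dtype ...
        y = self.inner(x.to(self.dtype).to(self.device))
        # ... then hand the result back in the caller's original dtype and device.
        return y.to(original_dtype).to(original_device)


wrapped = DeviceRoundTrip(nn.Linear(8, 8), device=torch.device("cpu"))
out = wrapped(torch.randn(2, 8))  # returned as float32 on the caller's device
```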
nunchaku/pipelines/flux.py

```diff
 import os
 
 import torch
-from diffusers import FluxPipeline
-from huggingface_hub import hf_hub_download
+from diffusers import __version__
+from diffusers import FluxPipeline, FluxTransformer2DModel
+from huggingface_hub import hf_hub_download, snapshot_download
+from safetensors.torch import load_file
+from torch import nn
 
 from ..models.flux import inject_pipeline, load_quantized_model
@@ -45,13 +48,34 @@ def from_pretrained(pretrained_model_name_or_path: str | os.PathLike, **kwargs)
     qencoder_path = kwargs.pop("qencoder_path", None)
     if not os.path.exists(qmodel_path):
-        hf_repo_id = os.path.dirname(qmodel_path)
-        filename = os.path.basename(qmodel_path)
-        qmodel_path = hf_hub_download(repo_id=hf_repo_id, filename=filename)
-    pipeline = FluxPipeline.from_pretrained(pretrained_model_name_or_path, **kwargs)
-    m = load_quantized_model(qmodel_path, 0 if qmodel_device.index is None else qmodel_device.index)
-    inject_pipeline(pipeline, m)
+        qmodel_path = snapshot_download(qmodel_path)
+
+    assert kwargs.pop("transformer", None) is None
+    config, unused_kwargs, commit_hash = FluxTransformer2DModel.load_config(
+        pretrained_model_name_or_path,
+        cache_dir=kwargs.get("cache_dir", None),
+        return_unused_kwargs=True,
+        return_commit_hash=True,
+        force_download=kwargs.get("force_download", False),
+        proxies=kwargs.get("proxies", None),
+        local_files_only=kwargs.get("local_files_only", None),
+        token=kwargs.get("token", None),
+        revision=kwargs.get("revision", None),
+        subfolder="transformer",
+        user_agent={"diffusers": __version__, "file_type": "model", "framework": "pytorch"},
+        **kwargs,
+    )
+    transformer: nn.Module = FluxTransformer2DModel.from_config(config).to(kwargs.get("torch_dtype", torch.bfloat16))
+    state_dict = load_file(os.path.join(qmodel_path, "unquantized_layers.safetensors"))
+    transformer.load_state_dict(state_dict, strict=False)
+
+    pipeline = FluxPipeline.from_pretrained(pretrained_model_name_or_path, transformer=transformer, **kwargs)
+    m = load_quantized_model(
+        os.path.join(qmodel_path, "transformer_blocks.safetensors"),
+        0 if qmodel_device.index is None else qmodel_device.index,
+    )
+    inject_pipeline(pipeline, m, qmodel_device)
     if qencoder_path is not None:
         assert isinstance(qencoder_path, str)
```
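The rewritten `from_pretrained` no longer downloads the full BF16 transformer: it builds `FluxTransformer2DModel` from its config alone, loads only the unquantized layers from the quantized checkpoint, and hands that transformer to `FluxPipeline.from_pretrained`. A simplified sketch of the pattern (cache/proxy/revision plumbing omitted; repo ids and file names taken from this commit):

```python
import os

import torch
from diffusers import FluxTransformer2DModel
from huggingface_hub import snapshot_download
from safetensors.torch import load_file

# Fetch only the quantized checkpoint repo.
qmodel_path = snapshot_download("mit-han-lab/svdq-int4-flux.1-schnell")

# Instantiate the transformer from its config only -- no BF16 weight download.
config = FluxTransformer2DModel.load_config(
    "black-forest-labs/FLUX.1-schnell", subfolder="transformer"
)
transformer = FluxTransformer2DModel.from_config(config).to(torch.bfloat16)

# Fill in just the layers that stay unquantized; the quantized blocks are
# loaded separately (see load_quantized_model in the diff above).
state_dict = load_file(os.path.join(qmodel_path, "unquantized_layers.safetensors"))
transformer.load_state_dict(state_dict, strict=False)
```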