fengzch-das / nunchaku · Commits

Commit 3ef186fd — Multiple LoRAs
Authored Mar 26, 2025 by Muyang Li; committed by Zhekai Zhang on Apr 01, 2025
Parent: ca1a2e90
Changes: 22
Showing 20 changed files with 1353 additions and 776 deletions (+1353 −776)
app/flux.1/t2i/data/DCI/DCI.py                      +2    −2
app/flux.1/t2i/data/MJHQ/MJHQ.py                    +2    −2
examples/int4-flux.1-canny-dev-lora.py              +31   −0
examples/int4-flux.1-depth-dev-lora.py              +34   −0
examples/int4-flux.1-dev-lora.py                    +2    −2
nunchaku/lora/flux/__init__.py                      +3    −4
nunchaku/lora/flux/comfyui_converter.py             +0    −124
nunchaku/lora/flux/compose.py                       +141  −0
nunchaku/lora/flux/convert.py                       +14   −47
nunchaku/lora/flux/diffusers_converter.py           +25   −405
nunchaku/lora/flux/nunchaku_converter.py            +531  −0
nunchaku/lora/flux/packer.py                        +298  −0
nunchaku/lora/flux/utils.py                         +34   −11
nunchaku/lora/flux/xlab_converter.py                +0    −57
nunchaku/models/text_encoders/tinychat_utils.py     +1    −1
nunchaku/models/transformers/transformer_flux.py    +183  −47
nunchaku/models/transformers/transformer_sana.py    +20   −60
nunchaku/models/transformers/utils.py               +29   −11
pyproject.toml                                      +1    −1
tests/data/MJHQ/MJHQ.py                             +2    −2
app/flux.1/t2i/data/DCI/DCI.py (+2 −2)

...
@@ -17,8 +17,8 @@ _CITATION = """\
"""

_DESCRIPTION = """\
The Densely Captioned Images dataset, or DCI, consists of 7805 images from SA-1B,
each with a complete description aiming to capture the full visual detail of what is present in the image.
Much of the description is directly aligned to submasks of the image.
"""
...
app/flux.1/t2i/data/MJHQ/MJHQ.py (+2 −2)

...
@@ -7,7 +7,7 @@ from PIL import Image
_CITATION = """\
@misc{li2024playground,
      title={Playground v2.5: Three Insights towards Enhancing Aesthetic Quality in Text-to-Image Generation},
      author={Daiqing Li and Aleks Kamko and Ehsan Akhgari and Ali Sabet and Linmiao Xu and Suhail Doshi},
      year={2024},
      eprint={2402.17245},
...
@@ -17,7 +17,7 @@ _CITATION = """\
"""

_DESCRIPTION = """\
We introduce a new benchmark, MJHQ-30K, for automatic evaluation of a model's aesthetic quality.
The benchmark computes FID on a high-quality dataset to gauge aesthetic quality.
"""
...
examples/int4-flux.1-canny-dev-lora.py (new file, 0 → 100644)

import torch
from controlnet_aux import CannyDetector
from diffusers import FluxControlPipeline
from diffusers.utils import load_image

from nunchaku import NunchakuFluxTransformer2dModel

transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-dev")
pipe = FluxControlPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")

### LoRA Related Code ###
transformer.update_lora_params(
    "black-forest-labs/FLUX.1-Canny-dev-lora/flux1-canny-dev-lora.safetensors"
)  # Path to your LoRA safetensors, can also be a remote HuggingFace path
transformer.set_lora_strength(0.85)  # Your LoRA strength here
### End of LoRA Related Code ###

prompt = "A robot made of exotic candies and chocolates of different kinds. The background is filled with confetti and celebratory gifts."
control_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png")

processor = CannyDetector()
control_image = processor(
    control_image, low_threshold=50, high_threshold=200, detect_resolution=1024, image_resolution=1024
)

image = pipe(
    prompt=prompt,
    control_image=control_image,
    height=1024,
    width=1024,
    num_inference_steps=50,
    guidance_scale=30.0,
).images[0]
image.save("int4-flux.1-canny-dev-lora.png")
examples/int4-flux.1-depth-dev-lora.py (new file, 0 → 100644)

import torch
from diffusers import FluxControlPipeline
from diffusers.utils import load_image
from image_gen_aux import DepthPreprocessor

from nunchaku import NunchakuFluxTransformer2dModel

transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-dev")
pipe = FluxControlPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")

### LoRA Related Code ###
transformer.update_lora_params(
    "black-forest-labs/FLUX.1-Depth-dev-lora/flux1-depth-dev-lora.safetensors"
)  # Path to your LoRA safetensors, can also be a remote HuggingFace path
transformer.set_lora_strength(0.85)  # Your LoRA strength here
### End of LoRA Related Code ###

control_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png")
processor = DepthPreprocessor.from_pretrained("LiheYoung/depth-anything-large-hf")
control_image = processor(control_image)[0].convert("RGB")

image = pipe(
    prompt="A robot made of exotic candies and chocolates of different kinds. The background is filled with confetti and celebratory gifts.",
    control_image=control_image,
    height=1024,
    width=1024,
    num_inference_steps=30,
    guidance_scale=10.0,
    generator=torch.Generator().manual_seed(42),
).images[0]
image.save("int4-flux.1-depth-dev-lora.png")
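The two examples above each load a single control LoRA; the commit title is "Multiple LoRAs", and the new compose_lora helper (nunchaku/lora/flux/compose.py, shown further down) is the piece that combines several LoRAs before loading. A hedged sketch tying the two together follows — the LoRA file names are placeholders, and writing the composed result to disk before calling update_lora_params is an assumption (the transformer-side changes in this diff are not fully shown, so it is not clear whether an in-memory dict is accepted directly):

# Minimal sketch: compose two LoRAs, then load the result with the same API used in the examples above.
from nunchaku.lora.flux.compose import compose_lora

compose_lora(
    [("ghibsky.safetensors", 1.0), ("flux1-canny-dev-lora.safetensors", 0.85)],  # (path, strength) pairs
    output_path="composed-lora.safetensors",
)
transformer.update_lora_params("composed-lora.safetensors")  # reuses the `transformer` from the example above
transformer.set_lora_strength(1)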
examples/int4-flux.1-dev-lora.py (+2 −2)

...
@@ -10,8 +10,8 @@ pipeline = FluxPipeline.from_pretrained(
 ### LoRA Related Code ###
-transformer.update_lora_params(
-    "mit-han-lab/svdquant-lora-collection/svdq-int4-flux.1-dev-ghibsky.safetensors"
-)  # Path to your converted LoRA safetensors, can also be a remote HuggingFace path
+transformer.update_lora_params(
+    "aleksa-codes/flux-ghibsky-illustration/lora.safetensors"
+)  # Path to your LoRA safetensors, can also be a remote HuggingFace path
 transformer.set_lora_strength(1)  # Your LoRA strength here
 ### End of LoRA Related Code ###
...
nunchaku/lora/flux/__init__.py (+3 −4)

-from .comfyui_converter import comfyui2diffusers
+from .diffusers_converter import to_diffusers
-from .diffusers_converter import convert_to_nunchaku_flux_lowrank_dict
+from .nunchaku_converter import convert_to_nunchaku_flux_lowrank_dict, to_nunchaku
-from .utils import detect_format
+from .utils import is_nunchaku_format
-from .xlab_converter import xlab2diffusers
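A small sketch of the reshaped public API implied by the new imports above (the LoRA path is a placeholder; the package-level re-exports are assumed to behave exactly like the underlying functions shown later on this page):

from nunchaku.lora.flux import is_nunchaku_format, to_diffusers

lora_path = "path/to/lora.safetensors"       # hypothetical local file
if not is_nunchaku_format(lora_path):
    diffusers_state_dict = to_diffusers(lora_path)   # normalize ComfyUI/XLab/diffusers LoRAs first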
nunchaku/lora/flux/comfyui_converter.py (deleted, 100644 → 0)

# convert the comfyui lora to diffusers format
import argparse
import os

import torch
from safetensors.torch import save_file

from ...utils import load_state_dict_in_safetensors


def comfyui2diffusers(
    input_lora: str | dict[str, torch.Tensor], output_path: str | None = None, min_rank: int | None = None
) -> dict[str, torch.Tensor]:
    if isinstance(input_lora, str):
        tensors = load_state_dict_in_safetensors(input_lora, device="cpu")
    else:
        tensors = input_lora
    new_tensors = {}
    max_rank = 0
    for k, v in tensors.items():
        if "alpha" in k or "lora_te" in k:
            continue
        new_k = k.replace("lora_down", "lora_A").replace("lora_up", "lora_B")
        if "lora_unet_double_blocks_" in k:
            new_k = new_k.replace("lora_unet_double_blocks_", "transformer.transformer_blocks.")
            if "qkv" in new_k:
                for i, p in enumerate(["q", "k", "v"]):
                    if "lora_A" in new_k:
                        # Copy the tensor
                        new_k = new_k.replace("_img_attn_qkv", f".attn.to_{p}")
                        new_k = new_k.replace("_txt_attn_qkv", f".attn.add_{p}_proj")
                        rank = v.shape[0]
                        alpha = tensors[k.replace("lora_down.weight", "alpha")]
                        new_tensors[new_k] = v.clone() * alpha / rank
                        max_rank = max(max_rank, rank)
                    else:
                        assert "lora_B" in new_k
                        assert v.shape[0] % 3 == 0
                        chunk_size = v.shape[0] // 3
                        new_k = new_k.replace("_img_attn_qkv", f".attn.to_{p}")
                        new_k = new_k.replace("_txt_attn_qkv", f".attn.add_{p}_proj")
                        new_tensors[new_k] = v[i * chunk_size : (i + 1) * chunk_size]
            else:
                new_k = new_k.replace("_img_attn_proj", ".attn.to_out.0")
                new_k = new_k.replace("_img_mlp_0", ".ff.net.0.proj")
                new_k = new_k.replace("_img_mlp_2", ".ff.net.2")
                new_k = new_k.replace("_img_mod_lin", ".norm1.linear")
                new_k = new_k.replace("_txt_attn_proj", ".attn.to_add_out")
                new_k = new_k.replace("_txt_mlp_0", ".ff_context.net.0.proj")
                new_k = new_k.replace("_txt_mlp_2", ".ff_context.net.2")
                new_k = new_k.replace("_txt_mod_lin", ".norm1_context.linear")
                if "lora_down" in k:
                    alpha = tensors[k.replace("lora_down.weight", "alpha")]
                    rank = v.shape[0]
                    v = v * alpha / rank
                    max_rank = max(max_rank, rank)
                new_tensors[new_k] = v
        else:
            assert "lora_unet_single_blocks" in k
            new_k = new_k.replace("lora_unet_single_blocks_", "transformer.single_transformer_blocks.")
            if "linear1" in k:
                start = 0
                for i, p in enumerate(["q", "k", "v", "i"]):
                    if "lora_A" in new_k:
                        if p == "i":
                            new_k1 = new_k.replace("_linear1", ".proj_mlp")
                        else:
                            new_k1 = new_k.replace("_linear1", f".attn.to_{p}")
                        rank = v.shape[0]
                        alpha = tensors[k.replace("lora_down.weight", "alpha")]
                        new_tensors[new_k1] = v.clone() * alpha / rank
                        max_rank = max(max_rank, rank)
                    else:
                        if p == "i":
                            new_k1 = new_k.replace("_linear1", ".proj_mlp")
                        else:
                            new_k1 = new_k.replace("_linear1", f".attn.to_{p}")
                        chunk_size = 12288 if p == "i" else 3072
                        new_tensors[new_k1] = v[start : start + chunk_size]
                        start += chunk_size
            else:
                new_k = new_k.replace("_linear2", ".proj_out")
                new_k = new_k.replace("_modulation_lin", ".norm.linear")
                if "lora_down" in k:
                    rank = v.shape[0]
                    alpha = tensors[k.replace("lora_down.weight", "alpha")]
                    v = v * alpha / rank
                    max_rank = max(max_rank, rank)
                new_tensors[new_k] = v
    if min_rank is not None:
        for k in new_tensors.keys():
            v = new_tensors[k]
            if "lora_A" in k:
                rank = v.shape[0]
                if rank < min_rank:
                    new_v = torch.zeros(min_rank, v.shape[1], dtype=v.dtype, device=v.device)
                    new_v[:rank] = v
                    new_tensors[k] = new_v
            else:
                assert "lora_B" in k
                rank = v.shape[1]
                if rank < min_rank:
                    new_v = torch.zeros(v.shape[0], min_rank, dtype=v.dtype, device=v.device)
                    new_v[:, :rank] = v
                    new_tensors[k] = new_v
    if output_path is not None:
        output_dir = os.path.dirname(os.path.abspath(output_path))
        os.makedirs(output_dir, exist_ok=True)
        save_file(new_tensors, output_path)
    return new_tensors


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-i", "--input-path", type=str, required=True, help="path to the comfyui lora safetensor file")
    parser.add_argument("-o", "--output-path", type=str, required=True, help="path to the output diffusers safetensor file")
    parser.add_argument("--min-rank", type=int, default=None, help="minimum rank for the LoRA weights")
    args = parser.parse_args()
    comfyui2diffusers(args.input_path, args.output_path, min_rank=args.min_rank)
nunchaku/lora/flux/compose.py (new file, 0 → 100644)

import argparse
import os

import torch
from safetensors.torch import save_file

from .diffusers_converter import to_diffusers
from .utils import is_nunchaku_format


def compose_lora(
    loras: list[tuple[str | dict[str, torch.Tensor], float]], output_path: str | None = None
) -> dict[str, torch.Tensor]:
    composed = {}
    for lora, strength in loras:
        assert not is_nunchaku_format(lora)
        lora = to_diffusers(lora)
        for k, v in list(lora.items()):
            if v.ndim == 1:
                previous_tensor = composed.get(k, None)
                if previous_tensor is None:
                    if "norm_q" in k or "norm_k" in k or "norm_added_q" in k or "norm_added_k" in k:
                        composed[k] = v
                    else:
                        composed[k] = v * strength
                else:
                    assert not ("norm_q" in k or "norm_k" in k or "norm_added_q" in k or "norm_added_k" in k)
                    composed[k] = previous_tensor + v * strength
            else:
                assert v.ndim == 2
                if "lora_A" in k:
                    v = v * strength
                if ".to_q." in k or ".add_q_proj." in k:
                    # qkv must all exist
                    if "lora_B" in k:
                        continue
                    q_a = v
                    k_a = lora[k.replace(".to_q.", ".to_k.").replace(".add_q_proj.", ".add_k_proj.")]
                    v_a = lora[k.replace(".to_q.", ".to_v.").replace(".add_q_proj.", ".add_v_proj.")]
                    q_b = lora[k.replace("lora_A", "lora_B")]
                    k_b = lora[
                        k.replace("lora_A", "lora_B")
                        .replace(".to_q.", ".to_k.")
                        .replace(".add_q_proj.", ".add_k_proj.")
                    ]
                    v_b = lora[
                        k.replace("lora_A", "lora_B")
                        .replace(".to_q.", ".to_v.")
                        .replace(".add_q_proj.", ".add_v_proj.")
                    ]
                    assert q_a.shape[0] == k_a.shape[0] == v_a.shape[0]
                    assert q_b.shape[1] == k_b.shape[1] == v_b.shape[1]
                    if torch.isclose(q_a, k_a).all() and torch.isclose(q_a, v_a).all():
                        lora_a = q_a
                        lora_b = torch.cat((q_b, k_b, v_b), dim=0)
                    else:
                        lora_a_group = (q_a, k_a, v_a)
                        new_shape_a = [sum([_.shape[0] for _ in lora_a_group]), q_a.shape[1]]
                        lora_a = torch.zeros(new_shape_a, dtype=q_a.dtype, device=q_a.device)
                        start_dim = 0
                        for tensor in lora_a_group:
                            lora_a[start_dim : start_dim + tensor.shape[0]] = tensor
                            start_dim += tensor.shape[0]
                        lora_b_group = (q_b, k_b, v_b)
                        new_shape_b = [
                            sum([_.shape[0] for _ in lora_b_group]),
                            sum([_.shape[1] for _ in lora_b_group]),
                        ]
                        lora_b = torch.zeros(new_shape_b, dtype=q_b.dtype, device=q_b.device)
                        start_dims = (0, 0)
                        for tensor in lora_b_group:
                            end_dims = (start_dims[0] + tensor.shape[0], start_dims[1] + tensor.shape[1])
                            lora_b[start_dims[0] : end_dims[0], start_dims[1] : end_dims[1]] = tensor
                            start_dims = end_dims
                    lora_a = lora_a * strength
                    new_k_a = k.replace(".to_q.", ".to_qkv.").replace(".add_q_proj.", ".add_qkv_proj.")
                    new_k_b = new_k_a.replace("lora_A", "lora_B")
                    for kk, vv, dim in ((new_k_a, lora_a, 0), (new_k_b, lora_b, 1)):
                        previous_lora = composed.get(kk, None)
                        composed[kk] = vv if previous_lora is None else torch.cat([previous_lora, vv], dim=dim)
                elif ".to_k." in k or ".to_v." in k or ".add_k_proj." in k or ".add_v_proj." in k:
                    continue
                else:
                    if "lora_A" in k:
                        v = v * strength
                    previous_lora = composed.get(k, None)
                    if previous_lora is None:
                        composed[k] = v
                    else:
                        if "lora_A" in k:
                            if previous_lora.shape[1] != v.shape[1]:
                                # flux.1-tools LoRA compatibility
                                assert "x_embedder" in k
                                expanded_size = max(previous_lora.shape[1], v.shape[1])
                                if expanded_size > previous_lora.shape[1]:
                                    expanded_previous_lora = torch.zeros(
                                        (previous_lora.shape[0], expanded_size),
                                        device=previous_lora.device,
                                        dtype=previous_lora.dtype,
                                    )
                                    expanded_previous_lora[:, : previous_lora.shape[1]] = previous_lora
                                else:
                                    expanded_previous_lora = previous_lora
                                if expanded_size > v.shape[1]:
                                    expanded_v = torch.zeros((v.shape[0], expanded_size), device=v.device, dtype=v.dtype)
                                    expanded_v[:, : v.shape[1]] = v
                                else:
                                    expanded_v = v
                                composed[k] = torch.cat([expanded_previous_lora, expanded_v], dim=0)
                            else:
                                composed[k] = torch.cat([previous_lora, v], dim=0)
                        else:
                            composed[k] = torch.cat([previous_lora, v], dim=1)
                    composed[k] = (
                        v if previous_lora is None else torch.cat([previous_lora, v], dim=0 if "lora_A" in k else 1)
                    )
    if output_path is not None:
        output_dir = os.path.dirname(os.path.abspath(output_path))
        os.makedirs(output_dir, exist_ok=True)
        save_file(composed, output_path)
    return composed


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-i", "--input-paths", type=str, nargs="*", required=True, help="paths to the lora safetensors files")
    parser.add_argument("-s", "--strengths", type=float, nargs="*", required=True, help="strengths for each lora")
    parser.add_argument("-o", "--output-path", type=str, required=True, help="path to the output safetensors file")
    args = parser.parse_args()
    assert len(args.input_paths) == len(args.strengths)
    composed = compose_lora(list(zip(args.input_paths, args.strengths)))
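A small synthetic check of how compose_lora stacks ranks, written as a sketch under two assumptions: that compose_lora is importable as nunchaku.lora.flux.compose.compose_lora, and that diffusers' FluxLoraLoaderMixin.lora_state_dict passes diffusers-format dicts through with their keys unchanged. Two rank-4 LoRAs on the same projection should compose into a rank-8 LoRA, with lora_A tensors concatenated on dim 0 and lora_B tensors on dim 1 (only shapes are asserted here):

import torch
from nunchaku.lora.flux.compose import compose_lora

def make_lora(rank: int) -> dict[str, torch.Tensor]:
    # Hypothetical single-layer LoRA on the first transformer block's attention output projection.
    prefix = "transformer.transformer_blocks.0.attn.to_out.0"
    return {
        f"{prefix}.lora_A.weight": torch.randn(rank, 3072),
        f"{prefix}.lora_B.weight": torch.randn(3072, rank),
    }

composed = compose_lora([(make_lora(4), 1.0), (make_lora(4), 0.5)])
assert composed["transformer.transformer_blocks.0.attn.to_out.0.lora_A.weight"].shape == (8, 3072)
assert composed["transformer.transformer_blocks.0.attn.to_out.0.lora_B.weight"].shape == (3072, 8)

Note that, as committed, the __main__ block above does not forward --output-path to compose_lora, so calling the function directly (as here, or with output_path=...) is the reliable way to persist the composed LoRA.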
nunchaku/lora/flux/convert.py (+14 −47)

 import argparse
 import os
-import torch
+from .nunchaku_converter import to_nunchaku
-from safetensors.torch import save_file
+from .utils import is_nunchaku_format
-from .comfyui_converter import comfyui2diffusers
-from .diffusers_converter import convert_to_nunchaku_flux_lowrank_dict
-from .utils import detect_format
-from .xlab_converter import xlab2diffusers
-from ...utils import filter_state_dict, load_state_dict_in_safetensors

 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
...
@@ -19,13 +13,6 @@ if __name__ == "__main__":
         default="mit-han-lab/svdq-int4-flux.1-dev/transformer_blocks.safetensors",
     )
     parser.add_argument("--lora-path", type=str, required=True, help="path to LoRA weights safetensor file")
     parser.add_argument(
         "--lora-format",
         type=str,
         default="auto",
         choices=["auto", "comfyui", "diffusers", "xlab"],
         help="format of the LoRA weights",
     )
     parser.add_argument("--output-root", type=str, default="", help="root to the output safetensor file")
     parser.add_argument("--lora-name", type=str, default=None, help="name of the LoRA weights")
     parser.add_argument(
...
@@ -37,46 +24,26 @@ if __name__ == "__main__":
     )
     args = parser.parse_args()
+    if is_nunchaku_format(args.lora_path):
+        print("Already in nunchaku format, no conversion needed.")
+        exit(0)
     if not args.output_root:
-        # output to the parent directory of the lora safetensor file
+        # output to the parent directory of the lora safetensors file
         args.output_root = os.path.dirname(args.lora_path)
     if args.lora_name is None:
         base_name = os.path.basename(args.lora_path)
         lora_name = base_name.rsplit(".", 1)[0]
-        lora_name = "svdq-int4-" + lora_name
+        precision = "fp4" if "fp4" in args.quant_path else "int4"
+        lora_name = f"svdq-{precision}-{lora_name}"
         print(f"LoRA name not provided, using {lora_name} as the LoRA name")
     else:
         lora_name = args.lora_name
     assert lora_name, "LoRA name must be provided."
-    assert args.quant_path.endswith(".safetensors"), "Quantized model must be a safetensor file"
-    assert args.lora_path.endswith(".safetensors"), "LoRA weights must be a safetensor file"
-    orig_state_dict = load_state_dict_in_safetensors(args.quant_path)
-    lora_format = args.lora_format
-    if lora_format == "auto":
-        lora_format = detect_format(args.lora_path)
-        print(f"Detected LoRA format: {lora_format}")
-    if lora_format == "svdquant":
-        print("Already in SVDQuant format, no conversion needed.")
-        exit(0)
-    if lora_format == "diffusers":
-        extra_lora_dict = load_state_dict_in_safetensors(args.lora_path)
-    else:
-        if lora_format == "comfyui":
-            extra_lora_dict = comfyui2diffusers(args.lora_path)
-        elif lora_format == "xlab":
-            extra_lora_dict = xlab2diffusers(args.lora_path)
-        else:
-            raise NotImplementedError(f"LoRA format {lora_format} is not supported.")
-    extra_lora_dict = filter_state_dict(extra_lora_dict)
-    converted = convert_to_nunchaku_flux_lowrank_dict(
-        base_model=orig_state_dict,
-        lora=extra_lora_dict,
-        default_dtype=torch.bfloat16 if args.dtype == "bfloat16" else torch.float16,
-    )
-    os.makedirs(args.output_root, exist_ok=True)
-    save_file(converted, os.path.join(args.output_root, f"{lora_name}.safetensors"))
-    print(f"Saved LoRA weights to {args.output_root}.")
+    to_nunchaku(
+        args.lora_path,
+        args.quant_path,
+        lora_format=args.lora_format,
+        dtype=args.dtype,
+        output_path=os.path.join(args.output_root, f"{lora_name}.safetensors"),
+    )
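The same conversion can be driven from Python instead of the CLI. The sketch below simply mirrors the to_nunchaku(...) call in the rewritten script above; the file paths are placeholders, and the keyword names are taken from that call rather than from the (collapsed) nunchaku_converter.py itself:

from nunchaku.lora.flux.nunchaku_converter import to_nunchaku

to_nunchaku(
    "path/to/lora.safetensors",                                           # LoRA in diffusers/ComfyUI/XLab format
    "mit-han-lab/svdq-int4-flux.1-dev/transformer_blocks.safetensors",    # quantized base transformer blocks
    lora_format="auto",
    dtype="bfloat16",
    output_path="path/to/svdq-int4-lora.safetensors",
)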
nunchaku/lora/flux/diffusers_converter.py (+25 −405; the module is rewritten, so the new contents and the removed code are shown separately)

New contents:

"""Convert LoRA weights to Nunchaku format."""
import argparse
import os
import warnings

import torch
from diffusers.loaders import FluxLoraLoaderMixin
from safetensors.torch import save_file

from ...utils import load_state_dict_in_safetensors


def to_diffusers(input_lora: str | dict[str, torch.Tensor], output_path: str | None = None) -> dict[str, torch.Tensor]:
    if isinstance(input_lora, str):
        tensors = load_state_dict_in_safetensors(input_lora, device="cpu")
    else:
        tensors = input_lora
    new_tensors, alphas = FluxLoraLoaderMixin.lora_state_dict(tensors, return_alphas=True)
    if alphas is not None and len(alphas) > 0:
        warnings.warn("Alpha values are not used in the conversion to diffusers format.")
    if output_path is not None:
        output_dir = os.path.dirname(os.path.abspath(output_path))
        os.makedirs(output_dir, exist_ok=True)
        save_file(new_tensors, output_path)
    return new_tensors


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-i", "--input-path", type=str, required=True, help="path to the comfyui lora safetensors file")
    parser.add_argument("-o", "--output-path", type=str, required=True, help="path to the output diffusers safetensors file")
    args = parser.parse_args()
    to_diffusers(args.input_path, args.output_path)

Removed:

# convert the diffusers lora to nunchaku format
import typing as tp

import torch
import tqdm

from ...utils import ceil_divide, filter_state_dict, load_state_dict_in_safetensors

# region utilities
def pad(
    tensor: tp.Optional[torch.Tensor],
    divisor: int | tp.Sequence[int],
    dim: int | tp.Sequence[int],
    fill_value: float | int = 0,
) -> torch.Tensor | None:
    ...  # body identical to the pad() helper now defined in nunchaku/lora/flux/utils.py (shown later on this page)


def update_state_dict(
    lhs: dict[str, torch.Tensor], rhs: dict[str, torch.Tensor], prefix: str = ""
) -> dict[str, torch.Tensor]:
    for rkey, value in rhs.items():
        lkey = f"{prefix}.{rkey}" if prefix else rkey
        assert lkey not in lhs, f"Key {lkey} already exists in the state dict."
        lhs[lkey] = value
    return lhs
# endregion


def pack_lowrank_weight(weight: torch.Tensor, down: bool) -> torch.Tensor:
    """Pack Low-Rank Weight.

    Args:
        weight (`torch.Tensor`): low-rank weight tensor.
        down (`bool`): whether the weight is for down projection in low-rank branch.
    """
    assert weight.dtype in (torch.float16, torch.bfloat16), f"Unsupported weight dtype {weight.dtype}."
    lane_n, lane_k = 1, 2  # lane_n is always 1, lane_k is 32 bits // 16 bits = 2
    n_pack_size, k_pack_size = 2, 2
    num_n_lanes, num_k_lanes = 8, 4
    frag_n = n_pack_size * num_n_lanes * lane_n
    frag_k = k_pack_size * num_k_lanes * lane_k
    weight = pad(weight, divisor=(frag_n, frag_k), dim=(0, 1))
    if down:
        r, c = weight.shape
        r_frags, c_frags = r // frag_n, c // frag_k
        weight = weight.view(r_frags, frag_n, c_frags, frag_k).permute(2, 0, 1, 3)
    else:
        c, r = weight.shape
        c_frags, r_frags = c // frag_n, r // frag_k
        weight = weight.view(c_frags, frag_n, r_frags, frag_k).permute(0, 2, 1, 3)
    weight = weight.reshape(c_frags, r_frags, n_pack_size, num_n_lanes, k_pack_size, num_k_lanes, lane_k)
    weight = weight.permute(0, 1, 3, 5, 2, 4, 6).contiguous()
    return weight.view(c, r)


def unpack_lowrank_weight(weight: torch.Tensor, down: bool) -> torch.Tensor:
    """Unpack Low-Rank Weight.

    Args:
        weight (`torch.Tensor`): low-rank weight tensor.
        down (`bool`): whether the weight is for down projection in low-rank branch.
    """
    c, r = weight.shape
    assert weight.dtype in (torch.float16, torch.bfloat16), f"Unsupported weight dtype {weight.dtype}."
    lane_n, lane_k = 1, 2  # lane_n is always 1, lane_k is 32 bits // 16 bits = 2
    n_pack_size, k_pack_size = 2, 2
    num_n_lanes, num_k_lanes = 8, 4
    frag_n = n_pack_size * num_n_lanes * lane_n
    frag_k = k_pack_size * num_k_lanes * lane_k
    if down:
        r_frags, c_frags = r // frag_n, c // frag_k
    else:
        c_frags, r_frags = c // frag_n, r // frag_k
    weight = weight.view(c_frags, r_frags, num_n_lanes, num_k_lanes, n_pack_size, k_pack_size, lane_k)
    weight = weight.permute(0, 1, 4, 2, 5, 3, 6).contiguous()
    weight = weight.view(c_frags, r_frags, frag_n, frag_k)
    if down:
        weight = weight.permute(1, 2, 0, 3).contiguous().view(r, c)
    else:
        weight = weight.permute(0, 2, 1, 3).contiguous().view(c, r)
    return weight


def reorder_adanorm_lora_up(lora_up: torch.Tensor, splits: int) -> torch.Tensor:
    c, r = lora_up.shape
    assert c % splits == 0
    return lora_up.view(splits, c // splits, r).transpose(0, 1).reshape(c, r).contiguous()


def convert_to_nunchaku_transformer_block_lowrank_dict(  # noqa: C901
    orig_state_dict: dict[str, torch.Tensor],
    extra_lora_dict: dict[str, torch.Tensor],
    converted_block_name: str,
    candidate_block_name: str,
    local_name_map: dict[str, str | list[str]],
    convert_map: dict[str, str],
    default_dtype: torch.dtype = torch.bfloat16,
) -> dict[str, torch.Tensor]:
    print(f"Converting LoRA branch for block {candidate_block_name}...")
    converted: dict[str, torch.Tensor] = {}
    for converted_local_name, candidate_local_names in tqdm.tqdm(
        local_name_map.items(), desc=f"Converting {candidate_block_name}", dynamic_ncols=True
    ):
        if isinstance(candidate_local_names, str):
            candidate_local_names = [candidate_local_names]
        # region original LoRA
        orig_lora = (
            orig_state_dict.get(f"{converted_block_name}.{converted_local_name}.lora_down", None),
            orig_state_dict.get(f"{converted_block_name}.{converted_local_name}.lora_up", None),
        )
        if orig_lora[0] is None or orig_lora[1] is None:
            assert orig_lora[0] is None and orig_lora[1] is None
            orig_lora = None
        else:
            assert orig_lora[0] is not None and orig_lora[1] is not None
            orig_lora = (
                unpack_lowrank_weight(orig_lora[0], down=True),
                unpack_lowrank_weight(orig_lora[1], down=False),
            )
            print(f" - Found {converted_block_name} LoRA of {converted_local_name} (rank: {orig_lora[0].shape[0]})")
        # endregion
        # region extra LoRA
        extra_lora = [
            (
                extra_lora_dict.get(f"{candidate_block_name}.{candidate_local_name}.lora_A.weight", None),
                extra_lora_dict.get(f"{candidate_block_name}.{candidate_local_name}.lora_B.weight", None),
            )
            for candidate_local_name in candidate_local_names
        ]
        if any(lora[0] is not None or lora[1] is not None for lora in extra_lora):
            # merge extra LoRAs into one LoRA
            if len(extra_lora) > 1:
                first_lora = None
                for lora in extra_lora:
                    if lora[0] is not None:
                        assert lora[1] is not None
                        first_lora = lora
                        break
                assert first_lora is not None
                for lora_index in range(len(extra_lora)):
                    if extra_lora[lora_index][0] is None:
                        assert extra_lora[lora_index][1] is None
                        extra_lora[lora_index] = (first_lora[0].clone(), torch.zeros_like(first_lora[1]))
                if all(lora[0].equal(extra_lora[0][0]) for lora in extra_lora):
                    # if all extra LoRAs have the same lora_down, use it
                    extra_lora_down = extra_lora[0][0]
                    extra_lora_up = torch.cat([lora[1] for lora in extra_lora], dim=0)
                else:
                    extra_lora_down = torch.cat([lora[0] for lora in extra_lora], dim=0)
                    extra_lora_up_c = sum(lora[1].shape[0] for lora in extra_lora)
                    extra_lora_up_r = sum(lora[1].shape[1] for lora in extra_lora)
                    assert extra_lora_up_r == extra_lora_down.shape[0]
                    extra_lora_up = torch.zeros((extra_lora_up_c, extra_lora_up_r), dtype=extra_lora_down.dtype)
                    c, r = 0, 0
                    for lora in extra_lora:
                        c_next, r_next = c + lora[1].shape[0], r + lora[1].shape[1]
                        extra_lora_up[c:c_next, r:r_next] = lora[1]
                        c, r = c_next, r_next
            else:
                extra_lora_down, extra_lora_up = extra_lora[0]
            extra_lora: tuple[torch.Tensor, torch.Tensor] = (extra_lora_down, extra_lora_up)
            print(f" - Found {candidate_block_name} LoRA of {candidate_local_names} (rank: {extra_lora[0].shape[0]})")
        else:
            extra_lora = None
        # endregion
        # region merge LoRA
        if orig_lora is None:
            if extra_lora is None:
                lora = None
            else:
                print(" - Using extra LoRA")
                lora = (extra_lora[0].to(default_dtype), extra_lora[1].to(default_dtype))
        elif extra_lora is None:
            print(" - Using original LoRA")
            lora = orig_lora
        else:
            lora = (
                torch.cat([orig_lora[0], extra_lora[0].to(orig_lora[0].dtype)], dim=0),
                torch.cat([orig_lora[1], extra_lora[1].to(orig_lora[1].dtype)], dim=1),
            )
            print(f" - Merging original and extra LoRA (rank: {lora[0].shape[0]})")
        # endregion
        if lora is not None:
            if convert_map[converted_local_name] == "adanorm_single":
                update_state_dict(
                    converted,
                    {
                        "lora_down": pad(lora[0], divisor=16, dim=0),
                        "lora_up": pad(reorder_adanorm_lora_up(lora[1], splits=3), divisor=16, dim=1),
                    },
                    prefix=converted_local_name,
                )
            elif convert_map[converted_local_name] == "adanorm_zero":
                update_state_dict(
                    converted,
                    {
                        "lora_down": pad(lora[0], divisor=16, dim=0),
                        "lora_up": pad(reorder_adanorm_lora_up(lora[1], splits=6), divisor=16, dim=1),
                    },
                    prefix=converted_local_name,
                )
            elif convert_map[converted_local_name] == "linear":
                update_state_dict(
                    converted,
                    {
                        "lora_down": pack_lowrank_weight(lora[0], down=True),
                        "lora_up": pack_lowrank_weight(lora[1], down=False),
                    },
                    prefix=converted_local_name,
                )
    return converted


def convert_to_nunchaku_flux_single_transformer_block_lowrank_dict(
    orig_state_dict: dict[str, torch.Tensor],
    extra_lora_dict: dict[str, torch.Tensor],
    converted_block_name: str,
    candidate_block_name: str,
    default_dtype: torch.dtype = torch.bfloat16,
) -> dict[str, torch.Tensor]:
    if f"{candidate_block_name}.proj_out.lora_A.weight" in extra_lora_dict:
        assert f"{converted_block_name}.out_proj.qweight" in orig_state_dict
        assert f"{converted_block_name}.mlp_fc2.qweight" in orig_state_dict
        n1 = orig_state_dict[f"{converted_block_name}.out_proj.qweight"].shape[1] * 2
        n2 = orig_state_dict[f"{converted_block_name}.mlp_fc2.qweight"].shape[1] * 2
        lora_down = extra_lora_dict[f"{candidate_block_name}.proj_out.lora_A.weight"]
        lora_up = extra_lora_dict[f"{candidate_block_name}.proj_out.lora_B.weight"]
        assert lora_down.shape[1] == n1 + n2
        extra_lora_dict[f"{candidate_block_name}.proj_out.linears.0.lora_A.weight"] = lora_down[:, :n1].clone()
        extra_lora_dict[f"{candidate_block_name}.proj_out.linears.0.lora_B.weight"] = lora_up.clone()
        extra_lora_dict[f"{candidate_block_name}.proj_out.linears.1.lora_A.weight"] = lora_down[:, n1:].clone()
        extra_lora_dict[f"{candidate_block_name}.proj_out.linears.1.lora_B.weight"] = lora_up.clone()
        extra_lora_dict.pop(f"{candidate_block_name}.proj_out.lora_A.weight")
        extra_lora_dict.pop(f"{candidate_block_name}.proj_out.lora_B.weight")
        for component in ["lora_A", "lora_B"]:
            fc1_k = f"{candidate_block_name}.proj_mlp.{component}.weight"
            fc2_k = f"{candidate_block_name}.proj_out.linears.1.{component}.weight"
            fc1_v = extra_lora_dict[fc1_k]
            fc2_v = extra_lora_dict[fc2_k]
            dim = 0 if "lora_A" in fc1_k else 1
            fc1_rank = fc1_v.shape[dim]
            fc2_rank = fc2_v.shape[dim]
            if fc1_rank != fc2_rank:
                rank = max(fc1_rank, fc2_rank)
                if fc1_rank < rank:
                    extra_lora_dict[fc1_k] = pad(fc1_v, divisor=rank, dim=dim)
                if fc2_rank < rank:
                    extra_lora_dict[fc2_k] = pad(fc2_v, divisor=rank, dim=dim)
    return convert_to_nunchaku_transformer_block_lowrank_dict(
        orig_state_dict=orig_state_dict,
        extra_lora_dict=extra_lora_dict,
        converted_block_name=converted_block_name,
        candidate_block_name=candidate_block_name,
        local_name_map={
            "norm.linear": "norm.linear",
            "qkv_proj": ["attn.to_q", "attn.to_k", "attn.to_v"],
            "norm_q": "attn.norm_q",
            "norm_k": "attn.norm_k",
            "out_proj": "proj_out.linears.0",
            "mlp_fc1": "proj_mlp",
            "mlp_fc2": "proj_out.linears.1",
        },
        convert_map={
            "norm.linear": "adanorm_single",
            "qkv_proj": "linear",
            "out_proj": "linear",
            "mlp_fc1": "linear",
            "mlp_fc2": "linear",
        },
        default_dtype=default_dtype,
    )


def convert_to_nunchaku_flux_transformer_block_lowrank_dict(
    orig_state_dict: dict[str, torch.Tensor],
    extra_lora_dict: dict[str, torch.Tensor],
    converted_block_name: str,
    candidate_block_name: str,
    default_dtype: torch.dtype = torch.bfloat16,
) -> dict[str, torch.Tensor]:
    return convert_to_nunchaku_transformer_block_lowrank_dict(
        orig_state_dict=orig_state_dict,
        extra_lora_dict=extra_lora_dict,
        converted_block_name=converted_block_name,
        candidate_block_name=candidate_block_name,
        local_name_map={
            "norm1.linear": "norm1.linear",
            "norm1_context.linear": "norm1_context.linear",
            "qkv_proj": ["attn.to_q", "attn.to_k", "attn.to_v"],
            "qkv_proj_context": ["attn.add_q_proj", "attn.add_k_proj", "attn.add_v_proj"],
            "norm_q": "attn.norm_q",
            "norm_k": "attn.norm_k",
            "norm_added_q": "attn.norm_added_q",
            "norm_added_k": "attn.norm_added_k",
            "out_proj": "attn.to_out.0",
            "out_proj_context": "attn.to_add_out",
            "mlp_fc1": "ff.net.0.proj",
            "mlp_fc2": "ff.net.2",
            "mlp_context_fc1": "ff_context.net.0.proj",
            "mlp_context_fc2": "ff_context.net.2",
        },
        convert_map={
            "norm1.linear": "adanorm_zero",
            "norm1_context.linear": "adanorm_zero",
            "qkv_proj": "linear",
            "qkv_proj_context": "linear",
            "out_proj": "linear",
            "out_proj_context": "linear",
            "mlp_fc1": "linear",
            "mlp_fc2": "linear",
            "mlp_context_fc1": "linear",
            "mlp_context_fc2": "linear",
        },
        default_dtype=default_dtype,
    )


def convert_to_nunchaku_flux_lowrank_dict(
    base_model: dict[str, torch.Tensor] | str,
    lora: dict[str, torch.Tensor] | str,
    default_dtype: torch.dtype = torch.bfloat16,
) -> dict[str, torch.Tensor]:
    if isinstance(base_model, str):
        orig_state_dict = load_state_dict_in_safetensors(base_model)
    else:
        orig_state_dict = base_model
    if isinstance(lora, str):
        extra_lora_dict = load_state_dict_in_safetensors(lora, filter_prefix="transformer.")
    else:
        extra_lora_dict = filter_state_dict(lora, filter_prefix="transformer.")

    unquantized_lora_dict = {}
    for k in list(extra_lora_dict.keys()):
        if "transformer_blocks" not in k:
            unquantized_lora_dict[k] = extra_lora_dict.pop(k)

    for k in extra_lora_dict.keys():
        fc1_k = k
        if "ff.net.0.proj" in k:
            fc2_k = k.replace("ff.net.0.proj", "ff.net.2")
        elif "ff_context.net.0.proj" in k:
            fc2_k = k.replace("ff_context.net.0.proj", "ff_context.net.2")
        else:
            continue
        assert fc2_k in extra_lora_dict
        fc1_v = extra_lora_dict[fc1_k]
        fc2_v = extra_lora_dict[fc2_k]
        dim = 0 if "lora_A" in fc1_k else 1
        fc1_rank = fc1_v.shape[dim]
        fc2_rank = fc2_v.shape[dim]
        if fc1_rank != fc2_rank:
            rank = max(fc1_rank, fc2_rank)
            if fc1_rank < rank:
                extra_lora_dict[fc1_k] = pad(fc1_v, divisor=rank, dim=dim)
            if fc2_rank < rank:
                extra_lora_dict[fc2_k] = pad(fc2_v, divisor=rank, dim=dim)

    block_names: set[str] = set()
    for param_name in orig_state_dict.keys():
        if param_name.startswith(("transformer_blocks.", "single_transformer_blocks.")):
            block_names.add(".".join(param_name.split(".")[:2]))
    block_names = sorted(block_names, key=lambda x: (x.split(".")[0], int(x.split(".")[-1])))
    print(f"Converting {len(block_names)} transformer blocks...")
    converted: dict[str, torch.Tensor] = {}
    for block_name in block_names:
        if block_name.startswith("transformer_blocks"):
            convert_fn = convert_to_nunchaku_flux_transformer_block_lowrank_dict
        else:
            convert_fn = convert_to_nunchaku_flux_single_transformer_block_lowrank_dict
        update_state_dict(
            converted,
            convert_fn(
                orig_state_dict=orig_state_dict,
                extra_lora_dict=extra_lora_dict,
                converted_block_name=block_name,
                candidate_block_name=block_name,
                default_dtype=default_dtype,
            ),
            prefix=block_name,
        )
    converted.update(unquantized_lora_dict)
    return converted
nunchaku/lora/flux/nunchaku_converter.py (new file, 0 → 100644, +531; this diff is collapsed in the page view, so its contents are not shown here)
nunchaku/lora/flux/packer.py (new file, 0 → 100644)

# Copy the packer from https://github.com/mit-han-lab/deepcompressor/
import torch

from .utils import pad
from ...utils import ceil_divide


class MmaWeightPackerBase:
    def __init__(self, bits: int, warp_n: int, comp_n: int = None, comp_k: int = None):
        self.bits = bits
        assert self.bits in (1, 4, 8, 16, 32), "weight bits should be 1, 4, 8, 16, or 32."
        # region compute tile size
        self.comp_n = comp_n if comp_n is not None else 16
        """smallest tile size in `n` dimension for MMA computation."""
        self.comp_k = comp_k if comp_k is not None else 256 // self.bits
        """smallest tile size in `k` dimension for MMA computation."""
        # the smallest MMA computation may contain several MMA instructions
        self.insn_n = 8  # mma instruction tile size in `n` dimension
        """tile size in `n` dimension for MMA instruction."""
        self.insn_k = self.comp_k
        """tile size in `k` dimension for MMA instruction."""
        assert self.insn_k * self.bits in (
            128,
            256,
        ), f"insn_k ({self.insn_k}) * bits ({self.bits}) should be 128 or 256."
        assert self.comp_n % self.insn_n == 0, f"comp_n ({self.comp_n}) should be divisible by insn_n ({self.insn_n})."
        self.num_lanes = 32
        """there are 32 lanes (or threds) in a warp."""
        self.num_k_lanes = 4
        self.num_n_lanes = 8
        assert (
            warp_n >= self.comp_n and warp_n % self.comp_n == 0
        ), f"warp_n ({warp_n}) should be divisible by comp_n({self.comp_n})."
        self.warp_n = warp_n
        # endregion
        # region memory
        self.reg_k = 32 // self.bits
        """number of elements in a register in `k` dimension."""
        self.reg_n = 1
        """number of elements in a register in `n` dimension (always 1)."""
        self.k_pack_size = self.comp_k // (self.num_k_lanes * self.reg_k)
        """number of elements in a pack in `k` dimension."""
        self.n_pack_size = self.comp_n // (self.num_n_lanes * self.reg_n)
        """number of elements in a pack in `n` dimension."""
        self.pack_size = self.k_pack_size * self.n_pack_size
        """number of elements in a pack accessed by a lane at a time."""
        assert 1 <= self.pack_size <= 4, "pack size should be less than or equal to 4."
        assert self.k_pack_size * self.num_k_lanes * self.reg_k == self.comp_k
        assert self.n_pack_size * self.num_n_lanes * self.reg_n == self.comp_n
        self.mem_k = self.comp_k
        """the tile size in `k` dimension for one tensor memory access."""
        self.mem_n = warp_n
        """the tile size in `n` dimension for one tensor memory access."""
        self.num_k_packs = self.mem_k // (self.k_pack_size * self.num_k_lanes * self.reg_k)
        """number of packs in `k` dimension for one tensor memory access."""
        self.num_n_packs = self.mem_n // (self.n_pack_size * self.num_n_lanes * self.reg_n)
        """number of packs in `n` dimension for one tensor memory access."""
        # endregion

    def get_view_shape(self, n: int, k: int) -> tuple[int, int, int, int, int, int, int, int, int, int]:
        assert n % self.mem_n == 0, "output channel size should be divisible by mem_n."
        assert k % self.mem_k == 0, "input channel size should be divisible by mem_k."
        return (
            n // self.mem_n,
            self.num_n_packs,
            self.n_pack_size,
            self.num_n_lanes,
            self.reg_n,
            k // self.mem_k,
            self.num_k_packs,
            self.k_pack_size,
            self.num_k_lanes,
            self.reg_k,
        )


class NunchakuWeightPacker(MmaWeightPackerBase):
    def __init__(self, bits: int, warp_n: int = 128):
        super().__init__(bits=bits, warp_n=warp_n)
        self.num_k_unrolls = 2

    def pack_weight(self, weight: torch.Tensor) -> torch.Tensor:
        assert weight.dtype == torch.int32, f"quantized weight should be torch.int32, but got {weight.dtype}."
        n, k = weight.shape
        assert n % self.mem_n == 0, f"output channel size ({n}) should be divisible by mem_n ({self.mem_n})."
        # currently, Nunchaku did not check the boundry of unrolled `k` dimension
        assert k % (self.mem_k * self.num_k_unrolls) == 0, (
            f"input channel size ({k}) should be divisible by "
            f"mem_k ({self.mem_k}) * num_k_unrolls ({self.num_k_unrolls})."
        )
        n_tiles, k_tiles = n // self.mem_n, k // self.mem_k
        weight = weight.reshape(
            n_tiles,
            self.num_n_packs,   # 8 when warp_n = 128
            self.n_pack_size,   # always 2 in nunchaku
            self.num_n_lanes,   # constant 8
            self.reg_n,         # constant 1
            k_tiles,
            self.num_k_packs,   # 1
            self.k_pack_size,   # always 2 in nunchaku
            self.num_k_lanes,   # constant 4
            self.reg_k,         # always 8 = 32 bits / 4 bits
        )
        # (n_tiles, num_n_packs, n_pack_size, num_n_lanes, reg_n, k_tiles, num_k_packs, k_pack_size, num_k_lanes, reg_k)
        # =>
        # (n_tiles, k_tiles, num_k_packs, num_n_packs, num_n_lanes, num_k_lanes, n_pack_size, k_pack_size, reg_n, reg_k)
        weight = weight.permute(0, 5, 6, 1, 3, 8, 2, 7, 4, 9).contiguous()
        assert weight.shape[4:-2] == (8, 4, 2, 2)
        if self.bits == 4:
            weight = weight.bitwise_and_(0xF)
            shift = torch.arange(0, 32, 4, dtype=torch.int32, device=weight.device)
            weight = weight.bitwise_left_shift_(shift)
            weight = weight.sum(dim=-1, dtype=torch.int32)
        elif self.bits == 8:
            weight = weight.bitwise_and_(0xFF)
            shift = torch.arange(0, 32, 8, dtype=torch.int32, device=weight.device)
            weight = weight.bitwise_left_shift_(shift)
            weight = weight.sum(dim=-1, dtype=torch.int32)
        else:
            raise NotImplementedError(f"weight bits {self.bits} is not supported.")
        return weight.view(dtype=torch.int8).view(n, -1)  # assume little-endian

    def pack_scale(self, scale: torch.Tensor, group_size: int) -> torch.Tensor:
        if self.check_if_micro_scale(group_size=group_size):
            return self.pack_micro_scale(scale, group_size=group_size)
        # note: refer to https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#mma-16864-c
        assert scale.dtype in (torch.float16, torch.bfloat16), "currently nunchaku only supports fp16 and bf16."
        n = scale.shape[0]
        # nunchaku load scales all in one access
        # for `[warp_n, warp_k]` weights, we load `[warp_n, warp_k / group_size]` scales
        # scale loading is parallelized in `n` dimension, that is,
        # `num_s_lanes` in a warp load `num_s_packs` of `s_pack_size` elements, in total `warp_s` elements
        # each element in `n` dimension is 16 bit as it contains 1 fp16
        # min `s_pack_size` set to 2 element, since each lane at least holds 2 accumulator results in `n` dimension
        # max `s_pack_size` set to 128b/16b = 8 elements
        # for `warp_n = 8`, we have `s_pack_size = 2`, `num_s_lanes = 4`, `num_s_packs = 1`
        # for `warp_n = 128`, we have `s_pack_size = 4`, `num_s_lanes = 32`, `num_s_packs = 1`
        # for `warp_n = 512`, we have `s_pack_size = 8`, `num_s_lanes = 32`, `num_s_packs = 2`
        s_pack_size = min(max(self.warp_n // self.num_lanes, 2), 8)
        num_s_lanes = min(self.num_lanes, self.warp_n // s_pack_size)
        num_s_packs = self.warp_n // (s_pack_size * num_s_lanes)
        warp_s = num_s_packs * num_s_lanes * s_pack_size
        assert warp_s == self.warp_n, "warp_n for scales should be equal to warp_n for weights."
        # `num_n_lanes = 8 (constant)` generates 8 elements consecutive in `n` dimension
        # however, they are held by 4 lanes, each lane holds 2 elements in `n` dimension
        # thus, we start from first 4 lanes, assign 2 elements to each lane, until all 8 elements are assigned
        # we then repeat the process for the same 4 lanes, until each lane holds `s_pack_size` elements
        # finally, we move to next 4 lanes, and repeat the process until all `num_s_lanes` lanes are assigned
        # the process is repeated for `num_s_packs` times
        # here is an example for `warp_n = 128, s_pack_size = 4, num_s_lanes = 32, num_s_packs = 1`
        # wscales store order:
        #     0   1   8   9  <-- load by lane 0, broadcast to lane {0, 4, 8, ..., 28} (8x)
        #     2   3  10  11  <-- load by lane 1, broadcast to lane {1, 5, 9, ..., 29} (8x)
        #     4   5  12  13  <-- load by lane 2, broadcast to lane {2, 6, 10, ..., 30} (8x)
        #     6   7  14  15  <-- load by lane 3, broadcast to lane {3, 7, 11, ..., 31} (8x)
        #    16  17  24  25  <-- load by lane 4, broadcast to lane {0, 4, 8, ..., 28} (8x)
        #    ...
        #    22  23  30  31  <-- load by lane 7, broadcast to lane {3, 7, 11, ..., 31} (8x)
        #    ... ...
        #   112 113 120 121  <-- load by lane 28, broadcast to lane {0, 4, 8, ..., 28} (8x)
        #    ...
        #   118 119 126 127  <-- load by lane 31, broadcast to lane {3, 7, 11, ..., 31} (8x)
        scale = scale.reshape(n // warp_s, num_s_packs, num_s_lanes // 4, s_pack_size // 2, 4, 2, -1)
        scale = scale.permute(0, 6, 1, 2, 4, 3, 5).contiguous()
        return scale.view(-1) if group_size == -1 else scale.view(-1, n)  # the shape is just used for validation

    def pack_micro_scale(self, scale: torch.Tensor, group_size: int) -> torch.Tensor:
        assert scale.dtype in (torch.float16, torch.bfloat16), "currently nunchaku only supports fp16 and bf16."
        assert scale.max() <= 448, "scale should be less than 448."
        assert scale.min() >= -448, "scale should be greater than -448."
        assert group_size == 16, "currently only support group size 16."
        assert self.insn_k == 64, "insn_k should be 64."
        scale = scale.to(dtype=torch.float8_e4m3fn)
        n = scale.shape[0]
        assert self.warp_n >= 32, "currently only support warp_n >= 32."
        # for `[warp_n, warp_k]` weights, we load `[warp_n, warp_k / group_size]` scales
        # scale loading is parallelized in `n` dimension, that is,
        # `num_s_lanes` in a warp load `num_s_packs` of `s_pack_size` elements, in total `warp_s` elements
        # each element in `n` dimension is 32 bit as it contains 4 fp8 in `k` dimension
        # min `s_pack_size` set to 1 element
        # max `s_pack_size` set to 128b/32b = 4 elements
        # for `warp_n = 128`, we have `s_pack_size = 4`, `num_s_lanes = 32`, `num_s_packs = 1`
        # for `warp_n = 512`, we have `s_pack_size = 8`, `num_s_lanes = 32`, `num_s_packs = 2`
        s_pack_size = min(max(self.warp_n // self.num_lanes, 1), 4)
        num_s_lanes = 4 * 8  # 32 lanes is divided into 4 pieces, each piece has 8 lanes at a stride of 4
        num_s_packs = ceil_divide(self.warp_n, s_pack_size * num_s_lanes)
        warp_s = num_s_packs * num_s_lanes * s_pack_size
        assert warp_s == self.warp_n, "warp_n for scales should be equal to warp_n for weights."
        # note: refer to https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#mma-scaling-thread-id-b-selection
        # we start from first 8 lines at a stride of 4, assign 1 element to each lane, until all 8 elements are assigned
        # we then move to next 8 lines at a stride of 4, and repeat the process until all 32 lanes are assigned
        # here is an example for `warp_n = 128, s_pack_size = 4, num_s_lanes = 32, num_s_packs = 1`
        # wscales store order:
        #    0  32  64  96  <-- load by lane 0
        #    8  40  72 104  <-- load by lane 1
        #   16  48  80 112  <-- load by lane 2
        #   24  56  88 120  <-- load by lane 3
        #    1  33  65  97  <-- load by lane 4
        #   ...
        #   25  57  81 113  <-- load by lane 7
        #   ...
        #    7  39  71 103  <-- load by lane 28
        #   ...
        #   31  63  95 127  <-- load by lane 31
        scale = scale.view(n // warp_s, num_s_packs, s_pack_size, 4, 8, -1, self.insn_k // group_size)
        scale = scale.permute(0, 5, 1, 4, 3, 2, 6).contiguous()
        return scale.view(-1, n)  # the shape is just used for validation

    def pack_lowrank_weight(self, weight: torch.Tensor, down: bool) -> torch.Tensor:
        """Pack Low-Rank Weight.

        Args:
            weight (`torch.Tensor`): low-rank weight tensor.
            down (`bool`): whether the weight is for down projection in low-rank branch.
        """
        assert weight.dtype in (torch.float16, torch.bfloat16), f"Unsupported weight dtype {weight.dtype}."
        reg_n, reg_k = 1, 2  # reg_n is always 1, reg_k is 32 bits // 16 bits = 2
        pack_n = self.n_pack_size * self.num_n_lanes * reg_n
        pack_k = self.k_pack_size * self.num_k_lanes * reg_k
        weight = pad(weight, divisor=(pack_n, pack_k), dim=(0, 1))
        if down:
            r, c = weight.shape
            r_packs, c_packs = r // pack_n, c // pack_k
            weight = weight.view(r_packs, pack_n, c_packs, pack_k).permute(2, 0, 1, 3)
        else:
            c, r = weight.shape
            c_packs, r_packs = c // pack_n, r // pack_k
            weight = weight.view(c_packs, pack_n, r_packs, pack_k).permute(0, 2, 1, 3)
        weight = weight.reshape(
            c_packs, r_packs, self.n_pack_size, self.num_n_lanes, reg_n, self.k_pack_size, self.num_k_lanes, reg_k
        )
        # (c_packs, r_packs, n_pack_size, num_n_lanes, reg_n, k_pack_size, num_k_lanes, reg_k)
        # =>
        # (c_packs, r_packs, num_n_lanes, num_k_lanes, n_pack_size, k_pack_size, reg_n, reg_k)
        weight = weight.permute(0, 1, 3, 6, 2, 5, 4, 7).contiguous()
        return weight.view(c, r)

    def unpack_lowrank_weight(self, weight: torch.Tensor, down: bool) -> torch.Tensor:
        """Unpack Low-Rank Weight.

        Args:
            weight (`torch.Tensor`): low-rank weight tensor.
            down (`bool`): whether the weight is for down projection in low-rank branch.
        """
        c, r = weight.shape
        assert weight.dtype in (torch.float16, torch.bfloat16), f"Unsupported weight dtype {weight.dtype}."
        reg_n, reg_k = 1, 2  # reg_n is always 1, reg_k is 32 bits // 16 bits = 2
        pack_n = self.n_pack_size * self.num_n_lanes * reg_n
        pack_k = self.k_pack_size * self.num_k_lanes * reg_k
        if down:
            r_packs, c_packs = r // pack_n, c // pack_k
        else:
            c_packs, r_packs = c // pack_n, r // pack_k
        weight = weight.view(
            c_packs, r_packs, self.num_n_lanes, self.num_k_lanes, self.n_pack_size, self.k_pack_size, reg_n, reg_k
        )
        # (c_packs, r_packs, num_n_lanes, num_k_lanes, n_pack_size, k_pack_size, reg_n, reg_k)
        # =>
        # (c_packs, r_packs, n_pack_size, num_n_lanes, reg_n, k_pack_size, num_k_lanes, reg_k)
        weight = weight.permute(0, 1, 4, 2, 6, 5, 3, 7).contiguous()
        weight = weight.view(c_packs, r_packs, pack_n, pack_k)
        if down:
            weight = weight.permute(1, 2, 0, 3).contiguous().view(r, c)
        else:
            weight = weight.permute(0, 2, 1, 3).contiguous().view(c, r)
        return weight

    def check_if_micro_scale(self, group_size: int) -> bool:
        return self.insn_k == group_size * 4

    def pad_weight(self, weight: torch.Tensor) -> torch.Tensor:
        assert weight.ndim == 2, "weight tensor should be 2D."
        return pad(weight, divisor=(self.mem_n, self.mem_k * self.num_k_unrolls), dim=(0, 1))

    def pad_scale(self, scale: torch.Tensor, group_size: int, fill_value: float = 0) -> torch.Tensor:
        if group_size > 0 and scale.numel() > scale.shape[0]:
            scale = scale.view(scale.shape[0], 1, -1, 1)
            if self.check_if_micro_scale(group_size=group_size):
                scale = pad(scale, divisor=(self.warp_n, self.insn_k // group_size), dim=(0, 2), fill_value=fill_value)
            else:
                scale = pad(scale, divisor=(self.warp_n, self.num_k_unrolls), dim=(0, 2), fill_value=fill_value)
        else:
            scale = pad(scale, divisor=self.warp_n, dim=0, fill_value=fill_value)
        return scale

    def pad_lowrank_weight(self, weight: torch.Tensor, down: bool) -> torch.Tensor:
        assert weight.ndim == 2, "weight tensor should be 2D."
        return pad(weight, divisor=self.warp_n, dim=1 if down else 0)
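A minimal sketch of the low-rank packing layout, assuming the file above is importable as nunchaku.lora.flux.packer. With bits=4 and the default warp_n=128, both pack tile sizes work out to 16, so a rank-32, 3072-wide lora_down needs no padding; only shapes are asserted here (the docstrings describe pack and unpack as inverses, but that is not re-verified in this sketch):

import torch
from nunchaku.lora.flux.packer import NunchakuWeightPacker

packer = NunchakuWeightPacker(bits=4)                      # warp_n defaults to 128
lora_down = torch.randn(32, 3072, dtype=torch.bfloat16)    # [rank, in_features], both divisible by 16

packed = packer.pack_lowrank_weight(lora_down, down=True)
assert packed.shape == (3072, 32)                          # packed layout is stored as [c, r]

restored = packer.unpack_lowrank_weight(packed, down=True)
assert restored.shape == lora_down.shape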
nunchaku/lora/flux/utils.py
View file @
3ef186fd
+import typing as tp
+
 import torch
 
-from ...utils import load_state_dict_in_safetensors
+from ...utils import ceil_divide, load_state_dict_in_safetensors
 
 
-def detect_format(lora: str | dict[str, torch.Tensor]) -> str:
+def is_nunchaku_format(lora: str | dict[str, torch.Tensor]) -> bool:
     if isinstance(lora, str):
         tensors = load_state_dict_in_safetensors(lora, device="cpu")
     else:
         tensors = lora
     for k in tensors.keys():
-        if "lora_unet_double_blocks_" in k or "lora_unet_single_blocks" in k:
-            return "comfyui"
-        elif ".mlp_fc" in k or "mlp_context_fc1" in k:
-            return "svdquant"
-        elif "double_blocks." in k or "single_blocks." in k:
-            return "xlab"
-        elif "transformer." in k:
-            return "diffusers"
-    raise ValueError("Unknown format, please provide the format explicitly.")
+        if ".mlp_fc" in k or "mlp_context_fc1" in k:
+            return True
+    return False
+
+
+def pad(
+    tensor: tp.Optional[torch.Tensor],
+    divisor: int | tp.Sequence[int],
+    dim: int | tp.Sequence[int],
+    fill_value: float | int = 0,
+) -> torch.Tensor | None:
+    if isinstance(divisor, int):
+        if divisor <= 1:
+            return tensor
+    elif all(d <= 1 for d in divisor):
+        return tensor
+    if tensor is None:
+        return None
+    shape = list(tensor.shape)
+    if isinstance(dim, int):
+        assert isinstance(divisor, int)
+        shape[dim] = ceil_divide(shape[dim], divisor) * divisor
+    else:
+        if isinstance(divisor, int):
+            divisor = [divisor] * len(dim)
+        for d, div in zip(dim, divisor, strict=True):
+            shape[d] = ceil_divide(shape[d], div) * div
+    result = torch.full(shape, fill_value, dtype=tensor.dtype, device=tensor.device)
+    result[[slice(0, extent) for extent in tensor.shape]] = tensor
+    return result
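As a quick illustration (a minimal sketch, not part of the commit): `pad` grows each requested dimension to the next multiple of its divisor and fills the new cells with `fill_value`, while `is_nunchaku_format` only inspects key names. The key below is hypothetical.

import torch

x = torch.randn(3, 5)
y = pad(x, divisor=(16, 64), dim=(0, 1))       # y.shape == (16, 64), padded region is zeros
assert torch.equal(y[:3, :5], x)

state_dict = {"transformer_blocks.0.mlp_fc1.lora_A.weight": torch.zeros(1)}  # hypothetical key
print(is_nunchaku_format(state_dict))          # True, because ".mlp_fc" appears in the key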
nunchaku/lora/flux/xlab_converter.py
deleted 100644 → 0
# convert the xlab lora to diffusers format
import os

import torch
from safetensors.torch import save_file

from ...utils import load_state_dict_in_safetensors


def xlab2diffusers(input_lora: str | dict[str, torch.Tensor], output_path: str | None = None) -> dict[str, torch.Tensor]:
    if isinstance(input_lora, str):
        tensors = load_state_dict_in_safetensors(input_lora, device="cpu")
    else:
        tensors = input_lora
    new_tensors = {}
    # lora1 is for img, lora2 is for text
    for k, v in tensors.items():
        assert "double_blocks" in k
        new_k = k.replace("double_blocks", "transformer.transformer_blocks").replace("processor", "attn")
        new_k = new_k.replace(".down.", ".lora_A.")
        new_k = new_k.replace(".up.", ".lora_B.")
        if ".proj_lora" in new_k:
            new_k = new_k.replace(".proj_lora1", ".to_out.0")
            new_k = new_k.replace(".proj_lora2", ".to_add_out")
            new_tensors[new_k] = v
        else:
            assert "qkv_lora" in new_k
            if "lora_A" in new_k:
                for p in ["q", "k", "v"]:
                    if ".qkv_lora1." in new_k:
                        new_tensors[new_k.replace(".qkv_lora1.", f".to_{p}.")] = v.clone()
                    else:
                        assert ".qkv_lora2." in new_k
                        new_tensors[new_k.replace(".qkv_lora2.", f".add_{p}_proj.")] = v.clone()
            else:
                assert "lora_B" in new_k
                for i, p in enumerate(["q", "k", "v"]):
                    assert v.shape[0] % 3 == 0
                    chunk_size = v.shape[0] // 3
                    if ".qkv_lora1." in new_k:
                        new_tensors[new_k.replace(".qkv_lora1.", f".to_{p}.")] = v[i * chunk_size : (i + 1) * chunk_size]
                    else:
                        assert ".qkv_lora2." in new_k
                        new_tensors[new_k.replace(".qkv_lora2.", f".add_{p}_proj.")] = v[
                            i * chunk_size : (i + 1) * chunk_size
                        ]
    if output_path is not None:
        output_dir = os.path.dirname(os.path.abspath(output_path))
        os.makedirs(output_dir, exist_ok=True)
        save_file(new_tensors, output_path)
    return new_tensors
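For reference, a minimal sketch of the key renaming this removed converter performed (the key below is illustrative, not taken from a real checkpoint):

# Hypothetical XLabs-style key and the diffusers-style key it was mapped to:
k = "double_blocks.0.processor.proj_lora1.down.weight"
new_k = (k.replace("double_blocks", "transformer.transformer_blocks")
          .replace("processor", "attn")
          .replace(".down.", ".lora_A.")
          .replace(".proj_lora1", ".to_out.0"))
print(new_k)  # transformer.transformer_blocks.0.attn.to_out.0.lora_A.weight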
nunchaku/models/text_encoders/tinychat_utils.py
...
@@ -120,4 +120,4 @@ def convert_to_tinychat_w4x16y16_linear_weight(
     _zero = torch.zeros((_ng, oc), dtype=dtype, device=device)
     _scale[:ng] = scale.view(oc, ng).t().to(dtype=dtype)
     _zero[:ng] = zero.view(oc, ng).t().to(dtype=dtype).neg_()
-    return _weight, _scale, _zero
\ No newline at end of file
+    return _weight, _scale, _zero
nunchaku/models/transformers/transformer_flux.py
+import logging
 import os
 
 import diffusers
...
@@ -6,14 +7,24 @@ from diffusers import FluxTransformer2DModel
 from diffusers.configuration_utils import register_to_config
 from huggingface_hub import utils
 from packaging.version import Version
+from safetensors.torch import load_file
 from torch import nn
 
-from .utils import NunchakuModelLoaderMixin, pad_tensor
+from .utils import get_precision, NunchakuModelLoaderMixin, pad_tensor
 from ..._C import QuantizedFluxModel, utils as cutils
+from ...lora.flux.nunchaku_converter import fuse_vectors, to_nunchaku
+from ...lora.flux.utils import is_nunchaku_format
 from ...utils import load_state_dict_in_safetensors
 
 SVD_RANK = 32
 
+# Get log level from environment variable (default to INFO)
+log_level = os.getenv("LOG_LEVEL", "INFO").upper()
+
+# Configure logging
+logging.basicConfig(level=getattr(logging, log_level, logging.INFO), format="%(asctime)s - %(levelname)s - %(message)s")
+logger = logging.getLogger(__name__)
+
 
 class NunchakuFluxTransformerBlocks(nn.Module):
     def __init__(self, m: QuantizedFluxModel, device: str | torch.device):
...
@@ -35,9 +46,9 @@ class NunchakuFluxTransformerBlocks(nn.Module):
         rotemb = rotemb.permute(0, 1, 3, 2, 4)
         # 16*8 pack, FP32 accumulator (C) format
         # https://docs.nvidia.com/cuda/parallel-thread-execution/#mma-16816-c
         ##########################################|--M--|--D--|
         ##########################################|-3--4--5--6|
         ##########################################  :  :  :  :
         rotemb = rotemb.reshape(*rotemb.shape[0:3], 2, 8, 4, 2)
         rotemb = rotemb.permute(0, 1, 2, 4, 5, 3, 6)
         rotemb = rotemb.contiguous()
...
@@ -208,8 +219,8 @@ class NunchakuFluxTransformer2dModel(FluxTransformer2DModel, NunchakuModelLoader
             patch_size=patch_size,
             in_channels=in_channels,
             out_channels=out_channels,
-            num_layers=0,
-            num_single_layers=0,
+            num_layers=num_layers,
+            num_single_layers=num_single_layers,
             attention_head_dim=attention_head_dim,
             num_attention_heads=num_attention_heads,
             joint_attention_dim=joint_attention_dim,
...
@@ -217,76 +228,201 @@ class NunchakuFluxTransformer2dModel(FluxTransformer2DModel, NunchakuModelLoader
             guidance_embeds=guidance_embeds,
             axes_dims_rope=axes_dims_rope,
         )
-        self.unquantized_loras = {}
-        self.unquantized_state_dict = None
+        # these state_dicts are used for supporting lora
+        self._unquantized_part_sd: dict[str, torch.Tensor] = {}
+        self._unquantized_part_loras: dict[str, torch.Tensor] = {}
+        self._quantized_part_sd: dict[str, torch.Tensor] = {}
+        self._quantized_part_vectors: dict[str, torch.Tensor] = {}
 
     @classmethod
     @utils.validate_hf_hub_args
     def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike, **kwargs):
         device = kwargs.get("device", "cuda")
-        precision = kwargs.get("precision", "int4")
+        if isinstance(device, str):
+            device = torch.device(device)
         offload = kwargs.get("offload", False)
-        assert precision in ["int4", "fp4"]
-        transformer, transformer_block_path = cls._build_model(pretrained_model_name_or_path, **kwargs)
+        precision = get_precision(kwargs.get("precision", "auto"), device, pretrained_model_name_or_path)
+        transformer, unquantized_part_path, transformer_block_path = cls._build_model(pretrained_model_name_or_path, **kwargs)
+
+        # get the default LoRA branch and all the vectors
+        quantized_part_sd = load_file(transformer_block_path)
+        new_quantized_part_sd = {}
+        for k, v in quantized_part_sd.items():
+            if v.ndim == 1:
+                new_quantized_part_sd[k] = v
+            elif "qweight" in k:
+                # only the shape information of this tensor is needed
+                new_quantized_part_sd[k] = v.to("meta")
+            elif "lora" in k:
+                new_quantized_part_sd[k] = v
+        transformer._quantized_part_sd = new_quantized_part_sd
         m = load_quantized_module(transformer_block_path, device=device, use_fp4=precision == "fp4", offload=offload)
         transformer.inject_quantized_module(m, device)
+        transformer.to_empty(device=device)
+
+        unquantized_part_sd = load_file(unquantized_part_path)
+        transformer.load_state_dict(unquantized_part_sd, strict=False)
+        transformer._unquantized_part_sd = unquantized_part_sd
         return transformer
 
-    def update_unquantized_lora_params(self, strength: float = 1):
-        new_state_dict = {}
-        for k in self.unquantized_state_dict.keys():
-            v = self.unquantized_state_dict[k]
-            if k.replace(".weight", ".lora_B.weight") in self.unquantized_loras:
-                new_state_dict[k] = v + strength * (
-                    self.unquantized_loras[k.replace(".weight", ".lora_B.weight")]
-                    @ self.unquantized_loras[k.replace(".weight", ".lora_A.weight")]
-                )
-            else:
-                new_state_dict[k] = v
-        self.load_state_dict(new_state_dict, strict=True)
+    def inject_quantized_module(self, m: QuantizedFluxModel, device: str | torch.device = "cuda"):
+        print("Injecting quantized module")
+        self.pos_embed = EmbedND(dim=self.inner_dim, theta=10000, axes_dim=[16, 56, 56])
+
+        ### Compatible with the original forward method
+        self.transformer_blocks = nn.ModuleList([NunchakuFluxTransformerBlocks(m, device)])
+        self.single_transformer_blocks = nn.ModuleList([])
+
+        return self
+
+    def set_attention_impl(self, impl: str):
+        block = self.transformer_blocks[0]
+        assert isinstance(block, NunchakuFluxTransformerBlocks)
+        block.m.setAttentionImpl(impl)
+
+    ### LoRA Related Functions
+
+    def _expand_module(self, module_name: str, new_shape: tuple[int, int]):
+        module = self.get_submodule(module_name)
+        assert isinstance(module, nn.Linear)
+        weight_shape = module.weight.shape
+        logger.info("Expand the shape of module {} from {} to {}".format(module_name, tuple(weight_shape), new_shape))
+        assert new_shape[0] >= weight_shape[0] and new_shape[1] >= weight_shape[1]
+        new_module = nn.Linear(
+            new_shape[1],
+            new_shape[0],
+            bias=module.bias is not None,
+            device=module.weight.device,
+            dtype=module.weight.dtype,
+        )
+        new_module.weight.data.zero_()
+        new_module.weight.data[: weight_shape[0], : weight_shape[1]] = module.weight.data
+        self._unquantized_part_sd[f"{module_name}.weight"] = new_module.weight.data.clone()
+        if new_module.bias is not None:
+            new_module.bias.data.zero_()
+            new_module.bias.data[: weight_shape[0]] = module.bias.data
+            self._unquantized_part_sd[f"{module_name}.bias"] = new_module.bias.data.clone()
+        parent_name = ".".join(module_name.split(".")[:-1])
+        parent_module = self.get_submodule(parent_name)
+        parent_module.add_module(module_name.split(".")[-1], new_module)
+        if module_name == "x_embedder":
+            new_value = int(new_module.weight.data.shape[1])
+            old_value = getattr(self.config, "in_channels")
+            if new_value != old_value:
+                logger.info(f"Update in_channels from {old_value} to {new_value}")
+                setattr(self.config, "in_channels", new_value)
+
+    def _update_unquantized_part_lora_params(self, strength: float = 1):
+        # check if we need to expand the linear layers
+        device = next(self.parameters()).device
+        for k, v in self._unquantized_part_loras.items():
+            if "lora_A" in k:
+                lora_a = v
+                lora_b = self._unquantized_part_loras[k.replace(".lora_A.", ".lora_B.")]
+                diff_shape = (lora_b.shape[0], lora_a.shape[1])
+                weight_shape = self._unquantized_part_sd[k.replace(".lora_A.", ".")].shape
+                if diff_shape[0] > weight_shape[0] or diff_shape[1] > weight_shape[1]:
+                    module_name = ".".join(k.split(".")[:-2])
+                    self._expand_module(module_name, diff_shape)
+            elif v.ndim == 1:
+                diff_shape = v.shape
+                weight_shape = self._unquantized_part_sd[k].shape
+                if diff_shape[0] > weight_shape[0]:
+                    assert diff_shape[0] >= weight_shape[0]
+                    module_name = ".".join(k.split(".")[:-1])
+                    module = self.get_submodule(module_name)
+                    weight_shape = module.weight.shape
+                    diff_shape = (diff_shape[0], weight_shape[1])
+                    self._expand_module(module_name, diff_shape)
+        new_state_dict = {}
+        for k in self._unquantized_part_sd.keys():
+            v = self._unquantized_part_sd[k]
+            v = v.to(device)
+            self._unquantized_part_sd[k] = v
+            if v.ndim == 1 and k in self._unquantized_part_loras:
+                diff = strength * self._unquantized_part_loras[k]
+                if diff.shape[0] < v.shape[0]:
+                    diff = torch.cat([diff, torch.zeros(v.shape[0] - diff.shape[0], device=device, dtype=v.dtype)], dim=0)
+                new_state_dict[k] = v + diff
+            elif v.ndim == 2 and k.replace(".weight", ".lora_B.weight") in self._unquantized_part_loras:
+                lora_a = self._unquantized_part_loras[k.replace(".weight", ".lora_A.weight")]
+                lora_b = self._unquantized_part_loras[k.replace(".weight", ".lora_B.weight")]
+                if lora_a.shape[1] < v.shape[1]:
+                    lora_a = torch.cat(
+                        [lora_a, torch.zeros(lora_a.shape[0], v.shape[1] - lora_a.shape[1], device=device, dtype=v.dtype)],
+                        dim=1,
+                    )
+                if lora_b.shape[0] < v.shape[0]:
+                    lora_b = torch.cat(
+                        [lora_b, torch.zeros(v.shape[0] - lora_b.shape[0], lora_b.shape[1], device=device, dtype=v.dtype)],
+                        dim=0,
+                    )
+                diff = strength * (lora_b @ lora_a)
+                new_state_dict[k] = v + diff
+            else:
+                new_state_dict[k] = v
+        self.load_state_dict(new_state_dict, strict=True)
 
     def update_lora_params(self, path_or_state_dict: str | dict[str, torch.Tensor]):
         if isinstance(path_or_state_dict, dict):
-            state_dict = path_or_state_dict
+            state_dict = {
+                k: v for k, v in path_or_state_dict.items()
+            }  # copy a new one to avoid modifying the original one
         else:
             state_dict = load_state_dict_in_safetensors(path_or_state_dict)
 
-        unquantized_loras = {}
-        for k in state_dict.keys():
-            if "transformer_blocks" not in k:
-                unquantized_loras[k] = state_dict[k]
-        for k in unquantized_loras.keys():
-            state_dict.pop(k)
-        self.unquantized_loras = unquantized_loras
-        if len(unquantized_loras) > 0:
-            if self.unquantized_state_dict is None:
-                unquantized_state_dict = self.state_dict()
-                self.unquantized_state_dict = {k: v.cpu() for k, v in unquantized_state_dict.items()}
-            self.update_unquantized_lora_params(1)
+        if not is_nunchaku_format(state_dict):
+            state_dict = to_nunchaku(state_dict, base_sd=self._quantized_part_sd)
+
+        unquantized_part_loras = {}
+        for k, v in list(state_dict.items()):
+            device = next(self.parameters()).device
+            if "transformer_blocks" not in k:
+                unquantized_part_loras[k] = state_dict.pop(k).to(device)
+
+        if len(self._unquantized_part_loras) > 0 or len(unquantized_part_loras) > 0:
+            self._unquantized_part_loras = unquantized_part_loras
+            self._update_unquantized_part_lora_params(1)
+
+        # Get the vectors from the quantized part
+        quantized_part_vectors = {}
+        for k, v in list(state_dict.items()):
+            if v.ndim == 1:
+                quantized_part_vectors[k] = state_dict.pop(k)
+        if len(self._quantized_part_vectors) > 0 or len(quantized_part_vectors) > 0:
+            self._quantized_part_vectors = quantized_part_vectors
+            updated_vectors = fuse_vectors(quantized_part_vectors, self._quantized_part_sd, 1)
+            state_dict.update(updated_vectors)
 
         block = self.transformer_blocks[0]
         assert isinstance(block, NunchakuFluxTransformerBlocks)
         block.m.loadDict(state_dict, True)
 
+    # This function can only be used with a single LoRA.
+    # For multiple LoRAs, please fuse the lora scale into the weights.
     def set_lora_strength(self, strength: float = 1):
         block = self.transformer_blocks[0]
         assert isinstance(block, NunchakuFluxTransformerBlocks)
         block.m.setLoraScale(SVD_RANK, strength)
-        if len(self.unquantized_loras) > 0:
-            self.update_unquantized_lora_params(strength)
-
-    def set_attention_impl(self, impl: str):
-        block = self.transformer_blocks[0]
-        assert isinstance(block, NunchakuFluxTransformerBlocks)
-        block.m.setAttentionImpl(impl)
-
-    def inject_quantized_module(self, m: QuantizedFluxModel, device: str | torch.device = "cuda"):
-        print("Injecting quantized module")
-        self.pos_embed = EmbedND(dim=self.inner_dim, theta=10000, axes_dim=[16, 56, 56])
-
-        ### Compatible with the original forward method
-        self.transformer_blocks = nn.ModuleList([NunchakuFluxTransformerBlocks(m, device)])
-        self.single_transformer_blocks = nn.ModuleList([])
-
-        return self
+        if len(self._unquantized_part_loras) > 0:
+            self._update_unquantized_part_lora_params(strength)
+        if len(self._quantized_part_vectors) > 0:
+            vector_dict = fuse_vectors(self._quantized_part_vectors, self._quantized_part_sd, strength)
+            block.m.loadDict(vector_dict, True)
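A minimal usage sketch of the new LoRA path (the repo ids and LoRA file below are placeholders, not part of the commit): `update_lora_params` converts a non-Nunchaku LoRA to the Nunchaku layout on the fly, and `set_lora_strength` rescales a single loaded LoRA.

import torch
from diffusers import FluxPipeline
from nunchaku import NunchakuFluxTransformer2dModel

# Hypothetical repo ids / paths, for illustration only.
transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-dev")
pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")

transformer.update_lora_params("path/to/your-flux-lora.safetensors")  # any supported LoRA format
transformer.set_lora_strength(0.8)  # only valid when a single LoRA is loaded
image = pipeline("a photo of a corgi wearing sunglasses", num_inference_steps=25).images[0]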
nunchaku/models/transformers/transformer_sana.py
...
@@ -2,13 +2,13 @@ import os
 from typing import Optional
 
 import torch
+import torch.nn.functional as F
 from diffusers import SanaTransformer2DModel
-from diffusers.configuration_utils import register_to_config
 from huggingface_hub import utils
+from safetensors.torch import load_file
 from torch import nn
-from torch.nn import functional as F
 
-from .utils import NunchakuModelLoaderMixin
+from .utils import get_precision, NunchakuModelLoaderMixin
 from ..._C import QuantizedSanaModel, utils as cutils
 
 SVD_RANK = 32
...
@@ -30,7 +30,7 @@ class NunchakuSanaTransformerBlocks(nn.Module):
         timestep: Optional[torch.LongTensor] = None,
         height: Optional[int] = None,
         width: Optional[int] = None,
-        skip_first_layer: Optional[bool] = False
+        skip_first_layer: Optional[bool] = False,
     ):
         batch_size = hidden_states.shape[0]
...
@@ -77,15 +77,15 @@ class NunchakuSanaTransformerBlocks(nn.Module):
         )
 
     def forward_layer_at(
         self,
         idx: int,
         hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
         encoder_hidden_states: Optional[torch.Tensor] = None,
         encoder_attention_mask: Optional[torch.Tensor] = None,
         timestep: Optional[torch.LongTensor] = None,
         height: Optional[int] = None,
         width: Optional[int] = None,
     ):
         batch_size = hidden_states.shape[0]
         img_tokens = hidden_states.shape[1]
...
@@ -132,62 +132,22 @@ class NunchakuSanaTransformerBlocks(nn.Module):
 class NunchakuSanaTransformer2DModel(SanaTransformer2DModel, NunchakuModelLoaderMixin):
-    @register_to_config
-    def __init__(
-        self,
-        in_channels: int = 32,
-        out_channels: Optional[int] = 32,
-        num_attention_heads: int = 70,
-        attention_head_dim: int = 32,
-        num_layers: int = 20,
-        num_cross_attention_heads: Optional[int] = 20,
-        cross_attention_head_dim: Optional[int] = 112,
-        cross_attention_dim: Optional[int] = 2240,
-        caption_channels: int = 2304,
-        mlp_ratio: float = 2.5,
-        dropout: float = 0.0,
-        attention_bias: bool = False,
-        sample_size: int = 32,
-        patch_size: int = 1,
-        norm_elementwise_affine: bool = False,
-        norm_eps: float = 1e-6,
-        interpolation_scale: Optional[int] = None,
-    ) -> None:
-        # set num_layers to 0 to avoid creating transformer blocks
-        self.original_num_layers = num_layers
-        super(NunchakuSanaTransformer2DModel, self).__init__(
-            in_channels=in_channels,
-            out_channels=out_channels,
-            num_attention_heads=num_attention_heads,
-            attention_head_dim=attention_head_dim,
-            num_layers=0,
-            num_cross_attention_heads=num_cross_attention_heads,
-            cross_attention_head_dim=cross_attention_head_dim,
-            cross_attention_dim=cross_attention_dim,
-            caption_channels=caption_channels,
-            mlp_ratio=mlp_ratio,
-            dropout=dropout,
-            attention_bias=attention_bias,
-            sample_size=sample_size,
-            patch_size=patch_size,
-            norm_elementwise_affine=norm_elementwise_affine,
-            norm_eps=norm_eps,
-            interpolation_scale=interpolation_scale,
-        )
-
     @classmethod
     @utils.validate_hf_hub_args
     def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike, **kwargs):
         device = kwargs.get("device", "cuda")
         pag_layers = kwargs.get("pag_layers", [])
-        precision = kwargs.get("precision", "int4")
-        assert precision in ["int4", "fp4"]
-        transformer, transformer_block_path = cls._build_model(
-            pretrained_model_name_or_path, **kwargs
-        )
-        transformer.config["num_layers"] = transformer.original_num_layers
+        precision = get_precision(kwargs.get("precision", "auto"), device, pretrained_model_name_or_path)
+        transformer, unquantized_part_path, transformer_block_path = cls._build_model(
+            pretrained_model_name_or_path, **kwargs
+        )
         m = load_quantized_module(
             transformer, transformer_block_path, device=device, pag_layers=pag_layers, use_fp4=precision == "fp4"
         )
         transformer.inject_quantized_module(m, device)
+        transformer.to_empty(device=device)
+
+        unquantized_state_dict = load_file(unquantized_part_path)
+        transformer.load_state_dict(unquantized_state_dict, strict=False)
         return transformer
 
     def inject_quantized_module(self, m: QuantizedSanaModel, device: str | torch.device = "cuda"):
...
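As with the Flux model, a small usage sketch (repo ids and settings below are placeholders, not from the commit): with `precision="auto"`, `get_precision` picks int4 or fp4 from the GPU architecture.

import torch
from diffusers import SanaPipeline
from nunchaku import NunchakuSanaTransformer2DModel

# Hypothetical repo ids, for illustration only.
transformer = NunchakuSanaTransformer2DModel.from_pretrained(
    "mit-han-lab/svdq-int4-sana-1600m", precision="auto", torch_dtype=torch.bfloat16
)
pipeline = SanaPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")
image = pipeline("a cyberpunk cat", num_inference_steps=20).images[0]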
nunchaku/models/transformers/utils.py
 import os
+import warnings
+from typing import Any, Optional
 
 import torch
 from diffusers import __version__
 from huggingface_hub import constants, hf_hub_download
-from safetensors.torch import load_file
 from torch import nn
-from typing import Optional, Any
+
+from nunchaku.utils import ceil_divide
 
 
 class NunchakuModelLoaderMixin:
     @classmethod
-    def _build_model(cls, pretrained_model_name_or_path: str | os.PathLike, **kwargs):
+    def _build_model(cls, pretrained_model_name_or_path: str | os.PathLike, **kwargs) -> tuple[nn.Module, str, str]:
         subfolder = kwargs.get("subfolder", None)
         if os.path.exists(pretrained_model_name_or_path):
             dirname = (
...
@@ -60,16 +63,13 @@ class NunchakuModelLoaderMixin:
             **kwargs,
         )
-        transformer = cls.from_config(config).to(kwargs.get("torch_dtype", torch.bfloat16))
-        state_dict = load_file(unquantized_part_path)
-        transformer.load_state_dict(state_dict, strict=False)
-        return transformer, transformer_block_path
-
-
-def ceil_div(x: int, y: int) -> int:
-    return (x + y - 1) // y
+        with torch.device("meta"):
+            transformer = cls.from_config(config).to(kwargs.get("torch_dtype", torch.bfloat16))
+        return transformer, unquantized_part_path, transformer_block_path
 
 
-def pad_tensor(tensor: Optional[torch.Tensor], multiples: int, dim: int, fill: Any = 0) -> torch.Tensor:
+def pad_tensor(tensor: Optional[torch.Tensor], multiples: int, dim: int, fill: Any = 0) -> torch.Tensor | None:
     if multiples <= 1:
         return tensor
     if tensor is None:
...
@@ -77,8 +77,26 @@ def pad_tensor(tensor: Optional[torch.Tensor], multiples: int, dim: int, fill: A
     shape = list(tensor.shape)
     if shape[dim] % multiples == 0:
         return tensor
-    shape[dim] = ceil_div(shape[dim], multiples) * multiples
+    shape[dim] = ceil_divide(shape[dim], multiples) * multiples
     result = torch.empty(shape, dtype=tensor.dtype, device=tensor.device)
     result.fill_(fill)
     result[[slice(0, extent) for extent in tensor.shape]] = tensor
     return result
+
+
+def get_precision(precision: str, device: str | torch.device, pretrained_model_name_or_path: str | None = None) -> str:
+    assert precision in ("auto", "int4", "fp4")
+    if precision == "auto":
+        if isinstance(device, str):
+            device = torch.device(device)
+        capability = torch.cuda.get_device_capability(0 if device.index is None else device.index)
+        sm = f"{capability[0]}{capability[1]}"
+        precision = "fp4" if sm == "120" else "int4"
+    if pretrained_model_name_or_path is not None:
+        if precision == "int4":
+            if "fp4" in pretrained_model_name_or_path:
+                warnings.warn("The model may be quantized to fp4, but you are loading it with int4 precision.")
+        elif precision == "fp4":
+            if "int4" in pretrained_model_name_or_path:
+                warnings.warn("The model may be quantized to int4, but you are loading it with fp4 precision.")
+    return precision
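A small hedged sketch of what the new `precision="auto"` resolution does (the device and repo id below are assumptions for illustration): compute capability 12.0 (Blackwell) resolves to "fp4", anything else to "int4", and an explicit precision that contradicts the checkpoint name only triggers a warning.

import torch

# Assuming a CUDA device is present; e.g. an SM 8.9 card resolves to "int4".
print(get_precision("auto", torch.device("cuda:0")))

# Hypothetical repo id: mismatch between requested precision and the name warns but still returns "int4".
print(get_precision("int4", "cuda", "mit-han-lab/svdq-fp4-flux.1-dev"))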
pyproject.toml
...
@@ -28,4 +28,4 @@ dependencies = [
     "protobuf",
     "huggingface_hub",
 ]
-requires-python = ">=3.10, <3.13"
+requires-python = ">=3.10"
tests/data/MJHQ/MJHQ.py
...
@@ -7,7 +7,7 @@ from PIL import Image
 _CITATION = """\
 @misc{li2024playground,
     title={Playground v2.5: Three Insights towards Enhancing Aesthetic Quality in Text-to-Image Generation},
     author={Daiqing Li and Aleks Kamko and Ehsan Akhgari and Ali Sabet and Linmiao Xu and Suhail Doshi},
     year={2024},
     eprint={2402.17245},
...
@@ -17,7 +17,7 @@ _CITATION = """\
 """
 
 _DESCRIPTION = """\
 We introduce a new benchmark, MJHQ-30K, for automatic evaluation of a model’s aesthetic quality.
 The benchmark computes FID on a high-quality dataset to gauge aesthetic quality.
 """
...