fengzch-das / nunchaku · Commits
"src/targets/gpu/vscode:/vscode.git/clone" did not exist on "3becd974ed6b662983d67789ee71561da1d4351b"
Commit b1b44398, authored Feb 26, 2025 by Samuel Tesfai

Fixing merges

Parents: 004e4e31, 4b9c2e03
Showing 20 changed files (of 55 in this commit) with 290 additions and 69 deletions (+290 −69).
examples/int4-sana_1600m.py                +0  −0
examples/int4-sana_1600m_pag.py            +0  −1
nunchaku/__version__.py                    +1  −1
nunchaku/csrc/flux.h                       +2  −2
nunchaku/csrc/gemm.h                       +2  −2
nunchaku/csrc/ops.h                        +8  −2
nunchaku/csrc/pybind.cpp                   +2  −0
nunchaku/csrc/sana.h                       +2  −1
nunchaku/lora/flux/comfyui_converter.py    +44 −4
nunchaku/lora/flux/convert.py              +12 −4
nunchaku/lora/flux/diffusers_converter.py  +43 −5
nunchaku/lora/flux/utils.py                +21 −0
nunchaku/models/transformer_flux.py        +5  −4
nunchaku/models/transformer_sana.py        +7  −2
nunchaku/test.py                           +7  −1
setup.py                                   +59 −10
src/FluxModel.cpp                          +17 −17
src/FluxModel.h                            +3  −3
src/Linear.cpp                             +50 −9
src/Linear.h                               +5  −1
examples/sana_1600m.py → examples/int4-sana_1600m.py
File moved.
examples/sana_1600m_pag.py → examples/int4-sana_1600m_pag.py

@@ -23,6 +23,5 @@ image = pipe(
     guidance_scale=5.0,
     pag_scale=2.0,
     num_inference_steps=20,
-    generator=torch.Generator().manual_seed(42),
 ).images[0]
 image.save("sana_1600m_pag.png")
nunchaku/__version__.py

-__version__ = "0.0.2beta6"
+__version__ = "0.1.3"
nunchaku/csrc/flux.h

@@ -9,9 +9,9 @@
 class QuantizedFluxModel : public ModuleWrapper<FluxModel> { // : public torch::CustomClassHolder {
 public:
-    void init(bool bf16, int8_t deviceId) {
+    void init(bool use_fp4, bool bf16, int8_t deviceId) {
         spdlog::info("Initializing QuantizedFluxModel");
-        net = std::make_unique<FluxModel>(bf16 ? Tensor::BF16 : Tensor::FP16, Device::cuda((int)deviceId));
+        net = std::make_unique<FluxModel>(use_fp4, bf16 ? Tensor::BF16 : Tensor::FP16, Device::cuda((int)deviceId));
     }

     torch::Tensor forward(
nunchaku/csrc/gemm.h

@@ -8,7 +8,7 @@
 class QuantizedGEMM : public ModuleWrapper<GEMM_W4A4> {
 public:
-    void init(int64_t in_features, int64_t out_features, bool bias, bool bf16, int8_t deviceId) {
+    void init(int64_t in_features, int64_t out_features, bool bias, bool use_fp4, bool bf16, int8_t deviceId) {
         spdlog::info("Initializing QuantizedGEMM");
         size_t val = 0;
@@ -16,7 +16,7 @@ public:
         checkCUDA(cudaDeviceGetLimit(&val, cudaLimitStackSize));
         spdlog::debug("Stack={}", val);
-        net = std::make_unique<GEMM_W4A4>((int)in_features, (int)out_features, bias, bf16 ? Tensor::BF16 : Tensor::FP16, Device::cuda((int)deviceId));
+        net = std::make_unique<GEMM_W4A4>((int)in_features, (int)out_features, bias, use_fp4, bf16 ? Tensor::BF16 : Tensor::FP16, Device::cuda((int)deviceId));
     }

     torch::Tensor forward(torch::Tensor x) {
nunchaku/csrc/ops.h

@@ -29,7 +29,10 @@ namespace nunchaku::ops {
         std::optional<torch::Tensor> out_linearattn, // linear [B, (M), N / 3]
         bool act_unsigned,
         std::vector<float> lora_scales,
-        bool fuse_silu
+        bool fuse_silu,
+        bool fp4,
+        float alpha,
+        std::optional<torch::Tensor> wcscales
     ) {
         spdlog::trace("running gemm_w4a4: ");
@@ -64,7 +67,10 @@ namespace nunchaku::ops {
             getTensor(out_linearattn),
             act_unsigned,
             lora_scales,
-            fuse_silu
+            fuse_silu,
+            fp4,
+            alpha,
+            getTensor(wcscales)
         );
         Tensor::synchronizeDevice();
     }
nunchaku/csrc/pybind.cpp

@@ -14,6 +14,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     py::class_<QuantizedFluxModel>(m, "QuantizedFluxModel")
         .def(py::init<>())
         .def("init", &QuantizedFluxModel::init,
+             py::arg("use_fp4"),
              py::arg("bf16"),
              py::arg("deviceId")
         )
@@ -36,6 +37,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         .def("init", &QuantizedSanaModel::init,
              py::arg("config"),
              py::arg("pag_layers"),
+             py::arg("use_fp4"),
              py::arg("bf16"),
              py::arg("deviceId")
         )
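Both bindings now take `use_fp4` ahead of `bf16`, so existing Python callers must be updated. A minimal sketch of the new calling convention; the extension module's import path is an assumption, as it is not shown in this diff:

# Hedged sketch: exercising the rebound init(); `nunchaku._C` is an assumed import path.
from nunchaku._C import QuantizedFluxModel

m = QuantizedFluxModel()
m.init(use_fp4=False, bf16=True, deviceId=0)  # use_fp4 is the new leading argument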
nunchaku/csrc/sana.h

@@ -8,7 +8,7 @@
 class QuantizedSanaModel : public ModuleWrapper<SanaModel> {
 public:
-    void init(pybind11::dict config, std::vector<int> pag_layers, bool bf16, int8_t deviceId) {
+    void init(pybind11::dict config, std::vector<int> pag_layers, bool use_fp4, bool bf16, int8_t deviceId) {
         spdlog::info("Initializing QuantizedSanaModel");
         SanaConfig cfg{
             .num_layers = config["num_layers"].cast<int>(),
@@ -17,6 +17,7 @@ public:
             .num_cross_attention_heads = config["num_cross_attention_heads"].cast<int>(),
             .expand_ratio = config["mlp_ratio"].cast<double>(),
             .pag_layers = pag_layers,
+            .use_fp4 = use_fp4,
         };
         net = std::make_unique<SanaModel>(cfg, bf16 ? Tensor::BF16 : Tensor::FP16, Device::cuda((int)deviceId));
     }
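For orientation, a sketch of the config dict this init() reads from Python. Only the keys visible in this hunk are shown, and the values are placeholders rather than the real Sana-1600M configuration:

config = {
    "num_layers": 20,                 # -> cfg.num_layers
    "num_cross_attention_heads": 20,  # -> cfg.num_cross_attention_heads
    "mlp_ratio": 2.5,                 # -> cfg.expand_ratio (cast to double)
}
# model.init(config, pag_layers=[8], use_fp4=False, bf16=True, deviceId=0)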
nunchaku/lora/flux/comfyui_converter.py

 # convert the comfyui lora to diffusers format
+import argparse
 import os

 import torch

@@ -8,7 +9,7 @@ from ...utils import load_state_dict_in_safetensors

 def comfyui2diffusers(
-    input_lora: str | dict[str, torch.Tensor], output_path: str | None = None
+    input_lora: str | dict[str, torch.Tensor], output_path: str | None = None, min_rank: int | None = None
 ) -> dict[str, torch.Tensor]:
     if isinstance(input_lora, str):
         tensors = load_state_dict_in_safetensors(input_lora, device="cpu")
@@ -16,7 +17,7 @@ def comfyui2diffusers(
         tensors = input_lora
     new_tensors = {}
+    max_rank = 0
     for k, v in tensors.items():
         if "alpha" in k:
             continue
@@ -29,7 +30,10 @@ def comfyui2diffusers(
             # Copy the tensor
             new_k = new_k.replace("_img_attn_qkv", f".attn.to_{p}")
             new_k = new_k.replace("_txt_attn_qkv", f".attn.add_{p}_proj")
-            new_tensors[new_k] = v.clone()
+            rank = v.shape[0]
+            alpha = tensors[k.replace("lora_down.weight", "alpha")]
+            new_tensors[new_k] = v.clone() * alpha / rank
+            max_rank = max(max_rank, rank)
         else:
             assert "lora_B" in new_k
             assert v.shape[0] % 3 == 0
@@ -58,7 +62,10 @@ def comfyui2diffusers(
                 new_k1 = new_k.replace("_linear1", ".proj_mlp")
             else:
                 new_k1 = new_k.replace("_linear1", f".attn.to_{p}")
-            new_tensors[new_k1] = v.clone()
+            rank = v.shape[0]
+            alpha = tensors[k.replace("lora_down.weight", "alpha")]
+            new_tensors[new_k1] = v.clone() * alpha / rank
+            max_rank = max(max_rank, rank)
         else:
             if p == "i":
                 new_k1 = new_k.replace("_linear1", ".proj_mlp")
@@ -70,10 +77,43 @@ def comfyui2diffusers(
         else:
             new_k = new_k.replace("_linear2", ".proj_out")
         new_k = new_k.replace("_modulation_lin", ".norm.linear")
+        if "lora_down" in k:
+            rank = v.shape[0]
+            alpha = tensors[k.replace("lora_down.weight", "alpha")]
+            v = v * alpha / rank
+            max_rank = max(max_rank, rank)
         new_tensors[new_k] = v
+
+    if min_rank is not None:
+        for k in new_tensors.keys():
+            v = new_tensors[k]
+            if "lora_A" in k:
+                rank = v.shape[0]
+                if rank < min_rank:
+                    new_v = torch.zeros(min_rank, v.shape[1], dtype=v.dtype, device=v.device)
+                    new_v[:rank] = v
+                    new_tensors[k] = new_v
+            else:
+                assert "lora_B" in k
+                rank = v.shape[1]
+                if rank < min_rank:
+                    new_v = torch.zeros(v.shape[0], min_rank, dtype=v.dtype, device=v.device)
+                    new_v[:, :rank] = v
+                    new_tensors[k] = new_v
     if output_path is not None:
         output_dir = os.path.dirname(os.path.abspath(output_path))
         os.makedirs(output_dir, exist_ok=True)
         save_file(new_tensors, output_path)
     return new_tensors
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("-i", "--input-path", type=str, required=True, help="path to the comfyui lora safetensor file")
+    parser.add_argument("-o", "--output-path", type=str, required=True, help="path to the output diffusers safetensor file")
+    parser.add_argument("--min-rank", type=int, default=None, help="minimum rank for the LoRA weights")
+    args = parser.parse_args()
+    comfyui2diffusers(args.input_path, args.output_path, min_rank=args.min_rank)
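The converter now folds each tensor's `alpha / rank` factor directly into the copied `lora_down` weight, so downstream code no longer needs the alpha entries. A quick self-contained sketch, with illustrative shapes, of why folding the scalar into one factor preserves the LoRA update:

import torch

rank, alpha = 4, 2.0
down = torch.randn(rank, 64)  # lora_down.weight
up = torch.randn(128, rank)   # lora_up.weight

# delta_W = (alpha / rank) * (up @ down) == up @ (down * alpha / rank)
assert torch.allclose((alpha / rank) * (up @ down), up @ (down * alpha / rank))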
nunchaku/lora/flux/convert.py

@@ -6,6 +6,7 @@ from safetensors.torch import save_file
 from .comfyui_converter import comfyui2diffusers
 from .diffusers_converter import convert_to_nunchaku_flux_lowrank_dict
+from .utils import detect_format
 from .xlab_converter import xlab2diffusers
 from ...utils import filter_state_dict, load_state_dict_in_safetensors
@@ -21,8 +22,8 @@ if __name__ == "__main__":
     parser.add_argument(
         "--lora-format",
         type=str,
-        default="diffusers",
-        choices=["comfyui", "diffusers", "xlab"],
+        default="auto",
+        choices=["auto", "comfyui", "diffusers", "xlab"],
         help="format of the LoRA weights",
     )
     parser.add_argument("--output-root", type=str, default="", help="root to the output safetensor file")
@@ -37,8 +38,8 @@ if __name__ == "__main__":
     args = parser.parse_args()
     if not args.output_root:
-        # output to the parent directory of the quantized model safetensor file
-        args.output_root = os.path.dirname(args.quant_path)
+        # output to the parent directory of the lora safetensor file
+        args.output_root = os.path.dirname(args.lora_path)
     if args.lora_name is None:
         base_name = os.path.basename(args.lora_path)
         lora_name = base_name.rsplit(".", 1)[0]
@@ -53,6 +54,13 @@ if __name__ == "__main__":
     orig_state_dict = load_state_dict_in_safetensors(args.quant_path)
     lora_format = args.lora_format
+    if lora_format == "auto":
+        lora_format = detect_format(args.lora_path)
+        print(f"Detected LoRA format: {lora_format}")
+    if lora_format == "svdquant":
+        print("Already in SVDQuant format, no conversion needed.")
+        exit(0)
     if lora_format == "diffusers":
         extra_lora_dict = load_state_dict_in_safetensors(args.lora_path)
     else:
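A sketch of the new auto-detection path in isolation, using the `detect_format` helper added by this commit ("lora.safetensors" is a placeholder path):

from nunchaku.lora.flux.utils import detect_format

lora_format = detect_format("lora.safetensors")
print(f"Detected LoRA format: {lora_format}")
if lora_format == "svdquant":
    print("Already in SVDQuant format, no conversion needed.")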
nunchaku/lora/flux/diffusers_converter.py

 # convert the diffusers lora to nunchaku format
 """Convert LoRA weights to Nunchaku format."""
 import typing as tp

 import torch

@@ -215,8 +214,8 @@ def convert_to_nunchaku_transformer_block_lowrank_dict(  # noqa: C901
             update_state_dict(
                 converted,
                 {
-                    "lora_down": lora[0],
-                    "lora_up": reorder_adanorm_lora_up(lora[1], splits=3),
+                    "lora_down": pad(lora[0], divisor=16, dim=0),
+                    "lora_up": pad(reorder_adanorm_lora_up(lora[1], splits=3), divisor=16, dim=1),
                 },
                 prefix=converted_local_name,
             )
@@ -224,8 +223,8 @@ def convert_to_nunchaku_transformer_block_lowrank_dict(  # noqa: C901
             update_state_dict(
                 converted,
                 {
-                    "lora_down": lora[0],
-                    "lora_up": reorder_adanorm_lora_up(lora[1], splits=6),
+                    "lora_down": pad(lora[0], divisor=16, dim=0),
+                    "lora_up": pad(reorder_adanorm_lora_up(lora[1], splits=6), divisor=16, dim=1),
                 },
                 prefix=converted_local_name,
             )
@@ -263,6 +262,22 @@ def convert_to_nunchaku_flux_single_transformer_block_lowrank_dict(
         extra_lora_dict.pop(f"{candidate_block_name}.proj_out.lora_A.weight")
         extra_lora_dict.pop(f"{candidate_block_name}.proj_out.lora_B.weight")

+    for component in ["lora_A", "lora_B"]:
+        fc1_k = f"{candidate_block_name}.proj_mlp.{component}.weight"
+        fc2_k = f"{candidate_block_name}.proj_out.linears.1.{component}.weight"
+        fc1_v = extra_lora_dict[fc1_k]
+        fc2_v = extra_lora_dict[fc2_k]
+        dim = 0 if "lora_A" in fc1_k else 1
+        fc1_rank = fc1_v.shape[dim]
+        fc2_rank = fc2_v.shape[dim]
+        if fc1_rank != fc2_rank:
+            rank = max(fc1_rank, fc2_rank)
+            if fc1_rank < rank:
+                extra_lora_dict[fc1_k] = pad(fc1_v, divisor=rank, dim=dim)
+            if fc2_rank < rank:
+                extra_lora_dict[fc2_k] = pad(fc2_v, divisor=rank, dim=dim)
+
     return convert_to_nunchaku_transformer_block_lowrank_dict(
         orig_state_dict=orig_state_dict,
         extra_lora_dict=extra_lora_dict,
@@ -347,6 +362,28 @@ def convert_to_nunchaku_flux_lowrank_dict(
     else:
         extra_lora_dict = filter_state_dict(lora, filter_prefix="transformer.")

+    for k in extra_lora_dict.keys():
+        fc1_k = k
+        if "ff.net.0.proj" in k:
+            fc2_k = k.replace("ff.net.0.proj", "ff.net.2")
+        elif "ff_context.net.0.proj" in k:
+            fc2_k = k.replace("ff_context.net.0.proj", "ff_context.net.2")
+        else:
+            continue
+        assert fc2_k in extra_lora_dict
+        fc1_v = extra_lora_dict[fc1_k]
+        fc2_v = extra_lora_dict[fc2_k]
+        dim = 0 if "lora_A" in fc1_k else 1
+        fc1_rank = fc1_v.shape[dim]
+        fc2_rank = fc2_v.shape[dim]
+        if fc1_rank != fc2_rank:
+            rank = max(fc1_rank, fc2_rank)
+            if fc1_rank < rank:
+                extra_lora_dict[fc1_k] = pad(fc1_v, divisor=rank, dim=dim)
+            if fc2_rank < rank:
+                extra_lora_dict[fc2_k] = pad(fc2_v, divisor=rank, dim=dim)
+
     block_names: set[str] = set()
     for param_name in orig_state_dict.keys():
         if param_name.startswith(("transformer_blocks.", "single_transformer_blocks.")):
@@ -370,4 +407,5 @@ def convert_to_nunchaku_flux_lowrank_dict(
         ),
         prefix=block_name,
     )
+
     return converted
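The new rank-equalization loops rely on the fact that zero-padding one LoRA factor along its rank dimension leaves the product unchanged. A self-contained sketch, where torch.cat stands in for the repo's `pad` helper:

import torch

A = torch.randn(4, 64)    # lora_A, rank 4
B = torch.randn(128, 4)   # lora_B, rank 4

A_pad = torch.cat([A, torch.zeros(4, 64)], dim=0)   # rank 4 -> 8 along dim 0
B_pad = torch.cat([B, torch.zeros(128, 4)], dim=1)  # rank 4 -> 8 along dim 1
assert torch.allclose(B_pad @ A_pad, B @ A)  # the extra rows/cols contribute zero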
nunchaku/lora/flux/utils.py (new file, 0 → 100644)

import torch

from ...utils import load_state_dict_in_safetensors


def detect_format(lora: str | dict[str, torch.Tensor]) -> str:
    if isinstance(lora, str):
        tensors = load_state_dict_in_safetensors(lora, device="cpu")
    else:
        tensors = lora
    for k in tensors.keys():
        if "lora_unet_double_blocks_" in k or "lora_unet_single_blocks" in k:
            return "comfyui"
        elif "mlp_fc" in k or "mlp_context_fc1" in k:
            return "svdquant"
        elif "double_blocks." in k or "single_blocks." in k:
            return "xlab"
        elif "transformer." in k:
            return "diffusers"
    raise ValueError("Unknown format, please provide the format explicitly.")
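The heuristic keys off substrings in the state-dict names, checked in order. A sketch with synthetic single-key dicts, one per branch (the keys are made up to trigger each case, not taken from real checkpoints):

import torch
from nunchaku.lora.flux.utils import detect_format

assert detect_format({"lora_unet_double_blocks_0": torch.zeros(1)}) == "comfyui"
assert detect_format({"double_blocks.0.img_attn.qkv.weight": torch.zeros(1)}) == "xlab"
assert detect_format({"transformer.x.lora_A.weight": torch.zeros(1)}) == "diffusers"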
nunchaku/models/transformer_flux.py

@@ -108,13 +108,12 @@ class EmbedND(nn.Module):
         return emb.unsqueeze(1)


-def load_quantized_module(path: str, device: str | torch.device = "cuda") -> QuantizedFluxModel:
+def load_quantized_module(path: str, device: str | torch.device = "cuda", use_fp4: bool = False) -> QuantizedFluxModel:
     device = torch.device(device)
     assert device.type == "cuda"

     m = QuantizedFluxModel()
     cutils.disable_memory_auto_release()
-    m.init(True, 0 if device.index is None else device.index)
+    m.init(use_fp4, True, 0 if device.index is None else device.index)
     m.load(path)
     return m

@@ -153,8 +152,10 @@ class NunchakuFluxTransformer2dModel(FluxTransformer2DModel, NunchakuModelLoader
     @utils.validate_hf_hub_args
     def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike, **kwargs):
         device = kwargs.get("device", "cuda")
+        precision = kwargs.get("precision", "int4")
+        assert precision in ["int4", "fp4"]
         transformer, transformer_block_path = cls._build_model(pretrained_model_name_or_path, **kwargs)
-        m = load_quantized_module(transformer_block_path, device=device)
+        m = load_quantized_module(transformer_block_path, device=device, use_fp4=precision == "fp4")
         transformer.inject_quantized_module(m, device)
         return transformer
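End to end, the new kwarg is used like this; the fp4 repo id follows the naming pattern exercised in nunchaku/test.py later in this diff:

from nunchaku.models.transformer_flux import NunchakuFluxTransformer2dModel

transformer = NunchakuFluxTransformer2dModel.from_pretrained(
    "mit-han-lab/svdq-fp4-flux.1-schnell", precision="fp4"  # "int4" remains the default
)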
nunchaku/models/transformer_sana.py

@@ -124,9 +124,13 @@ class NunchakuSanaTransformer2DModel(SanaTransformer2DModel, NunchakuModelLoader
     def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike, **kwargs):
         device = kwargs.get("device", "cuda")
         pag_layers = kwargs.get("pag_layers", [])
+        precision = kwargs.get("precision", "int4")
+        assert precision in ["int4", "fp4"]
         transformer, transformer_block_path = cls._build_model(pretrained_model_name_or_path, **kwargs)
         transformer.config["num_layers"] = transformer.original_num_layers
-        m = load_quantized_module(transformer, transformer_block_path, device=device, pag_layers=pag_layers)
+        m = load_quantized_module(
+            transformer, transformer_block_path, device=device, pag_layers=pag_layers, use_fp4=precision == "fp4"
+        )
         transformer.inject_quantized_module(m, device)
         return transformer

@@ -140,6 +144,7 @@ def load_quantized_module(
     path: str,
     device: str | torch.device = "cuda",
     pag_layers: int | list[int] | None = None,
+    use_fp4: bool = False,
 ) -> QuantizedSanaModel:
     if pag_layers is None:
         pag_layers = []
@@ -150,7 +155,7 @@ def load_quantized_module(
     m = QuantizedSanaModel()
     cutils.disable_memory_auto_release()
-    m.init(net.config, pag_layers, net.dtype == torch.bfloat16, 0 if device.index is None else device.index)
+    m.init(net.config, pag_layers, use_fp4, net.dtype == torch.bfloat16, 0 if device.index is None else device.index)
     m.load(path)
     return m
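The Sana path mirrors the Flux one, with `pag_layers` riding along through kwargs. A hedged sketch; the model id and layer index below are assumptions, not values from this diff:

from nunchaku.models.transformer_sana import NunchakuSanaTransformer2DModel

transformer = NunchakuSanaTransformer2DModel.from_pretrained(
    "mit-han-lab/svdq-int4-sana-1600m",  # hypothetical repo id
    pag_layers=[8],                      # illustrative PAG layer
    precision="int4",
)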
nunchaku/test.py

@@ -4,7 +4,13 @@ from diffusers import FluxPipeline

 from .models.transformer_flux import NunchakuFluxTransformer2dModel

 if __name__ == "__main__":
-    transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-schnell")
+    capability = torch.cuda.get_device_capability(0)
+    sm = f"{capability[0]}{capability[1]}"
+    precision = "fp4" if sm == "120" else "int4"
+    transformer = NunchakuFluxTransformer2dModel.from_pretrained(
+        f"mit-han-lab/svdq-{precision}-flux.1-schnell", precision=precision
+    )
     pipeline = FluxPipeline.from_pretrained(
         "black-forest-labs/FLUX.1-schnell", transformer=transformer, torch_dtype=torch.bfloat16
     ).to("cuda")
setup.py

 import os
+import re
+import subprocess
+import sys

 import setuptools
-from torch.utils.cpp_extension import BuildExtension, CUDAExtension
+import torch
+from packaging import version as packaging_version
+from torch.utils.cpp_extension import BuildExtension, CUDA_HOME, CUDAExtension


 class CustomBuildExtension(BuildExtension):
     def build_extensions(self):
@@ -16,10 +22,49 @@ class CustomBuildExtension(BuildExtension):
             ext.extra_compile_args["cxx"] += ext.extra_compile_args["gcc"]
         super().build_extensions()


+def get_sm_targets() -> list[str]:
+    nvcc_path = os.path.join(CUDA_HOME, "bin/nvcc") if CUDA_HOME else "nvcc"
+    try:
+        nvcc_output = subprocess.check_output([nvcc_path, "--version"]).decode()
+        match = re.search(r"release (\d+\.\d+), V(\d+\.\d+\.\d+)", nvcc_output)
+        if match:
+            nvcc_version = match.group(2)
+        else:
+            raise Exception("nvcc version not found")
+        print(f"Found nvcc version: {nvcc_version}")
+    except:
+        raise Exception("nvcc not found")
+
+    support_sm120 = packaging_version.parse(nvcc_version) >= packaging_version.parse("12.8")
+
+    install_mode = os.getenv("NUNCHAKU_INSTALL_MODE", "FAST")
+    if install_mode == "FAST":
+        ret = []
+        for i in range(torch.cuda.device_count()):
+            capability = torch.cuda.get_device_capability(i)
+            sm = f"{capability[0]}{capability[1]}"
+            if sm == "120" and support_sm120:
+                sm = "120a"
+            assert sm in ["80", "86", "89", "120a"], f"Unsupported SM {sm}"
+            if sm not in ret:
+                ret.append(sm)
+    else:
+        assert install_mode == "ALL"
+        ret = ["80", "86", "89"]
+        if support_sm120:
+            ret.append("120a")
+    return ret
+
+
 if __name__ == "__main__":
     fp = open("nunchaku/__version__.py", "r").read()
     version = eval(fp.strip().split()[-1])
+    torch_version = torch.__version__.split("+")[0]
+    torch_major_minor_version = ".".join(torch_version.split(".")[:2])
+    version = version + "+torch" + torch_major_minor_version

     ROOT_DIR = os.path.dirname(__file__)
     INCLUDE_DIRS = [
@@ -54,12 +99,6 @@ if __name__ == "__main__":
     NVCC_FLAGS = [
         "-DENABLE_BF16=1",
         "-DBUILD_NUNCHAKU=1",
-        "-gencode",
-        "arch=compute_86,code=sm_86",
-        "-gencode",
-        "arch=compute_89,code=sm_89",
-        # "-gencode",
-        # "arch=compute_89,code=sm_120a",
         "-g",
         "-std=c++20",
         "-UNDEBUG",
@@ -74,13 +113,23 @@ if __name__ == "__main__":
         "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
         "-U__CUDA_NO_BFLOAT162_OPERATORS__",
         "-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
-        "--threads=2",
+        "--threads=3",
         "--expt-relaxed-constexpr",
         "--expt-extended-lambda",
-        "--generate-line-info",
         "--ptxas-options=--allow-expensive-optimizations=true",
     ]
+
+    # https://github.com/NVIDIA/cutlass/pull/1479#issuecomment-2052300487
+    if os.getenv("NUNCHAKU_BUILD_WHEELS", "0") == "0":
+        NVCC_FLAGS.append("--generate-line-info")
+
+    sm_targets = get_sm_targets()
+    print(f"Detected SM targets: {sm_targets}", file=sys.stderr)
+    assert len(sm_targets) > 0, "No SM targets found"
+
+    for target in sm_targets:
+        NVCC_FLAGS += ["-gencode", f"arch=compute_{target},code=sm_{target}"]

     NVCC_MSVC_FLAGS = ["-Xcompiler", "/Zc:__cplusplus"]
     nunchaku_extension = CUDAExtension(
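The hard-coded `-gencode` flags are replaced by targets computed from the local toolchain, steered by two environment variables. A sketch of how one might set them before invoking the build; the workflow shown is an assumption:

import os

os.environ["NUNCHAKU_INSTALL_MODE"] = "ALL"  # compile for 80/86/89, plus 120a when nvcc >= 12.8
os.environ["NUNCHAKU_BUILD_WHEELS"] = "1"    # skip --generate-line-info (see the CUTLASS issue linked above)
# The default mode, FAST, compiles only for the SM versions of GPUs visible to torch.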
src/FluxModel.cpp

@@ -259,19 +259,19 @@ void Attention::setForceFP16(Module *module, bool value) {
     });
 }

-FluxSingleTransformerBlock::FluxSingleTransformerBlock(int dim, int num_attention_heads, int attention_head_dim, int mlp_ratio, Tensor::ScalarType dtype, Device device) :
+FluxSingleTransformerBlock::FluxSingleTransformerBlock(int dim, int num_attention_heads, int attention_head_dim, int mlp_ratio, bool use_fp4, Tensor::ScalarType dtype, Device device) :
     dim(dim),
     dim_head(attention_head_dim / num_attention_heads),
     num_heads(num_attention_heads),
     mlp_hidden_dim(dim * mlp_ratio),
     norm(dim, dtype, device),
-    mlp_fc1(dim, mlp_hidden_dim, true, dtype, device),
-    mlp_fc2(mlp_hidden_dim, dim, true, dtype, device),
-    qkv_proj(dim, dim * 3, true, dtype, device),
+    mlp_fc1(dim, mlp_hidden_dim, true, use_fp4, dtype, device),
+    mlp_fc2(mlp_hidden_dim, dim, true, use_fp4, dtype, device),
+    qkv_proj(dim, dim * 3, true, use_fp4, dtype, device),
     norm_q(dim_head, 1e-6, false, dtype, device),
     norm_k(dim_head, 1e-6, false, dtype, device),
     attn(num_attention_heads, attention_head_dim / num_attention_heads, device),
-    out_proj(dim, dim, true, dtype, device)
+    out_proj(dim, dim, true, use_fp4, dtype, device)
 {
     registerChildren
         (norm, "norm")
@@ -327,28 +327,28 @@ Tensor FluxSingleTransformerBlock::forward(Tensor hidden_states, Tensor temb, Tensor rotary_emb) {
     return hidden_states;
 }

-JointTransformerBlock::JointTransformerBlock(int dim, int num_attention_heads, int attention_head_dim, bool context_pre_only, Tensor::ScalarType dtype, Device device) :
+JointTransformerBlock::JointTransformerBlock(int dim, int num_attention_heads, int attention_head_dim, bool context_pre_only, bool use_fp4, Tensor::ScalarType dtype, Device device) :
     dim(dim),
     dim_head(attention_head_dim / num_attention_heads),
     num_heads(num_attention_heads),
     context_pre_only(context_pre_only),
     norm1(dim, false, dtype, device),
     norm1_context(dim, context_pre_only, dtype, device),
-    qkv_proj(dim, dim * 3, true, dtype, device),
-    qkv_proj_context(dim, dim * 3, true, dtype, device),
+    qkv_proj(dim, dim * 3, true, use_fp4, dtype, device),
+    qkv_proj_context(dim, dim * 3, true, use_fp4, dtype, device),
     norm_q(dim_head, 1e-6, false, dtype, device),
     norm_k(dim_head, 1e-6, false, dtype, device),
     norm_added_q(dim_head, 1e-6, false, dtype, device),
     norm_added_k(dim_head, 1e-6, false, dtype, device),
     attn(num_attention_heads, attention_head_dim / num_attention_heads, device),
-    out_proj(dim, dim, true, dtype, device),
-    out_proj_context(dim, dim, true, dtype, device),
+    out_proj(dim, dim, true, use_fp4, dtype, device),
+    out_proj_context(dim, dim, true, use_fp4, dtype, device),
     norm2(dim, 1e-6, false, dtype, device),
     norm2_context(dim, 1e-6, false, dtype, device),
-    mlp_fc1(dim, dim * 4, true, dtype, device),
-    mlp_fc2(dim * 4, dim, true, dtype, device),
-    mlp_context_fc1(dim, dim * 4, true, dtype, device),
-    mlp_context_fc2(dim * 4, dim, true, dtype, device)
+    mlp_fc1(dim, dim * 4, true, use_fp4, dtype, device),
+    mlp_fc2(dim * 4, dim, true, use_fp4, dtype, device),
+    mlp_context_fc1(dim, dim * 4, true, use_fp4, dtype, device),
+    mlp_context_fc2(dim * 4, dim, true, use_fp4, dtype, device)
 {
     registerChildren
         (norm1, "norm1")
@@ -607,13 +607,13 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
     return {hidden_states, encoder_hidden_states};
 }

-FluxModel::FluxModel(Tensor::ScalarType dtype, Device device) {
+FluxModel::FluxModel(bool use_fp4, Tensor::ScalarType dtype, Device device) {
     for (int i = 0; i < 19; i++) {
-        transformer_blocks.push_back(std::make_unique<JointTransformerBlock>(3072, 24, 3072, false, dtype, device));
+        transformer_blocks.push_back(std::make_unique<JointTransformerBlock>(3072, 24, 3072, false, use_fp4, dtype, device));
         registerChildren(*transformer_blocks.back(), format("transformer_blocks.{}", i));
     }
     for (int i = 0; i < 38; i++) {
-        single_transformer_blocks.push_back(std::make_unique<FluxSingleTransformerBlock>(3072, 24, 3072, 4, dtype, Device::cuda()));
+        single_transformer_blocks.push_back(std::make_unique<FluxSingleTransformerBlock>(3072, 24, 3072, 4, use_fp4, dtype, Device::cuda()));
         registerChildren(*single_transformer_blocks.back(), format("single_transformer_blocks.{}", i));
     }
 }
View file @
b1b44398
...
@@ -77,7 +77,7 @@ public:
...
@@ -77,7 +77,7 @@ public:
static
constexpr
bool
USE_4BIT
=
true
;
static
constexpr
bool
USE_4BIT
=
true
;
using
GEMM
=
std
::
conditional_t
<
USE_4BIT
,
GEMM_W4A4
,
GEMM_W8A8
>
;
using
GEMM
=
std
::
conditional_t
<
USE_4BIT
,
GEMM_W4A4
,
GEMM_W8A8
>
;
FluxSingleTransformerBlock
(
int
dim
,
int
num_attention_heads
,
int
attention_head_dim
,
int
mlp_ratio
,
Tensor
::
ScalarType
dtype
,
Device
device
);
FluxSingleTransformerBlock
(
int
dim
,
int
num_attention_heads
,
int
attention_head_dim
,
int
mlp_ratio
,
bool
use_fp4
,
Tensor
::
ScalarType
dtype
,
Device
device
);
Tensor
forward
(
Tensor
hidden_states
,
Tensor
temb
,
Tensor
rotary_emb
);
Tensor
forward
(
Tensor
hidden_states
,
Tensor
temb
,
Tensor
rotary_emb
);
public:
public:
...
@@ -101,7 +101,7 @@ public:
...
@@ -101,7 +101,7 @@ public:
static
constexpr
bool
USE_4BIT
=
true
;
static
constexpr
bool
USE_4BIT
=
true
;
using
GEMM
=
std
::
conditional_t
<
USE_4BIT
,
GEMM_W4A4
,
GEMM_W8A8
>
;
using
GEMM
=
std
::
conditional_t
<
USE_4BIT
,
GEMM_W4A4
,
GEMM_W8A8
>
;
JointTransformerBlock
(
int
dim
,
int
num_attention_heads
,
int
attention_head_dim
,
bool
context_pre_only
,
Tensor
::
ScalarType
dtype
,
Device
device
);
JointTransformerBlock
(
int
dim
,
int
num_attention_heads
,
int
attention_head_dim
,
bool
context_pre_only
,
bool
use_fp4
,
Tensor
::
ScalarType
dtype
,
Device
device
);
std
::
tuple
<
Tensor
,
Tensor
>
forward
(
Tensor
hidden_states
,
Tensor
encoder_hidden_states
,
Tensor
temb
,
Tensor
rotary_emb
,
Tensor
rotary_emb_context
,
float
sparsityRatio
);
std
::
tuple
<
Tensor
,
Tensor
>
forward
(
Tensor
hidden_states
,
Tensor
encoder_hidden_states
,
Tensor
temb
,
Tensor
rotary_emb
,
Tensor
rotary_emb_context
,
float
sparsityRatio
);
public:
public:
...
@@ -128,7 +128,7 @@ private:
...
@@ -128,7 +128,7 @@ private:
class
FluxModel
:
public
Module
{
class
FluxModel
:
public
Module
{
public:
public:
FluxModel
(
Tensor
::
ScalarType
dtype
,
Device
device
);
FluxModel
(
bool
use_fp4
,
Tensor
::
ScalarType
dtype
,
Device
device
);
Tensor
forward
(
Tensor
hidden_states
,
Tensor
encoder_hidden_states
,
Tensor
temb
,
Tensor
rotary_emb_img
,
Tensor
rotary_emb_context
,
Tensor
rotary_emb_single
);
Tensor
forward
(
Tensor
hidden_states
,
Tensor
encoder_hidden_states
,
Tensor
temb
,
Tensor
rotary_emb_img
,
Tensor
rotary_emb_context
,
Tensor
rotary_emb_single
);
public:
public:
...
...
src/Linear.cpp

@@ -96,23 +96,33 @@ Tensor GEMV_AWQ::forward(Tensor x) {

 #define NO_LORA_FUSION 0

-GEMM_W4A4::GEMM_W4A4(int in_features, int out_features, bool bias, Tensor::ScalarType dtype, Device device) :
+GEMM_W4A4::GEMM_W4A4(int in_features, int out_features, bool bias, bool use_fp4, Tensor::ScalarType dtype, Device device) :
     in_features(in_features), out_features(out_features),
     in_features_pad(ceilDiv(in_features, 128) * 128), out_features_pad(ceilDiv(out_features, 128) * 128),
-    lora_rank(0), dtype(dtype)
+    use_fp4(use_fp4), lora_rank(0), dtype(dtype)
 {
     this->qweight = Tensor::allocate({out_features_pad, in_features_pad / 2}, Tensor::INT8, device, true);
-    this->wscales = Tensor::allocate({in_features_pad / 64, out_features_pad}, dtype, device, true);
+    if (use_fp4) {
+        this->wscales = Tensor::allocate({in_features_pad / 16, out_features_pad}, Tensor::FP8_E4M3, device, true);
+    } else {
+        this->wscales = Tensor::allocate({in_features_pad / 64, out_features_pad}, dtype, device, true);
+    }
     this->bias = bias ? Tensor::allocate({out_features_pad}, dtype, device, true) : Tensor{};
     this->lora_down = Tensor::allocate({in_features_pad, lora_rank}, dtype, device, true);
     this->lora_up = Tensor::allocate({out_features_pad, lora_rank}, dtype, device, true);
-    // TODO: smooth factor in FC1+FC2 fusion
     // TODO: smooth factor in non-Lora fusion
     this->smooth = Tensor::allocate({in_features_pad}, dtype, device, true);

+    // FIXME: reset wtscale and wcscales to default values when reloading the weights
+    this->wtscale = Tensor::allocate({1}, Tensor::FP32, Device::cpu(), true);
+    *this->wtscale.data_ptr<float>() = 1.0f;
+    this->wcscales = Tensor::allocate({0}, dtype, device, true);
+
     registerParams
         (qweight, "qweight")
         (wscales, "wscales")
@@ -120,6 +130,8 @@ GEMM_W4A4::GEMM_W4A4(int in_features, int out_features, bool bias, Tensor::Scala
         (lora_down, "lora_down", ParamFlags::Optional)
         (lora_up, "lora_up", ParamFlags::Optional)
         (smooth, "smooth")
+        (wtscale, "wtscale", ParamFlags::Optional)
+        (wcscales, "wcscales", ParamFlags::Optional)
     ;

 #if NO_LORA_FUSION
@@ -137,6 +149,21 @@ void GEMM_W4A4::loadParam(std::string key, Tensor &dst, Tensor src) {
         } else {
             dst.copy_(src);
         }
+    } else if (key == "wcscales") {
+        assert(src.ndims() == 1);
+        assert(src.shape[0] == out_features_pad);
+        dst = src.copy(this->qweight.device());
+    } else if (key == "wtscale") {
+        assert(src.numel() == 1);
+        if (src.dtype() == Tensor::BF16) {
+            *dst.data_ptr<float>() = float(*src.data_ptr<__nv_bfloat16>());
+        } else if (src.dtype() == Tensor::FP16) {
+            *dst.data_ptr<float>() = float(*src.data_ptr<half>());
+        } else if (src.dtype() == Tensor::FP32) {
+            dst.copy_(src);
+        } else {
+            assert(false);
+        }
     } else {
         Module::loadParam(key, dst, src);
     }
@@ -167,7 +194,10 @@ void GEMM_W4A4::forward(Tensor x, Tensor out, Tensor pool, Tensor norm_q, Tensor
     debug("gemm.nolora.out", out);
 #endif

-    kernels::gemm_w4a4(qact.act, qweight, out, {}, qact.ascales, wscales, {}, pool, qact.lora_act, this->lora_up, {}, {}, norm_q, norm_k, rotary_emb, this->bias, {}, {}, {}, qact.is_unsigned, this->lora_scales, false);
+    kernels::gemm_w4a4(qact.act, qweight, out, {}, qact.ascales, wscales, {}, pool, qact.lora_act, this->lora_up, {}, {}, norm_q, norm_k, rotary_emb, this->bias, {}, {}, {}, qact.is_unsigned, this->lora_scales, false, use_fp4, *this->wtscale.data_ptr<float>(), wcscales.numel() > 0 ? wcscales : Tensor{});

     debug("gemm.out", out);
 #else
@@ -215,9 +245,13 @@ std::variant<Tensor, GEMM_W4A4::QuantizedActivation> GEMM_W4A4::forward_quant(Qu
         out = Tensor::allocate(shape, dtype, qweight.device());
     } else {
         qout.act = Tensor::allocate({M, out_features_pad / 2}, Tensor::INT8, qweight.device());
-        qout.ascales = Tensor::allocate({out_features_pad / 64, M}, dtype, qweight.device());
+        if (use_fp4) {
+            qout.ascales = Tensor::allocate({out_features_pad / 16, M}, Tensor::FP8_E4M3, qweight.device());
+        } else {
+            qout.ascales = Tensor::allocate({out_features_pad / 64, M}, dtype, qweight.device());
+        }
         qout.lora_act = Tensor::allocate({M, lora_rank}, Tensor::FP32, qweight.device());
-        qout.is_unsigned = true;
+        qout.is_unsigned = !use_fp4;
         qout.actShape = qact.actShape;

         next_lora = nextGEMM->lora_down;
@@ -241,7 +275,10 @@ std::variant<Tensor, GEMM_W4A4::QuantizedActivation> GEMM_W4A4::forward_quant(Qu
     }
 #endif

-    kernels::gemm_w4a4(qact.act, qweight, out, qout.act, qact.ascales, wscales, qout.ascales, {}, qact.lora_act, this->lora_up, next_lora, qout.lora_act, {}, {}, {}, this->bias, next_smooth, {}, {}, qact.is_unsigned, this->lora_scales, fuse == FuseOptions::SILU);
+    kernels::gemm_w4a4(qact.act, qweight, out, qout.act, qact.ascales, wscales, qout.ascales, {}, qact.lora_act, this->lora_up, next_lora, qout.lora_act, {}, {}, {}, this->bias, next_smooth, {}, {}, qact.is_unsigned, this->lora_scales, fuse == FuseOptions::SILU, use_fp4, *this->wtscale.data_ptr<float>(), wcscales.numel() > 0 ? wcscales : Tensor{});

     if (fuse == FuseOptions::EMPTY || fuse == FuseOptions::SILU) {
         debug("gemm.out", out);
@@ -327,7 +364,11 @@ GEMM_W4A4::QuantizedActivation GEMM_W4A4::quantize(Tensor x, bool fuse_glu) {
     QuantizedActivation qact;
     qact.act = Tensor::allocate({M, in_features_pad / 2}, Tensor::INT8, qweight.device());
-    qact.ascales = Tensor::allocate({in_features_pad / 64, M}, dtype, qweight.device());
+    if (use_fp4) {
+        qact.ascales = Tensor::allocate({in_features_pad / 16, M}, Tensor::FP8_E4M3, qweight.device());
+    } else {
+        qact.ascales = Tensor::allocate({in_features_pad / 64, M}, dtype, qweight.device());
+    }
     qact.lora_act = Tensor::allocate({M, lora_rank}, Tensor::FP32, qweight.device());
     qact.is_unsigned = false;
     qact.actShape = x.shape.dataExtent;
@@ -336,7 +377,7 @@ GEMM_W4A4::QuantizedActivation GEMM_W4A4::quantize(Tensor x, bool fuse_glu) {
     debug("quantize.x", x);
     debug("quantize.smooth", this->smooth);

-    kernels::quantize_w4a4_act_fuse_lora(x, qact.act, qact.ascales, this->lora_down, qact.lora_act, this->smooth, fuse_glu);
+    kernels::quantize_w4a4_act_fuse_lora(x, qact.act, qact.ascales, this->lora_down, qact.lora_act, this->smooth, fuse_glu, use_fp4);

     debug("quantize.qact", qact.act);
     debug("quantize.ascales", qact.ascales);
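A worked size comparison of the two scale layouts allocated above, for a 3072x3072 layer (already a multiple of 128, so padding is a no-op). Reading `wtscale` as the kernel's `alpha` follows from the call sites in this hunk, where it is passed in the new float slot:

in_pad, out_pad = 3072, 3072
int4_wscales_shape = (in_pad // 64, out_pad)  # (48, 3072), stored in the model dtype
fp4_wscales_shape = (in_pad // 16, out_pad)   # (192, 3072), stored as FP8 E4M3
# fp4 additionally carries a scalar fp32 wtscale (passed to kernels::gemm_w4a4
# as its float argument) and an optional per-channel wcscales tensor.
print(int4_wscales_shape, fp4_wscales_shape)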
src/Linear.h

@@ -64,7 +64,7 @@ public:
     };

 public:
-    GEMM_W4A4(int in_features, int out_features, bool bias, Tensor::ScalarType dtype, Device device);
+    GEMM_W4A4(int in_features, int out_features, bool bias, bool use_fp4, Tensor::ScalarType dtype, Device device);

     Tensor forward(Tensor x);
     Tensor forward_silu(Tensor x);
     std::variant<Tensor, QuantizedActivation> forward(Tensor x, FuseOptions fuse, GEMM_W4A4 *nextGEMM = nullptr);
@@ -80,6 +80,7 @@ public:
     const int out_features;
     const int in_features_pad;
     const int out_features_pad;
+    const bool use_fp4;

     int lora_rank;
     std::vector<float> lora_scales;  // every 16 ranks share a scale
@@ -99,6 +100,9 @@ public:
     Tensor smooth;

+    Tensor wtscale;
+    Tensor wcscales;
+
     cublasHandle_t handle;
 };