fengzch-das / nunchaku / Commits

Commit b1b44398, authored Feb 26, 2025 by Samuel Tesfai
Fixing merges
Parents: 004e4e31, 4b9c2e03

Changes: 55. Showing 20 changed files with 290 additions and 69 deletions (+290, -69).
examples/int4-sana_1600m.py                 +0   -0
examples/int4-sana_1600m_pag.py             +0   -1
nunchaku/__version__.py                     +1   -1
nunchaku/csrc/flux.h                        +2   -2
nunchaku/csrc/gemm.h                        +2   -2
nunchaku/csrc/ops.h                         +8   -2
nunchaku/csrc/pybind.cpp                    +2   -0
nunchaku/csrc/sana.h                        +2   -1
nunchaku/lora/flux/comfyui_converter.py     +44  -4
nunchaku/lora/flux/convert.py               +12  -4
nunchaku/lora/flux/diffusers_converter.py   +43  -5
nunchaku/lora/flux/utils.py                 +21  -0
nunchaku/models/transformer_flux.py         +5   -4
nunchaku/models/transformer_sana.py         +7   -2
nunchaku/test.py                            +7   -1
setup.py                                    +59  -10
src/FluxModel.cpp                           +17  -17
src/FluxModel.h                             +3   -3
src/Linear.cpp                              +50  -9
src/Linear.h                                +5   -1
examples/sana_1600m.py → examples/int4-sana_1600m.py
File moved
examples/sana_1600m_pag.py → examples/int4-sana_1600m_pag.py

@@ -23,6 +23,5 @@ image = pipe(
    guidance_scale=5.0,
    pag_scale=2.0,
    num_inference_steps=20,
    generator=torch.Generator().manual_seed(42),
).images[0]
image.save("sana_1600m_pag.png")
nunchaku/__version__.py

-__version__ = "0.0.2beta6"
+__version__ = "0.1.3"
nunchaku/csrc/flux.h

@@ -9,9 +9,9 @@
 class QuantizedFluxModel : public ModuleWrapper<FluxModel> { // : public torch::CustomClassHolder {
 public:
-    void init(bool bf16, int8_t deviceId) {
+    void init(bool use_fp4, bool bf16, int8_t deviceId) {
         spdlog::info("Initializing QuantizedFluxModel");

-        net = std::make_unique<FluxModel>(bf16 ? Tensor::BF16 : Tensor::FP16, Device::cuda((int)deviceId));
+        net = std::make_unique<FluxModel>(use_fp4, bf16 ? Tensor::BF16 : Tensor::FP16, Device::cuda((int)deviceId));
     }

     torch::Tensor forward(
nunchaku/csrc/gemm.h

@@ -8,7 +8,7 @@
 class QuantizedGEMM : public ModuleWrapper<GEMM_W4A4> {
 public:
-    void init(int64_t in_features, int64_t out_features, bool bias, bool bf16, int8_t deviceId) {
+    void init(int64_t in_features, int64_t out_features, bool bias, bool use_fp4, bool bf16, int8_t deviceId) {
         spdlog::info("Initializing QuantizedGEMM");

         size_t val = 0;
@@ -16,7 +16,7 @@ public:
         checkCUDA(cudaDeviceGetLimit(&val, cudaLimitStackSize));
         spdlog::debug("Stack={}", val);

-        net = std::make_unique<GEMM_W4A4>((int)in_features, (int)out_features, bias, bf16 ? Tensor::BF16 : Tensor::FP16, Device::cuda((int)deviceId));
+        net = std::make_unique<GEMM_W4A4>((int)in_features, (int)out_features, bias, use_fp4, bf16 ? Tensor::BF16 : Tensor::FP16, Device::cuda((int)deviceId));
     }

     torch::Tensor forward(torch::Tensor x) {
nunchaku/csrc/ops.h

@@ -29,7 +29,10 @@ namespace nunchaku::ops {
         std::optional<torch::Tensor> out_linearattn,  // linear [B, (M), N / 3]
         bool act_unsigned,
         std::vector<float> lora_scales,
-        bool fuse_silu
+        bool fuse_silu,
+        bool fp4,
+        float alpha,
+        std::optional<torch::Tensor> wcscales
     ) {
         spdlog::trace("running gemm_w4a4: ");
@@ -64,7 +67,10 @@ namespace nunchaku::ops {
             getTensor(out_linearattn),
             act_unsigned,
             lora_scales,
-            fuse_silu
+            fuse_silu,
+            fp4,
+            alpha,
+            getTensor(wcscales)
         );
         Tensor::synchronizeDevice();
     }
nunchaku/csrc/pybind.cpp

@@ -14,6 +14,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     py::class_<QuantizedFluxModel>(m, "QuantizedFluxModel")
         .def(py::init<>())
         .def("init", &QuantizedFluxModel::init,
+            py::arg("use_fp4"),
             py::arg("bf16"),
             py::arg("deviceId")
         )
@@ -36,6 +37,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         .def("init", &QuantizedSanaModel::init,
             py::arg("config"),
             py::arg("pag_layers"),
+            py::arg("use_fp4"),
             py::arg("bf16"),
             py::arg("deviceId")
         )
nunchaku/csrc/sana.h

@@ -8,7 +8,7 @@
 class QuantizedSanaModel : public ModuleWrapper<SanaModel> {
 public:
-    void init(pybind11::dict config, std::vector<int> pag_layers, bool bf16, int8_t deviceId) {
+    void init(pybind11::dict config, std::vector<int> pag_layers, bool use_fp4, bool bf16, int8_t deviceId) {
         spdlog::info("Initializing QuantizedSanaModel");
         SanaConfig cfg{
             .num_layers = config["num_layers"].cast<int>(),
@@ -17,6 +17,7 @@ public:
             .num_cross_attention_heads = config["num_cross_attention_heads"].cast<int>(),
             .expand_ratio = config["mlp_ratio"].cast<double>(),
             .pag_layers = pag_layers,
+            .use_fp4 = use_fp4,
         };
         net = std::make_unique<SanaModel>(cfg, bf16 ? Tensor::BF16 : Tensor::FP16, Device::cuda((int)deviceId));
     }
nunchaku/lora/flux/comfyui_converter.py

 # convert the comfyui lora to diffusers format
 import argparse
 import os

 import torch
@@ -8,7 +9,7 @@ from ...utils import load_state_dict_in_safetensors
 def comfyui2diffusers(
-    input_lora: str | dict[str, torch.Tensor], output_path: str | None = None
+    input_lora: str | dict[str, torch.Tensor], output_path: str | None = None, min_rank: int | None = None
 ) -> dict[str, torch.Tensor]:
     if isinstance(input_lora, str):
         tensors = load_state_dict_in_safetensors(input_lora, device="cpu")
@@ -16,7 +17,7 @@ def comfyui2diffusers(
         tensors = input_lora
     new_tensors = {}
+    max_rank = 0
     for k, v in tensors.items():
         if "alpha" in k:
             continue
@@ -29,7 +30,10 @@ def comfyui2diffusers(
             # Copy the tensor
             new_k = new_k.replace("_img_attn_qkv", f".attn.to_{p}")
             new_k = new_k.replace("_txt_attn_qkv", f".attn.add_{p}_proj")
-            new_tensors[new_k] = v.clone()
+            rank = v.shape[0]
+            alpha = tensors[k.replace("lora_down.weight", "alpha")]
+            new_tensors[new_k] = v.clone() * alpha / rank
+            max_rank = max(max_rank, rank)
         else:
             assert "lora_B" in new_k
             assert v.shape[0] % 3 == 0
@@ -58,7 +62,10 @@ def comfyui2diffusers(
                 new_k1 = new_k.replace("_linear1", ".proj_mlp")
             else:
                 new_k1 = new_k.replace("_linear1", f".attn.to_{p}")
-            new_tensors[new_k1] = v.clone()
+            rank = v.shape[0]
+            alpha = tensors[k.replace("lora_down.weight", "alpha")]
+            new_tensors[new_k1] = v.clone() * alpha / rank
+            max_rank = max(max_rank, rank)
         else:
             if p == "i":
                 new_k1 = new_k.replace("_linear1", ".proj_mlp")
@@ -70,10 +77,43 @@ def comfyui2diffusers(
         else:
             new_k = new_k.replace("_linear2", ".proj_out")
         new_k = new_k.replace("_modulation_lin", ".norm.linear")
+        if "lora_down" in k:
+            rank = v.shape[0]
+            alpha = tensors[k.replace("lora_down.weight", "alpha")]
+            v = v * alpha / rank
+            max_rank = max(max_rank, rank)
         new_tensors[new_k] = v

+    if min_rank is not None:
+        for k in new_tensors.keys():
+            v = new_tensors[k]
+            if "lora_A" in k:
+                rank = v.shape[0]
+                if rank < min_rank:
+                    new_v = torch.zeros(min_rank, v.shape[1], dtype=v.dtype, device=v.device)
+                    new_v[:rank] = v
+                    new_tensors[k] = new_v
+            else:
+                assert "lora_B" in k
+                rank = v.shape[1]
+                if rank < min_rank:
+                    new_v = torch.zeros(v.shape[0], min_rank, dtype=v.dtype, device=v.device)
+                    new_v[:, :rank] = v
+                    new_tensors[k] = new_v
+
     if output_path is not None:
         output_dir = os.path.dirname(os.path.abspath(output_path))
         os.makedirs(output_dir, exist_ok=True)
         save_file(new_tensors, output_path)
     return new_tensors


 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("-i", "--input-path", type=str, required=True, help="path to the comfyui lora safetensor file")
     parser.add_argument("-o", "--output-path", type=str, required=True, help="path to the output diffusers safetensor file")
+    parser.add_argument("--min-rank", type=int, default=None, help="minimum rank for the LoRA weights")
     args = parser.parse_args()
     comfyui2diffusers(args.input_path, args.output_path, min_rank=args.min_rank)
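Note on the converter math above, with a minimal standalone sketch (toy shapes, not the converter's real keys): a LoRA update is applied as W' = W + (alpha / rank) * up @ down, so baking alpha / rank into the down factor once at conversion time lets later code use a plain up @ down product, and zero-padding both factors to a larger --min-rank leaves that product unchanged.

import torch

# Toy rank-4 LoRA factors for a 16x16 weight (hypothetical shapes, for illustration only).
rank, d_out, d_in, alpha = 4, 16, 16, 8.0
lora_down = torch.randn(rank, d_in)   # ComfyUI "lora_down.weight" / diffusers lora_A
lora_up = torch.randn(d_out, rank)    # ComfyUI "lora_up.weight"   / diffusers lora_B

# Standard LoRA update: W' = W + (alpha / rank) * up @ down.
delta = (alpha / rank) * (lora_up @ lora_down)

# Fold alpha / rank into the down factor once, as the converter now does.
down_scaled = lora_down * alpha / rank
assert torch.allclose(delta, lora_up @ down_scaled)

# Zero-pad both factors to a minimum rank (the --min-rank path); the extra
# zero rows/columns contribute nothing, so the update is unchanged.
min_rank = 8
down_pad = torch.zeros(min_rank, d_in)
down_pad[:rank] = down_scaled
up_pad = torch.zeros(d_out, min_rank)
up_pad[:, :rank] = lora_up
assert torch.allclose(delta, up_pad @ down_pad)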
nunchaku/lora/flux/convert.py

@@ -6,6 +6,7 @@ from safetensors.torch import save_file
 from .comfyui_converter import comfyui2diffusers
 from .diffusers_converter import convert_to_nunchaku_flux_lowrank_dict
+from .utils import detect_format
 from .xlab_converter import xlab2diffusers
 from ...utils import filter_state_dict, load_state_dict_in_safetensors
@@ -21,8 +22,8 @@ if __name__ == "__main__":
     parser.add_argument(
         "--lora-format",
         type=str,
-        default="diffusers",
-        choices=["comfyui", "diffusers", "xlab"],
+        default="auto",
+        choices=["auto", "comfyui", "diffusers", "xlab"],
         help="format of the LoRA weights",
     )
     parser.add_argument("--output-root", type=str, default="", help="root to the output safetensor file")
@@ -37,8 +38,8 @@ if __name__ == "__main__":
     args = parser.parse_args()

     if not args.output_root:
-        # output to the parent directory of the quantized model safetensor file
-        args.output_root = os.path.dirname(args.quant_path)
+        # output to the parent directory of the lora safetensor file
+        args.output_root = os.path.dirname(args.lora_path)
     if args.lora_name is None:
         base_name = os.path.basename(args.lora_path)
         lora_name = base_name.rsplit(".", 1)[0]
@@ -53,6 +54,13 @@ if __name__ == "__main__":
     orig_state_dict = load_state_dict_in_safetensors(args.quant_path)

     lora_format = args.lora_format
+    if lora_format == "auto":
+        lora_format = detect_format(args.lora_path)
+        print(f"Detected LoRA format: {lora_format}")
+    if lora_format == "svdquant":
+        print("Already in SVDQuant format, no conversion needed.")
+        exit(0)
     if lora_format == "diffusers":
         extra_lora_dict = load_state_dict_in_safetensors(args.lora_path)
     else:
nunchaku/lora/flux/diffusers_converter.py

-# convert the diffusers lora to nunchaku format
+"""Convert LoRA weights to Nunchaku format."""

 import typing as tp

 import torch
@@ -215,8 +214,8 @@ def convert_to_nunchaku_transformer_block_lowrank_dict(  # noqa: C901
             update_state_dict(
                 converted,
                 {
-                    "lora_down": lora[0],
-                    "lora_up": reorder_adanorm_lora_up(lora[1], splits=3),
+                    "lora_down": pad(lora[0], divisor=16, dim=0),
+                    "lora_up": pad(reorder_adanorm_lora_up(lora[1], splits=3), divisor=16, dim=1),
                 },
                 prefix=converted_local_name,
             )
@@ -224,8 +223,8 @@ def convert_to_nunchaku_transformer_block_lowrank_dict(  # noqa: C901
             update_state_dict(
                 converted,
                 {
-                    "lora_down": lora[0],
-                    "lora_up": reorder_adanorm_lora_up(lora[1], splits=6),
+                    "lora_down": pad(lora[0], divisor=16, dim=0),
+                    "lora_up": pad(reorder_adanorm_lora_up(lora[1], splits=6), divisor=16, dim=1),
                 },
                 prefix=converted_local_name,
             )
@@ -263,6 +262,22 @@ def convert_to_nunchaku_flux_single_transformer_block_lowrank_dict(
         extra_lora_dict.pop(f"{candidate_block_name}.proj_out.lora_A.weight")
         extra_lora_dict.pop(f"{candidate_block_name}.proj_out.lora_B.weight")

+    for component in ["lora_A", "lora_B"]:
+        fc1_k = f"{candidate_block_name}.proj_mlp.{component}.weight"
+        fc2_k = f"{candidate_block_name}.proj_out.linears.1.{component}.weight"
+        fc1_v = extra_lora_dict[fc1_k]
+        fc2_v = extra_lora_dict[fc2_k]
+        dim = 0 if "lora_A" in fc1_k else 1
+        fc1_rank = fc1_v.shape[dim]
+        fc2_rank = fc2_v.shape[dim]
+        if fc1_rank != fc2_rank:
+            rank = max(fc1_rank, fc2_rank)
+            if fc1_rank < rank:
+                extra_lora_dict[fc1_k] = pad(fc1_v, divisor=rank, dim=dim)
+            if fc2_rank < rank:
+                extra_lora_dict[fc2_k] = pad(fc2_v, divisor=rank, dim=dim)
+
     return convert_to_nunchaku_transformer_block_lowrank_dict(
         orig_state_dict=orig_state_dict,
         extra_lora_dict=extra_lora_dict,
@@ -347,6 +362,28 @@ def convert_to_nunchaku_flux_lowrank_dict(
     else:
         extra_lora_dict = filter_state_dict(lora, filter_prefix="transformer.")

+    for k in extra_lora_dict.keys():
+        fc1_k = k
+        if "ff.net.0.proj" in k:
+            fc2_k = k.replace("ff.net.0.proj", "ff.net.2")
+        elif "ff_context.net.0.proj" in k:
+            fc2_k = k.replace("ff_context.net.0.proj", "ff_context.net.2")
+        else:
+            continue
+        assert fc2_k in extra_lora_dict
+        fc1_v = extra_lora_dict[fc1_k]
+        fc2_v = extra_lora_dict[fc2_k]
+        dim = 0 if "lora_A" in fc1_k else 1
+        fc1_rank = fc1_v.shape[dim]
+        fc2_rank = fc2_v.shape[dim]
+        if fc1_rank != fc2_rank:
+            rank = max(fc1_rank, fc2_rank)
+            if fc1_rank < rank:
+                extra_lora_dict[fc1_k] = pad(fc1_v, divisor=rank, dim=dim)
+            if fc2_rank < rank:
+                extra_lora_dict[fc2_k] = pad(fc2_v, divisor=rank, dim=dim)
+
     block_names: set[str] = set()
     for param_name in orig_state_dict.keys():
         if param_name.startswith(("transformer_blocks.", "single_transformer_blocks.")):
@@ -370,4 +407,5 @@ def convert_to_nunchaku_flux_lowrank_dict(
             ),
             prefix=block_name,
         )

     return converted
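The pad(...) calls above round LoRA factors up to a multiple of 16 ranks (the quantized GEMM shares one LoRA scale per 16 ranks, per the comment in src/Linear.h later in this diff) and bring paired branches to a common rank. A rough sketch with a stand-in helper (the repository's actual pad() implementation is not shown in this commit): lora_A-style factors are padded along dim 0 and lora_B-style factors along dim 1, and the zero ranks cancel in the product.

import torch

def pad_up(t: torch.Tensor, divisor: int, dim: int) -> torch.Tensor:
    # Stand-in for pad(t, divisor=..., dim=...): zero-pad `dim` to the next multiple of `divisor`.
    target = ((t.shape[dim] + divisor - 1) // divisor) * divisor
    if target == t.shape[dim]:
        return t
    shape = list(t.shape)
    shape[dim] = target
    out = torch.zeros(shape, dtype=t.dtype, device=t.device)
    out.narrow(dim, 0, t.shape[dim]).copy_(t)
    return out

lora_down = torch.randn(12, 3072)   # rank 12, padded to 16 along dim 0
lora_up = torch.randn(3072, 12)     # rank 12, padded to 16 along dim 1
down16 = pad_up(lora_down, divisor=16, dim=0)
up16 = pad_up(lora_up, divisor=16, dim=1)

print(down16.shape, up16.shape)  # torch.Size([16, 3072]) torch.Size([3072, 16])
assert torch.allclose(up16 @ down16, lora_up @ lora_down)  # zero ranks change nothing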
nunchaku/lora/flux/utils.py (new file)

import torch

from ...utils import load_state_dict_in_safetensors


def detect_format(lora: str | dict[str, torch.Tensor]) -> str:
    if isinstance(lora, str):
        tensors = load_state_dict_in_safetensors(lora, device="cpu")
    else:
        tensors = lora
    for k in tensors.keys():
        if "lora_unet_double_blocks_" in k or "lora_unet_single_blocks" in k:
            return "comfyui"
        elif "mlp_fc" in k or "mlp_context_fc1" in k:
            return "svdquant"
        elif "double_blocks." in k or "single_blocks." in k:
            return "xlab"
        elif "transformer." in k:
            return "diffusers"
    raise ValueError("Unknown format, please provide the format explicitly.")
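A hedged usage sketch of the new helper (assuming the module is importable as nunchaku.lora.flux.utils; the state dict below is a dummy stand-in whose key names alone drive the detection):

import torch

from nunchaku.lora.flux.utils import detect_format

# Dummy diffusers-style LoRA entries; only the key prefixes matter for detection.
dummy_lora = {
    "transformer.transformer_blocks.0.attn.to_q.lora_A.weight": torch.zeros(16, 3072),
    "transformer.transformer_blocks.0.attn.to_q.lora_B.weight": torch.zeros(3072, 16),
}
print(detect_format(dummy_lora))  # expected: "diffusers"

# Passing a path instead loads the safetensors file first:
# detect_format("/path/to/lora.safetensors")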
nunchaku/models/transformer_flux.py

@@ -108,13 +108,12 @@ class EmbedND(nn.Module):
         return emb.unsqueeze(1)


-def load_quantized_module(path: str, device: str | torch.device = "cuda") -> QuantizedFluxModel:
+def load_quantized_module(path: str, device: str | torch.device = "cuda", use_fp4: bool = False) -> QuantizedFluxModel:
     device = torch.device(device)
     assert device.type == "cuda"
     m = QuantizedFluxModel()
     cutils.disable_memory_auto_release()
-    m.init(True, 0 if device.index is None else device.index)
+    m.init(use_fp4, True, 0 if device.index is None else device.index)
     m.load(path)
     return m
@@ -153,8 +152,10 @@ class NunchakuFluxTransformer2dModel(FluxTransformer2DModel, NunchakuModelLoader
     @utils.validate_hf_hub_args
     def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike, **kwargs):
         device = kwargs.get("device", "cuda")
+        precision = kwargs.get("precision", "int4")
+        assert precision in ["int4", "fp4"]
         transformer, transformer_block_path = cls._build_model(pretrained_model_name_or_path, **kwargs)
-        m = load_quantized_module(transformer_block_path, device=device)
+        m = load_quantized_module(transformer_block_path, device=device, use_fp4=precision == "fp4")
         transformer.inject_quantized_module(m, device)
         return transformer
nunchaku/models/transformer_sana.py

@@ -124,9 +124,13 @@ class NunchakuSanaTransformer2DModel(SanaTransformer2DModel, NunchakuModelLoader
     def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike, **kwargs):
         device = kwargs.get("device", "cuda")
         pag_layers = kwargs.get("pag_layers", [])
+        precision = kwargs.get("precision", "int4")
+        assert precision in ["int4", "fp4"]
         transformer, transformer_block_path = cls._build_model(pretrained_model_name_or_path, **kwargs)
         transformer.config["num_layers"] = transformer.original_num_layers
-        m = load_quantized_module(transformer, transformer_block_path, device=device, pag_layers=pag_layers)
+        m = load_quantized_module(
+            transformer, transformer_block_path, device=device, pag_layers=pag_layers, use_fp4=precision == "fp4"
+        )
         transformer.inject_quantized_module(m, device)
         return transformer
@@ -140,6 +144,7 @@ def load_quantized_module(
     path: str,
     device: str | torch.device = "cuda",
     pag_layers: int | list[int] | None = None,
+    use_fp4: bool = False,
 ) -> QuantizedSanaModel:
     if pag_layers is None:
         pag_layers = []
@@ -150,7 +155,7 @@ def load_quantized_module(
     m = QuantizedSanaModel()
     cutils.disable_memory_auto_release()
-    m.init(net.config, pag_layers, net.dtype == torch.bfloat16, 0 if device.index is None else device.index)
+    m.init(net.config, pag_layers, use_fp4, net.dtype == torch.bfloat16, 0 if device.index is None else device.index)
     m.load(path)
     return m
nunchaku/test.py

@@ -4,7 +4,13 @@ from diffusers import FluxPipeline
 from .models.transformer_flux import NunchakuFluxTransformer2dModel

 if __name__ == "__main__":
-    transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-schnell")
+    capability = torch.cuda.get_device_capability(0)
+    sm = f"{capability[0]}{capability[1]}"
+    precision = "fp4" if sm == "120" else "int4"
+    transformer = NunchakuFluxTransformer2dModel.from_pretrained(
+        f"mit-han-lab/svdq-{precision}-flux.1-schnell", precision=precision
+    )
     pipeline = FluxPipeline.from_pretrained(
         "black-forest-labs/FLUX.1-schnell", transformer=transformer, torch_dtype=torch.bfloat16
     ).to("cuda")
setup.py

 import os
 import re
 import subprocess
 import sys

 import setuptools
-from torch.utils.cpp_extension import BuildExtension, CUDAExtension
+import torch
+from packaging import version as packaging_version
+from torch.utils.cpp_extension import BuildExtension, CUDA_HOME, CUDAExtension


 class CustomBuildExtension(BuildExtension):
     def build_extensions(self):
@@ -16,10 +22,49 @@ class CustomBuildExtension(BuildExtension):
                 ext.extra_compile_args["cxx"] += ext.extra_compile_args["gcc"]
         super().build_extensions()


+def get_sm_targets() -> list[str]:
+    nvcc_path = os.path.join(CUDA_HOME, "bin/nvcc") if CUDA_HOME else "nvcc"
+    try:
+        nvcc_output = subprocess.check_output([nvcc_path, "--version"]).decode()
+        match = re.search(r"release (\d+\.\d+), V(\d+\.\d+\.\d+)", nvcc_output)
+        if match:
+            nvcc_version = match.group(2)
+        else:
+            raise Exception("nvcc version not found")
+        print(f"Found nvcc version: {nvcc_version}")
+    except:
+        raise Exception("nvcc not found")
+
+    support_sm120 = packaging_version.parse(nvcc_version) >= packaging_version.parse("12.8")
+
+    install_mode = os.getenv("NUNCHAKU_INSTALL_MODE", "FAST")
+    if install_mode == "FAST":
+        ret = []
+        for i in range(torch.cuda.device_count()):
+            capability = torch.cuda.get_device_capability(i)
+            sm = f"{capability[0]}{capability[1]}"
+            if sm == "120" and support_sm120:
+                sm = "120a"
+            assert sm in ["80", "86", "89", "120a"], f"Unsupported SM {sm}"
+            if sm not in ret:
+                ret.append(sm)
+    else:
+        assert install_mode == "ALL"
+        ret = ["80", "86", "89"]
+        if support_sm120:
+            ret.append("120a")
+    return ret
+
+
 if __name__ == "__main__":
     fp = open("nunchaku/__version__.py", "r").read()
     version = eval(fp.strip().split()[-1])

+    torch_version = torch.__version__.split("+")[0]
+    torch_major_minor_version = ".".join(torch_version.split(".")[:2])
+    version = version + "+torch" + torch_major_minor_version
+
     ROOT_DIR = os.path.dirname(__file__)

     INCLUDE_DIRS = [
@@ -54,12 +99,6 @@ if __name__ == "__main__":
     NVCC_FLAGS = [
         "-DENABLE_BF16=1",
         "-DBUILD_NUNCHAKU=1",
-        "-gencode",
-        "arch=compute_86,code=sm_86",
-        "-gencode",
-        "arch=compute_89,code=sm_89",
-        # "-gencode",
-        # "arch=compute_89,code=sm_120a",
         "-g",
         "-std=c++20",
         "-UNDEBUG",
@@ -74,13 +113,23 @@ if __name__ == "__main__":
         "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
         "-U__CUDA_NO_BFLOAT162_OPERATORS__",
         "-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
-        "--threads=2",
+        "--threads=3",
         "--expt-relaxed-constexpr",
         "--expt-extended-lambda",
-        "--generate-line-info",
         "--ptxas-options=--allow-expensive-optimizations=true",
     ]

+    # https://github.com/NVIDIA/cutlass/pull/1479#issuecomment-2052300487
+    if os.getenv("NUNCHAKU_BUILD_WHEELS", "0") == "0":
+        NVCC_FLAGS.append("--generate-line-info")
+
+    sm_targets = get_sm_targets()
+    print(f"Detected SM targets: {sm_targets}", file=sys.stderr)
+    assert len(sm_targets) > 0, "No SM targets found"
+    for target in sm_targets:
+        NVCC_FLAGS += ["-gencode", f"arch=compute_{target},code=sm_{target}"]
+
     NVCC_MSVC_FLAGS = ["-Xcompiler", "/Zc:__cplusplus"]

     nunchaku_extension = CUDAExtension(
src/FluxModel.cpp

@@ -259,19 +259,19 @@ void Attention::setForceFP16(Module *module, bool value) {
     });
 }

-FluxSingleTransformerBlock::FluxSingleTransformerBlock(int dim, int num_attention_heads, int attention_head_dim, int mlp_ratio, Tensor::ScalarType dtype, Device device) :
+FluxSingleTransformerBlock::FluxSingleTransformerBlock(int dim, int num_attention_heads, int attention_head_dim, int mlp_ratio, bool use_fp4, Tensor::ScalarType dtype, Device device) :
     dim(dim), dim_head(attention_head_dim / num_attention_heads), num_heads(num_attention_heads), mlp_hidden_dim(dim * mlp_ratio),
     norm(dim, dtype, device),
-    mlp_fc1(dim, mlp_hidden_dim, true, dtype, device),
-    mlp_fc2(mlp_hidden_dim, dim, true, dtype, device),
-    qkv_proj(dim, dim * 3, true, dtype, device),
+    mlp_fc1(dim, mlp_hidden_dim, true, use_fp4, dtype, device),
+    mlp_fc2(mlp_hidden_dim, dim, true, use_fp4, dtype, device),
+    qkv_proj(dim, dim * 3, true, use_fp4, dtype, device),
     norm_q(dim_head, 1e-6, false, dtype, device),
     norm_k(dim_head, 1e-6, false, dtype, device),
     attn(num_attention_heads, attention_head_dim / num_attention_heads, device),
-    out_proj(dim, dim, true, dtype, device)
+    out_proj(dim, dim, true, use_fp4, dtype, device)
 {
     registerChildren
         (norm, "norm")
@@ -327,28 +327,28 @@ Tensor FluxSingleTransformerBlock::forward(Tensor hidden_states, Tensor temb, Te
     return hidden_states;
 }

-JointTransformerBlock::JointTransformerBlock(int dim, int num_attention_heads, int attention_head_dim, bool context_pre_only, Tensor::ScalarType dtype, Device device) :
+JointTransformerBlock::JointTransformerBlock(int dim, int num_attention_heads, int attention_head_dim, bool context_pre_only, bool use_fp4, Tensor::ScalarType dtype, Device device) :
     dim(dim), dim_head(attention_head_dim / num_attention_heads), num_heads(num_attention_heads), context_pre_only(context_pre_only),
     norm1(dim, false, dtype, device),
     norm1_context(dim, context_pre_only, dtype, device),
-    qkv_proj(dim, dim * 3, true, dtype, device),
-    qkv_proj_context(dim, dim * 3, true, dtype, device),
+    qkv_proj(dim, dim * 3, true, use_fp4, dtype, device),
+    qkv_proj_context(dim, dim * 3, true, use_fp4, dtype, device),
     norm_q(dim_head, 1e-6, false, dtype, device),
     norm_k(dim_head, 1e-6, false, dtype, device),
     norm_added_q(dim_head, 1e-6, false, dtype, device),
     norm_added_k(dim_head, 1e-6, false, dtype, device),
     attn(num_attention_heads, attention_head_dim / num_attention_heads, device),
-    out_proj(dim, dim, true, dtype, device),
-    out_proj_context(dim, dim, true, dtype, device),
+    out_proj(dim, dim, true, use_fp4, dtype, device),
+    out_proj_context(dim, dim, true, use_fp4, dtype, device),
     norm2(dim, 1e-6, false, dtype, device),
     norm2_context(dim, 1e-6, false, dtype, device),
-    mlp_fc1(dim, dim * 4, true, dtype, device),
-    mlp_fc2(dim * 4, dim, true, dtype, device),
-    mlp_context_fc1(dim, dim * 4, true, dtype, device),
-    mlp_context_fc2(dim * 4, dim, true, dtype, device)
+    mlp_fc1(dim, dim * 4, true, use_fp4, dtype, device),
+    mlp_fc2(dim * 4, dim, true, use_fp4, dtype, device),
+    mlp_context_fc1(dim, dim * 4, true, use_fp4, dtype, device),
+    mlp_context_fc2(dim * 4, dim, true, use_fp4, dtype, device)
 {
     registerChildren
         (norm1, "norm1")
@@ -607,13 +607,13 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
     return {hidden_states, encoder_hidden_states};
 }

-FluxModel::FluxModel(Tensor::ScalarType dtype, Device device) {
+FluxModel::FluxModel(bool use_fp4, Tensor::ScalarType dtype, Device device) {
     for (int i = 0; i < 19; i++) {
-        transformer_blocks.push_back(std::make_unique<JointTransformerBlock>(3072, 24, 3072, false, dtype, device));
+        transformer_blocks.push_back(std::make_unique<JointTransformerBlock>(3072, 24, 3072, false, use_fp4, dtype, device));
         registerChildren(*transformer_blocks.back(), format("transformer_blocks.{}", i));
     }
     for (int i = 0; i < 38; i++) {
-        single_transformer_blocks.push_back(std::make_unique<FluxSingleTransformerBlock>(3072, 24, 3072, 4, dtype, Device::cuda()));
+        single_transformer_blocks.push_back(std::make_unique<FluxSingleTransformerBlock>(3072, 24, 3072, 4, use_fp4, dtype, Device::cuda()));
         registerChildren(*single_transformer_blocks.back(), format("single_transformer_blocks.{}", i));
     }
 }
src/FluxModel.h

@@ -77,7 +77,7 @@ public:
     static constexpr bool USE_4BIT = true;
     using GEMM = std::conditional_t<USE_4BIT, GEMM_W4A4, GEMM_W8A8>;

-    FluxSingleTransformerBlock(int dim, int num_attention_heads, int attention_head_dim, int mlp_ratio, Tensor::ScalarType dtype, Device device);
+    FluxSingleTransformerBlock(int dim, int num_attention_heads, int attention_head_dim, int mlp_ratio, bool use_fp4, Tensor::ScalarType dtype, Device device);
     Tensor forward(Tensor hidden_states, Tensor temb, Tensor rotary_emb);

 public:
@@ -101,7 +101,7 @@ public:
     static constexpr bool USE_4BIT = true;
     using GEMM = std::conditional_t<USE_4BIT, GEMM_W4A4, GEMM_W8A8>;

-    JointTransformerBlock(int dim, int num_attention_heads, int attention_head_dim, bool context_pre_only, Tensor::ScalarType dtype, Device device);
+    JointTransformerBlock(int dim, int num_attention_heads, int attention_head_dim, bool context_pre_only, bool use_fp4, Tensor::ScalarType dtype, Device device);
     std::tuple<Tensor, Tensor> forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor temb, Tensor rotary_emb, Tensor rotary_emb_context, float sparsityRatio);

 public:
@@ -128,7 +128,7 @@ private:
 class FluxModel : public Module {
 public:
-    FluxModel(Tensor::ScalarType dtype, Device device);
+    FluxModel(bool use_fp4, Tensor::ScalarType dtype, Device device);
     Tensor forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor temb, Tensor rotary_emb_img, Tensor rotary_emb_context, Tensor rotary_emb_single);

 public:
src/Linear.cpp

@@ -96,23 +96,33 @@ Tensor GEMV_AWQ::forward(Tensor x) {
 #define NO_LORA_FUSION 0

-GEMM_W4A4::GEMM_W4A4(int in_features, int out_features, bool bias, Tensor::ScalarType dtype, Device device) :
+GEMM_W4A4::GEMM_W4A4(int in_features, int out_features, bool bias, bool use_fp4, Tensor::ScalarType dtype, Device device) :
     in_features(in_features), out_features(out_features),
     in_features_pad(ceilDiv(in_features, 128) * 128), out_features_pad(ceilDiv(out_features, 128) * 128),
     use_fp4(use_fp4), lora_rank(0), dtype(dtype)
 {
     this->qweight = Tensor::allocate({out_features_pad, in_features_pad / 2}, Tensor::INT8, device, true);
-    this->wscales = Tensor::allocate({in_features_pad / 64, out_features_pad}, dtype, device, true);
+    if (use_fp4) {
+        this->wscales = Tensor::allocate({in_features_pad / 16, out_features_pad}, Tensor::FP8_E4M3, device, true);
+    } else {
+        this->wscales = Tensor::allocate({in_features_pad / 64, out_features_pad}, dtype, device, true);
+    }
     this->bias = bias ? Tensor::allocate({out_features_pad}, dtype, device, true) : Tensor{};
     this->lora_down = Tensor::allocate({in_features_pad, lora_rank}, dtype, device, true);
     this->lora_up = Tensor::allocate({out_features_pad, lora_rank}, dtype, device, true);
-    // TODO: smooth factor in FC1+FC2 fusion
+    // TODO: smooth factor in non-Lora fusion
     this->smooth = Tensor::allocate({in_features_pad}, dtype, device, true);

+    // FIXME: reset wtscale and wcscales to default values when reloading the weights
+    this->wtscale = Tensor::allocate({1}, Tensor::FP32, Device::cpu(), true);
+    *this->wtscale.data_ptr<float>() = 1.0f;
+    this->wcscales = Tensor::allocate({0}, dtype, device, true);
+
     registerParams
         (qweight, "qweight")
         (wscales, "wscales")
@@ -120,6 +130,8 @@ GEMM_W4A4::GEMM_W4A4(int in_features, int out_features, bool bias, Tensor::Scala
         (lora_down, "lora_down", ParamFlags::Optional)
         (lora_up, "lora_up", ParamFlags::Optional)
         (smooth, "smooth")
+        (wtscale, "wtscale", ParamFlags::Optional)
+        (wcscales, "wcscales", ParamFlags::Optional)
     ;

 #if NO_LORA_FUSION
@@ -137,6 +149,21 @@ void GEMM_W4A4::loadParam(std::string key, Tensor &dst, Tensor src) {
         } else {
             dst.copy_(src);
         }
+    } else if (key == "wcscales") {
+        assert(src.ndims() == 1);
+        assert(src.shape[0] == out_features_pad);
+        dst = src.copy(this->qweight.device());
+    } else if (key == "wtscale") {
+        assert(src.numel() == 1);
+        if (src.dtype() == Tensor::BF16) {
+            *dst.data_ptr<float>() = float(*src.data_ptr<__nv_bfloat16>());
+        } else if (src.dtype() == Tensor::FP16) {
+            *dst.data_ptr<float>() = float(*src.data_ptr<half>());
+        } else if (src.dtype() == Tensor::FP32) {
+            dst.copy_(src);
+        } else {
+            assert(false);
+        }
     } else {
         Module::loadParam(key, dst, src);
     }
@@ -167,7 +194,10 @@ void GEMM_W4A4::forward(Tensor x, Tensor out, Tensor pool, Tensor norm_q, Tensor
     debug("gemm.nolora.out", out);
 #endif

-    kernels::gemm_w4a4(qact.act, qweight, out, {}, qact.ascales, wscales, {}, pool, qact.lora_act, this->lora_up, {}, {}, norm_q, norm_k, rotary_emb, this->bias, {}, {}, {}, qact.is_unsigned, this->lora_scales, false);
+    kernels::gemm_w4a4(
+        qact.act, qweight, out, {}, qact.ascales, wscales, {}, pool, qact.lora_act, this->lora_up, {}, {},
+        norm_q, norm_k, rotary_emb, this->bias, {}, {}, {}, qact.is_unsigned, this->lora_scales, false,
+        use_fp4, *this->wtscale.data_ptr<float>(), wcscales.numel() > 0 ? wcscales : Tensor{});

     debug("gemm.out", out);
 #else
@@ -215,9 +245,13 @@ std::variant<Tensor, GEMM_W4A4::QuantizedActivation> GEMM_W4A4::forward_quant(Qu
         out = Tensor::allocate(shape, dtype, qweight.device());
     } else {
         qout.act = Tensor::allocate({M, out_features_pad / 2}, Tensor::INT8, qweight.device());
-        qout.ascales = Tensor::allocate({out_features_pad / 64, M}, dtype, qweight.device());
+        if (use_fp4) {
+            qout.ascales = Tensor::allocate({out_features_pad / 16, M}, Tensor::FP8_E4M3, qweight.device());
+        } else {
+            qout.ascales = Tensor::allocate({out_features_pad / 64, M}, dtype, qweight.device());
+        }
         qout.lora_act = Tensor::allocate({M, lora_rank}, Tensor::FP32, qweight.device());
-        qout.is_unsigned = true;
+        qout.is_unsigned = !use_fp4;
         qout.actShape = qact.actShape;

         next_lora = nextGEMM->lora_down;
@@ -241,7 +275,10 @@ std::variant<Tensor, GEMM_W4A4::QuantizedActivation> GEMM_W4A4::forward_quant(Qu
     }
 #endif

-    kernels::gemm_w4a4(qact.act, qweight, out, qout.act, qact.ascales, wscales, qout.ascales, {}, qact.lora_act, this->lora_up, next_lora, qout.lora_act, {}, {}, {}, this->bias, next_smooth, {}, {}, qact.is_unsigned, this->lora_scales, fuse == FuseOptions::SILU);
+    kernels::gemm_w4a4(
+        qact.act, qweight, out, qout.act, qact.ascales, wscales, qout.ascales, {}, qact.lora_act, this->lora_up,
+        next_lora, qout.lora_act, {}, {}, {}, this->bias, next_smooth, {}, {}, qact.is_unsigned, this->lora_scales,
+        fuse == FuseOptions::SILU, use_fp4, *this->wtscale.data_ptr<float>(), wcscales.numel() > 0 ? wcscales : Tensor{});

     if (fuse == FuseOptions::EMPTY || fuse == FuseOptions::SILU) {
         debug("gemm.out", out);
@@ -327,7 +364,11 @@ GEMM_W4A4::QuantizedActivation GEMM_W4A4::quantize(Tensor x, bool fuse_glu) {
     QuantizedActivation qact;
     qact.act = Tensor::allocate({M, in_features_pad / 2}, Tensor::INT8, qweight.device());
-    qact.ascales = Tensor::allocate({in_features_pad / 64, M}, dtype, qweight.device());
+    if (use_fp4) {
+        qact.ascales = Tensor::allocate({in_features_pad / 16, M}, Tensor::FP8_E4M3, qweight.device());
+    } else {
+        qact.ascales = Tensor::allocate({in_features_pad / 64, M}, dtype, qweight.device());
+    }
     qact.lora_act = Tensor::allocate({M, lora_rank}, Tensor::FP32, qweight.device());
     qact.is_unsigned = false;
     qact.actShape = x.shape.dataExtent;
@@ -336,7 +377,7 @@ GEMM_W4A4::QuantizedActivation GEMM_W4A4::quantize(Tensor x, bool fuse_glu) {
     debug("quantize.x", x);
     debug("quantize.smooth", this->smooth);

-    kernels::quantize_w4a4_act_fuse_lora(x, qact.act, qact.ascales, this->lora_down, qact.lora_act, this->smooth, fuse_glu);
+    kernels::quantize_w4a4_act_fuse_lora(x, qact.act, qact.ascales, this->lora_down, qact.lora_act, this->smooth, fuse_glu, use_fp4);

     debug("quantize.qact", qact.act);
     debug("quantize.ascales", qact.ascales);
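The allocations above change the quantization scale layout: the INT4 path keeps one weight scale (in the activation dtype) per 64 input channels, while the FP4 path stores one FP8 (E4M3) scale per 16 input channels plus a single FP32 per-tensor scale (wtscale, default 1.0) and an optional per-channel wcscales. A small bookkeeping sketch of the resulting shapes, mirroring the constructor arithmetic (pure Python, no GPU required; the example dimensions follow the 3072-wide qkv projection built in src/FluxModel.cpp):

def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

def scale_shapes(in_features: int, out_features: int, use_fp4: bool) -> dict:
    # Features are padded to multiples of 128, as in GEMM_W4A4's constructor.
    in_pad = ceil_div(in_features, 128) * 128
    out_pad = ceil_div(out_features, 128) * 128
    group = 16 if use_fp4 else 64  # input channels sharing one scale
    return {
        "qweight": (out_pad, in_pad // 2),      # packed 4-bit weights, two per byte
        "wscales": (in_pad // group, out_pad),  # FP8_E4M3 for FP4, activation dtype for INT4
        "ascales": (in_pad // group, "M"),      # activation scales produced by quantize(), M = tokens
    }

print(scale_shapes(3072, 3072 * 3, use_fp4=False))
print(scale_shapes(3072, 3072 * 3, use_fp4=True))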
src/Linear.h

@@ -64,7 +64,7 @@ public:
     };

 public:
-    GEMM_W4A4(int in_features, int out_features, bool bias, Tensor::ScalarType dtype, Device device);
+    GEMM_W4A4(int in_features, int out_features, bool bias, bool use_fp4, Tensor::ScalarType dtype, Device device);
     Tensor forward(Tensor x);
     Tensor forward_silu(Tensor x);
     std::variant<Tensor, QuantizedActivation> forward(Tensor x, FuseOptions fuse, GEMM_W4A4 *nextGEMM = nullptr);
@@ -80,6 +80,7 @@ public:
     const int out_features;
     const int in_features_pad;
     const int out_features_pad;
+    const bool use_fp4;

     int lora_rank;
     std::vector<float> lora_scales;  // every 16 ranks share a scale
@@ -99,6 +100,9 @@ public:
     Tensor smooth;

+    Tensor wtscale;
+    Tensor wcscales;
+
     cublasHandle_t handle;
 };