fengzch-das / nunchaku · Commit 54e6d065

[major] support NVFP4; upgrade to 0.1

Authored Feb 20, 2025 by muyangli
Parent: c7f41661
Changes: 45. Showing 20 changed files with 402 additions and 112 deletions (+402 / -112).
Changed files on this page:

  nunchaku/csrc/ops.h                  +8   -2
  nunchaku/csrc/pybind.cpp             +2   -0
  nunchaku/csrc/sana.h                 +2   -1
  nunchaku/models/transformer_flux.py  +5   -4
  nunchaku/test.py                     +7   -1
  setup.py                             +50  -9
  src/FluxModel.cpp                    +17  -17
  src/FluxModel.h                      +3   -3
  src/Linear.cpp                       +50  -9
  src/Linear.h                         +5   -1
  src/SanaModel.cpp                    +20  -15
  src/SanaModel.h                      +5   -4
  src/Serialization.cpp                +2   -0
  src/Tensor.h                         +4   -1
  src/common.h                         +10  -0
  src/interop/torch.cpp                +4   -0
  src/kernels/awq/gemv_awq.cu          +4   -2
  src/kernels/zgemm/gemm_base.cuh      +89  -38
  src/kernels/zgemm/gemm_utils.cuh     +93  -0
  src/kernels/zgemm/gemm_w4a4.cu       +22  -5
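At the Python API level, this commit threads a new precision option from NunchakuFluxTransformer2dModel.from_pretrained down to the quantized CUDA module (see the nunchaku/models/transformer_flux.py and nunchaku/test.py diffs below). A minimal usage sketch, assuming the FluxPipeline setup and the checkpoint id pattern shown in nunchaku/test.py; other checkpoints may use different ids:

    import torch
    from diffusers import FluxPipeline
    from nunchaku.models.transformer_flux import NunchakuFluxTransformer2dModel

    # NVFP4 kernels target SM 120 (Blackwell); fall back to INT4 elsewhere,
    # mirroring the logic added to nunchaku/test.py in this commit.
    major, minor = torch.cuda.get_device_capability(0)
    precision = "fp4" if f"{major}{minor}" == "120" else "int4"

    transformer = NunchakuFluxTransformer2dModel.from_pretrained(
        f"mit-han-lab/svdq-{precision}-flux.1-schnell", precision=precision
    )
    pipeline = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-schnell", transformer=transformer, torch_dtype=torch.bfloat16
    ).to("cuda")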
nunchaku/csrc/ops.h

@@ -28,7 +28,10 @@ namespace nunchaku::ops {
         std::optional<torch::Tensor> out_linearattn,  // linear [B, (M), N / 3]
         bool act_unsigned,
         std::vector<float> lora_scales,
-        bool fuse_silu
+        bool fuse_silu,
+        bool fp4,
+        float alpha,
+        std::optional<torch::Tensor> wcscales
     ) {
         spdlog::trace("running gemm_w4a4: ");

@@ -63,7 +66,10 @@ namespace nunchaku::ops {
             getTensor(out_linearattn),
             act_unsigned,
             lora_scales,
-            fuse_silu
+            fuse_silu,
+            fp4,
+            alpha,
+            getTensor(wcscales)
         );
         Tensor::synchronizeDevice();
     }
nunchaku/csrc/pybind.cpp

@@ -11,6 +11,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     py::class_<QuantizedFluxModel>(m, "QuantizedFluxModel")
         .def(py::init<>())
         .def("init", &QuantizedFluxModel::init,
+             py::arg("use_fp4"),
              py::arg("bf16"),
              py::arg("deviceId")
         )

@@ -33,6 +34,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         .def("init", &QuantizedSanaModel::init,
             py::arg("config"),
             py::arg("pag_layers"),
+            py::arg("use_fp4"),
             py::arg("bf16"),
             py::arg("deviceId")
        )
nunchaku/csrc/sana.h

@@ -8,7 +8,7 @@
 class QuantizedSanaModel : public ModuleWrapper<SanaModel> {
 public:
-    void init(pybind11::dict config, std::vector<int> pag_layers, bool bf16, int8_t deviceId) {
+    void init(pybind11::dict config, std::vector<int> pag_layers, bool use_fp4, bool bf16, int8_t deviceId) {
         spdlog::info("Initializing QuantizedSanaModel");
         SanaConfig cfg{
             .num_layers = config["num_layers"].cast<int>(),

@@ -17,6 +17,7 @@ public:
             .num_cross_attention_heads = config["num_cross_attention_heads"].cast<int>(),
             .expand_ratio = config["mlp_ratio"].cast<double>(),
             .pag_layers = pag_layers,
+            .use_fp4 = use_fp4,
         };
         net = std::make_unique<SanaModel>(cfg, bf16 ? Tensor::BF16 : Tensor::FP16, Device::cuda((int)deviceId));
     }
nunchaku/models/transformer_flux.py

@@ -108,13 +108,12 @@ class EmbedND(nn.Module):
         return emb.unsqueeze(1)

-def load_quantized_module(path: str, device: str | torch.device = "cuda") -> QuantizedFluxModel:
+def load_quantized_module(path: str, device: str | torch.device = "cuda", use_fp4: bool = False) -> QuantizedFluxModel:
     device = torch.device(device)
     assert device.type == "cuda"
     m = QuantizedFluxModel()
     cutils.disable_memory_auto_release()
-    m.init(True, 0 if device.index is None else device.index)
+    m.init(use_fp4, True, 0 if device.index is None else device.index)
     m.load(path)
     return m

@@ -153,8 +152,10 @@ class NunchakuFluxTransformer2dModel(FluxTransformer2DModel, NunchakuModelLoader):
     @utils.validate_hf_hub_args
     def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike, **kwargs):
         device = kwargs.get("device", "cuda")
+        precision = kwargs.get("precision", "int4")
+        assert precision in ["int4", "fp4"]
         transformer, transformer_block_path = cls._build_model(pretrained_model_name_or_path, **kwargs)
-        m = load_quantized_module(transformer_block_path, device=device)
+        m = load_quantized_module(transformer_block_path, device=device, use_fp4=precision == "fp4")
         transformer.inject_quantized_module(m, device)
         return transformer
nunchaku/test.py

@@ -4,7 +4,13 @@ from diffusers import FluxPipeline
 from .models.transformer_flux import NunchakuFluxTransformer2dModel

 if __name__ == "__main__":
-    transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-schnell")
+    capability = torch.cuda.get_device_capability(0)
+    sm = f"{capability[0]}{capability[1]}"
+    precision = "fp4" if sm == "120" else "int4"
+    transformer = NunchakuFluxTransformer2dModel.from_pretrained(
+        f"mit-han-lab/svdq-{precision}-flux.1-schnell", precision=precision
+    )
     pipeline = FluxPipeline.from_pretrained(
         "black-forest-labs/FLUX.1-schnell", transformer=transformer, torch_dtype=torch.bfloat16
     ).to("cuda")
setup.py

 import os
 import re
 import subprocess
 import sys

 import setuptools
-from torch.utils.cpp_extension import BuildExtension, CUDAExtension
+import torch
+from packaging import version as packaging_version
+from torch.utils.cpp_extension import BuildExtension, CUDA_HOME, CUDAExtension

 class CustomBuildExtension(BuildExtension):

@@ -19,6 +24,40 @@ class CustomBuildExtension(BuildExtension):
         super().build_extensions()

+def get_sm_targets() -> list[str]:
+    nvcc_path = os.path.join(CUDA_HOME, "bin/nvcc") if CUDA_HOME else "nvcc"
+    try:
+        nvcc_output = subprocess.check_output([nvcc_path, "--version"]).decode()
+        match = re.search(r"release (\d+\.\d+), V(\d+\.\d+\.\d+)", nvcc_output)
+        if match:
+            nvcc_version = match.group(2)
+        else:
+            raise Exception("nvcc version not found")
+        print(f"Found nvcc version: {nvcc_version}")
+    except:
+        raise Exception("nvcc not found")
+
+    support_sm120 = packaging_version.parse(nvcc_version) >= packaging_version.parse("12.8")
+
+    install_mode = os.getenv("NUNCHAKU_INSTALL_MODE", "FAST")
+    if install_mode == "FAST":
+        ret = []
+        for i in range(torch.cuda.device_count()):
+            capability = torch.cuda.get_device_capability(i)
+            sm = f"{capability[0]}{capability[1]}"
+            if sm == "120" and support_sm120:
+                sm = "120a"
+            assert sm in ["80", "86", "89", "120a"], f"Unsupported SM {sm}"
+            if sm not in ret:
+                ret.append(sm)
+    else:
+        assert install_mode == "ALL"
+        ret = ["80", "86", "89"]
+        if support_sm120:
+            ret.append("120a")
+    return ret
+
 if __name__ == "__main__":
     fp = open("nunchaku/__version__.py", "r").read()
     version = eval(fp.strip().split()[-1])

@@ -55,12 +94,6 @@ if __name__ == "__main__":
     NVCC_FLAGS = [
         "-DENABLE_BF16=1",
         "-DBUILD_NUNCHAKU=1",
-        "-gencode",
-        "arch=compute_86,code=sm_86",
-        "-gencode",
-        "arch=compute_89,code=sm_89",
-        # "-gencode",
-        # "arch=compute_89,code=sm_120a",
         "-g",
         "-std=c++20",
         "-UNDEBUG",

@@ -75,13 +108,21 @@ if __name__ == "__main__":
         "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
         "-U__CUDA_NO_BFLOAT162_OPERATORS__",
         "-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
-        "--threads=2",
+        "--threads=3",
         "--expt-relaxed-constexpr",
         "--expt-extended-lambda",
         "--generate-line-info",
         "--ptxas-options=--allow-expensive-optimizations=true",
     ]
     # https://github.com/NVIDIA/cutlass/pull/1479#issuecomment-2052300487

+    sm_targets = get_sm_targets()
+    print(f"Detected SM targets: {sm_targets}", file=sys.stderr)
+    assert len(sm_targets) > 0, "No SM targets found"
+    for target in sm_targets:
+        NVCC_FLAGS += ["-gencode", f"arch=compute_{target},code=sm_{target}"]

     NVCC_MSVC_FLAGS = ["-Xcompiler", "/Zc:__cplusplus"]

     nunchaku_extension = CUDAExtension(
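The build now derives its -gencode flags from get_sm_targets() instead of a hard-coded sm_86/sm_89 list: FAST mode (the default for NUNCHAKU_INSTALL_MODE) compiles only for the GPUs present in the machine, while ALL compiles for every supported architecture and adds 120a when nvcc is at least 12.8. A small illustration of the flag expansion, with the target list as a hypothetical input:

    sm_targets = ["89", "120a"]   # hypothetical output of get_sm_targets()
    nvcc_flags = []
    for target in sm_targets:
        nvcc_flags += ["-gencode", f"arch=compute_{target},code=sm_{target}"]
    # ['-gencode', 'arch=compute_89,code=sm_89', '-gencode', 'arch=compute_120a,code=sm_120a']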
src/FluxModel.cpp

@@ -259,19 +259,19 @@ void Attention::setForceFP16(Module *module, bool value) {
     });
 }

-FluxSingleTransformerBlock::FluxSingleTransformerBlock(int dim, int num_attention_heads, int attention_head_dim, int mlp_ratio, Tensor::ScalarType dtype, Device device) :
+FluxSingleTransformerBlock::FluxSingleTransformerBlock(int dim, int num_attention_heads, int attention_head_dim, int mlp_ratio, bool use_fp4, Tensor::ScalarType dtype, Device device) :
     dim(dim), dim_head(attention_head_dim / num_attention_heads), num_heads(num_attention_heads), mlp_hidden_dim(dim * mlp_ratio),
     norm(dim, dtype, device),
-    mlp_fc1(dim, mlp_hidden_dim, true, dtype, device),
-    mlp_fc2(mlp_hidden_dim, dim, true, dtype, device),
-    qkv_proj(dim, dim * 3, true, dtype, device),
+    mlp_fc1(dim, mlp_hidden_dim, true, use_fp4, dtype, device),
+    mlp_fc2(mlp_hidden_dim, dim, true, use_fp4, dtype, device),
+    qkv_proj(dim, dim * 3, true, use_fp4, dtype, device),
     norm_q(dim_head, 1e-6, false, dtype, device),
     norm_k(dim_head, 1e-6, false, dtype, device),
     attn(num_attention_heads, attention_head_dim / num_attention_heads, device),
-    out_proj(dim, dim, true, dtype, device)
+    out_proj(dim, dim, true, use_fp4, dtype, device)
 {
     registerChildren
         (norm, "norm")

@@ -327,28 +327,28 @@ Tensor FluxSingleTransformerBlock::forward(Tensor hidden_states, Tensor temb, Te
     return hidden_states;
 }

-JointTransformerBlock::JointTransformerBlock(int dim, int num_attention_heads, int attention_head_dim, bool context_pre_only, Tensor::ScalarType dtype, Device device) :
+JointTransformerBlock::JointTransformerBlock(int dim, int num_attention_heads, int attention_head_dim, bool context_pre_only, bool use_fp4, Tensor::ScalarType dtype, Device device) :
     dim(dim), dim_head(attention_head_dim / num_attention_heads), num_heads(num_attention_heads), context_pre_only(context_pre_only),
     norm1(dim, false, dtype, device),
     norm1_context(dim, context_pre_only, dtype, device),
-    qkv_proj(dim, dim * 3, true, dtype, device),
-    qkv_proj_context(dim, dim * 3, true, dtype, device),
+    qkv_proj(dim, dim * 3, true, use_fp4, dtype, device),
+    qkv_proj_context(dim, dim * 3, true, use_fp4, dtype, device),
     norm_q(dim_head, 1e-6, false, dtype, device),
     norm_k(dim_head, 1e-6, false, dtype, device),
     norm_added_q(dim_head, 1e-6, false, dtype, device),
     norm_added_k(dim_head, 1e-6, false, dtype, device),
     attn(num_attention_heads, attention_head_dim / num_attention_heads, device),
-    out_proj(dim, dim, true, dtype, device),
-    out_proj_context(dim, dim, true, dtype, device),
+    out_proj(dim, dim, true, use_fp4, dtype, device),
+    out_proj_context(dim, dim, true, use_fp4, dtype, device),
     norm2(dim, 1e-6, false, dtype, device),
     norm2_context(dim, 1e-6, false, dtype, device),
-    mlp_fc1(dim, dim * 4, true, dtype, device),
-    mlp_fc2(dim * 4, dim, true, dtype, device),
-    mlp_context_fc1(dim, dim * 4, true, dtype, device),
-    mlp_context_fc2(dim * 4, dim, true, dtype, device)
+    mlp_fc1(dim, dim * 4, true, use_fp4, dtype, device),
+    mlp_fc2(dim * 4, dim, true, use_fp4, dtype, device),
+    mlp_context_fc1(dim, dim * 4, true, use_fp4, dtype, device),
+    mlp_context_fc2(dim * 4, dim, true, use_fp4, dtype, device)
 {
     registerChildren
         (norm1, "norm1")

@@ -607,13 +607,13 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
     return {hidden_states, encoder_hidden_states};
 }

-FluxModel::FluxModel(Tensor::ScalarType dtype, Device device) {
+FluxModel::FluxModel(bool use_fp4, Tensor::ScalarType dtype, Device device) {
     for (int i = 0; i < 19; i++) {
-        transformer_blocks.push_back(std::make_unique<JointTransformerBlock>(3072, 24, 3072, false, dtype, device));
+        transformer_blocks.push_back(std::make_unique<JointTransformerBlock>(3072, 24, 3072, false, use_fp4, dtype, device));
         registerChildren(*transformer_blocks.back(), format("transformer_blocks.{}", i));
     }
     for (int i = 0; i < 38; i++) {
-        single_transformer_blocks.push_back(std::make_unique<FluxSingleTransformerBlock>(3072, 24, 3072, 4, dtype, Device::cuda()));
+        single_transformer_blocks.push_back(std::make_unique<FluxSingleTransformerBlock>(3072, 24, 3072, 4, use_fp4, dtype, Device::cuda()));
         registerChildren(*single_transformer_blocks.back(), format("single_transformer_blocks.{}", i));
     }
 }
src/FluxModel.h

@@ -77,7 +77,7 @@ public:
     static constexpr bool USE_4BIT = true;
     using GEMM = std::conditional_t<USE_4BIT, GEMM_W4A4, GEMM_W8A8>;

-    FluxSingleTransformerBlock(int dim, int num_attention_heads, int attention_head_dim, int mlp_ratio, Tensor::ScalarType dtype, Device device);
+    FluxSingleTransformerBlock(int dim, int num_attention_heads, int attention_head_dim, int mlp_ratio, bool use_fp4, Tensor::ScalarType dtype, Device device);
     Tensor forward(Tensor hidden_states, Tensor temb, Tensor rotary_emb);

 public:

@@ -101,7 +101,7 @@ public:
     static constexpr bool USE_4BIT = true;
     using GEMM = std::conditional_t<USE_4BIT, GEMM_W4A4, GEMM_W8A8>;

-    JointTransformerBlock(int dim, int num_attention_heads, int attention_head_dim, bool context_pre_only, Tensor::ScalarType dtype, Device device);
+    JointTransformerBlock(int dim, int num_attention_heads, int attention_head_dim, bool context_pre_only, bool use_fp4, Tensor::ScalarType dtype, Device device);
     std::tuple<Tensor, Tensor> forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor temb, Tensor rotary_emb, Tensor rotary_emb_context, float sparsityRatio);

 public:

@@ -128,7 +128,7 @@ private:
 class FluxModel : public Module {
 public:
-    FluxModel(Tensor::ScalarType dtype, Device device);
+    FluxModel(bool use_fp4, Tensor::ScalarType dtype, Device device);
     Tensor forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor temb, Tensor rotary_emb_img, Tensor rotary_emb_context, Tensor rotary_emb_single);

 public:
src/Linear.cpp

@@ -96,23 +96,33 @@ Tensor GEMV_AWQ::forward(Tensor x) {
 #define NO_LORA_FUSION 0

-GEMM_W4A4::GEMM_W4A4(int in_features, int out_features, bool bias, Tensor::ScalarType dtype, Device device) :
+GEMM_W4A4::GEMM_W4A4(int in_features, int out_features, bool bias, bool use_fp4, Tensor::ScalarType dtype, Device device) :
     in_features(in_features), out_features(out_features),
     in_features_pad(ceilDiv(in_features, 128) * 128), out_features_pad(ceilDiv(out_features, 128) * 128),
+    use_fp4(use_fp4),
     lora_rank(0), dtype(dtype)
 {
     this->qweight = Tensor::allocate({out_features_pad, in_features_pad / 2}, Tensor::INT8, device, true);
-    this->wscales = Tensor::allocate({in_features_pad / 64, out_features_pad}, dtype, device, true);
+    if (use_fp4) {
+        this->wscales = Tensor::allocate({in_features_pad / 16, out_features_pad}, Tensor::FP8_E4M3, device, true);
+    } else {
+        this->wscales = Tensor::allocate({in_features_pad / 64, out_features_pad}, dtype, device, true);
+    }
     this->bias = bias ? Tensor::allocate({out_features_pad}, dtype, device, true) : Tensor{};
     this->lora_down = Tensor::allocate({in_features_pad, lora_rank}, dtype, device, true);
     this->lora_up = Tensor::allocate({out_features_pad, lora_rank}, dtype, device, true);
-    // TODO: smooth factor in FC1+FC2 fusion
+    // TODO: smooth factor in non-Lora fusion
     this->smooth = Tensor::allocate({in_features_pad}, dtype, device, true);
+    // FIXME: reset wtscale and wcscales to default values when reloading the weights
+    this->wtscale = Tensor::allocate({1}, Tensor::FP32, Device::cpu(), true);
+    *this->wtscale.data_ptr<float>() = 1.0f;
+    this->wcscales = Tensor::allocate({0}, dtype, device, true);

     registerParams
         (qweight, "qweight")
         (wscales, "wscales")

@@ -120,6 +130,8 @@ GEMM_W4A4::GEMM_W4A4(int in_features, int out_features, bool bias, Tensor::Scala
         (lora_down, "lora_down", ParamFlags::Optional)
         (lora_up, "lora_up", ParamFlags::Optional)
         (smooth, "smooth")
+        (wtscale, "wtscale", ParamFlags::Optional)
+        (wcscales, "wcscales", ParamFlags::Optional)
     ;

 #if NO_LORA_FUSION

@@ -137,6 +149,21 @@ void GEMM_W4A4::loadParam(std::string key, Tensor &dst, Tensor src) {
         } else {
             dst.copy_(src);
         }
+    } else if (key == "wcscales") {
+        assert(src.ndims() == 1);
+        assert(src.shape[0] == out_features_pad);
+        dst = src.copy(this->qweight.device());
+    } else if (key == "wtscale") {
+        assert(src.numel() == 1);
+        if (src.dtype() == Tensor::BF16) {
+            *dst.data_ptr<float>() = float(*src.data_ptr<__nv_bfloat16>());
+        } else if (src.dtype() == Tensor::FP16) {
+            *dst.data_ptr<float>() = float(*src.data_ptr<half>());
+        } else if (src.dtype() == Tensor::FP32) {
+            dst.copy_(src);
+        } else {
+            assert(false);
+        }
     } else {
         Module::loadParam(key, dst, src);
     }

@@ -167,7 +194,10 @@ void GEMM_W4A4::forward(Tensor x, Tensor out, Tensor pool, Tensor norm_q, Tensor
     debug("gemm.nolora.out", out);
 #endif

-    kernels::gemm_w4a4(qact.act, qweight, out, {}, qact.ascales, wscales, {}, pool, qact.lora_act, this->lora_up, {}, {}, norm_q, norm_k, rotary_emb, this->bias, {}, {}, {}, qact.is_unsigned, this->lora_scales, false);
+    kernels::gemm_w4a4(qact.act, qweight, out, {}, qact.ascales, wscales, {}, pool, qact.lora_act, this->lora_up, {}, {}, norm_q, norm_k, rotary_emb, this->bias, {}, {}, {}, qact.is_unsigned, this->lora_scales, false,
+        use_fp4, *this->wtscale.data_ptr<float>(), wcscales.numel() > 0 ? wcscales : Tensor{});

     debug("gemm.out", out);
 #else

@@ -215,9 +245,13 @@ std::variant<Tensor, GEMM_W4A4::QuantizedActivation> GEMM_W4A4::forward_quant(Qu
         out = Tensor::allocate(shape, dtype, qweight.device());
     } else {
         qout.act = Tensor::allocate({M, out_features_pad / 2}, Tensor::INT8, qweight.device());
-        qout.ascales = Tensor::allocate({out_features_pad / 64, M}, dtype, qweight.device());
+        if (use_fp4) {
+            qout.ascales = Tensor::allocate({out_features_pad / 16, M}, Tensor::FP8_E4M3, qweight.device());
+        } else {
+            qout.ascales = Tensor::allocate({out_features_pad / 64, M}, dtype, qweight.device());
+        }
         qout.lora_act = Tensor::allocate({M, lora_rank}, Tensor::FP32, qweight.device());
-        qout.is_unsigned = true;
+        qout.is_unsigned = !use_fp4;
         qout.actShape = qact.actShape;

         next_lora = nextGEMM->lora_down;

@@ -241,7 +275,10 @@ std::variant<Tensor, GEMM_W4A4::QuantizedActivation> GEMM_W4A4::forward_quant(Qu
     }
 #endif

-    kernels::gemm_w4a4(qact.act, qweight, out, qout.act, qact.ascales, wscales, qout.ascales, {}, qact.lora_act, this->lora_up, next_lora, qout.lora_act, {}, {}, {}, this->bias, next_smooth, {}, {}, qact.is_unsigned, this->lora_scales, fuse == FuseOptions::SILU);
+    kernels::gemm_w4a4(qact.act, qweight, out, qout.act, qact.ascales, wscales, qout.ascales, {}, qact.lora_act, this->lora_up, next_lora, qout.lora_act, {}, {}, {}, this->bias, next_smooth, {}, {}, qact.is_unsigned, this->lora_scales, fuse == FuseOptions::SILU,
+        use_fp4, *this->wtscale.data_ptr<float>(), wcscales.numel() > 0 ? wcscales : Tensor{});

     if (fuse == FuseOptions::EMPTY || fuse == FuseOptions::SILU) {
         debug("gemm.out", out);

@@ -327,7 +364,11 @@ GEMM_W4A4::QuantizedActivation GEMM_W4A4::quantize(Tensor x, bool fuse_glu) {
     QuantizedActivation qact;
     qact.act = Tensor::allocate({M, in_features_pad / 2}, Tensor::INT8, qweight.device());
-    qact.ascales = Tensor::allocate({in_features_pad / 64, M}, dtype, qweight.device());
+    if (use_fp4) {
+        qact.ascales = Tensor::allocate({in_features_pad / 16, M}, Tensor::FP8_E4M3, qweight.device());
+    } else {
+        qact.ascales = Tensor::allocate({in_features_pad / 64, M}, dtype, qweight.device());
+    }
     qact.lora_act = Tensor::allocate({M, lora_rank}, Tensor::FP32, qweight.device());
     qact.is_unsigned = false;
     qact.actShape = x.shape.dataExtent;

@@ -336,7 +377,7 @@ GEMM_W4A4::QuantizedActivation GEMM_W4A4::quantize(Tensor x, bool fuse_glu) {
     debug("quantize.x", x);
     debug("quantize.smooth", this->smooth);

-    kernels::quantize_w4a4_act_fuse_lora(x, qact.act, qact.ascales, this->lora_down, qact.lora_act, this->smooth, fuse_glu);
+    kernels::quantize_w4a4_act_fuse_lora(x, qact.act, qact.ascales, this->lora_down, qact.lora_act, this->smooth, fuse_glu, use_fp4);

     debug("quantize.qact", qact.act);
     debug("quantize.ascales", qact.ascales);
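The constructor changes above switch the quantization-scale layout when use_fp4 is set: weight and activation scales go from one scale per 64-element group stored in the model dtype to one FP8 E4M3 scale per 16-element group, plus a per-tensor FP32 scale (wtscale) and optional per-channel scales (wcscales). A small Python sketch of the resulting shape arithmetic, with the helper name and feature sizes as hypothetical inputs:

    def w4a4_scale_shapes(in_features_pad: int, out_features_pad: int, use_fp4: bool):
        """Mirror the shapes allocated in GEMM_W4A4::GEMM_W4A4 (src/Linear.cpp)."""
        group = 16 if use_fp4 else 64              # elements sharing one scale
        scale_dtype = "fp8_e4m3" if use_fp4 else "model dtype (fp16/bf16)"
        return {
            "qweight": (out_features_pad, in_features_pad // 2),   # packed 4-bit weights
            "wscales": (in_features_pad // group, out_features_pad),
            "wscales dtype": scale_dtype,
            "wtscale": (1,),                       # per-tensor FP32 scale on the fp4 path
        }

    print(w4a4_scale_shapes(3072, 3072 * 3, use_fp4=True))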
src/Linear.h

@@ -64,7 +64,7 @@ public:
     };

 public:
-    GEMM_W4A4(int in_features, int out_features, bool bias, Tensor::ScalarType dtype, Device device);
+    GEMM_W4A4(int in_features, int out_features, bool bias, bool use_fp4, Tensor::ScalarType dtype, Device device);

     Tensor forward(Tensor x);
     Tensor forward_silu(Tensor x);
     std::variant<Tensor, QuantizedActivation> forward(Tensor x, FuseOptions fuse, GEMM_W4A4 *nextGEMM = nullptr);

@@ -80,6 +80,7 @@ public:
     const int out_features;
     const int in_features_pad;
     const int out_features_pad;
+    const bool use_fp4;

     int lora_rank;
     std::vector<float> lora_scales;  // every 16 ranks share a scale

@@ -99,6 +100,9 @@ public:
     Tensor smooth;

+    Tensor wtscale;
+    Tensor wcscales;
+
     cublasHandle_t handle;
 };
src/SanaModel.cpp

@@ -8,11 +8,11 @@
 using spdlog::fmt_lib::format;
 using namespace nunchaku;

-SanaLinearAttention::SanaLinearAttention(int dim, bool bias, bool pag, Tensor::ScalarType dtype, Device device) :
+SanaLinearAttention::SanaLinearAttention(int dim, bool bias, bool pag, bool use_fp4, Tensor::ScalarType dtype, Device device) :
     dim(dim), dim_pad(ceilDiv(dim, 128) * 128),
-    qkv_proj(dim, dim_pad * 3, bias, dtype, device),
-    out_proj(dim_pad, dim, bias, dtype, device),
+    qkv_proj(dim, dim_pad * 3, bias, use_fp4, dtype, device),
+    out_proj(dim_pad, dim, bias, use_fp4, dtype, device),
     pag_to_v(std::nullopt)
 {
     registerChildren

@@ -21,7 +21,7 @@ SanaLinearAttention::SanaLinearAttention(int dim, bool bias, bool pag, Tensor::S
     ;
     if (pag) {
-        pag_to_v.emplace(dim, dim_pad, bias, dtype, device);
+        pag_to_v.emplace(dim, dim_pad, bias, use_fp4, dtype, device);
         registerChildren(pag_to_v.value(), "pag_to_v");
     }
 }

@@ -63,7 +63,11 @@ Tensor SanaLinearAttention::forward(Tensor x, Tensor out) {
         qkv_proj.wscales, {}, {}, qact.lora_act, qkv_proj.lora_up, {}, {}, {}, {}, {}, qkv_proj.bias, {}, vk, q,
-        qact.is_unsigned, qkv_proj.lora_scales, false);
+        qact.is_unsigned, qkv_proj.lora_scales, false,
+        qkv_proj.use_fp4, *qkv_proj.wtscale.data_ptr<float>(),
+        qkv_proj.wcscales.numel() > 0 ? qkv_proj.wcscales : Tensor{});

     debug("vk", vk);
     debug("q", q);

@@ -121,11 +125,11 @@ Tensor SanaLinearAttention::forward_pag(Tensor x, bool cfg) {
     return out;
 }

-MultiHeadCrossAttention::MultiHeadCrossAttention(int num_heads, int head_dim, Tensor::ScalarType dtype, Device device) :
+MultiHeadCrossAttention::MultiHeadCrossAttention(int num_heads, int head_dim, bool use_fp4, Tensor::ScalarType dtype, Device device) :
     num_heads(num_heads), head_dim(head_dim),
-    q_linear(num_heads * head_dim, num_heads * head_dim, true, dtype, device),
+    q_linear(num_heads * head_dim, num_heads * head_dim, true, use_fp4, dtype, device),
     kv_linear(num_heads * head_dim, num_heads * head_dim * 2, true, dtype, device),
-    out_proj(num_heads * head_dim, num_heads * head_dim, true, dtype, device)
+    out_proj(num_heads * head_dim, num_heads * head_dim, true, use_fp4, dtype, device)
 {
     registerChildren
         (q_linear, "q_linear")

@@ -173,11 +177,11 @@ Tensor MultiHeadCrossAttention::forward(Tensor x, Tensor cond, Tensor cu_seqlens
     return out_proj.forward(attn_output);
 }

-SanaGLUMBConv::SanaGLUMBConv(int in_features, int hidden_features, Tensor::ScalarType dtype, Device device) :
+SanaGLUMBConv::SanaGLUMBConv(int in_features, int hidden_features, bool use_fp4, Tensor::ScalarType dtype, Device device) :
     in_features(in_features), hidden_features(hidden_features),
-    inverted_conv(in_features, hidden_features * 2, true, dtype, device),
+    inverted_conv(in_features, hidden_features * 2, true, use_fp4, dtype, device),
     depth_conv(hidden_features * 2, true, dtype, device),
-    point_conv(hidden_features, in_features, false, dtype, device)
+    point_conv(hidden_features, in_features, false, use_fp4, dtype, device)
 {
     registerChildren
         (inverted_conv, "inverted_conv")

@@ -200,11 +204,11 @@ Tensor SanaGLUMBConv::forward(Tensor x, int H, int W) {
     return point_conv.forward_quant(qact);
 }

-SanaLinearTransformerBlock::SanaLinearTransformerBlock(int hidden_size, int intermediate_size, int num_cross_attention_heads, bool pag, Tensor::ScalarType dtype, Device device) :
+SanaLinearTransformerBlock::SanaLinearTransformerBlock(int hidden_size, int intermediate_size, int num_cross_attention_heads, bool pag, bool use_fp4, Tensor::ScalarType dtype, Device device) :
     hidden_size(hidden_size), num_cross_attention_heads(num_cross_attention_heads),
-    attn(hidden_size, false, pag, dtype, device),
-    cross_attn(num_cross_attention_heads, hidden_size / num_cross_attention_heads, dtype, device),
-    ff(hidden_size, intermediate_size, dtype, device),
+    attn(hidden_size, false, pag, use_fp4, dtype, device),
+    cross_attn(num_cross_attention_heads, hidden_size / num_cross_attention_heads, use_fp4, dtype, device),
+    ff(hidden_size, intermediate_size, use_fp4, dtype, device),
     norm1(hidden_size, 1e-6, false, dtype, device),
     norm2(hidden_size, 1e-6, false, dtype, device)
 {

@@ -313,6 +317,7 @@ SanaModel::SanaModel(SanaConfig config, Tensor::ScalarType dtype, Device device)
             ceilDiv(int(round(config.expand_ratio * inner_dim)), 64) * 64,
             config.num_cross_attention_heads,
             std::find(config.pag_layers.begin(), config.pag_layers.end(), i) != config.pag_layers.end(),
+            config.use_fp4,
             dtype, device));
         registerChildren(*transformer_blocks.back(), format("transformer_blocks.{}", i));
src/SanaModel.h

@@ -7,7 +7,7 @@
 class SanaLinearAttention : public Module {
 public:
-    SanaLinearAttention(int dim, bool bias, bool pag, Tensor::ScalarType dtype, Device device);
+    SanaLinearAttention(int dim, bool bias, bool pag, bool use_fp4, Tensor::ScalarType dtype, Device device);

     Tensor forward(Tensor x, Tensor out = {});
     Tensor forward_pag(Tensor x, bool cfg);

@@ -25,7 +25,7 @@ private:
 class MultiHeadCrossAttention : public Module {
 public:
-    MultiHeadCrossAttention(int num_heads, int head_dim, Tensor::ScalarType dtype, Device device);
+    MultiHeadCrossAttention(int num_heads, int head_dim, bool use_fp4, Tensor::ScalarType dtype, Device device);

     Tensor forward(Tensor x, Tensor cond, Tensor cu_seqlens_img, Tensor cu_seqlens_txt);

@@ -41,7 +41,7 @@ private:
 class SanaGLUMBConv : public Module {
 public:
-    SanaGLUMBConv(int in_features, int hidden_features, Tensor::ScalarType dtype, Device device);
+    SanaGLUMBConv(int in_features, int hidden_features, bool use_fp4, Tensor::ScalarType dtype, Device device);

     Tensor forward(Tensor x, int H, int W);

@@ -57,7 +57,7 @@ private:
 class SanaLinearTransformerBlock : public Module {
 public:
-    SanaLinearTransformerBlock(int hidden_size, int intermediate_size, int num_cross_attention_heads, bool pag, Tensor::ScalarType dtype, Device device);
+    SanaLinearTransformerBlock(int hidden_size, int intermediate_size, int num_cross_attention_heads, bool pag, bool use_fp4, Tensor::ScalarType dtype, Device device);

     Tensor forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor timestep, Tensor cu_seqlens_img, Tensor cu_seqlens_txt, int H, int W, bool pag, bool cfg);

@@ -83,6 +83,7 @@ struct SanaConfig {
     int num_cross_attention_heads;
     double expand_ratio;
     std::vector<int> pag_layers;
+    bool use_fp4;
 };

 class SanaModel : public Module {
src/Serialization.cpp

@@ -117,6 +117,8 @@ void SafeTensors::parseHeader() {
         {"I8",  Tensor::INT8},
         {"I32", Tensor::INT32},
         {"I64", Tensor::INT64},
+        {"F8_E4M3", Tensor::FP8_E4M3},
+        {"F8_E5M2", Tensor::FP8_E5M2},
     };

     auto check = [](bool cond, std::source_location location = std::source_location::current()) {
src/Tensor.h

@@ -218,7 +218,8 @@ public:
     enum ScalarType {
         INVALID_SCALAR_TYPE,
         INT8, INT32, INT64,
-        FP16, FP32, BF16
+        FP16, FP32, BF16,
+        FP8_E4M3, FP8_E5M2,
     };

     struct TensorOptions {

@@ -545,6 +546,8 @@ inline const std::map<Tensor::ScalarType, size_t> Tensor::scalarSize = {
     {FP16, 2},
     {FP32, 4},
     {BF16, 2},
+    {FP8_E4M3, 1},
+    {FP8_E5M2, 1},
 };

 struct TensorsProvider {
src/common.h

@@ -9,6 +9,7 @@
 #include <memory>
 #include <source_location>
 #include <vector>
+#include <list>
 #include <stack>
 #include <map>
 #include <unordered_map>

@@ -79,6 +80,15 @@ constexpr T ceilDiv(T a, T b) {
     return (a + b - 1) / b;
 }

+template<typename T>
+constexpr int log2Up(T value) {
+    if (value <= 0) return 0;
+    if (value == 1) return 0;
+    return log2Up((value + 1) / 2) + 1;
+}
+
 struct CUBLASWrapper {
     cublasHandle_t handle = nullptr;
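The new log2Up helper computes the ceiling of log2 for positive integers (log2Up(1) = 0, log2Up(3) = 2, log2Up(8) = 3) via the recursion log2Up(v) = log2Up(ceil(v / 2)) + 1. A quick Python cross-check of that equivalence, for illustration only:

    import math

    def log2_up(value: int) -> int:
        # Same recursion as the C++ template in src/common.h.
        if value <= 1:
            return 0
        return log2_up((value + 1) // 2) + 1

    assert all(log2_up(v) == math.ceil(math.log2(v)) for v in range(1, 10_000))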
src/interop/torch.cpp

@@ -28,6 +28,8 @@ Tensor from_torch(at::Tensor input) {
         {at::ScalarType::Float,    Tensor::FP32},
         {at::ScalarType::Half,     Tensor::FP16},
         {at::ScalarType::BFloat16, Tensor::BF16},
+        {at::ScalarType::Float8_e4m3fn, Tensor::FP8_E4M3},
+        {at::ScalarType::Float8_e5m2,   Tensor::FP8_E5M2},
     };
     result.scalarType = mapType.at(input.scalar_type());

@@ -53,6 +55,8 @@ at::Tensor to_torch(Tensor input) {
         {Tensor::FP32, at::ScalarType::Float},
         {Tensor::FP16, at::ScalarType::Half},
         {Tensor::BF16, at::ScalarType::BFloat16},
+        {Tensor::FP8_E4M3, at::ScalarType::Float8_e4m3fn},
+        {Tensor::FP8_E5M2, at::ScalarType::Float8_e5m2},
     };
     c10::TensorOptions opts(mapType.at(input.scalar_type()));
src/kernels/awq/gemv_awq.cu

@@ -140,8 +140,10 @@ __global__ void gemv_kernel(
     for (int i = 0; i < Num; ++i)
         psum[i] = static_cast<accum_t>(0.f);

-    extern __shared__ uint8_t shmem[];
-    float(*out_smem)[Num * kInterleave] = reinterpret_cast<float(*)[Num * kInterleave]>(shmem);
+    // extern __shared__ uint8_t shmem[];
+    // float(*out_smem)[Num * kInterleave] = reinterpret_cast<float(*)[Num * kInterleave]>(shmem);
+    __shared__ float out_smem[BlockSize / WARP_SIZE * 2][Num * kInterleave];

     const int blk_row_offset = blockIdx.x * NPerBlock * kInterleave;
     const int thd_row_offset = (threadIdx.x / kThreadsNumPerTile) % kInterleave;
src/kernels/zgemm/gemm_base.cuh

@@ -319,10 +319,10 @@ public:
         int warpId = threadIdx.x / WARP_SIZE;
 #pragma unroll
         for (int i = 0; i < WARP_M_TILES; i++) {
-            if (pred) {
+            // if (pred) {
                 // out[i] = load(&act[((warpId * WARP_M_TILES + i) * K / WARP_K + k) * WARP_SIZE + laneId]);
-                out[i] = load(&act[((k * NUM_WARPS + warpId) * WARP_M_TILES + i) * WARP_SIZE + laneId]);
-            }
+                out[i] = load_pred(&act[((k * NUM_WARPS + warpId) * WARP_M_TILES + i) * WARP_SIZE + laneId], pred);
+            // }
         }
     }

@@ -336,12 +336,12 @@ public:
         // int offset = K / WARP_K * WARP_SIZE;
 #pragma unroll
         for (int i = 0; i < WARP_N_TILES; i++) {
-            if (pred) {
+            // if (pred) {
                 // out[i] = load(&wgt[(i * K / WARP_K + k) * WARP_SIZE + laneId]);
                 // out[i] = load(&wgt[(i + k * WARP_N_TILES) * WARP_SIZE + laneId]);
-                out[i] = load(&ptr[i * WARP_SIZE]);
+                out[i] = load_pred(&ptr[i * WARP_SIZE], pred);
                 // ptr += offset;
-            }
+            // }
         }
     }

@@ -352,11 +352,11 @@ public:
         int warpId = threadIdx.x / WARP_SIZE;
 #pragma unroll
         for (int i = 0; i < ASCALES_NUM_PACKS; i++) {
-            if (pred && laneId < ASCALES_VALID_LANES) {
+            // if (pred && laneId < ASCALES_VALID_LANES) {
                 // out[i] = ascales[(group * M / WARP_M + warpId) * ASCALES_VALID_LANES * ASCALES_NUM_PACKS + i * ASCALES_VALID_LANES + laneId];
-                out[i] = ascales[(group * NUM_WARPS + warpId) * ASCALES_NUM_PACKS * ASCALES_VALID_LANES + i * ASCALES_VALID_LANES + laneId];
-            }
+                out[i] = load_pred(&ascales[(group * NUM_WARPS + warpId) * ASCALES_NUM_PACKS * ASCALES_VALID_LANES + i * ASCALES_VALID_LANES + laneId], pred && laneId < ASCALES_VALID_LANES);
+            // }
         }
     }

@@ -373,13 +373,13 @@ public:
 #pragma unroll
         for (int i = 0; i < WSCALES_NUM_PACKS; i++) {
-            if (pred && laneId < WSCALES_VALID_LANES) {
+            // if (pred && laneId < WSCALES_VALID_LANES) {
                 // out[i] = wscales[group * N / WARP_N * WSCALES_VALID_LANES * WSCALES_NUM_PACKS + i * WSCALES_VALID_LANES + laneId];
                 // out[i] = load(&wscales[group * N / WARP_N * WSCALES_VALID_LANES * WSCALES_NUM_PACKS + i * WSCALES_VALID_LANES + laneId]);
-                out[i] = load(&wscales[(group * WSCALES_NUM_PACKS + i) * WSCALES_VALID_LANES + laneId]);
+                out[i] = load_pred(&wscales[(group * WSCALES_NUM_PACKS + i) * WSCALES_VALID_LANES + laneId], pred && laneId < WSCALES_VALID_LANES);
                 // out[i] = load(&ptr[i * WSCALES_VALID_LANES]);
-            }
+            // }
         }
     }

@@ -400,7 +400,7 @@ public:
         return __shfl_sync(~0, block[packIdx].data[elementIdx], srcLane);
     }

-    template<typename F>
+    template<bool FAST_I2F = false, typename F>
     __device__ __forceinline__
     static void apply_scales(F &&getpsum, ascale_warp ascale, wscale_warp wscale, fpsum_warp &fpsum) {
         const int laneId = threadIdx.x % WARP_SIZE;

@@ -429,12 +429,31 @@ public:
                 //     printf("before ws2 = %f %f fsum.data[%d] = %f %f\n", (float)ws2.x, (float)ws2.y, target, (float)fsum.data[target].x, (float)fsum.data[target].y);
                 // }

-                fsum.data[0] = __hfma2(float22half2<half2_t>(make_float2(__int2float_rn(psum.data[0]), __int2float_rn(psum.data[1]))), __hmul2(asx[i], ws1), fsum.data[0]);
-                fsum.data[1] = __hfma2(float22half2<half2_t>(make_float2(__int2float_rn(psum.data[2]), __int2float_rn(psum.data[3]))), __hmul2(asy[i], ws1), fsum.data[1]);
-                fsum.data[2] = __hfma2(float22half2<half2_t>(make_float2(__int2float_rn(psum.data[4]), __int2float_rn(psum.data[5]))), __hmul2(asx[i], ws2), fsum.data[2]);
-                fsum.data[3] = __hfma2(float22half2<half2_t>(make_float2(__int2float_rn(psum.data[6]), __int2float_rn(psum.data[7]))), __hmul2(asy[i], ws2), fsum.data[3]);
+                auto scale_fma_normal = [&]() ALWAYSINLINE {
+                    fsum.data[0] = __hfma2(float22half2<half2_t>(make_float2(__int2float_rn(psum.data[0]), __int2float_rn(psum.data[1]))), __hmul2(asx[i], ws1), fsum.data[0]);
+                    fsum.data[1] = __hfma2(float22half2<half2_t>(make_float2(__int2float_rn(psum.data[2]), __int2float_rn(psum.data[3]))), __hmul2(asy[i], ws1), fsum.data[1]);
+                    fsum.data[2] = __hfma2(float22half2<half2_t>(make_float2(__int2float_rn(psum.data[4]), __int2float_rn(psum.data[5]))), __hmul2(asx[i], ws2), fsum.data[2]);
+                    fsum.data[3] = __hfma2(float22half2<half2_t>(make_float2(__int2float_rn(psum.data[6]), __int2float_rn(psum.data[7]))), __hmul2(asy[i], ws2), fsum.data[3]);
+                };
+                // should be faster on sm_80
+                auto scale_fma_fast = [&]() ALWAYSINLINE {
+                    fsum.data[0] = __hfma2(float22half2<half2_t>(make_float2(int2float_fast(psum.data[0]), int2float_fast(psum.data[1]))), __hmul2(asx[i], ws1), fsum.data[0]);
+                    fsum.data[1] = __hfma2(float22half2<half2_t>(make_float2(int2float_fast(psum.data[2]), int2float_fast(psum.data[3]))), __hmul2(asy[i], ws1), fsum.data[1]);
+                    fsum.data[2] = __hfma2(float22half2<half2_t>(make_float2(int2float_fast(psum.data[4]), int2float_fast(psum.data[5]))), __hmul2(asx[i], ws2), fsum.data[2]);
+                    fsum.data[3] = __hfma2(float22half2<half2_t>(make_float2(int2float_fast(psum.data[6]), int2float_fast(psum.data[7]))), __hmul2(asy[i], ws2), fsum.data[3]);
+                };
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ <= 800
+                if constexpr (FAST_I2F) {
+                    scale_fma_fast();
+                } else {
+                    scale_fma_normal();
+                }
+#else
+                scale_fma_normal();
+#endif

                 // if (threadIdx.x == 3 && j == 1 && i == 0) {
                 //     printf("before ws2 = %f %f fsum.data[%d] = %f %f\n", (float)ws2.x, (float)ws2.y, target, (float)fsum.data[target].x, (float)fsum.data[target].y);
                 // }

@@ -575,9 +594,9 @@ public:
             (plugins(i * INSN_M + row, pack), ...);
             bool pred = i * INSN_M + row < maxRows && laneId * PACK_SIZE < maxCols;
-            if (pred) {
-                store(reinterpret_cast<pack_t *>(&output[(i * INSN_M + row) * stride + laneId * PACK_SIZE]), pack);
-            }
+            // if (pred) {
+                store_pred(reinterpret_cast<pack_t *>(&output[(i * INSN_M + row) * stride + laneId * PACK_SIZE]), pack, pred);
+            // }
         }
         __syncwarp();

@@ -602,9 +621,9 @@ public:
             (plugins(i * INSN_M + 8 + row, pack), ...);
             bool pred = i * INSN_M + 8 + row < maxRows && laneId * PACK_SIZE < maxCols;
-            if (pred) {
-                store(reinterpret_cast<pack_t *>(&output[(i * INSN_M + 8 + row) * stride + laneId * PACK_SIZE]), pack);
-            }
+            // if (pred) {
+                store_pred(reinterpret_cast<pack_t *>(&output[(i * INSN_M + 8 + row) * stride + laneId * PACK_SIZE]), pack, pred);
+            // }
         }
         __syncwarp();
     }

@@ -680,33 +699,61 @@ public:
         }
     };

+    template<bool USE_BIAS = true, bool USE_SCALE = false>
     struct EpilogueBias {
         struct Arguments {
             const packed_wscale_t *bias;   // [N / BLOCK_N, WSCALES_NUM_PACKS, WSCALES_VALID_LANES] of packed_wscale_t
+            const packed_wscale_t *scale;
         };

         __device__ __forceinline__
-        void apply_bias(fpsum_warp &fpsum, int M, int N, int K, const packed_wscale_t *bias) {
+        void apply_bias(fpsum_warp &fpsum, int M, int N, int K, const packed_wscale_t *bias, const packed_wscale_t *scale) {
             const int laneId = threadIdx.x % WARP_SIZE;

             // if (laneId == 0) {
             //     printf("block.x=%d block.y=%d warpId=%d bias=%p\n", blockIdx.x, blockIdx.y, threadIdx.x / WARP_SIZE, bias);
             // }

-            wscale_warp b;
-            load_wscale(bias, 0, N, b, true);
+            wscale_warp b, s;
+            if constexpr (USE_BIAS) {
+                load_wscale(bias, 0, N, b, true);
+            }
+            if constexpr (USE_SCALE) {
+                load_wscale(scale, 0, N, s, true);
+            }

             for (int j = 0; j < WARP_N_TILES; j++) {
-                half2_t b1 = broadcast_wscale(b, j * 4, laneId);
-                half2_t b2 = broadcast_wscale(b, j * 4 + 2, laneId);
+                half2_t b1, b2;
+                half2_t s1, s2;
+                if constexpr (USE_BIAS) {
+                    b1 = broadcast_wscale(b, j * 4, laneId);
+                    b2 = broadcast_wscale(b, j * 4 + 2, laneId);
+                }
+                if constexpr (USE_SCALE) {
+                    s1 = broadcast_wscale(s, j * 4, laneId);
+                    s2 = broadcast_wscale(s, j * 4 + 2, laneId);
+                }

                 for (int i = 0; i < WARP_M_TILES; i++) {
                     auto &fsum = fpsum[i * WARP_N_TILES + j];

-                    fsum.data[0] = __hadd2(fsum.data[0], b1);
-                    fsum.data[1] = __hadd2(fsum.data[1], b1);
-                    fsum.data[2] = __hadd2(fsum.data[2], b2);
-                    fsum.data[3] = __hadd2(fsum.data[3], b2);
+                    if constexpr (USE_SCALE && USE_BIAS) {
+                        fsum.data[0] = __hfma2(fsum.data[0], s1, b1);
+                        fsum.data[1] = __hfma2(fsum.data[1], s1, b1);
+                        fsum.data[2] = __hfma2(fsum.data[2], s2, b2);
+                        fsum.data[3] = __hfma2(fsum.data[3], s2, b2);
+                    } else if constexpr (USE_SCALE) {
+                        fsum.data[0] = __hmul2(fsum.data[0], s1);
+                        fsum.data[1] = __hmul2(fsum.data[1], s1);
+                        fsum.data[2] = __hmul2(fsum.data[2], s2);
+                        fsum.data[3] = __hmul2(fsum.data[3], s2);
+                    } else if constexpr (USE_BIAS) {
+                        fsum.data[0] = __hadd2(fsum.data[0], b1);
+                        fsum.data[1] = __hadd2(fsum.data[1], b1);
+                        fsum.data[2] = __hadd2(fsum.data[2], b2);
+                        fsum.data[3] = __hadd2(fsum.data[3], b2);
+                    }
                 }
             }
         }

@@ -714,10 +761,13 @@ public:
         __device__ __forceinline__
         void operator()(const BlockInfo binfo, fpsum_warp &fpsum, int M, int N, int K, Arguments args) {
             const int bn = binfo.bn;
-            apply_bias(fpsum, M, N, K, args.bias + bn * WSCALES_NUM_PACKS * WSCALES_VALID_LANES);
+            if constexpr (USE_BIAS || USE_SCALE) {
+                apply_bias(fpsum, M, N, K,
+                           args.bias + bn * WSCALES_NUM_PACKS * WSCALES_VALID_LANES,
+                           args.scale + bn * WSCALES_NUM_PACKS * WSCALES_VALID_LANES);
+            }
         }
     };

@@ -797,7 +847,8 @@ public:
     using typename Base::unpack_fpsum; \
     using typename Base::EpilogueDefault; \
     using typename Base::EpilogueNop; \
-    using typename Base::EpilogueBias;
+    template<bool USE_BIAS, bool USE_SCALE> \
+    using EpilogueBias = typename Base::EpilogueBias<USE_BIAS, USE_SCALE>;

 template<typename kernel, typename ...T>
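Conceptually, apply_scales and EpilogueBias turn the integer partial sums of the W4A4 GEMM back into real values: each int32 partial sum is multiplied by its activation scale and weight scale, and the epilogue can then apply an optional per-channel scale and a bias. A rough NumPy sketch of that math, heavily simplified (per-row/per-column scales instead of the kernel's group-wise scales and tiling, and without the fp4 per-tensor alpha):

    import numpy as np

    def w4a4_gemm_reference(act_q, wgt_q, a_scales, w_scales, bias=None, channel_scale=None):
        """act_q: [M, K] int, wgt_q: [N, K] int, a_scales: [M], w_scales: [N]."""
        psum = act_q.astype(np.int32) @ wgt_q.astype(np.int32).T      # integer partial sums
        out = psum.astype(np.float32) * a_scales[:, None] * w_scales[None, :]
        if channel_scale is not None:                                 # EpilogueBias<USE_SCALE>
            out = out * channel_scale[None, :]
        if bias is not None:                                          # EpilogueBias<USE_BIAS>
            out = out + bias[None, :]
        return out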
src/kernels/zgemm/gemm_utils.cuh

@@ -43,6 +43,41 @@ static T load(const T *addr) {
     return *addr;
 }

+template<typename T>
+__device__ __forceinline__
+static T load_pred(const T *addr, bool pred) {
+    if constexpr (sizeof(T) == 4) {
+        uint32_t data;
+        asm volatile ("{ .reg .pred loadpred; setp.ne.b32 loadpred, %2, 0;"
+                      "@loadpred ld.global.nc.b32 %0, [%1]; }"
+                      : "=r"(data) : "l"(addr), "r"((int)pred));
+        return *reinterpret_cast<T *>(&data);
+    }
+    if constexpr (sizeof(T) == 8) {
+        uint2 data;
+        asm volatile ("{ .reg .pred loadpred; setp.ne.b32 loadpred, %3, 0;"
+                      "@loadpred ld.global.nc.v2.b32 {%0, %1}, [%2]; }"
+                      : "=r"(data.x), "=r"(data.y) : "l"(addr), "r"((int)pred));
+        return *reinterpret_cast<T *>(&data);
+    }
+    if constexpr (sizeof(T) == 16) {
+        uint4 data;
+        asm volatile ("{ .reg .pred loadpred; setp.ne.b32 loadpred, %5, 0;"
+                      "@loadpred ld.global.nc.v4.b32 {%0, %1, %2, %3}, [%4]; }"
+                      : "=r"(data.x), "=r"(data.y), "=r"(data.z), "=r"(data.w) : "l"(addr), "r"((int)pred));
+        return *reinterpret_cast<T *>(&data);
+    }
+    T result;
+    if (pred) {
+        result = *addr;
+    }
+    return result;
+}
+
 template<bool shmem = false, typename T>
 __device__ __forceinline__
 static void store(T *addr, T val) {

@@ -76,6 +111,39 @@ static void store(T *addr, T val) {
     *addr = val;
 }

+template<typename T>
+__device__ __forceinline__
+static void store_pred(T *addr, T val, bool pred) {
+    if constexpr (sizeof(T) == 4) {
+        uint32_t data = *reinterpret_cast<uint32_t *>(&val);
+        asm volatile ("{ .reg .pred storepred; setp.ne.b32 storepred, %0, 0;"
+                      "@storepred st.global.cg.b32 [%1], %2; }"
+                      :: "r"((int)pred), "l"(addr), "r"(data));
+        return;
+    }
+    if constexpr (sizeof(T) == 8) {
+        uint2 data = *reinterpret_cast<uint2 *>(&val);
+        asm volatile ("{ .reg .pred storepred; setp.ne.b32 storepred, %0, 0;"
+                      "@storepred st.global.cg.v2.b32 [%1], {%2, %3}; }"
+                      :: "r"((int)pred), "l"(addr), "r"(data.x), "r"(data.y));
+        return;
+    }
+    if constexpr (sizeof(T) == 16) {
+        uint4 data = *reinterpret_cast<uint4 *>(&val);
+        asm volatile ("{ .reg .pred storepred; setp.ne.b32 storepred, %0, 0;"
+                      "@storepred st.global.cg.v4.b32 [%1], {%2, %3, %4, %5}; }"
+                      :: "r"((int)pred), "l"(addr), "r"(data.x), "r"(data.y), "r"(data.z), "r"(data.w));
+        return;
+    }
+    if (pred) {
+        *addr = val;
+    }
+}
+
 __device__ __forceinline__
 static float2 half22float2(half2 val) {
     return __half22float2(val);

@@ -159,6 +227,21 @@ uint32_t quantize_float2<8, false>(float2 value) {
     return result;
 }

+__device__ __forceinline__
+uint32_t quantize_float2_fp4(float2 value) {
+    uint32_t result;
+    asm volatile ("{ .reg .b8 tmp; cvt.rn.satfinite.e2m1x2.f32 tmp, %1, %2; cvt.u32.u8 %0, tmp; }"
+                  : "=r"(result) : "f"(value.y), "f"(value.x));
+    return result;
+}
+
+__device__ __forceinline__
+uint32_t quantize_float4_fp8(float4 value) {
+    uint16_t lo, hi;
+    asm volatile ("cvt.rn.satfinite.e4m3x2.f32 %0, %1, %2;" : "=h"(lo) : "f"(value.y), "f"(value.x));
+    asm volatile ("cvt.rn.satfinite.e4m3x2.f32 %0, %1, %2;" : "=h"(hi) : "f"(value.w), "f"(value.z));
+    return uint32_t(lo) | (uint32_t(hi) << 16);
+}
+
 __device__ __forceinline__
 static float cuda_tanhf(float x) {
     float result;
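quantize_float2_fp4 packs two floats into two 4-bit FP4 (E2M1) codes with PTX cvt.rn.satfinite.e2m1x2.f32, and quantize_float4_fp8 does the analogous packing to FP8 E4M3. For reference, an E2M1 code has 1 sign, 2 exponent, and 1 mantissa bit, so its representable magnitudes are only {0, 0.5, 1, 1.5, 2, 3, 4, 6}. A hedged Python sketch of the decode side, to make that value set concrete (illustrative only, not code from this repository):

    def decode_e2m1(code: int) -> float:
        """Decode one 4-bit FP4 (E2M1) code: 1 sign bit, 2 exponent bits, 1 mantissa bit."""
        sign = -1.0 if (code >> 3) & 1 else 1.0
        exp = (code >> 1) & 0b11
        man = code & 0b1
        if exp == 0:                      # subnormal: 0 or 0.5
            return sign * 0.5 * man
        return sign * (1.0 + 0.5 * man) * 2.0 ** (exp - 1)

    # All positive codes: [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
    print([decode_e2m1(c) for c in range(8)])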
@@ -271,4 +354,14 @@ static void unrolled_loop(F &&lambda) {
     call(std::make_integer_sequence<int, cnt>());
 }

+// int2float is slow on sm_80 and before
+// val in [-4194304, 4194303]
+__device__ __forceinline__
+static float int2float_fast(int val) {
+    float fval;
+    // fval = (val & 0x7FFFFF) ^ 0x4B400000
+    asm volatile ("lop3.b32 %0, %1, %2, %3, %4;"
+                  : "=f"(fval) : "r"(val), "n"(0x7FFFFF), "n"(0x4B400000), "n"((0xF0 & 0xCC) ^ 0xAA));
+    return fval - 12582912.0f;
+}
+
 };  // namespace nunchaku::kernels
 \ No newline at end of file
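int2float_fast uses the classic magic-number trick: XORing a small integer into the mantissa of the float 12582912.0 (0x4B400000, i.e. 1.5 * 2^23) yields the float 12582912 + val, so subtracting the constant recovers val without an int-to-float conversion instruction; the XOR against bit 22 is what biases negative inputs into range. A quick host-side Python check of the same bit manipulation (illustrative only, using struct to reinterpret bits):

    import struct

    def int2float_fast(val: int) -> float:
        # Same bit trick as src/kernels/zgemm/gemm_utils.cuh, valid for val in [-2**22, 2**22 - 1].
        bits = (val & 0x7FFFFF) ^ 0x4B400000
        fval = struct.unpack("<f", struct.pack("<I", bits))[0]
        return fval - 12582912.0

    assert all(int2float_fast(v) == float(v) for v in (-4194304, -1, 0, 1, 12345, 4194303))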
src/kernels/zgemm/gemm_w4a4.cu

@@ -36,9 +36,23 @@ void gemm_w4a4(
     Tensor out_linearattn,  // linear [B, (M), N / 3]
     bool act_unsigned,
     std::vector<float> lora_scales,  // [R / 16]
-    bool fuse_silu
+    bool fuse_silu,
+    bool fp4,
+    float alpha,
+    Tensor wcscales
 ) {
-    invoke_launch(ascales.dtype(), [&]<typename Config>() {
+    Tensor::ScalarType dtype = Tensor::INVALID_SCALAR_TYPE;
+    if (!fp4) {
+        dtype = ascales.dtype();
+    } else {
+        for (auto tensor : {out, bias, lora_up, lora_down, poolout, wcscales}) {
+            if (tensor.valid()) {
+                assert(dtype == Tensor::INVALID_SCALAR_TYPE || dtype == tensor.dtype());
+                dtype = tensor.dtype();
+            }
+        }
+    }
+    invoke_launch(dtype, [&]<typename Config>() {
         GEMM_W4A4_Launch<Config>::gemm_w4a4(
             act,
             wgt,

@@ -61,7 +75,10 @@ void gemm_w4a4(
             out_linearattn,
             act_unsigned,
             lora_scales,
-            fuse_silu
+            fuse_silu,
+            fp4,
+            alpha,
+            wcscales
         );
     });
 }

@@ -72,10 +89,10 @@ void linearattn_vk_mul_q(Tensor q, Tensor vk) {
     });
 }

-void quantize_w4a4_act_fuse_lora(Tensor input, Tensor output, Tensor oscales, Tensor lora_down, Tensor lora_act_out, Tensor smooth, bool fuse_glu) {
+void quantize_w4a4_act_fuse_lora(Tensor input, Tensor output, Tensor oscales, Tensor lora_down, Tensor lora_act_out, Tensor smooth, bool fuse_glu, bool fp4) {
     invoke_launch(input.dtype(), [&]<typename Config>() {
         GEMM_W4A4_Launch<Config>::quantize_w4a4_act_fuse_lora(
-            input, output, oscales, lora_down, lora_act_out, smooth, fuse_glu
+            input, output, oscales, lora_down, lora_act_out, smooth, fuse_glu, fp4
         );
     });
 }