GitLab commit — fengzch-das / nunchaku

Commit 92ac7b40: Add our own FP16 Attention implementation
Authored Mar 22, 2025 by sxtyzhangzk; committed by Zhekai Zhang, Apr 01, 2025.
Parent: 182c323c

Changes: 25 files in total; this page shows 20 changed files with 1560 additions and 129 deletions (+1560 −129). The remaining files are on a second page of the diff view.
Changed files shown on this page:
  nunchaku/csrc/flux.h                                  +14   −2
  nunchaku/csrc/ops.h                                   +56   −2
  nunchaku/csrc/pybind.cpp                              +7    −1
  nunchaku/csrc/utils.h                                 +7    −0
  nunchaku/models/transformers/transformer_flux.py      +32   −5
  setup.py                                              +6    −2
  src/FluxModel.cpp                                     +171  −70
  src/FluxModel.h                                       +11   −0
  src/Linear.cpp                                        +5    −3
  src/Linear.h                                          +5    −1
  src/Module.h                                          +1    −1
  src/SanaModel.cpp                                     +2    −1
  src/kernels/zgemm/attention.cu                        +101  −0
  src/kernels/zgemm/attention.cuh                       +715  −0
  src/kernels/zgemm/gemm_utils.cuh                      +22   −0
  src/kernels/zgemm/gemm_w4a4.cu                        +46   −34
  src/kernels/zgemm/gemm_w4a4.cuh                       +347  −4
  src/kernels/zgemm/gemm_w4a4_launch.cuh                +6    −2
  src/kernels/zgemm/gemm_w4a4_launch_bf16_fp4.cu        +1    −1
  src/kernels/zgemm/gemm_w4a4_launch_bf16_int4.cu       +5    −0
nunchaku/csrc/flux.h

@@ -143,8 +143,20 @@ public:
        });
    }

-   void forceFP16Attention(bool enable) {
-       Attention::setForceFP16(net.get(), enable);
+   void setAttentionImpl(std::string name) {
+       if (name.empty() || name == "default") {
+           name = "flashattn2";
+       }
+       spdlog::info("Set attention implementation to {}", name);
+       if (name == "flashattn2") {
+           net->setAttentionImpl(AttentionImpl::FlashAttention2);
+       } else if (name == "nunchaku-fp16") {
+           net->setAttentionImpl(AttentionImpl::NunchakuFP16);
+       } else {
+           throw std::invalid_argument(spdlog::fmt_lib::format("Invalid attention implementation {}", name));
+       }
    }
};
\ No newline at end of file
nunchaku/csrc/ops.h

@@ -32,7 +32,11 @@ namespace nunchaku::ops {
        bool fuse_silu, bool fp4,
        float alpha,
-       std::optional<torch::Tensor> wcscales
+       std::optional<torch::Tensor> wcscales,
+       std::optional<torch::Tensor> out_q,  // packed attention [B, H, M, D]
+       std::optional<torch::Tensor> out_k,  // packed attention [B, H, M, D]
+       std::optional<torch::Tensor> out_v,  // packed attention [B, H, M, D]
+       int attn_tokens
    ) {
        spdlog::trace("running gemm_w4a4: ");

@@ -70,11 +74,31 @@ namespace nunchaku::ops {
            fuse_silu, fp4,
            alpha,
-           getTensor(wcscales)
+           getTensor(wcscales),
+           getTensor(out_q),
+           getTensor(out_k),
+           getTensor(out_v),
+           attn_tokens
        );
        // Tensor::synchronizeDevice();
    }

+   void attention_fp16(
+       torch::Tensor q,  // packed [Batch, Head, TokensQ, HEAD_DIM]
+       torch::Tensor k,  // packed [Batch, Head, TokensKV, HEAD_DIM]
+       torch::Tensor v,  // packed [Batch, Head, TokensKV, HEAD_DIM]
+       torch::Tensor o,  // linear [Batch, TokensQ, Head * HEAD_DIM]
+       float scale)
+   {
+       nunchaku::kernels::attention_fp16(
+           from_torch(q), from_torch(k), from_torch(v), from_torch(o), scale);
+   }

    torch::Tensor gemv_awq(
        torch::Tensor _in_feats,
        torch::Tensor _kernel,

@@ -122,6 +146,36 @@ namespace nunchaku::ops {
        return output;
    }

+   void test_rmsnorm_rope(torch::Tensor input, torch::Tensor output, torch::Tensor norm_q, torch::Tensor norm_k, torch::Tensor rotary_emb) {
+       nunchaku::kernels::test_rmsnorm_rope(
+           from_torch(input), from_torch(output), from_torch(norm_q), from_torch(norm_k), from_torch(rotary_emb));
+   }

+   void test_pack_qkv(torch::Tensor input, torch::Tensor out_q, torch::Tensor out_k, torch::Tensor out_v, int numTokens) {
+       nunchaku::kernels::test_pack_qkv(
+           from_torch(input), from_torch(out_q), from_torch(out_k), from_torch(out_v), numTokens);
+   }
};
\ No newline at end of file
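Editor's note: functionally, the o tensor that attention_fp16 fills corresponds to standard scaled-dot-product attention per head, with the heads folded back into the last dimension of a [Batch, TokensQ, Head * HEAD_DIM] output. The sketch below is only that functional reference in PyTorch with made-up sizes; it ignores the packed tile layout that the real q/k/v buffers use (they are produced by the packing epilogue added later in this commit), so it is not a drop-in check against the kernel.

    import torch

    # Illustrative sizes only; the kernel additionally requires padded token counts.
    B, H, TQ, TKV, D = 1, 24, 512, 512, 128
    q = torch.randn(B, H, TQ, D, dtype=torch.float16)
    k = torch.randn(B, H, TKV, D, dtype=torch.float16)
    v = torch.randn(B, H, TKV, D, dtype=torch.float16)
    scale = D ** -0.5

    # softmax(scale * q @ k^T) @ v, then heads folded into the last dim
    attn = torch.softmax(scale * (q.float() @ k.float().transpose(-1, -2)), dim=-1)
    o_ref = (attn @ v.float()).transpose(1, 2).reshape(B, TQ, H * D)  # [B, TQ, H*D]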
nunchaku/csrc/pybind.cpp

@@ -33,7 +33,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
        .def("stopDebug", &QuantizedFluxModel::stopDebug)
        .def("getDebugResults", &QuantizedFluxModel::getDebugResults)
        .def("setLoraScale", &QuantizedFluxModel::setLoraScale)
-       .def("forceFP16Attention", &QuantizedFluxModel::forceFP16Attention)
+       .def("setAttentionImpl", &QuantizedFluxModel::setAttentionImpl)
    ;
    py::class_<QuantizedSanaModel>(m, "QuantizedSanaModel")
        .def(py::init<>())

@@ -82,14 +82,20 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    ;
    m.def_submodule("ops")
        .def("gemm_w4a4", nunchaku::ops::gemm_w4a4)
+       .def("attention_fp16", nunchaku::ops::attention_fp16)
        .def("gemm_awq", nunchaku::ops::gemm_awq)
        .def("gemv_awq", nunchaku::ops::gemv_awq)
+       .def("test_rmsnorm_rope", nunchaku::ops::test_rmsnorm_rope)
+       .def("test_pack_qkv", nunchaku::ops::test_pack_qkv)
    ;
    m.def_submodule("utils")
        .def("set_log_level", [](const std::string &level) { spdlog::set_level(spdlog::level::from_str(level)); })
+       .def("set_cuda_stack_limit", nunchaku::utils::set_cuda_stack_limit)
        .def("disable_memory_auto_release", nunchaku::utils::disable_memory_auto_release)
        .def("trim_memory", nunchaku::utils::trim_memory)
    ;
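Editor's note: a rough sketch of how the rebound methods might be driven from Python. The import path and the no-argument constructor are assumptions (the real module name comes from TORCH_EXTENSION_NAME, and only QuantizedSanaModel's py::init<>() is visible in this diff); only setAttentionImpl, set_cuda_stack_limit and set_log_level are taken from the bindings above.

    # Hypothetical import path; the real extension name comes from TORCH_EXTENSION_NAME.
    from nunchaku import _C

    _C.utils.set_log_level("debug")
    _C.utils.set_cuda_stack_limit(8192)        # new utils binding added in this commit

    # Assumed default-constructible, mirroring QuantizedSanaModel's py::init<>();
    # weight loading is outside the scope of this diff.
    model = _C.QuantizedFluxModel()
    model.setAttentionImpl("nunchaku-fp16")    # replaces the old forceFP16Attention binding
    model.setAttentionImpl("default")          # "" / "default" fall back to flashattn2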
nunchaku/csrc/utils.h

@@ -5,6 +5,13 @@
namespace nunchaku::utils {

+void set_cuda_stack_limit(int64_t newval) {
+    size_t val = 0;
+    checkCUDA(cudaDeviceSetLimit(cudaLimitStackSize, (size_t)newval));
+    checkCUDA(cudaDeviceGetLimit(&val, cudaLimitStackSize));
+    spdlog::debug("Stack={}", val);
+}

void disable_memory_auto_release() {
    int device;
    checkCUDA(cudaGetDevice(&device));
nunchaku/models/transformers/transformer_flux.py

@@ -22,6 +22,28 @@ class NunchakuFluxTransformerBlocks(nn.Module):
        self.dtype = torch.bfloat16
        self.device = device

+   @staticmethod
+   def pack_rotemb(rotemb: torch.Tensor) -> torch.Tensor:
+       assert rotemb.dtype == torch.float32
+       B = rotemb.shape[0]
+       M = rotemb.shape[1]
+       D = rotemb.shape[2] * 2
+       assert rotemb.shape == (B, M, D // 2, 1, 2)
+       assert M % 16 == 0
+       assert D % 8 == 0
+       rotemb = rotemb.reshape(B, M // 16, 16, D // 8, 8)
+       rotemb = rotemb.permute(0, 1, 3, 2, 4)
+       # 16*8 pack, FP32 accumulator (C) format
+       # https://docs.nvidia.com/cuda/parallel-thread-execution/#mma-16816-c
+       ##########################################|--M--|--D--|
+       ##########################################|-3--4--5--6|
+       ##########################################  :  :  :  :
+       rotemb = rotemb.reshape(*rotemb.shape[0:3], 2, 8, 4, 2)
+       rotemb = rotemb.permute(0, 1, 2, 4, 5, 3, 6)
+       rotemb = rotemb.contiguous()
+       rotemb = rotemb.view(B, M, D)
+       return rotemb

    def forward(
        self,
        hidden_states: torch.Tensor,
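Editor's note: a quick shape check of pack_rotemb as defined above, on dummy data with illustrative sizes. A float32 rotary-embedding tensor of logical shape (B, M, D // 2, 1, 2), with M a multiple of 16 and D a multiple of 8, comes back flattened to (B, M, D) in the MMA m16n8 accumulator order expected by the fused kernel.

    import torch
    from nunchaku.models.transformers.transformer_flux import NunchakuFluxTransformerBlocks

    B, M, D = 1, 256, 128   # illustrative sizes; D plays the role of the head dimension
    rotemb = torch.randn(B, M, D // 2, 1, 2, dtype=torch.float32)

    packed = NunchakuFluxTransformerBlocks.pack_rotemb(rotemb)
    assert packed.shape == (B, M, D)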
@@ -53,9 +75,9 @@ class NunchakuFluxTransformerBlocks(nn.Module):
        rotary_emb_img = image_rotary_emb[:, txt_tokens:, ...]  # .to(self.dtype)
        rotary_emb_single = image_rotary_emb  # .to(self.dtype)

-       rotary_emb_txt = pad_tensor(rotary_emb_txt, 256, 1)
-       rotary_emb_img = pad_tensor(rotary_emb_img, 256, 1)
-       rotary_emb_single = pad_tensor(rotary_emb_single, 256, 1)
+       rotary_emb_txt = self.pack_rotemb(pad_tensor(rotary_emb_txt, 256, 1))
+       rotary_emb_img = self.pack_rotemb(pad_tensor(rotary_emb_img, 256, 1))
+       rotary_emb_single = self.pack_rotemb(pad_tensor(rotary_emb_single, 256, 1))

        hidden_states = self.m.forward(
            hidden_states,

@@ -104,8 +126,8 @@ class NunchakuFluxTransformerBlocks(nn.Module):
        rotary_emb_txt = image_rotary_emb[:, :txt_tokens, ...]  # .to(self.dtype)
        rotary_emb_img = image_rotary_emb[:, txt_tokens:, ...]  # .to(self.dtype)

-       rotary_emb_txt = pad_tensor(rotary_emb_txt, 256, 1)
-       rotary_emb_img = pad_tensor(rotary_emb_img, 256, 1)
+       rotary_emb_txt = self.pack_rotemb(pad_tensor(rotary_emb_txt, 256, 1))
+       rotary_emb_img = self.pack_rotemb(pad_tensor(rotary_emb_img, 256, 1))

        hidden_states, encoder_hidden_states = self.m.forward_layer(
            idx,
            hidden_states,
            encoder_hidden_states,
            temb,
            rotary_emb_img,
            rotary_emb_txt
@@ -254,6 +276,11 @@ class NunchakuFluxTransformer2dModel(FluxTransformer2DModel, NunchakuModelLoader…
        if len(self.unquantized_loras) > 0:
            self.update_unquantized_lora_params(strength)

+   def set_attention_impl(self, impl: str):
+       block = self.transformer_blocks[0]
+       assert isinstance(block, NunchakuFluxTransformerBlocks)
+       block.m.setAttentionImpl(impl)

    def inject_quantized_module(self, m: QuantizedFluxModel, device: str | torch.device = "cuda"):
        print("Injecting quantized module")
        self.pos_embed = EmbedND(dim=self.inner_dim, theta=10000, axes_dim=[16, 56, 56])
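Editor's note: with the set_attention_impl entry point added above, switching implementations at runtime from the diffusers-side wrapper might look like the following sketch. Model construction is elided; the from_pretrained call and checkpoint path are assumptions about the existing loader mixin, and only set_attention_impl and the accepted names come from this commit.

    from nunchaku.models.transformers.transformer_flux import NunchakuFluxTransformer2dModel

    # Hypothetical checkpoint path; loading is not part of this diff.
    transformer = NunchakuFluxTransformer2dModel.from_pretrained("path/to/svdquant-flux-model")

    transformer.set_attention_impl("nunchaku-fp16")   # use the new in-house FP16 attention
    transformer.set_attention_impl("flashattn2")      # or go back to FlashAttention-2 ("default")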
setup.py

@@ -158,9 +158,13 @@ if __name__ == "__main__":
                    "src/kernels/layernorm_kernels.cu",
                    "src/kernels/misc_kernels.cu",
                    "src/kernels/zgemm/gemm_w4a4.cu",
-                   "src/kernels/zgemm/gemm_w4a4_launch_fp16.cu",
-                   "src/kernels/zgemm/gemm_w4a4_launch_bf16.cu",
                    "src/kernels/zgemm/gemm_w4a4_test.cu",
+                   "src/kernels/zgemm/gemm_w4a4_launch_fp16_int4.cu",
+                   "src/kernels/zgemm/gemm_w4a4_launch_fp16_fp4.cu",
+                   "src/kernels/zgemm/gemm_w4a4_launch_bf16_int4.cu",
+                   "src/kernels/zgemm/gemm_w4a4_launch_bf16_fp4.cu",
                    "src/kernels/zgemm/gemm_w8a8.cu",
+                   "src/kernels/zgemm/attention.cu",
                    "src/kernels/dwconv.cu",
                    "src/kernels/gemm_batched.cu",
                    "src/kernels/gemm_f16.cu",
src/FluxModel.cpp

#include "FluxModel.h"
#include "kernels/misc_kernels.h"
#include "kernels/gemm_batched.h"
#include "kernels/zgemm/zgemm.h"
#include "flash_api.h"
#include "activation.h"

@@ -235,7 +236,7 @@ Tensor Attention::forward(Tensor qkv, Tensor pool_qkv, float sparsityRatio) {
    Tensor raw_attn_output = mha_varlen_fwd(
        q, k, v,
        cu_seqlens, cu_seqlens,
-       num_tokens_img + num_tokens_context, num_tokens_img + num_tokens_context,
+       num_tokens_img + num_tokens_txt, num_tokens_img + num_tokens_txt,
        0.0f,
        pow(q.shape[-1], (-0.5)),
        false, false, -1, -1, false
@@ -298,19 +299,49 @@ Tensor FluxSingleTransformerBlock::forward(Tensor hidden_states, Tensor temb, …
    Tensor residual = hidden_states;

-   Tensor qkv = Tensor::allocate({batch_size, num_tokens, dim * 3}, norm_hidden_states.scalar_type(), norm_hidden_states.device());
-   // qkv_proj.forward(norm_hidden_states, qkv, {});
-   // debug("qkv_raw", qkv);
-   debug("rotary_emb", rotary_emb);
-   qkv_proj.forward(norm_hidden_states, qkv, {}, norm_q.weight, norm_k.weight, rotary_emb);
-   debug("qkv", qkv);
-   // Tensor qkv = forward_fc(qkv_proj, norm_hidden_states);
-   Tensor attn_output = attn.forward(qkv, {}, 0);
-   attn_output = attn_output.reshape({batch_size, num_tokens, num_heads * dim_head});
+   Tensor attn_output;
+   if (attnImpl == AttentionImpl::FlashAttention2) {
+       Tensor qkv = Tensor::allocate({batch_size, num_tokens, dim * 3}, norm_hidden_states.scalar_type(), norm_hidden_states.device());
+       // qkv_proj.forward(norm_hidden_states, qkv, {});
+       // debug("qkv_raw", qkv);
+       debug("rotary_emb", rotary_emb);
+       qkv_proj.forward(norm_hidden_states, qkv, {}, norm_q.weight, norm_k.weight, rotary_emb);
+       debug("qkv", qkv);
+       // Tensor qkv = forward_fc(qkv_proj, norm_hidden_states);
+       attn_output = attn.forward(qkv, {}, 0);
+       attn_output = attn_output.reshape({batch_size, num_tokens, num_heads * dim_head});
+   } else if (attnImpl == AttentionImpl::NunchakuFP16) {
+       assert(batch_size == 1);
+       const int num_tokens_pad = ceilDiv(num_tokens, 256) * 256;
+       Tensor q = Tensor::allocate({batch_size, num_heads, num_tokens_pad, dim_head}, Tensor::FP16, norm_hidden_states.device());
+       Tensor k = Tensor::allocate({batch_size, num_heads, num_tokens_pad, dim_head}, Tensor::FP16, norm_hidden_states.device());
+       Tensor v = Tensor::allocate({batch_size, num_heads, num_tokens_pad, dim_head}, Tensor::FP16, norm_hidden_states.device());
+       qkv_proj.forward(norm_hidden_states, {}, {}, norm_q.weight, norm_k.weight, rotary_emb, q, k, v, num_tokens);
+       debug("packed_q", q);
+       debug("packed_k", k);
+       debug("packed_v", v);
+       Tensor o = Tensor::allocate({batch_size, num_tokens_pad, num_heads * dim_head}, norm_hidden_states.scalar_type(), norm_hidden_states.device());
+       kernels::attention_fp16(q, k, v, o, pow(dim_head, (-0.5)));
+       attn_output = o.slice(1, 0, num_tokens);
+   } else {
+       assert(false);
+   }

    debug("raw_attn_output", attn_output);
    attn_output = forward_fc(out_proj, attn_output);
    debug("attn_output", attn_output);
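Editor's note: the nunchaku-fp16 path above rounds the token count up to a multiple of 256 before allocating the packed q/k/v buffers. A one-line Python equivalent of that ceilDiv-based padding, with illustrative token counts:

    def pad_to_256(num_tokens: int) -> int:
        # ceilDiv(num_tokens, 256) * 256, as in the C++ branch above
        return (num_tokens + 255) // 256 * 256

    assert pad_to_256(4608) == 4608   # already aligned
    assert pad_to_256(4500) == 4608   # padded up to the next multiple of 256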
@@ -384,13 +415,13 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states, …
    int num_tokens_img = hidden_states.shape[1];
-   int num_tokens_context = encoder_hidden_states.shape[1];
+   int num_tokens_txt = encoder_hidden_states.shape[1];

    assert(hidden_states.shape[2] == dim);
    assert(encoder_hidden_states.shape[2] == dim);

    spdlog::debug("hidden_states={} encoder_hidden_states={} temb={}", hidden_states.shape.str(), encoder_hidden_states.shape.str(), temb.shape.str());
-   spdlog::debug("batch_size={} num_tokens_img={} num_tokens_context={}", batch_size, num_tokens_img, num_tokens_context);
+   spdlog::debug("batch_size={} num_tokens_img={} num_tokens_txt={}", batch_size, num_tokens_img, num_tokens_txt);

    auto norm1_output = norm1.forward(hidden_states, temb);
    auto norm1_context_output = norm1_context.forward(encoder_hidden_states, temb);
@@ -408,76 +439,137 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states, …
    nvtxRangePop();

    auto stream = getCurrentCUDAStream();

-   Tensor concat;
-   Tensor pool;
-   {
-       nvtxRangePushA("qkv_proj");
-       const bool blockSparse = sparsityRatio > 0;
-       const int poolTokens = num_tokens_img / POOL_SIZE + num_tokens_context / POOL_SIZE;
-       concat = Tensor::allocate({batch_size, num_tokens_img + num_tokens_context, dim * 3}, norm1_output.x.scalar_type(), norm1_output.x.device());
-       pool = blockSparse ? Tensor::allocate({batch_size, poolTokens, dim * 3}, norm1_output.x.scalar_type(), norm1_output.x.device()) : Tensor{};
-       for (int i = 0; i < batch_size; i++) {
-           // img first
-           Tensor qkv = concat.slice(0, i, i + 1).slice(1, 0, num_tokens_img);
-           Tensor qkv_context = concat.slice(0, i, i + 1).slice(1, num_tokens_img, num_tokens_img + num_tokens_context);
-           Tensor pool_qkv = pool.valid() ? pool.slice(0, i, i + 1).slice(1, 0, num_tokens_img / POOL_SIZE) : Tensor{};
-           Tensor pool_qkv_context = pool.valid() ? concat.slice(0, i, i + 1).slice(1, num_tokens_img / POOL_SIZE, num_tokens_img / POOL_SIZE + num_tokens_context / POOL_SIZE) : Tensor{};
-           // qkv_proj.forward(norm1_output.x.slice(0, i, i + 1), qkv);
-           // debug("qkv_raw", qkv);
-           debug("rotary_emb", rotary_emb);
-           qkv_proj.forward(norm1_output.x.slice(0, i, i + 1), qkv, pool_qkv, norm_q.weight, norm_k.weight, rotary_emb);
-           debug("qkv", qkv);
-           // qkv_proj_context.forward(norm1_context_output.x.slice(0, i, i + 1), qkv_context);
-           // debug("qkv_context_raw", qkv_context);
-           debug("rotary_emb_context", rotary_emb_context);
-           qkv_proj_context.forward(norm1_context_output.x.slice(0, i, i + 1), qkv_context, pool_qkv_context, norm_added_q.weight, norm_added_k.weight, rotary_emb_context);
-           debug("qkv_context", qkv_context);
-       }
-       nvtxRangePop();
-   }
-   spdlog::debug("concat={}", concat.shape.str());
-   debug("concat", concat);
-   assert(concat.shape[2] == num_heads * dim_head * 3);
-   nvtxRangePushA("Attention");
-   Tensor raw_attn_output = attn.forward(concat, pool, sparsityRatio);
-   nvtxRangePop();
-   spdlog::debug("raw_attn_output={}", raw_attn_output.shape.str());
-   raw_attn_output = raw_attn_output.view({batch_size, num_tokens_img + num_tokens_context, num_heads, dim_head});
+   int num_tokens_img_pad = 0, num_tokens_txt_pad = 0;
+   Tensor raw_attn_output;
+   if (attnImpl == AttentionImpl::FlashAttention2) {
+       num_tokens_img_pad = num_tokens_img;
+       num_tokens_txt_pad = num_tokens_txt;
+       Tensor concat;
+       Tensor pool;
+       {
+           nvtxRangePushA("qkv_proj");
+           const bool blockSparse = sparsityRatio > 0;
+           const int poolTokens = num_tokens_img / POOL_SIZE + num_tokens_txt / POOL_SIZE;
+           concat = Tensor::allocate({batch_size, num_tokens_img + num_tokens_txt, dim * 3}, norm1_output.x.scalar_type(), norm1_output.x.device());
+           pool = blockSparse ? Tensor::allocate({batch_size, poolTokens, dim * 3}, norm1_output.x.scalar_type(), norm1_output.x.device()) : Tensor{};
+           for (int i = 0; i < batch_size; i++) {
+               // img first
+               Tensor qkv = concat.slice(0, i, i + 1).slice(1, 0, num_tokens_img);
+               Tensor qkv_context = concat.slice(0, i, i + 1).slice(1, num_tokens_img, num_tokens_img + num_tokens_txt);
+               Tensor pool_qkv = pool.valid() ? pool.slice(0, i, i + 1).slice(1, 0, num_tokens_img / POOL_SIZE) : Tensor{};
+               Tensor pool_qkv_context = pool.valid() ? concat.slice(0, i, i + 1).slice(1, num_tokens_img / POOL_SIZE, num_tokens_img / POOL_SIZE + num_tokens_txt / POOL_SIZE) : Tensor{};
+               // qkv_proj.forward(norm1_output.x.slice(0, i, i + 1), qkv);
+               // debug("qkv_raw", qkv);
+               debug("rotary_emb", rotary_emb);
+               qkv_proj.forward(norm1_output.x.slice(0, i, i + 1), qkv, pool_qkv, norm_q.weight, norm_k.weight, rotary_emb);
+               debug("qkv", qkv);
+               // qkv_proj_context.forward(norm1_context_output.x.slice(0, i, i + 1), qkv_context);
+               // debug("qkv_context_raw", qkv_context);
+               debug("rotary_emb_context", rotary_emb_context);
+               qkv_proj_context.forward(norm1_context_output.x.slice(0, i, i + 1), qkv_context, pool_qkv_context, norm_added_q.weight, norm_added_k.weight, rotary_emb_context);
+               debug("qkv_context", qkv_context);
+           }
+           nvtxRangePop();
+       }
+       spdlog::debug("concat={}", concat.shape.str());
+       debug("concat", concat);
+       assert(concat.shape[2] == num_heads * dim_head * 3);
+       nvtxRangePushA("Attention");
+       raw_attn_output = attn.forward(concat, pool, sparsityRatio);
+       nvtxRangePop();
+       spdlog::debug("raw_attn_output={}", raw_attn_output.shape.str());
+       raw_attn_output = raw_attn_output.view({batch_size, num_tokens_img + num_tokens_txt, num_heads, dim_head});
+   } else if (attnImpl == AttentionImpl::NunchakuFP16) {
+       num_tokens_img_pad = ceilDiv(num_tokens_img, 256) * 256;
+       num_tokens_txt_pad = ceilDiv(num_tokens_txt, 256) * 256;
+       Tensor concat_q, concat_k, concat_v;
+       debug("rotary_emb", rotary_emb);
+       debug("rotary_emb_context", rotary_emb_context);
+       {
+           nvtxRangePushA("qkv_proj");
+           concat_q = Tensor::allocate({batch_size, num_heads, num_tokens_img_pad + num_tokens_txt_pad, dim_head}, Tensor::FP16, norm1_output.x.device());
+           concat_k = Tensor::empty_like(concat_q);
+           concat_v = Tensor::empty_like(concat_q);
+           for (int i = 0; i < batch_size; i++) {
+               // img first
+               auto sliceImg = [&](Tensor x) { return x.slice(0, i, i + 1).slice(2, 0, num_tokens_img_pad); };
+               auto sliceTxt = [&](Tensor x) { return x.slice(0, i, i + 1).slice(2, num_tokens_img_pad, num_tokens_img_pad + num_tokens_txt_pad); };
+               qkv_proj.forward(norm1_output.x.slice(0, i, i + 1), {}, {}, norm_q.weight, norm_k.weight, rotary_emb, sliceImg(concat_q), sliceImg(concat_k), sliceImg(concat_v), num_tokens_img);
+               qkv_proj_context.forward(norm1_context_output.x.slice(0, i, i + 1), {}, {}, norm_added_q.weight, norm_added_k.weight, rotary_emb_context, sliceTxt(concat_q), sliceTxt(concat_k), sliceTxt(concat_v), num_tokens_txt);
+           }
+           debug("concat_q", concat_q);
+           debug("concat_k", concat_k);
+           debug("concat_v", concat_v);
+           nvtxRangePop();
+       }
+       raw_attn_output = Tensor::allocate({batch_size, num_tokens_img_pad + num_tokens_txt_pad, num_heads * dim_head}, norm1_output.x.scalar_type(), norm1_output.x.device());
+       nvtxRangePushA("Attention");
+       kernels::attention_fp16(concat_q, concat_k, concat_v, raw_attn_output, pow(dim_head, (-0.5)));
+       nvtxRangePop();
+       spdlog::debug("raw_attn_output={}", raw_attn_output.shape.str());
+       raw_attn_output = raw_attn_output.view({batch_size, num_tokens_img_pad + num_tokens_txt_pad, num_heads, dim_head});
+   } else {
+       assert(false);
+   }

    debug("raw_attn_output", raw_attn_output);

    {
        nvtxRangePushA("o_proj");
        auto &&[_, gate_msa, shift_mlp, scale_mlp, gate_mlp] = norm1_output;
-       // raw_attn_output: [batch_size, num_tokens_img + num_tokens_context, num_heads * dim_head]
+       // raw_attn_output: [batch_size, num_tokens_img + num_tokens_txt, num_heads * dim_head]
        Tensor raw_attn_output_split;
        if (batch_size == 1) {
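Editor's note: in the joint block's nunchaku-fp16 path, image and text tokens are padded separately and laid out back-to-back (image first), so the attention output has num_tokens_img_pad + num_tokens_txt_pad rows per batch entry and the text part is recovered by slicing at num_tokens_img_pad — which is also what the cudaMemcpy2DAsync strides in the next two hunks assume. A small Python sketch of that row layout, with illustrative counts:

    num_tokens_img, num_tokens_txt = 4096, 512

    img_pad = (num_tokens_img + 255) // 256 * 256   # 4096 (already aligned)
    txt_pad = (num_tokens_txt + 255) // 256 * 256   # 512

    # rows [0, img_pad) hold image tokens, rows [img_pad, img_pad + txt_pad) hold text tokens
    total_rows = img_pad + txt_pad
    valid_txt_rows = range(img_pad, img_pad + num_tokens_txt)   # unpadded text rows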
@@ -488,7 +580,7 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states, …
                    raw_attn_output_split.data_ptr(),
                    num_tokens_img * num_heads * dim_head * raw_attn_output_split.scalar_size(),
                    raw_attn_output.data_ptr(),
-                   (num_tokens_img + num_tokens_context) * num_heads * dim_head * raw_attn_output.scalar_size(),
+                   (num_tokens_img_pad + num_tokens_txt_pad) * num_heads * dim_head * raw_attn_output.scalar_size(),
                    num_tokens_img * num_heads * dim_head * raw_attn_output_split.scalar_size(),
                    batch_size,
                    cudaMemcpyDeviceToDevice,
@@ -546,15 +638,15 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states, …
            Tensor raw_attn_output_split;
            if (batch_size == 1) {
-               raw_attn_output_split = raw_attn_output.slice(1, num_tokens_img, num_tokens_img + num_tokens_context).reshape({batch_size, num_tokens_context, num_heads * dim_head});
+               raw_attn_output_split = raw_attn_output.slice(1, num_tokens_img_pad, num_tokens_img_pad + num_tokens_txt).reshape({batch_size, num_tokens_txt, num_heads * dim_head});
            } else {
-               raw_attn_output_split = Tensor::allocate({batch_size, num_tokens_context, num_heads * dim_head}, raw_attn_output.scalar_type(), raw_attn_output.device());
+               raw_attn_output_split = Tensor::allocate({batch_size, num_tokens_txt, num_heads * dim_head}, raw_attn_output.scalar_type(), raw_attn_output.device());
                checkCUDA(cudaMemcpy2DAsync(
                    raw_attn_output_split.data_ptr(),
-                   num_tokens_context * num_heads * dim_head * raw_attn_output_split.scalar_size(),
-                   raw_attn_output.data_ptr<char>() + num_tokens_img * num_heads * dim_head * raw_attn_output_split.scalar_size(),
-                   (num_tokens_img + num_tokens_context) * num_heads * dim_head * raw_attn_output.scalar_size(),
-                   num_tokens_context * num_heads * dim_head * raw_attn_output_split.scalar_size(),
+                   num_tokens_txt * num_heads * dim_head * raw_attn_output_split.scalar_size(),
+                   raw_attn_output.data_ptr<char>() + num_tokens_img_pad * num_heads * dim_head * raw_attn_output_split.scalar_size(),
+                   (num_tokens_img_pad + num_tokens_txt_pad) * num_heads * dim_head * raw_attn_output.scalar_size(),
+                   num_tokens_txt * num_heads * dim_head * raw_attn_output_split.scalar_size(),
                    batch_size,
                    cudaMemcpyDeviceToDevice,
                    stream));
@@ -682,4 +774,13 @@ Tensor FluxModel::forward(Tensor hidden_states, Tensor encoder_hidden_states, …
    helper.run();

    return hidden_states;
 }
-\ No newline at end of file
+
+void FluxModel::setAttentionImpl(AttentionImpl impl) {
+   for (auto &&block : this->transformer_blocks) {
+       block->attnImpl = impl;
+   }
+   for (auto &&block : this->single_transformer_blocks) {
+       block->attnImpl = impl;
+   }
+}
src/FluxModel.h

@@ -6,6 +6,11 @@
#include "Linear.h"
#include "layernorm.h"

+enum class AttentionImpl {
+    FlashAttention2 = 0,
+    NunchakuFP16,
+};

class AdaLayerNormZeroSingle : public Module {
public:
    static constexpr bool USE_4BIT = true;

@@ -86,6 +91,8 @@ public:
    const int num_heads;
    const int mlp_hidden_dim;

+   AttentionImpl attnImpl = AttentionImpl::FlashAttention2;

private:
    AdaLayerNormZeroSingle norm;
    GEMM mlp_fc1;

@@ -110,6 +117,8 @@ public:
    const int num_heads;
    const bool context_pre_only;

+   AttentionImpl attnImpl = AttentionImpl::FlashAttention2;

private:
    AdaLayerNormZero norm1;
    AdaLayerNormZero norm1_context;

@@ -131,6 +140,8 @@ public:
    FluxModel(bool use_fp4, bool offload, Tensor::ScalarType dtype, Device device);
    Tensor forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor temb, Tensor rotary_emb_img, Tensor rotary_emb_context, Tensor rotary_emb_single, bool skip_first_layer = false);

+   void setAttentionImpl(AttentionImpl impl);

public:
    std::vector<std::unique_ptr<JointTransformerBlock>> transformer_blocks;
    std::vector<std::unique_ptr<FluxSingleTransformerBlock>> single_transformer_blocks;
src/Linear.cpp

@@ -181,7 +181,7 @@ std::variant<Tensor, GEMM_W4A4::QuantizedActivation> GEMM_W4A4::forward(Tensor x…
    return forward_quant(quantize(x, false), fuse, nextGEMM);
}

-void GEMM_W4A4::forward(Tensor x, Tensor out, Tensor pool, Tensor norm_q, Tensor norm_k, Tensor rotary_emb) {
+void GEMM_W4A4::forward(Tensor x, Tensor out, Tensor pool, Tensor norm_q, Tensor norm_k, Tensor rotary_emb, Tensor out_q, Tensor out_k, Tensor out_v, int numTokens) {
    QuantizedActivation qact = quantize(x, false);

#if !NO_LORA_FUSION

@@ -196,7 +196,8 @@ void GEMM_W4A4::forward(Tensor x, Tensor out, Tensor pool, Tensor norm_q, Tensor…
    kernels::gemm_w4a4(qact.act, qweight, out, {}, qact.ascales, wscales, {}, pool, qact.lora_act, this->lora_up, {}, {}, norm_q, norm_k, rotary_emb, this->bias, {}, {}, {}, qact.is_unsigned, this->lora_scales, false,
-       use_fp4, *this->wtscale.data_ptr<float>(), wcscales.numel() > 0 ? wcscales : Tensor{});
+       use_fp4, *this->wtscale.data_ptr<float>(), wcscales.numel() > 0 ? wcscales : Tensor{},
+       out_q, out_k, out_v, numTokens);

    debug("gemm.out", out);

@@ -277,7 +278,8 @@ std::variant<Tensor, GEMM_W4A4::QuantizedActivation> GEMM_W4A4::forward_quant(Qu…
    kernels::gemm_w4a4(qact.act, qweight, out, qout.act, qact.ascales, wscales, qout.ascales, {}, qact.lora_act, this->lora_up, next_lora, qout.lora_act, {}, {}, {}, this->bias, next_smooth, {}, {}, qact.is_unsigned, this->lora_scales, fuse == FuseOptions::SILU,
-       use_fp4, *this->wtscale.data_ptr<float>(), wcscales.numel() > 0 ? wcscales : Tensor{});
+       use_fp4, *this->wtscale.data_ptr<float>(), wcscales.numel() > 0 ? wcscales : Tensor{},
+       {}, {}, {}, 0);

    if (fuse == FuseOptions::EMPTY || fuse == FuseOptions::SILU) {
src/Linear.h

@@ -69,7 +69,11 @@ public:
    Tensor forward(Tensor x);
    Tensor forward_silu(Tensor x);
    std::variant<Tensor, QuantizedActivation> forward(Tensor x, FuseOptions fuse, GEMM_W4A4 *nextGEMM = nullptr);
-   void forward(Tensor x, Tensor out, Tensor pool = {}, Tensor norm_q = {}, Tensor norm_k = {}, Tensor rotary_emb = {});
+   void forward(Tensor x, Tensor out, Tensor pool = {}, Tensor norm_q = {}, Tensor norm_k = {}, Tensor rotary_emb = {},
+                Tensor out_q = {}, Tensor out_k = {}, Tensor out_v = {}, int numTokens = 0);
    std::variant<Tensor, QuantizedActivation> forward_quant(QuantizedActivation qact, FuseOptions fuse, GEMM_W4A4 *nextGEMM = nullptr);
    Tensor forward_quant(QuantizedActivation qact);
src/Module.h

@@ -174,7 +174,7 @@ protected:
    }

    void debug(std::string name, Tensor tensor) {
-       if (DebugContext::ctxs.empty()) {
+       if (DebugContext::ctxs.empty() || !tensor.valid()) {
            return;
        }
        std::string prefix = getFullName();
src/SanaModel.cpp

@@ -69,7 +69,8 @@ Tensor SanaLinearAttention::forward(Tensor x, Tensor out) {
        qact.is_unsigned, qkv_proj.lora_scales, false, qkv_proj.use_fp4, *qkv_proj.wtscale.data_ptr<float>(),
-       qkv_proj.wcscales.numel() > 0 ? qkv_proj.wcscales : Tensor{});
+       qkv_proj.wcscales.numel() > 0 ? qkv_proj.wcscales : Tensor{},
+       {}, {}, {}, 0);

    debug("vk", vk);
src/kernels/zgemm/attention.cu (new file, 0 → 100644)

#include "zgemm.h"
#include "attention.cuh"

#ifndef M_LOG2E
#define M_LOG2E 1.4426950408889634074
#endif

namespace nunchaku::kernels {

void attention_fp16(Tensor q,  // packed [Batch, Head, TokensQ, HEAD_DIM]
                    Tensor k,  // packed [Batch, Head, TokensKV, HEAD_DIM]
                    Tensor v,  // packed [Batch, Head, TokensKV, HEAD_DIM]
                    Tensor o,  // linear [Batch, TokensQ, Head * HEAD_DIM]
                    float scale)
{
    int sizeBatch   = q.shape[0];
    int numHeads    = q.shape[1];
    int numTokensQ  = q.shape[2];
    int headDim     = q.shape[3];
    int numTokensKV = k.shape[2];

    assert(o.ndims() == 3);
    assert(o.shape[0] == sizeBatch);
    assert(o.shape[1] == numTokensQ);
    assert(o.shape[2] == numHeads * headDim);

    spdlog::trace("attention_fp16: B={} H={} NQ={} NK={}", sizeBatch, numHeads, numTokensQ, numTokensKV);
    spdlog::trace("q at {}", q.data_ptr());
    spdlog::trace("k at {}", k.data_ptr());
    spdlog::trace("v at {}", v.data_ptr());
    spdlog::trace("o at {}", o.data_ptr());
    spdlog::trace("scale={}", scale);

    dispatchBool(o.scalar_type() == Tensor::BF16, [&]<bool bf16out>() {
#ifndef __INTELLISENSE__
        using Attention = typename nunchaku::kernels::Attention<AttentionFP16Config<bf16out>>;
#else
        using Attention = typename nunchaku::kernels::Attention<AttentionFP16Config<true>>;
#endif
        using GEMM = typename Attention::GEMM;

        assert(isTypeMatch<typename Attention::half_t>(q.scalar_type()));
        assert(isTypeMatch<typename Attention::half_t>(k.scalar_type()));
        assert(isTypeMatch<typename Attention::half_t>(v.scalar_type()));
        assert(isTypeMatch<typename Attention::epilogue_half_t>(o.scalar_type()));

        int shmem = 0;

        // we use exp2 instead of exp in the kernel
        scale *= M_LOG2E;

        assert(numTokensQ % Attention::BLOCK_M == 0);
        assert(numTokensKV % Attention::WARP_K == 0);
        assert(headDim == Attention::HEAD_DIM);

        auto launch = [&]<typename Epilogue>(Epilogue::Arguments args) {
            dim3 grid(numTokensQ / Attention::BLOCK_M, numHeads, sizeBatch);

            using packed_q_t = typename Attention::packed_q_t;
            using packed_k_t = typename Attention::packed_k_t;
            using packed_v_t = typename Attention::packed_v_t;

            auto func = invoke_kernel<typename Attention::attention_fp16_kernel<Epilogue>,
                const packed_q_t *, const packed_k_t *, const packed_v_t *,
                float, int, int, typename Epilogue::Arguments, bool>;

            shmem = std::max(shmem, Attention::template attention_fp16_kernel<Epilogue>::SHMEM_SIZE);
            if (shmem >= 24 * 1024) {
                checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
            }

            func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, shmem, getCurrentCUDAStream()>>>(
                q.data_ptr<packed_q_t>(), k.data_ptr<packed_k_t>(), v.data_ptr<packed_v_t>(),
                scale, numTokensQ, numTokensKV, args, false);
            checkCUDA(cudaGetLastError());
        };

        launch.template operator()<typename GEMM::EpilogueDefault>(typename GEMM::EpilogueDefault::Arguments{
            .out = o.data_ptr<typename GEMM::half_t>(),
            .actualM = sizeBatch * numTokensQ,
            .actualN = numHeads * headDim,
        });
    });
}

};  // namespace nunchaku::kernels
\ No newline at end of file
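Editor's note: the scale *= M_LOG2E step exists because the kernel evaluates the softmax exponentials with the hardware base-2 exponential (the cuda_exp2 / ex2.approx.ftz.f32 helper added in gemm_utils.cuh below) rather than exp; folding log2(e) into the scale keeps the result mathematically identical. A quick numerical check of that identity:

    import math

    x, scale = 3.7, 128 ** -0.5                    # arbitrary logit value and 1/sqrt(HEAD_DIM)
    lhs = math.exp(scale * x)                      # what softmax normally exponentiates
    rhs = 2.0 ** (scale * math.log2(math.e) * x)   # what the kernel computes after scale *= M_LOG2E
    assert abs(lhs - rhs) < 1e-12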
src/kernels/zgemm/attention.cuh (new file, 0 → 100644)

This diff is collapsed in the page view (+715 lines, not shown here).
src/kernels/zgemm/gemm_utils.cuh

@@ -188,6 +188,13 @@ static void ldmatrix(const void *ptr, uint4 &out) {
    );
}

+template<typename T>
+__device__ __forceinline__
+static T movmatrix(T x) {
+    asm volatile ("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;"
+                  : "=r"(*reinterpret_cast<uint32_t *>(&x))
+                  : "r"(*reinterpret_cast<uint32_t *>(&x)));
+    return x;
+}

// x in low bit, y in high bit
template<int bitwidth, bool use_unsigned>

@@ -277,6 +284,13 @@ static float cuda_cos(float x) {
    return result;
}

+__device__ __forceinline__
+static float cuda_exp2(float x) {
+    float result;
+    asm ("ex2.approx.ftz.f32 %0, %1;" : "=f"(result) : "f"(x));
+    return result;
+}

// https://forums.developer.nvidia.com/t/hardware-accelerated-computation-of-the-sigmoid-logistic-function/266206
__forceinline__ __device__ static float cuda_sigmoidf(float a)

@@ -364,4 +378,12 @@ static float int2float_fast(int val) {
    return fval - 12582912.0f;
}

+template<typename To, typename From>
+__device__ __forceinline__
+static To bit_cast(const From &input) {
+    static_assert(sizeof(To) == sizeof(From));
+    // not safe but anyway
+    return *reinterpret_cast<const To *>(&input);
+}

};  // namespace nunchaku::kernels
\ No newline at end of file
src/kernels/zgemm/gemm_w4a4.cu

@@ -39,7 +39,11 @@ void gemm_w4a4(
    bool fuse_silu, bool fp4,
    float alpha,
-   Tensor wcscales
+   Tensor wcscales,
+   Tensor out_q,  // packed attention [B, H, M, D]
+   Tensor out_k,  // packed attention [B, H, M, D]
+   Tensor out_v,  // packed attention [B, H, M, D]
+   int attn_tokens
) {
    Tensor::ScalarType dtype = Tensor::INVALID_SCALAR_TYPE;
    if (!fp4) {

@@ -53,60 +57,68 @@ void gemm_w4a4(
        }
    }
    invoke_launch(dtype, [&]<typename Config>() {
-       GEMM_W4A4_Launch<Config>::gemm_w4a4(
-           act, wgt, out, qout, ascales, wscales, oscales, poolout,
-           lora_act_in, lora_up, lora_down, lora_act_out,
-           norm_q, norm_k, rotary_emb, bias, smooth_factor,
-           out_vk, out_linearattn, act_unsigned, lora_scales,
-           fuse_silu, fp4, alpha, wcscales);
+       dispatchBool(fp4, [&]<bool USE_FP4>() {
+           GEMM_W4A4_Launch<Config, USE_FP4>::gemm_w4a4(
+               act, wgt, out, qout, ascales, wscales, oscales, poolout,
+               lora_act_in, lora_up, lora_down, lora_act_out,
+               norm_q, norm_k, rotary_emb, bias, smooth_factor,
+               out_vk, out_linearattn, act_unsigned, lora_scales,
+               fuse_silu, fp4, alpha, wcscales,
+               out_q, out_k, out_v, attn_tokens);
+       });
    });
}

void linearattn_vk_mul_q(Tensor q, Tensor vk) {
    invoke_launch(q.dtype(), [&]<typename Config>() {
-       GEMM_W4A4_Launch<Config>::linearattn_vk_mul_q(q, vk);
+       GEMM_W4A4_Launch<Config, false>::linearattn_vk_mul_q(q, vk);
    });
}

void quantize_w4a4_act_fuse_lora(Tensor input, Tensor output, Tensor oscales, Tensor lora_down, Tensor lora_act_out, Tensor smooth, bool fuse_glu, bool fp4) {
    invoke_launch(input.dtype(), [&]<typename Config>() {
-       GEMM_W4A4_Launch<Config>::quantize_w4a4_act_fuse_lora(input, output, oscales, lora_down, lora_act_out, smooth, fuse_glu, fp4);
+       dispatchBool(fp4, [&]<bool USE_FP4>() {
+           GEMM_W4A4_Launch<Config, USE_FP4>::quantize_w4a4_act_fuse_lora(input, output, oscales, lora_down, lora_act_out, smooth, fuse_glu, fp4);
+       });
    });
}

void quantize_w4a4_act(Tensor input, Tensor output, Tensor oscales) {
    invoke_launch(input.dtype(), [&]<typename Config>() {
-       GEMM_W4A4_Launch<Config>::quantize_w4a4_act(
+       GEMM_W4A4_Launch<Config, false>::quantize_w4a4_act(
            input, output, oscales);
    });
}

void quantize_w4a4_wgt(Tensor input, Tensor output, Tensor oscales) {
    invoke_launch(input.dtype(), [&]<typename Config>() {
-       GEMM_W4A4_Launch<Config>::quantize_w4a4_wgt(
+       GEMM_W4A4_Launch<Config, false>::quantize_w4a4_wgt(
            input, output, oscales);
    });
src/kernels/zgemm/gemm_w4a4.cuh

@@ -1618,13 +1618,300 @@ public:
    }
};

-struct EpilogueLiteLA {
+struct EpilogueRMSNormRope {
+    static constexpr int HEAD_DIM = 128;
+    static constexpr int NUM_HEADS_PER_WARP = WARP_N / HEAD_DIM;
+    static constexpr int WARP_N_TILES_PER_HEAD = WARP_N_TILES / NUM_HEADS_PER_WARP;
+
+    static constexpr int ROTARY_EMB_NUM_ELEMENTS = 2;
+    using packed_rotemb_t = float4;
+    static constexpr int WARP_N_ROTEMB_TILES = WARP_N_TILES / NUM_HEADS_PER_WARP * 2;
+    using rotemb_warp = std::array<packed_rotemb_t, WARP_M_TILES * WARP_N_ROTEMB_TILES>;  // 128 regs
+
+    struct Arguments {
+        // **packed** [M, HEAD_DIM] float => [M // 16, HEAD_DIM // 8, WARP_SIZE] of packed_rotemb_t
+        // aka [M // BLOCK_M, NUM_WARPS, WARP_M_TILES, WARP_N_TILES // NUM_HEADS_PER_WARP * 2, WARP_SIZE]
+        const packed_rotemb_t *rotary_emb;
+        const half_t *rmsnorm_weight_q;  // [HEAD_DIM]
+        const half_t *rmsnorm_weight_k;  // [HEAD_DIM]
+        float epsilon;
+    };
+
+    __device__ __forceinline__
+    static rotemb_warp load_rotemb(const packed_rotemb_t *ptr_rotemb) {
+        const int laneId = threadIdx.x % WARP_SIZE;
+        const int warpId = threadIdx.x / WARP_SIZE;
+
+        rotemb_warp rotemb;
+        const packed_rotemb_t *ptrlane = &ptr_rotemb[warpId * WARP_M_TILES * WARP_N_ROTEMB_TILES * WARP_SIZE + laneId];
+        unrolled_loop<WARP_M_TILES>([&]<int i>() {
+            unrolled_loop<WARP_N_ROTEMB_TILES>([&]<int j>() {
+                constexpr int offset = (i * WARP_N_ROTEMB_TILES + j) * WARP_SIZE;
+                rotemb[i * WARP_N_ROTEMB_TILES + j] = load(&ptrlane[offset]);
+            });
+        });
+        return rotemb;
+    }
+
+    __device__ __forceinline__
+    static void load_rmsnorm(const half_t *ptr_rmsnorm_weight, half_t *shmem) {
+        const int laneId = threadIdx.x % WARP_SIZE;
+        static constexpr int PACK_SIZE = HEAD_DIM / WARP_SIZE;
+        using packed_t = std::array<half_t, PACK_SIZE>;
+        packed_t pack = load(reinterpret_cast<const packed_t *>(ptr_rmsnorm_weight + laneId * PACK_SIZE));
+        store<true>(reinterpret_cast<packed_t *>(shmem + laneId * PACK_SIZE), pack);
+    }
+
+    __device__ __forceinline__
+    static packed_fpsum_t load_rmsnorm_from_shmem(half_t *shmem, int n) {
+        const int laneId = threadIdx.x % WARP_SIZE;
+        const int col = n * INSN_N + laneId / 16 * 8;  // lane 0-15: n*16+0, lane 16-31: n*16+8
+        uint4 tmp;
+        ldmatrix(shmem + col, tmp);
+        return bit_cast<packed_fpsum_t>(tmp);
+    }
+
+    __device__ __forceinline__
+    static void apply(fpsum_warp &fpsum, const packed_rotemb_t *ptr_rotemb, const half_t *ptr_rmsnorm_weight, float epsilon) {
+        const int laneId = threadIdx.x % WARP_SIZE;
+        const int warpId = threadIdx.x / WARP_SIZE;
+
+        __shared__ half_t shmem_rmsnorm[NUM_WARPS][HEAD_DIM];
+        load_rmsnorm(ptr_rmsnorm_weight, &shmem_rmsnorm[warpId][0]);
+        __syncwarp();
+
+        rotemb_warp rotemb = load_rotemb(ptr_rotemb);
+
+        float rmsnorm_coef[NUM_HEADS_PER_WARP][WARP_M_TILES][2];
+
+        auto sqr = [](half2_t val) ALWAYSINLINE {
+            float2 fval = half22float2(val);
+            return fval.x * fval.x + fval.y * fval.y;
+        };
+
+#pragma unroll
+        for (int head = 0; head < NUM_HEADS_PER_WARP; head++) {
+            const int n_offset = head * WARP_N_TILES_PER_HEAD;
+#pragma unroll
+            for (int m = 0; m < WARP_M_TILES; m++) {
+                float sqrsum[2] = {0.0f, 0.0f};
+#pragma unroll
+                for (int n = 0; n < WARP_N_TILES_PER_HEAD; n++) {
+                    sqrsum[0] += sqr(fpsum[m * WARP_N_TILES + n + n_offset].data[0]);
+                    sqrsum[1] += sqr(fpsum[m * WARP_N_TILES + n + n_offset].data[1]);
+                    sqrsum[0] += sqr(fpsum[m * WARP_N_TILES + n + n_offset].data[2]);
+                    sqrsum[1] += sqr(fpsum[m * WARP_N_TILES + n + n_offset].data[3]);
+                }
+#pragma unroll
+                for (int mask = 1; mask <= 2; mask *= 2) {
+                    sqrsum[0] += __shfl_xor_sync(~0, sqrsum[0], mask);
+                    sqrsum[1] += __shfl_xor_sync(~0, sqrsum[1], mask);
+                }
+                rmsnorm_coef[head][m][0] = cuda_frsqrt(sqrsum[0] / HEAD_DIM + epsilon);
+                rmsnorm_coef[head][m][1] = cuda_frsqrt(sqrsum[1] / HEAD_DIM + epsilon);
+            }
+        }
+
+#pragma unroll
+        for (int head = 0; head < NUM_HEADS_PER_WARP; head++) {
+            const int n_offset = head * WARP_N_TILES_PER_HEAD;
+#pragma unroll
+            for (int n = 0; n < WARP_N_TILES_PER_HEAD; n++) {
+                packed_f32psum_t rms = packed_fp16_to_fp32(load_rmsnorm_from_shmem(&shmem_rmsnorm[warpId][0], n));
+#pragma unroll
+                for (int m = 0; m < WARP_M_TILES; m++) {
+                    packed_f32psum_t pack = packed_fp16_to_fp32(fpsum[m * WARP_N_TILES + n + n_offset]);
+                    pack.data[0] *= rmsnorm_coef[head][m][0] * rms.data[0];
+                    pack.data[1] *= rmsnorm_coef[head][m][0] * rms.data[1];
+                    pack.data[2] *= rmsnorm_coef[head][m][1] * rms.data[2];
+                    pack.data[3] *= rmsnorm_coef[head][m][1] * rms.data[3];
+                    pack.data[4] *= rmsnorm_coef[head][m][0] * rms.data[4];
+                    pack.data[5] *= rmsnorm_coef[head][m][0] * rms.data[5];
+                    pack.data[6] *= rmsnorm_coef[head][m][1] * rms.data[6];
+                    pack.data[7] *= rmsnorm_coef[head][m][1] * rms.data[7];
+
+                    auto rope = [](float &x, float &y, float sin, float cos) ALWAYSINLINE {
+                        float ix = x, iy = y;
+                        x = ix * cos - iy * sin;
+                        y = ix * sin + iy * cos;
+                    };
+                    {
+                        packed_rotemb_t sincos = rotemb[m * WARP_N_ROTEMB_TILES + n * 2];
+                        rope(pack.data[0], pack.data[1], sincos.x, sincos.y);
+                        rope(pack.data[2], pack.data[3], sincos.z, sincos.w);
+                    }
+                    {
+                        packed_rotemb_t sincos = rotemb[m * WARP_N_ROTEMB_TILES + n * 2 + 1];
+                        rope(pack.data[4], pack.data[5], sincos.x, sincos.y);
+                        rope(pack.data[6], pack.data[7], sincos.z, sincos.w);
+                    }
+
+                    fpsum[m * WARP_N_TILES + n + n_offset] = packed_fp32_to_fp16(pack);
+                }
+            }
+        }
+    }
+
+    __device__ __forceinline__
+    void operator()(const BlockInfo binfo, fpsum_warp &fpsum, int M, int N, int K, Arguments args) {
+        const int bm = binfo.bm;
+        const int bn = binfo.bn;
+
+        assert(binfo.numBlocksN % 3 == 0);
+        const bool is_q = bn < binfo.numBlocksN / 3;
+        const bool is_k = !is_q && bn < binfo.numBlocksN / 3 * 2;
+
+        if (is_q || is_k) {
+            apply(fpsum,
+                  args.rotary_emb + bm * NUM_WARPS * WARP_M_TILES * WARP_N_ROTEMB_TILES * WARP_SIZE,
+                  is_q ? args.rmsnorm_weight_q : args.rmsnorm_weight_k,
+                  args.epsilon);
+        }
+    }
+};
+
+struct EpiloguePackQKV {
+    using attn_half_t = half;
+    using attn_half2_t = half2;
+    using packed_qkv_t = uint4;
+
+    static constexpr int HEAD_DIM = 128;
+    static constexpr int INSN_K_QK = 16;
+    static constexpr int INSN_K_PV = 16;
+
+    struct Arguments {
+        packed_qkv_t *out_q, *out_k, *out_v;
+        int actualM;
+        // !!! stride in number of packed_qkv_t !!!
+        int strideHead_q;
+        int strideHead_k;
+        int strideHead_v;
+    };
+
+    __device__ __forceinline__
+    static attn_half2_t convert_half2(half2_t input) {
+        if constexpr (std::is_same_v<half2_t, attn_half2_t>) {
+            return input;
+        } else {
+            float2 fval = half22float2(input);
+            return float22half2<attn_half2_t>(fval);
+        }
+    }
+
+    __device__ __forceinline__
+    static packed_qkv_t pack_q(packed_fpsum_t input) {
+        packed_qkv_t output;
+        output.x = bit_cast<int>(convert_half2(input.data[0]));
+        output.y = bit_cast<int>(convert_half2(input.data[1]));
+        output.z = bit_cast<int>(convert_half2(input.data[2]));
+        output.w = bit_cast<int>(convert_half2(input.data[3]));
+        return output;
+    }
+
+    __device__ __forceinline__
+    static packed_qkv_t pack_k(packed_fpsum_t input) {
+        packed_qkv_t output;
+        output.x = bit_cast<int>(convert_half2(input.data[0]));
+        output.y = bit_cast<int>(convert_half2(input.data[2]));
+        output.z = bit_cast<int>(convert_half2(input.data[1]));
+        output.w = bit_cast<int>(convert_half2(input.data[3]));
+        return output;
+    }
+
    __device__ __forceinline__
-   static half2_t movmatrix(half2_t x) {
-       asm volatile ("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;"
-                     : "=r"(*reinterpret_cast<uint32_t *>(&x))
-                     : "r"(*reinterpret_cast<uint32_t *>(&x)));
-       return x;
+   static packed_qkv_t pack_v(packed_fpsum_t input) {
+       packed_qkv_t output;
+       output.x = bit_cast<int>(convert_half2(movmatrix(input.data[0])));
+       output.y = bit_cast<int>(convert_half2(movmatrix(input.data[1])));
+       output.z = bit_cast<int>(convert_half2(movmatrix(input.data[2])));
+       output.w = bit_cast<int>(convert_half2(movmatrix(input.data[3])));
+       return output;
    }

+    __device__ __forceinline__
+    static void mask(packed_qkv_t &pack, uint32_t maskVal, int m, int maxRows) {
+        const int laneId = threadIdx.x % WARP_SIZE;
+        if (m * INSN_M + laneId / 4 >= maxRows) {
+            pack.x = maskVal;
+            pack.z = maskVal;
+        }
+        if (m * INSN_M + laneId / 4 + 8 >= maxRows) {
+            pack.y = maskVal;
+            pack.w = maskVal;
+        }
+    }
+
+    // qkv: [batch, head, bm, NUM_WARPS, WARP_M_TILES, WARP_N_TILES, WARP_SIZE] of packed_qkv_t
+    template<typename F>
+    __device__ __forceinline__
+    static void apply(fpsum_warp &fpsum, packed_qkv_t *ptr_output, int maxRows, F &&funcPack, attn_half2_t maskVal) {
+        const int laneId = threadIdx.x % WARP_SIZE;
+        const int warpId = threadIdx.x / WARP_SIZE;
+
+        static_assert(HEAD_DIM == WARP_N);
+
+        packed_qkv_t *ptrlane = &ptr_output[((warpId * WARP_M_TILES + 0) * WARP_N_TILES + 0) * WARP_SIZE + laneId];
+
+        unrolled_loop<WARP_M_TILES>([&]<int m>() ALWAYSINLINE {
+            unrolled_loop<WARP_N_TILES>([&]<int n>() ALWAYSINLINE {
+                packed_qkv_t pack = funcPack(fpsum[m * WARP_N_TILES + n]);
+                mask(pack, bit_cast<uint32_t>(maskVal), m, maxRows - warpId * WARP_M);
+                store(&ptrlane[(m * WARP_N_TILES + n) * WARP_SIZE], pack);
+            });
+        });
+    }
+
+    __device__ __forceinline__
+    void operator()(const BlockInfo binfo, fpsum_warp fpsum, int M, int N, int K, Arguments args) {
+        const int bm = binfo.bm;
+        const int bn = binfo.bn;
+
+        assert(binfo.numBlocksN % 3 == 0);
+        const int numBlocksQ = binfo.numBlocksN / 3;
+        const bool is_q = bn < numBlocksQ;
+        const bool is_k = !is_q && bn < numBlocksQ * 2;
+
+        // bn is head_id (assume HEAD_DIM == WARP_N)
+        int head_id, strideHead;
+        if (is_q) {
+            head_id = bn;
+            strideHead = args.strideHead_q;
+        } else if (is_k) {
+            head_id = bn - numBlocksQ;
+            strideHead = args.strideHead_k;
+        } else {
+            head_id = bn - numBlocksQ * 2;
+            strideHead = args.strideHead_v;
+        }
+
+        int block_offset = head_id * strideHead + bm * NUM_WARPS * WARP_M_TILES * WARP_N_TILES * WARP_SIZE;
+        int maxRows = args.actualM - bm * BLOCK_M;
+
+        // static constexpr float neginf = -std::numeric_limits<float>::infinity();
+
+        if (is_q) {
+            apply(fpsum, args.out_q + block_offset, maxRows, pack_q, attn_half2_t(0.0f, 0.0f));
+        } else if (is_k) {
+            apply(fpsum, args.out_k + block_offset, maxRows, pack_k, attn_half2_t(NAN, NAN));
+        } else {
+            apply(fpsum, args.out_v + block_offset, maxRows, pack_v, attn_half2_t(0.0f, 0.0f));
+        }
+    }
+};
+
+struct EpilogueLiteLA {
    __device__ __forceinline__
    static packed_f32psum_t mma_litela(packed_fpsum_t k, packed_fpsum_t v, packed_f32psum_t psum) {
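Editor's note: EpilogueRMSNormRope fuses the Q/K RMSNorm and the rotary embedding into the W4A4 GEMM epilogue: each head row is scaled by rsqrt(mean(x²) + eps), multiplied by the per-channel RMSNorm weight, and consecutive channel pairs are then rotated by per-position angles. The sketch below is a plain PyTorch reference of that per-head math only; the kernel's accumulator tiling and the exact channel pairing from the packed rotary-embedding layout are ignored, the interleaved (even, odd) pairing follows FLUX's RoPE convention, and the epsilon default is illustrative.

    import torch

    def rmsnorm_rope_ref(x, weight, sin, cos, eps=1e-6):
        # x: [tokens, head_dim], weight: [head_dim], sin/cos: [tokens, head_dim // 2]
        x = x.float()
        x = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps) * weight.float()
        x_even, x_odd = x[..., 0::2], x[..., 1::2]
        rotated = torch.empty_like(x)
        rotated[..., 0::2] = x_even * cos - x_odd * sin
        rotated[..., 1::2] = x_even * sin + x_odd * cos
        return rotated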
@@ -1874,6 +2161,62 @@
        );
    }
};

+template<typename Epilogue>
+struct test_epilogue_kernel {
+    static constexpr size_t SHMEM_PER_WARP = ceilDiv<size_t>(load_act_to_fpsum<false>::SHMEM_SIZE, 128) * 128;
+    static constexpr size_t SHMEM_SIZE = SHMEM_PER_WARP * NUM_WARPS;
+
+    struct Arguments {
+        const half_t *input;
+        half_t *output;
+        // aligned to BLOCK_M and BLOCK_N
+        int M, N;
+        int actualM, actualN;
+        typename Epilogue::Arguments argsEpilogue;
+    };
+
+    __device__ __forceinline__
+    void operator()(Arguments args) {
+        const BlockInfo binfo = {
+            .bm = (int)blockIdx.x,
+            .bn = (int)blockIdx.y,
+            .numBlocksM = (int)gridDim.x,
+            .numBlocksN = (int)gridDim.y,
+        };
+        const int bm = binfo.bm;
+        const int bn = binfo.bn;
+        const int warpId = threadIdx.x / WARP_SIZE;
+
+        const int m_offset = bm * BLOCK_M + warpId * WARP_M;
+        const int n_offset = bn * BLOCK_N;
+
+        extern __shared__ uint8_t shmem[];
+
+        fpsum_warp fpsum;
+
+        load_act_to_fpsum<false>()(
+            args.input + m_offset * args.actualN + n_offset,
+            args.actualN,
+            args.actualM - m_offset,
+            args.actualN - n_offset,
+            fpsum,
+            shmem + warpId * SHMEM_PER_WARP);
+
+        Epilogue()(binfo, fpsum, args.M, args.N, 0, args.argsEpilogue);
+
+        EpilogueDefault()(binfo, fpsum, args.M, args.N, 0, typename EpilogueDefault::Arguments{
+            .out = args.output,
+            .actualM = args.actualM,
+            .actualN = args.actualN,
+        });
+    }
+};

};
};  // namespace nunchaku::kernels
\ No newline at end of file
src/kernels/zgemm/gemm_w4a4_launch.cuh

@@ -2,7 +2,7 @@
namespace nunchaku::kernels {

-template<typename Config>
+template<typename Config, bool USE_FP4>
class GEMM_W4A4_Launch {
    using GEMM = GEMM_W4A4<Config>;
    // using LoraRanks = std::integer_sequence<int, 0, 32>;

@@ -48,7 +48,11 @@ public:
        bool fuse_silu, bool fp4,
        float alpha,
-       Tensor wcscales       // packed ws [N]
+       Tensor wcscales,      // packed ws [N]
+       Tensor out_q,         // packed attention [B, H, M, D]
+       Tensor out_k,         // packed attention [B, H, M, D]
+       Tensor out_v,         // packed attention [B, H, M, D]
+       int attn_tokens
    );
    static void quantize_w4a4_act_fuse_lora(Tensor input, Tensor output, Tensor oscales, Tensor lora_down, Tensor lora_act_out, Tensor smooth, bool fuse_glu, bool fp4);
    static void quantize_w4a4_act(Tensor input, Tensor output, Tensor oscales);
src/kernels/zgemm/gemm_w4a4_launch_bf16.cu → src/kernels/zgemm/gemm_w4a4_launch_bf16_fp4.cu (renamed)

#include "gemm_w4a4_launch_impl.cuh"

namespace nunchaku::kernels {

-template class GEMM_W4A4_Launch<GEMMConfig_W4A4_BF16>;
+template class GEMM_W4A4_Launch<GEMMConfig_W4A4_BF16, true>;

};
\ No newline at end of file
src/kernels/zgemm/gemm_w4a4_launch_bf16_int4.cu (new file, 0 → 100644)

#include "gemm_w4a4_launch_impl.cuh"

namespace nunchaku::kernels {

template class GEMM_W4A4_Launch<GEMMConfig_W4A4_BF16, false>;

};
\ No newline at end of file