Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
FlashMLA
Commits
200f01d5
"vscode:/vscode.git/clone" did not exist on "add95438dfea684bac387fe1cf48e8bdd0c482d8"
Commit
200f01d5
authored
Jan 26, 2026
by
zhanghj2
Browse files
支持attn_sink
parent
9b54b03c
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
46 additions
and
31 deletions
+46
-31
csrc/smxx/decode/combine/combine.cu
csrc/smxx/decode/combine/combine.cu
+46
-31
No files found.
csrc/smxx/decode/combine/combine.cu
View file @
200f01d5
...
...
@@ -17,7 +17,7 @@ using namespace cute;
namespace
smxx
::
decode
{
template
<
typename
ElementT
,
int
HEAD_DIM_V
,
int
BLOCK_SIZE_M
,
int
MAX_SPLITS
,
int
NUM_THREADS
>
__global__
void
__launch_bounds__
(
NUM_THREADS
)
__global__
void
__launch_bounds__
(
NUM_THREADS
,
1
)
flash_fwd_mla_combine_kernel
(
const
CombineParams
params
)
{
// grid_shape: [batch_size, s_q, h_q/BLOCK_SIZE_M]
// Each CTA gathers the activation of some heads from one batch, do scaling & accumulation, and save the result
...
...
@@ -54,20 +54,8 @@ flash_fwd_mla_combine_kernel(const CombineParams params) {
);
__shared__
float
smem_buf
[
BLOCK_SIZE_M
][
MAX_SPLITS
];
// Wait for the previous kernel (the MLA kernel) to finish
// cudaGridDependencySynchronize();
// Prefetch
static_assert
(
HEAD_DIM_V
%
(
64
*
4
)
==
0
);
constexpr
int
ELEMS_PER_THREAD
=
HEAD_DIM_V
/
(
64
*
4
);
float
*
oaccum_ptr
=
params
.
o_accum
+
start_split_idx
*
params
.
stride_o_accum_split
+
s_q_idx
*
params
.
stride_o_accum_s_q
+
(
h_block_idx
*
BLOCK_SIZE_M
+
warp_idx
)
*
params
.
stride_o_accum_h_q
;
float4
datas
[
ELEMS_PER_THREAD
];
CUTLASS_PRAGMA_UNROLL
for
(
int
i
=
0
;
i
<
ELEMS_PER_THREAD
;
++
i
)
{
datas
[
i
]
=
*
(
float4
*
)(
oaccum_ptr
+
lane_idx
*
4
+
i
*
256
);
// NOTE We don't use __ldg here since it is incompatible with PDL
}
// __syncthreads();
// Warp #i gathers LseAccum for seq #i
{
constexpr
int
NUM_LSE_PER_THREAD
=
cute
::
ceil_div
(
MAX_SPLITS
,
64
);
...
...
@@ -90,36 +78,50 @@ flash_fwd_mla_combine_kernel(const CombineParams params) {
float
sum_lse
=
0
;
CUTLASS_PRAGMA_UNROLL
for
(
int
i
=
0
;
i
<
NUM_LSE_PER_THREAD
;
++
i
)
sum_lse
=
sum_lse
+
exp2f
(
local_lse
[
i
]
-
max_lse
);
sum_lse
=
sum_lse
+
__builtin_amdgcn_
exp2f
(
(
local_lse
[
i
]
-
max_lse
)
*
CUDART_L2E_F
)
;
CUTLASS_PRAGMA_UNROLL
for
(
int
offset
=
32
;
offset
>=
1
;
offset
/=
2
)
sum_lse
=
sum_lse
+
__shfl_xor
(
sum_lse
,
offset
);
float
global_lse
=
(
sum_lse
==
0.
f
||
sum_lse
==
-
INFINITY
)
?
INFINITY
:
log
2
f
(
sum_lse
)
+
max_lse
;
float
global_lse
=
(
sum_lse
==
0.
f
||
sum_lse
==
-
INFINITY
)
?
INFINITY
:
logf
(
sum_lse
)
+
max_lse
;
if
(
lane_idx
==
0
)
gLse
(
warp_idx
)
=
global_lse
/
(
float
)
M_LOG2E
;
gLse
(
warp_idx
)
=
global_lse
;
float
o_scale
=
0.0
f
;
if
(
params
.
attn_sink
!=
nullptr
)
{
int
q_head_idx
=
h_block_idx
*
BLOCK_SIZE_M
+
warp_idx
;
float
attn_sink
=
__ldg
(
params
.
attn_sink
+
q_head_idx
);
if
(
global_lse
!=
INFINITY
)
{
// If attn_sink is +inf, global_lse will be +inf and scale factors will be exp2f(local_lse - inf) = 0 (since local_lse never becomes +inf)
// If attn_sink is -inf, this has no effect on global_lse
global_lse
+=
log2f
(
1
+
exp2f
(
attn_sink
*
CUDART_L2E_F
-
global_lse
));
}
else
{
// We have no tokens to attend, so global lse should be attn_sink*CUDART_L2E_F (+inf if it's -inf or +inf)
global_lse
=
attn_sink
==
-
INFINITY
?
+
INFINITY
:
attn_sink
*
CUDART_L2E_F
;
}
float
Attn_sink_exp2
=
__builtin_amdgcn_exp2f
(
attn_sink
*
CUDART_L2E_F
);
float
lse_exp2
=
__builtin_amdgcn_exp2f
(
global_lse
*
CUDART_L2E_F
);
o_scale
=
lse_exp2
/
(
lse_exp2
+
Attn_sink_exp2
);
// if (global_lse != INFINITY) {
// // If attn_sink is +inf, global_lse will be +inf and scale factors will be exp2f(local_lse - inf) = 0 (since local_lse never becomes +inf)
// // If attn_sink is -inf, this has no effect on global_lse
// global_lse += log2f(1 + __builtin_amdgcn_exp2f(attn_sink*CUDART_L2E_F - global_lse));
// } else {
// // We have no tokens to attend, so global lse should be attn_sink*CUDART_L2E_F (+inf if it's -inf or +inf)
// global_lse = attn_sink == -INFINITY ? +INFINITY : attn_sink*CUDART_L2E_F;
// }
}
CUTLASS_PRAGMA_UNROLL
for
(
int
i
=
0
;
i
<
NUM_LSE_PER_THREAD
;
++
i
)
{
const
int
split_idx
=
i
*
64
+
lane_idx
;
smem_buf
[
warp_idx
][
split_idx
]
=
exp2f
(
local_lse
[
i
]
-
global_lse
);
if
(
split_idx
<
my_num_splits
)
{
// printf("local_lse %.2f global_lse = %.2f \n", local_lse[i], global_lse);
smem_buf
[
warp_idx
][
split_idx
]
=
__builtin_amdgcn_exp2f
((
local_lse
[
i
]
-
global_lse
)
*
CUDART_L2E_F
)
*
o_scale
;
}
}
}
__syncthreads
();
static_assert
(
HEAD_DIM_V
%
(
64
*
4
)
==
0
);
constexpr
int
ELEMS_PER_THREAD
=
HEAD_DIM_V
/
(
64
*
4
);
float
*
oaccum_ptr
=
params
.
o_accum
+
start_split_idx
*
params
.
stride_o_accum_split
+
s_q_idx
*
params
.
stride_o_accum_s_q
+
(
h_block_idx
*
BLOCK_SIZE_M
+
warp_idx
)
*
params
.
stride_o_accum_h_q
;
float4
datas
[
ELEMS_PER_THREAD
];
CUTLASS_PRAGMA_UNROLL
for
(
int
i
=
0
;
i
<
ELEMS_PER_THREAD
;
++
i
)
{
datas
[
i
]
=
*
(
float4
*
)(
oaccum_ptr
+
lane_idx
*
4
+
i
*
256
);
// NOTE We don't use __ldg here since it is incompatible with PDL
}
// Warp #i accumulates activation for seq #i
{
float4
result
[
ELEMS_PER_THREAD
];
...
...
@@ -130,6 +132,10 @@ flash_fwd_mla_combine_kernel(const CombineParams params) {
#pragma unroll 1
for
(
int
split
=
0
;
split
<
my_num_splits
;
++
split
)
{
float
lse_scale
=
smem_buf
[
warp_idx
][
split
];
// if (warp_idx == 2 && threadIdx.x == 128)
// {
// printf("threadIdx.x = %d %.3f %.3f lse_scale = %.2f \n",threadIdx.x, datas[0].x, datas[1].x, lse_scale);
// }
// if (lse_scale != 0.f) {
CUTLASS_PRAGMA_UNROLL
for
(
int
i
=
0
;
i
<
ELEMS_PER_THREAD
;
++
i
)
{
...
...
@@ -143,7 +149,10 @@ flash_fwd_mla_combine_kernel(const CombineParams params) {
}
// }
}
// if (warp_idx == 2)
// {
// printf(" %.3f \n", result[0].x);
// }
const
int
h_q_idx
=
h_block_idx
*
BLOCK_SIZE_M
+
warp_idx
;
ElementT
*
o_ptr
=
(
ElementT
*
)
params
.
out
+
batch_idx
*
params
.
stride_o_b
+
s_q_idx
*
params
.
stride_o_s_q
+
h_q_idx
*
params
.
stride_o_h_q
;
...
...
@@ -151,6 +160,12 @@ flash_fwd_mla_combine_kernel(const CombineParams params) {
for
(
int
i
=
0
;
i
<
ELEMS_PER_THREAD
;
++
i
)
{
float4
data
=
result
[
i
];
ElementT
data_converted
[
4
];
// auto res = __builtin_hcu_cvt_pk_bf16_f32(0, data.x, 0, data.y, 0);
// data_converted[0].storage = res[0];
// data_converted[1].storage = res[1];
// res = __builtin_hcu_cvt_pk_bf16_f32(0, data.z, 0, data.w, 0);
// data_converted[2].storage = res[0];
// data_converted[3].storage = res[1];
data_converted
[
0
]
=
(
ElementT
)(
data
.
x
);
data_converted
[
1
]
=
(
ElementT
)(
data
.
y
);
data_converted
[
2
]
=
(
ElementT
)(
data
.
z
);
...
...
@@ -208,7 +223,7 @@ void run_flash_mla_combine_kernel(CombineParams ¶ms) {
// 1
// };
combine_kernel
<<<
dim3
(
params
.
b
,
params
.
s_q
,
ku
::
ceil_div
(
params
.
h_q
,
BLOCK_SIZE_M
)),
dim3
(
NUM_THREADS
,
1
,
1
),
NUM_THREADS
,
smem_size
,
params
.
stream
>>>
(
params
);
});
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment