Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
FlashMLA
Commits
755d8be7
Commit
755d8be7
authored
Jan 25, 2026
by
zhanghj2
Browse files
适配combine kernel
parent
572946f5
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
156 additions
and
152 deletions
+156
-152
csrc/smxx/decode/combine/combine.cu
csrc/smxx/decode/combine/combine.cu
+156
-152
No files found.
csrc/smxx/decode/combine/combine.cu
View file @
755d8be7
...
...
@@ -10,6 +10,7 @@
#include "params.h"
#include "utils.h"
#define CUDART_L2E_F 1.442695041F
using
namespace
cute
;
...
...
@@ -18,146 +19,146 @@ namespace smxx::decode {
template
<
typename
ElementT
,
int
HEAD_DIM_V
,
int
BLOCK_SIZE_M
,
int
MAX_SPLITS
,
int
NUM_THREADS
>
__global__
void
__launch_bounds__
(
NUM_THREADS
)
flash_fwd_mla_combine_kernel
(
const
CombineParams
params
)
{
//
//
grid_shape: [batch_size, s_q, h_q/BLOCK_SIZE_M]
//
//
Each CTA gathers the activation of some heads from one batch, do scaling & accumulation, and save the result
//
static_assert(NUM_THREADS/
32
== BLOCK_SIZE_M); // The number of warps == block_size_m
//
const int batch_idx = blockIdx.x;
//
const int s_q_idx = blockIdx.y;
//
const int h_block_idx = blockIdx.z;
//
const int warp_idx = threadIdx.x /
32
;
//
const int lane_idx = threadIdx.x %
32
;
//
int num_valid_heads = std::min(BLOCK_SIZE_M, params.h_q - BLOCK_SIZE_M*h_block_idx);
//
if (warp_idx >= num_valid_heads) {
//
return;
//
}
//
const int start_split_idx = __ldg(params.num_splits_ptr + batch_idx);
//
const int end_split_idx = __ldg(params.num_splits_ptr + batch_idx + 1);
//
const int my_num_splits = end_split_idx - start_split_idx;
//
if (my_num_splits == 1) {
//
return;
//
}
// grid_shape: [batch_size, s_q, h_q/BLOCK_SIZE_M]
// Each CTA gathers the activation of some heads from one batch, do scaling & accumulation, and save the result
static_assert
(
NUM_THREADS
/
64
==
BLOCK_SIZE_M
);
// The number of warps == block_size_m
const
int
batch_idx
=
blockIdx
.
x
;
const
int
s_q_idx
=
blockIdx
.
y
;
const
int
h_block_idx
=
blockIdx
.
z
;
const
int
warp_idx
=
threadIdx
.
x
/
64
;
const
int
lane_idx
=
threadIdx
.
x
%
64
;
int
num_valid_heads
=
std
::
min
(
BLOCK_SIZE_M
,
params
.
h_q
-
BLOCK_SIZE_M
*
h_block_idx
);
if
(
warp_idx
>=
num_valid_heads
)
{
return
;
}
const
int
start_split_idx
=
__ldg
(
params
.
num_splits_ptr
+
batch_idx
);
const
int
end_split_idx
=
__ldg
(
params
.
num_splits_ptr
+
batch_idx
+
1
);
const
int
my_num_splits
=
end_split_idx
-
start_split_idx
;
if
(
my_num_splits
==
1
)
{
return
;
}
//
FLASH_DEVICE_ASSERT(my_num_splits <= MAX_SPLITS);
FLASH_DEVICE_ASSERT
(
my_num_splits
<=
MAX_SPLITS
);
//
Tensor gLseAccum = make_tensor(
//
make_gmem_ptr((float*)params.lse_accum + start_split_idx*params.stride_lse_accum_split + s_q_idx*params.stride_lse_accum_s_q + h_block_idx*BLOCK_SIZE_M),
//
Shape<Int<MAX_SPLITS>, Int<BLOCK_SIZE_M>>{},
//
make_stride(params.stride_lse_accum_split, _1{})
//
);
//
Tensor gLse = make_tensor(
//
make_gmem_ptr((float*)params.lse + batch_idx*params.stride_lse_b + s_q_idx*params.stride_lse_s_q + h_block_idx*BLOCK_SIZE_M),
//
Shape<Int<BLOCK_SIZE_M>>{},
//
Stride<_1>{}
//
);
Tensor
gLseAccum
=
make_tensor
(
make_gmem_ptr
((
float
*
)
params
.
lse_accum
+
start_split_idx
*
params
.
stride_lse_accum_split
+
s_q_idx
*
params
.
stride_lse_accum_s_q
+
h_block_idx
*
BLOCK_SIZE_M
),
Shape
<
Int
<
MAX_SPLITS
>
,
Int
<
BLOCK_SIZE_M
>>
{},
make_stride
(
params
.
stride_lse_accum_split
,
_1
{})
);
Tensor
gLse
=
make_tensor
(
make_gmem_ptr
((
float
*
)
params
.
lse
+
batch_idx
*
params
.
stride_lse_b
+
s_q_idx
*
params
.
stride_lse_s_q
+
h_block_idx
*
BLOCK_SIZE_M
),
Shape
<
Int
<
BLOCK_SIZE_M
>>
{},
Stride
<
_1
>
{}
);
//
__shared__ float smem_buf[BLOCK_SIZE_M][MAX_SPLITS];
__shared__
float
smem_buf
[
BLOCK_SIZE_M
][
MAX_SPLITS
];
//
//
Wait for the previous kernel (the MLA kernel) to finish
// Wait for the previous kernel (the MLA kernel) to finish
// cudaGridDependencySynchronize();
//
//
Prefetch
//
static_assert(HEAD_DIM_V % (
32
*4) == 0);
//
constexpr int ELEMS_PER_THREAD = HEAD_DIM_V / (
32
*4);
//
float* oaccum_ptr = params.o_accum + start_split_idx*params.stride_o_accum_split + s_q_idx*params.stride_o_accum_s_q + (h_block_idx*BLOCK_SIZE_M + warp_idx)*params.stride_o_accum_h_q;
//
float4 datas[ELEMS_PER_THREAD];
//
CUTLASS_PRAGMA_UNROLL
//
for (int i = 0; i < ELEMS_PER_THREAD; ++i) {
//
datas[i] = *(float4*)(oaccum_ptr + lane_idx*4 + i*
128
); // NOTE We don't use __ldg here since it is incompatible with PDL
//
}
//
//
Warp #i gathers LseAccum for seq #i
//
{
//
constexpr int NUM_LSE_PER_THREAD = cute::ceil_div(MAX_SPLITS,
32
);
//
float local_lse[NUM_LSE_PER_THREAD];
//
CUTLASS_PRAGMA_UNROLL
//
for (int i = 0; i < NUM_LSE_PER_THREAD; ++i) {
//
const int split_idx = i*
32
+ lane_idx;
//
local_lse[i] = split_idx < my_num_splits ? gLseAccum(split_idx, warp_idx) : -INFINITY;
//
}
//
float max_lse = -INFINITY;
//
CUTLASS_PRAGMA_UNROLL
//
for (int i = 0; i < NUM_LSE_PER_THREAD; ++i)
//
max_lse = max(max_lse, local_lse[i]);
//
CUTLASS_PRAGMA_UNROLL
//
for (int offset =
16
; offset >= 1; offset /= 2)
//
max_lse = max(max_lse, __shfl_xor
_sync(uint32_t(-1),
max_lse, offset));
//
max_lse = max_lse == -INFINITY ? 0.0f : max_lse; // In case all local LSEs are -inf
//
float sum_lse = 0;
//
CUTLASS_PRAGMA_UNROLL
//
for (int i = 0; i < NUM_LSE_PER_THREAD; ++i)
//
sum_lse = sum_lse + exp2f(local_lse[i] - max_lse);
//
CUTLASS_PRAGMA_UNROLL
//
for (int offset =
16
; offset >= 1; offset /= 2)
//
sum_lse = sum_lse + __shfl_xor
_sync(uint32_t(-1),
sum_lse, offset);
//
float global_lse = (sum_lse == 0.f || sum_lse == -INFINITY) ? INFINITY : log2f(sum_lse) + max_lse;
//
if (lane_idx == 0)
//
gLse(warp_idx) = global_lse / (float)M_LOG2E;
// Prefetch
static_assert
(
HEAD_DIM_V
%
(
64
*
4
)
==
0
);
constexpr
int
ELEMS_PER_THREAD
=
HEAD_DIM_V
/
(
64
*
4
);
float
*
oaccum_ptr
=
params
.
o_accum
+
start_split_idx
*
params
.
stride_o_accum_split
+
s_q_idx
*
params
.
stride_o_accum_s_q
+
(
h_block_idx
*
BLOCK_SIZE_M
+
warp_idx
)
*
params
.
stride_o_accum_h_q
;
float4
datas
[
ELEMS_PER_THREAD
];
CUTLASS_PRAGMA_UNROLL
for
(
int
i
=
0
;
i
<
ELEMS_PER_THREAD
;
++
i
)
{
datas
[
i
]
=
*
(
float4
*
)(
oaccum_ptr
+
lane_idx
*
4
+
i
*
256
);
// NOTE We don't use __ldg here since it is incompatible with PDL
}
// Warp #i gathers LseAccum for seq #i
{
constexpr
int
NUM_LSE_PER_THREAD
=
cute
::
ceil_div
(
MAX_SPLITS
,
64
);
float
local_lse
[
NUM_LSE_PER_THREAD
];
CUTLASS_PRAGMA_UNROLL
for
(
int
i
=
0
;
i
<
NUM_LSE_PER_THREAD
;
++
i
)
{
const
int
split_idx
=
i
*
64
+
lane_idx
;
local_lse
[
i
]
=
split_idx
<
my_num_splits
?
gLseAccum
(
split_idx
,
warp_idx
)
:
-
INFINITY
;
}
float
max_lse
=
-
INFINITY
;
CUTLASS_PRAGMA_UNROLL
for
(
int
i
=
0
;
i
<
NUM_LSE_PER_THREAD
;
++
i
)
max_lse
=
max
(
max_lse
,
local_lse
[
i
]);
CUTLASS_PRAGMA_UNROLL
for
(
int
offset
=
32
;
offset
>=
1
;
offset
/=
2
)
max_lse
=
max
(
max_lse
,
__shfl_xor
(
max_lse
,
offset
));
max_lse
=
max_lse
==
-
INFINITY
?
0.0
f
:
max_lse
;
// In case all local LSEs are -inf
float
sum_lse
=
0
;
CUTLASS_PRAGMA_UNROLL
for
(
int
i
=
0
;
i
<
NUM_LSE_PER_THREAD
;
++
i
)
sum_lse
=
sum_lse
+
exp2f
(
local_lse
[
i
]
-
max_lse
);
CUTLASS_PRAGMA_UNROLL
for
(
int
offset
=
32
;
offset
>=
1
;
offset
/=
2
)
sum_lse
=
sum_lse
+
__shfl_xor
(
sum_lse
,
offset
);
float
global_lse
=
(
sum_lse
==
0.
f
||
sum_lse
==
-
INFINITY
)
?
INFINITY
:
log2f
(
sum_lse
)
+
max_lse
;
if
(
lane_idx
==
0
)
gLse
(
warp_idx
)
=
global_lse
/
(
float
)
M_LOG2E
;
//
if (params.attn_sink != nullptr) {
//
int q_head_idx = h_block_idx*BLOCK_SIZE_M + warp_idx;
//
float attn_sink = __ldg(params.attn_sink + q_head_idx);
//
if (global_lse != INFINITY) {
//
// If attn_sink is +inf, global_lse will be +inf and scale factors will be exp2f(local_lse - inf) = 0 (since local_lse never becomes +inf)
//
// If attn_sink is -inf, this has no effect on global_lse
//
global_lse += log2f(1 + exp2f(attn_sink*CUDART_L2E_F - global_lse));
//
} else {
//
// We have no tokens to attend, so global lse should be attn_sink*CUDART_L2E_F (+inf if it's -inf or +inf)
//
global_lse = attn_sink == -INFINITY ? +INFINITY : attn_sink*CUDART_L2E_F;
//
}
//
}
//
CUTLASS_PRAGMA_UNROLL
//
for (int i = 0; i < NUM_LSE_PER_THREAD; ++i) {
//
const int split_idx = i*
32
+ lane_idx;
//
smem_buf[warp_idx][split_idx] = exp2f(local_lse[i] - global_lse);
//
}
//
}
//
__sync
warp
();
//
//
Warp #i accumulates activation for seq #i
//
{
//
float4 result[ELEMS_PER_THREAD];
//
CUTLASS_PRAGMA_UNROLL
//
for (int i = 0; i < ELEMS_PER_THREAD; ++i)
//
result[i] = {0.0f, 0.0f, 0.0f, 0.0f};
//
#pragma unroll 1
//
for (int split = 0; split < my_num_splits; ++split) {
//
float lse_scale = smem_buf[warp_idx][split];
//
// if (lse_scale != 0.f) {
//
CUTLASS_PRAGMA_UNROLL
//
for (int i = 0; i < ELEMS_PER_THREAD; ++i) {
//
result[i].x += lse_scale * datas[i].x;
//
result[i].y += lse_scale * datas[i].y;
//
result[i].z += lse_scale * datas[i].z;
//
result[i].w += lse_scale * datas[i].w;
//
if (split != my_num_splits-1) {
//
datas[i] = *(float4*)(oaccum_ptr + (split+1)*params.stride_o_accum_split + lane_idx*4 + i*
128
);
//
}
//
}
//
// }
//
}
if
(
params
.
attn_sink
!=
nullptr
)
{
int
q_head_idx
=
h_block_idx
*
BLOCK_SIZE_M
+
warp_idx
;
float
attn_sink
=
__ldg
(
params
.
attn_sink
+
q_head_idx
);
if
(
global_lse
!=
INFINITY
)
{
// If attn_sink is +inf, global_lse will be +inf and scale factors will be exp2f(local_lse - inf) = 0 (since local_lse never becomes +inf)
// If attn_sink is -inf, this has no effect on global_lse
global_lse
+=
log2f
(
1
+
exp2f
(
attn_sink
*
CUDART_L2E_F
-
global_lse
));
}
else
{
// We have no tokens to attend, so global lse should be attn_sink*CUDART_L2E_F (+inf if it's -inf or +inf)
global_lse
=
attn_sink
==
-
INFINITY
?
+
INFINITY
:
attn_sink
*
CUDART_L2E_F
;
}
}
CUTLASS_PRAGMA_UNROLL
for
(
int
i
=
0
;
i
<
NUM_LSE_PER_THREAD
;
++
i
)
{
const
int
split_idx
=
i
*
64
+
lane_idx
;
smem_buf
[
warp_idx
][
split_idx
]
=
exp2f
(
local_lse
[
i
]
-
global_lse
);
}
}
__sync
threads
();
// Warp #i accumulates activation for seq #i
{
float4
result
[
ELEMS_PER_THREAD
];
CUTLASS_PRAGMA_UNROLL
for
(
int
i
=
0
;
i
<
ELEMS_PER_THREAD
;
++
i
)
result
[
i
]
=
{
0.0
f
,
0.0
f
,
0.0
f
,
0.0
f
};
#pragma unroll 1
for
(
int
split
=
0
;
split
<
my_num_splits
;
++
split
)
{
float
lse_scale
=
smem_buf
[
warp_idx
][
split
];
// if (lse_scale != 0.f) {
CUTLASS_PRAGMA_UNROLL
for
(
int
i
=
0
;
i
<
ELEMS_PER_THREAD
;
++
i
)
{
result
[
i
].
x
+=
lse_scale
*
datas
[
i
].
x
;
result
[
i
].
y
+=
lse_scale
*
datas
[
i
].
y
;
result
[
i
].
z
+=
lse_scale
*
datas
[
i
].
z
;
result
[
i
].
w
+=
lse_scale
*
datas
[
i
].
w
;
if
(
split
!=
my_num_splits
-
1
)
{
datas
[
i
]
=
*
(
float4
*
)(
oaccum_ptr
+
(
split
+
1
)
*
params
.
stride_o_accum_split
+
lane_idx
*
4
+
i
*
256
);
}
}
// }
}
//
const int h_q_idx = h_block_idx*BLOCK_SIZE_M + warp_idx;
//
ElementT* o_ptr = (ElementT*)params.out + batch_idx*params.stride_o_b + s_q_idx*params.stride_o_s_q + h_q_idx*params.stride_o_h_q;
//
CUTLASS_PRAGMA_UNROLL
//
for (int i = 0; i < ELEMS_PER_THREAD; ++i) {
//
float4 data = result[i];
//
ElementT data_converted[4];
//
data_converted[0] = (ElementT)(data.x);
//
data_converted[1] = (ElementT)(data.y);
//
data_converted[2] = (ElementT)(data.z);
//
data_converted[3] = (ElementT)(data.w);
//
static_assert(sizeof(ElementT) == 2);
//
*(uint64_t*)(o_ptr + lane_idx*4 + i*
128
) = *(uint64_t*)data_converted;
//
}
//
}
const
int
h_q_idx
=
h_block_idx
*
BLOCK_SIZE_M
+
warp_idx
;
ElementT
*
o_ptr
=
(
ElementT
*
)
params
.
out
+
batch_idx
*
params
.
stride_o_b
+
s_q_idx
*
params
.
stride_o_s_q
+
h_q_idx
*
params
.
stride_o_h_q
;
CUTLASS_PRAGMA_UNROLL
for
(
int
i
=
0
;
i
<
ELEMS_PER_THREAD
;
++
i
)
{
float4
data
=
result
[
i
];
ElementT
data_converted
[
4
];
data_converted
[
0
]
=
(
ElementT
)(
data
.
x
);
data_converted
[
1
]
=
(
ElementT
)(
data
.
y
);
data_converted
[
2
]
=
(
ElementT
)(
data
.
z
);
data_converted
[
3
]
=
(
ElementT
)(
data
.
w
);
static_assert
(
sizeof
(
ElementT
)
==
2
);
*
(
uint64_t
*
)(
o_ptr
+
lane_idx
*
4
+
i
*
256
)
=
*
(
uint64_t
*
)
data_converted
;
}
}
}
...
...
@@ -188,26 +189,29 @@ template<typename ElementT>
void
run_flash_mla_combine_kernel
(
CombineParams
&
params
)
{
static
constexpr
int
HEAD_DIM_V
=
512
;
// Since only this head dimension is supported by Flash MLA
FLASH_ASSERT
(
params
.
d_v
==
HEAD_DIM_V
);
// MLA_NUM_SPLITS_SWITCH(params.num_sm_parts, NUM_SPLITS, [&] {
// constexpr int BLOCK_SIZE_M = 8;
// constexpr int NUM_THREADS = BLOCK_SIZE_M*32;
// constexpr size_t smem_size = BLOCK_SIZE_M*(NUM_SPLITS+1)*sizeof(float);
// auto combine_kernel = &flash_fwd_mla_combine_kernel<ElementT, HEAD_DIM_V, BLOCK_SIZE_M, NUM_SPLITS, NUM_THREADS>;
// CHECK_CUDA(cudaFuncSetAttribute(combine_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
// // Use cudaLaunchKernelEx to enable PDL (Programmatic Dependent Launch)
// cudaLaunchAttribute attribute[1];
// attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
// attribute[0].val.programmaticStreamSerializationAllowed = 1;
// cudaLaunchConfig_t combine_kernel_config = {
// dim3(params.b, params.s_q, ku::ceil_div(params.h_q, BLOCK_SIZE_M)),
// dim3(NUM_THREADS, 1, 1),
// 0,
// params.stream,
// attribute,
// 1
// };
// CHECK_CUDA(cudaLaunchKernelEx(&combine_kernel_config, combine_kernel, params));
// });
MLA_NUM_SPLITS_SWITCH
(
params
.
num_sm_parts
,
NUM_SPLITS
,
[
&
]
{
constexpr
int
BLOCK_SIZE_M
=
4
;
constexpr
int
NUM_THREADS
=
BLOCK_SIZE_M
*
64
;
constexpr
size_t
smem_size
=
BLOCK_SIZE_M
*
(
NUM_SPLITS
+
1
)
*
sizeof
(
float
);
auto
combine_kernel
=
&
flash_fwd_mla_combine_kernel
<
ElementT
,
HEAD_DIM_V
,
BLOCK_SIZE_M
,
NUM_SPLITS
,
NUM_THREADS
>
;
// CHECK_CUDA(cudaFuncSetAttribute(combine_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
// // Use cudaLaunchKernelEx to enable PDL (Programmatic Dependent Launch)
// cudaLaunchAttribute attribute[1];
// attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
// attribute[0].val.programmaticStreamSerializationAllowed = 1;
// cudaLaunchConfig_t combine_kernel_config = {
// dim3(params.b, params.s_q, ku::ceil_div(params.h_q, BLOCK_SIZE_M)),
// dim3(NUM_THREADS, 1, 1),
// 0,
// params.stream,
// attribute,
// 1
// };
combine_kernel
<<<
dim3
(
params
.
b
,
params
.
s_q
,
ku
::
ceil_div
(
params
.
h_q
,
BLOCK_SIZE_M
)),
dim3
(
NUM_THREADS
,
1
,
1
),
smem_size
,
params
.
stream
>>>
(
params
);
});
CHECK_CUDA_KERNEL_LAUNCH
();
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment