Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
FlashMLA
Commits
755d8be7
"csrc/vscode:/vscode.git/clone" did not exist on "2d0cf41dd1900b105d74cb071f4cac35e3fb6f47"
Commit
755d8be7
authored
Jan 25, 2026
by
zhanghj2
Browse files
适配combine kernel
parent
572946f5
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
156 additions
and
152 deletions
+156
-152
csrc/smxx/decode/combine/combine.cu
csrc/smxx/decode/combine/combine.cu
+156
-152
No files found.
csrc/smxx/decode/combine/combine.cu
View file @
755d8be7
...
...
@@ -10,6 +10,7 @@
#include "params.h"
#include "utils.h"
#define CUDART_L2E_F 1.442695041F
using
namespace
cute
;
...
...
@@ -18,146 +19,146 @@ namespace smxx::decode {
// Combines the per-split partial results of the MLA decode kernel into the
// final output and the final log-sum-exp (LSE).
//
// Launch layout (as implied by the indexing below):
//   grid  : [batch_size, s_q, ceil(h_q / BLOCK_SIZE_M)]
//   block : NUM_THREADS threads = BLOCK_SIZE_M warps of 64 lanes each;
//           warp #i handles query head h_block_idx*BLOCK_SIZE_M + i.
//
// Template parameters:
//   ElementT     - 2-byte output element type (enforced by a static_assert).
//   HEAD_DIM_V   - value head dimension; must be a multiple of 64*4.
//   BLOCK_SIZE_M - heads (== warps) handled per block.
//   MAX_SPLITS   - upper bound on the number of splits of one batch entry.
//   NUM_THREADS  - threads per block; must equal BLOCK_SIZE_M * 64.
//
// For each (batch, s_q, head) the kernel:
//   1. reads the per-split LSE values and reduces them across the warp to a
//      global LSE (the arithmetic uses exp2f/log2f), optionally folding in
//      params.attn_sink;
//   2. stores one scale factor exp2f(local_lse - global_lse) per split in
//      shared memory;
//   3. accumulates scale * o_accum over the splits and writes the result,
//      converted to ElementT, to params.out.
// Batch entries with exactly one split are left untouched (early return).
template<typename ElementT, int HEAD_DIM_V, int BLOCK_SIZE_M, int MAX_SPLITS, int NUM_THREADS>
__global__ void __launch_bounds__(NUM_THREADS)
flash_fwd_mla_combine_kernel(const CombineParams params) {
    // Lane count of one warp as assumed throughout this kernel, and the number
    // of floats a whole warp covers when every lane moves one float4.
    constexpr int kLanesPerWarp = 64;
    constexpr int kFloatsPerPass = kLanesPerWarp * 4;

    // The number of warps == block_size_m
    static_assert(NUM_THREADS / kLanesPerWarp == BLOCK_SIZE_M);

    const int batch_idx = blockIdx.x;
    const int s_q_idx = blockIdx.y;
    const int h_block_idx = blockIdx.z;
    const int warp_idx = threadIdx.x / kLanesPerWarp;
    const int lane_idx = threadIdx.x % kLanesPerWarp;

    // The last head block may be partial: warps without a head have nothing to do.
    const int num_valid_heads = std::min(BLOCK_SIZE_M, params.h_q - BLOCK_SIZE_M*h_block_idx);
    if (warp_idx >= num_valid_heads) {
        return;
    }

    // num_splits_ptr is read as a prefix array: entries [batch_idx, batch_idx+1)
    // delimit this batch entry's range of split indices.
    const int start_split_idx = __ldg(params.num_splits_ptr + batch_idx);
    const int end_split_idx = __ldg(params.num_splits_ptr + batch_idx + 1);
    const int my_num_splits = end_split_idx - start_split_idx;
    if (my_num_splits == 1) {
        // Nothing to combine for a single split. This exit is uniform across the block.
        return;
    }

    FLASH_DEVICE_ASSERT(my_num_splits <= MAX_SPLITS);

    // Per-split LSE, viewed as (split, head-in-block).
    Tensor gLseAccum = make_tensor(
        make_gmem_ptr((float*)params.lse_accum + start_split_idx*params.stride_lse_accum_split + s_q_idx*params.stride_lse_accum_s_q + h_block_idx*BLOCK_SIZE_M),
        Shape<Int<MAX_SPLITS>, Int<BLOCK_SIZE_M>>{},
        make_stride(params.stride_lse_accum_split, _1{})
    );
    // Final LSE, viewed as (head-in-block).
    Tensor gLse = make_tensor(
        make_gmem_ptr((float*)params.lse + batch_idx*params.stride_lse_b + s_q_idx*params.stride_lse_s_q + h_block_idx*BLOCK_SIZE_M),
        Shape<Int<BLOCK_SIZE_M>>{},
        Stride<_1>{}
    );

    // One row of per-split scale factors per warp.
    __shared__ float smem_buf[BLOCK_SIZE_M][MAX_SPLITS];

    // Wait for the previous kernel (the MLA kernel) to finish
    // (the explicit wait is currently disabled)
    // cudaGridDependencySynchronize();

    // Prefetch the first split's partial output for this head; each lane owns
    // ELEMS_PER_THREAD float4 chunks spaced kFloatsPerPass floats apart.
    static_assert(HEAD_DIM_V % kFloatsPerPass == 0);
    constexpr int ELEMS_PER_THREAD = HEAD_DIM_V / kFloatsPerPass;
    float* oaccum_ptr = params.o_accum + start_split_idx*params.stride_o_accum_split + s_q_idx*params.stride_o_accum_s_q + (h_block_idx*BLOCK_SIZE_M + warp_idx)*params.stride_o_accum_h_q;
    float4 datas[ELEMS_PER_THREAD];
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < ELEMS_PER_THREAD; ++i) {
        datas[i] = *(float4*)(oaccum_ptr + lane_idx*4 + i*kFloatsPerPass);  // NOTE We don't use __ldg here since it is incompatible with PDL
    }

    // Warp #i gathers LseAccum for seq #i
    {
        constexpr int NUM_LSE_PER_THREAD = cute::ceil_div(MAX_SPLITS, kLanesPerWarp);
        float local_lse[NUM_LSE_PER_THREAD];
        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < NUM_LSE_PER_THREAD; ++i) {
            const int split_idx = i*kLanesPerWarp + lane_idx;
            // Slots beyond the real split count contribute exp2f(-inf) = 0 below.
            local_lse[i] = split_idx < my_num_splits ? gLseAccum(split_idx, warp_idx) : -INFINITY;
        }

        // Warp-wide maximum (butterfly reduction over all 64 lanes).
        float max_lse = -INFINITY;
        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < NUM_LSE_PER_THREAD; ++i)
            max_lse = max(max_lse, local_lse[i]);
        CUTLASS_PRAGMA_UNROLL
        for (int offset = kLanesPerWarp / 2; offset >= 1; offset /= 2)
            max_lse = max(max_lse, __shfl_xor(max_lse, offset));
        max_lse = max_lse == -INFINITY ? 0.0f : max_lse;  // In case all local LSEs are -inf

        // Warp-wide sum of exp2(lse - max).
        float sum_lse = 0;
        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < NUM_LSE_PER_THREAD; ++i)
            sum_lse = sum_lse + exp2f(local_lse[i] - max_lse);
        CUTLASS_PRAGMA_UNROLL
        for (int offset = kLanesPerWarp / 2; offset >= 1; offset /= 2)
            sum_lse = sum_lse + __shfl_xor(sum_lse, offset);

        float global_lse = (sum_lse == 0.f || sum_lse == -INFINITY) ? INFINITY : log2f(sum_lse) + max_lse;
        if (lane_idx == 0)
            gLse(warp_idx) = global_lse / (float)M_LOG2E;  // stored LSE is rescaled from base 2 to base e

        if (params.attn_sink != nullptr) {
            // Note: the LSE written above does not include the sink; only the
            // scale factors computed below do.
            int q_head_idx = h_block_idx*BLOCK_SIZE_M + warp_idx;
            float attn_sink = __ldg(params.attn_sink + q_head_idx);
            if (global_lse != INFINITY) {
                // If attn_sink is +inf, global_lse will be +inf and scale factors will be exp2f(local_lse - inf) = 0 (since local_lse never becomes +inf)
                // If attn_sink is -inf, this has no effect on global_lse
                global_lse += log2f(1 + exp2f(attn_sink*CUDART_L2E_F - global_lse));
            } else {
                // We have no tokens to attend, so global lse should be attn_sink*CUDART_L2E_F (+inf if it's -inf or +inf)
                global_lse = attn_sink == -INFINITY ? +INFINITY : attn_sink*CUDART_L2E_F;
            }
        }

        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < NUM_LSE_PER_THREAD; ++i) {
            const int split_idx = i*kLanesPerWarp + lane_idx;
            // NUM_LSE_PER_THREAD*64 can exceed MAX_SPLITS when MAX_SPLITS is not
            // a multiple of 64. Without this guard such lanes would write past
            // the end of this warp's row, i.e. into the next warp's row (racing
            // with its real scale factors) or beyond smem_buf altogether.
            if (split_idx < MAX_SPLITS) {
                smem_buf[warp_idx][split_idx] = exp2f(local_lse[i] - global_lse);
            }
        }
    }

    // Make the scale factors visible before they are read back. Each warp only
    // touches its own row of smem_buf.
    // NOTE(review): warps that returned early above (warp_idx >= num_valid_heads)
    // never reach this block-wide barrier - confirm this is well-defined on the
    // target platform, or switch to a warp-scope barrier.
    __syncthreads();

    // Warp #i accumulates activation for seq #i
    {
        float4 result[ELEMS_PER_THREAD];
        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < ELEMS_PER_THREAD; ++i)
            result[i] = {0.0f, 0.0f, 0.0f, 0.0f};

        #pragma unroll 1
        for (int split = 0; split < my_num_splits; ++split) {
            float lse_scale = smem_buf[warp_idx][split];
            CUTLASS_PRAGMA_UNROLL
            for (int i = 0; i < ELEMS_PER_THREAD; ++i) {
                result[i].x += lse_scale * datas[i].x;
                result[i].y += lse_scale * datas[i].y;
                result[i].z += lse_scale * datas[i].z;
                result[i].w += lse_scale * datas[i].w;
                // Fetch the next split's chunk while this one is being consumed.
                if (split != my_num_splits-1) {
                    datas[i] = *(float4*)(oaccum_ptr + (split+1)*params.stride_o_accum_split + lane_idx*4 + i*kFloatsPerPass);
                }
            }
        }

        const int h_q_idx = h_block_idx*BLOCK_SIZE_M + warp_idx;
        ElementT* o_ptr = (ElementT*)params.out + batch_idx*params.stride_o_b + s_q_idx*params.stride_o_s_q + h_q_idx*params.stride_o_h_q;
        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < ELEMS_PER_THREAD; ++i) {
            float4 data = result[i];
            ElementT data_converted[4];
            data_converted[0] = (ElementT)(data.x);
            data_converted[1] = (ElementT)(data.y);
            data_converted[2] = (ElementT)(data.z);
            data_converted[3] = (ElementT)(data.w);
            // Four 2-byte elements are written with a single 8-byte store.
            static_assert(sizeof(ElementT) == 2);
            *(uint64_t*)(o_ptr + lane_idx*4 + i*kFloatsPerPass) = *(uint64_t*)data_converted;
        }
    }
}
...
...
@@ -188,11 +189,11 @@ template<typename ElementT>
void
run_flash_mla_combine_kernel
(
CombineParams
&
params
)
{
static
constexpr
int
HEAD_DIM_V
=
512
;
// Since only this head dimension is supported by Flash MLA
FLASH_ASSERT
(
params
.
d_v
==
HEAD_DIM_V
);
//
MLA_NUM_SPLITS_SWITCH(params.num_sm_parts, NUM_SPLITS, [&] {
//
constexpr int BLOCK_SIZE_M =
8
;
//
constexpr int NUM_THREADS = BLOCK_SIZE_M*
32
;
//
constexpr size_t smem_size = BLOCK_SIZE_M*(NUM_SPLITS+1)*sizeof(float);
//
auto combine_kernel = &flash_fwd_mla_combine_kernel<ElementT, HEAD_DIM_V, BLOCK_SIZE_M, NUM_SPLITS, NUM_THREADS>;
MLA_NUM_SPLITS_SWITCH
(
params
.
num_sm_parts
,
NUM_SPLITS
,
[
&
]
{
constexpr
int
BLOCK_SIZE_M
=
4
;
constexpr
int
NUM_THREADS
=
BLOCK_SIZE_M
*
64
;
constexpr
size_t
smem_size
=
BLOCK_SIZE_M
*
(
NUM_SPLITS
+
1
)
*
sizeof
(
float
);
auto
combine_kernel
=
&
flash_fwd_mla_combine_kernel
<
ElementT
,
HEAD_DIM_V
,
BLOCK_SIZE_M
,
NUM_SPLITS
,
NUM_THREADS
>
;
// CHECK_CUDA(cudaFuncSetAttribute(combine_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
// // Use cudaLaunchKernelEx to enable PDL (Programmatic Dependent Launch)
// cudaLaunchAttribute attribute[1];
...
...
@@ -206,8 +207,11 @@ void run_flash_mla_combine_kernel(CombineParams &params) {
// attribute,
// 1
// };
// CHECK_CUDA(cudaLaunchKernelEx(&combine_kernel_config, combine_kernel, params));
// });
combine_kernel
<<<
dim3
(
params
.
b
,
params
.
s_q
,
ku
::
ceil_div
(
params
.
h_q
,
BLOCK_SIZE_M
)),
dim3
(
NUM_THREADS
,
1
,
1
),
smem_size
,
params
.
stream
>>>
(
params
);
});
CHECK_CUDA_KERNEL_LAUNCH
();
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment