Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
FlashMLA
Commits
755d8be7
Commit
755d8be7
authored
Jan 25, 2026
by
zhanghj2
Browse files
适配combine kernel
parent
572946f5
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
156 additions
and
152 deletions
+156
-152
csrc/smxx/decode/combine/combine.cu
csrc/smxx/decode/combine/combine.cu
+156
-152
No files found.
csrc/smxx/decode/combine/combine.cu
View file @
755d8be7
...
@@ -10,6 +10,7 @@
...
@@ -10,6 +10,7 @@
#include "params.h"
#include "params.h"
#include "utils.h"
#include "utils.h"
#define CUDART_L2E_F 1.442695041F
using
namespace
cute
;
using
namespace
cute
;
...
@@ -18,146 +19,146 @@ namespace smxx::decode {
...
@@ -18,146 +19,146 @@ namespace smxx::decode {
// Combine kernel for split-KV MLA decoding.
//
// Launch layout:
//   grid_shape:  [batch_size, s_q, ceil(h_q / BLOCK_SIZE_M)]
//   block_shape: [NUM_THREADS] == BLOCK_SIZE_M warps of WARP_SIZE (64) lanes each
//   shared mem:  static, BLOCK_SIZE_M * MAX_SPLITS floats (no dynamic shared memory is used)
//
// Each CTA handles up to BLOCK_SIZE_M heads of one (batch, s_q) position; warp #i handles head #i
// of the CTA. For its head, a warp
//   1. reads the per-split LSE values from params.lse_accum and reduces them to a global LSE
//      (all LSE math is done in base 2 via exp2f/log2f),
//   2. stores the global LSE, divided by M_LOG2E, to params.lse,
//   3. optionally folds params.attn_sink into the normalizer used for the scale factors,
//   4. accumulates the per-split partial outputs from params.o_accum, each scaled by
//      exp2(local_lse - global_lse), and writes the result to params.out as ElementT.
//
// Preconditions enforced in the kernel:
//   - NUM_THREADS / 64 == BLOCK_SIZE_M
//   - HEAD_DIM_V % (64 * 4) == 0
//   - sizeof(ElementT) == 2
//   - number of splits of the batch <= MAX_SPLITS (FLASH_DEVICE_ASSERT)
// Batches with exactly one split are skipped entirely (the kernel returns without writing).
template<typename ElementT, int HEAD_DIM_V, int BLOCK_SIZE_M, int MAX_SPLITS, int NUM_THREADS>
__global__ void __launch_bounds__(NUM_THREADS)
flash_fwd_mla_combine_kernel(const CombineParams params) {
    // Single source of truth for the lane count; every stride below is derived from it.
    constexpr int WARP_SIZE = 64;
    constexpr int FLOATS_PER_LOAD = 4;                                  // one float4 per lane per load
    constexpr int FLOATS_PER_WARP_LOAD = WARP_SIZE * FLOATS_PER_LOAD;   // 256 floats per warp-wide load

    static_assert(NUM_THREADS/WARP_SIZE == BLOCK_SIZE_M);  // The number of warps == block_size_m
    const int batch_idx = blockIdx.x;
    const int s_q_idx = blockIdx.y;
    const int h_block_idx = blockIdx.z;
    const int warp_idx = threadIdx.x / WARP_SIZE;
    const int lane_idx = threadIdx.x % WARP_SIZE;

    // The last head block may be partially filled; warps without a head have nothing to do.
    int num_valid_heads = std::min(BLOCK_SIZE_M, params.h_q - BLOCK_SIZE_M*h_block_idx);
    if (warp_idx >= num_valid_heads) {
        return;
    }

    // num_splits_ptr is read as a prefix array: entries [batch_idx, batch_idx + 1) delimit this batch's splits.
    const int start_split_idx = __ldg(params.num_splits_ptr + batch_idx);
    const int end_split_idx = __ldg(params.num_splits_ptr + batch_idx + 1);
    const int my_num_splits = end_split_idx - start_split_idx;
    if (my_num_splits == 1) {
        // Nothing to combine for a single split.
        return;
    }

    FLASH_DEVICE_ASSERT(my_num_splits <= MAX_SPLITS);

    // View of the per-split LSE values: (split, head-in-block).
    Tensor gLseAccum = make_tensor(
        make_gmem_ptr((float*)params.lse_accum + start_split_idx*params.stride_lse_accum_split + s_q_idx*params.stride_lse_accum_s_q + h_block_idx*BLOCK_SIZE_M),
        Shape<Int<MAX_SPLITS>, Int<BLOCK_SIZE_M>>{},
        make_stride(params.stride_lse_accum_split, _1{})
    );
    // View of the final LSE output: (head-in-block).
    Tensor gLse = make_tensor(
        make_gmem_ptr((float*)params.lse + batch_idx*params.stride_lse_b + s_q_idx*params.stride_lse_s_q + h_block_idx*BLOCK_SIZE_M),
        Shape<Int<BLOCK_SIZE_M>>{},
        Stride<_1>{}
    );

    // Per-warp row of scale factors, one entry per split.
    __shared__ float smem_buf[BLOCK_SIZE_M][MAX_SPLITS];

    // Wait for the previous kernel (the MLA kernel) to finish
    // cudaGridDependencySynchronize();

    // Prefetch the first split's partial output for this warp's head.
    static_assert(HEAD_DIM_V % FLOATS_PER_WARP_LOAD == 0);
    constexpr int ELEMS_PER_THREAD = HEAD_DIM_V / FLOATS_PER_WARP_LOAD;
    float* oaccum_ptr = params.o_accum + start_split_idx*params.stride_o_accum_split + s_q_idx*params.stride_o_accum_s_q + (h_block_idx*BLOCK_SIZE_M + warp_idx)*params.stride_o_accum_h_q;
    float4 datas[ELEMS_PER_THREAD];
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < ELEMS_PER_THREAD; ++i) {
        datas[i] = *(float4*)(oaccum_ptr + lane_idx*FLOATS_PER_LOAD + i*FLOATS_PER_WARP_LOAD);  // NOTE We don't use __ldg here since it is incompatible with PDL
    }

    // Warp #i gathers LseAccum for seq #i
    {
        constexpr int NUM_LSE_PER_THREAD = cute::ceil_div(MAX_SPLITS, WARP_SIZE);
        float local_lse[NUM_LSE_PER_THREAD];
        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < NUM_LSE_PER_THREAD; ++i) {
            const int split_idx = i*WARP_SIZE + lane_idx;
            // Lanes beyond the real split count contribute -inf (i.e. a zero weight).
            local_lse[i] = split_idx < my_num_splits ? gLseAccum(split_idx, warp_idx) : -INFINITY;
        }

        // Warp-wide max of the LSE values (butterfly reduction across all lanes).
        // NOTE(review): __shfl_xor is the mask-less shuffle; kept as in the original port — confirm it is
        // the intended intrinsic for the 64-lane target.
        float max_lse = -INFINITY;
        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < NUM_LSE_PER_THREAD; ++i)
            max_lse = max(max_lse, local_lse[i]);
        CUTLASS_PRAGMA_UNROLL
        for (int offset = WARP_SIZE/2; offset >= 1; offset /= 2)
            max_lse = max(max_lse, __shfl_xor(max_lse, offset));
        max_lse = max_lse == -INFINITY ? 0.0f : max_lse;  // In case all local LSEs are -inf

        // Warp-wide sum of exp2(lse - max).
        float sum_lse = 0;
        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < NUM_LSE_PER_THREAD; ++i)
            sum_lse = sum_lse + exp2f(local_lse[i] - max_lse);
        CUTLASS_PRAGMA_UNROLL
        for (int offset = WARP_SIZE/2; offset >= 1; offset /= 2)
            sum_lse = sum_lse + __shfl_xor(sum_lse, offset);

        // An empty sum means no split contributed anything; +inf makes every scale factor below zero.
        float global_lse = (sum_lse == 0.f || sum_lse == -INFINITY) ? INFINITY : log2f(sum_lse) + max_lse;
        if (lane_idx == 0)
            gLse(warp_idx) = global_lse / (float)M_LOG2E;  // stored before the attn_sink adjustment below

        if (params.attn_sink != nullptr) {
            int q_head_idx = h_block_idx*BLOCK_SIZE_M + warp_idx;
            float attn_sink = __ldg(params.attn_sink + q_head_idx);
            if (global_lse != INFINITY) {
                // If attn_sink is +inf, global_lse will be +inf and scale factors will be exp2f(local_lse - inf) = 0 (since local_lse never becomes +inf)
                // If attn_sink is -inf, this has no effect on global_lse
                global_lse += log2f(1 + exp2f(attn_sink*CUDART_L2E_F - global_lse));
            } else {
                // We have no tokens to attend, so global lse should be attn_sink*CUDART_L2E_F (+inf if it's -inf or +inf)
                global_lse = attn_sink == -INFINITY ? +INFINITY : attn_sink*CUDART_L2E_F;
            }
        }

        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < NUM_LSE_PER_THREAD; ++i) {
            const int split_idx = i*WARP_SIZE + lane_idx;
            // NUM_LSE_PER_THREAD*WARP_SIZE rounds MAX_SPLITS up to a multiple of the warp size, so
            // split_idx may exceed the row length. Without this guard the store would spill into the
            // next warp's row (racing with its writes) or past the end of smem_buf.
            if (split_idx < MAX_SPLITS) {
                smem_buf[warp_idx][split_idx] = exp2f(local_lse[i] - global_lse);
            }
        }
    }

    // Make the scale factors visible to all lanes before they are read back.
    // NOTE(review): warps that returned early above never reach this block-wide barrier; this relies on
    // the target treating exited threads as having arrived — confirm for the intended hardware.
    __syncthreads();

    // Warp #i accumulates activation for seq #i
    {
        float4 result[ELEMS_PER_THREAD];
        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < ELEMS_PER_THREAD; ++i)
            result[i] = {0.0f, 0.0f, 0.0f, 0.0f};

        #pragma unroll 1
        for (int split = 0; split < my_num_splits; ++split) {
            float lse_scale = smem_buf[warp_idx][split];
            // if (lse_scale != 0.f) {
                CUTLASS_PRAGMA_UNROLL
                for (int i = 0; i < ELEMS_PER_THREAD; ++i) {
                    result[i].x += lse_scale * datas[i].x;
                    result[i].y += lse_scale * datas[i].y;
                    result[i].z += lse_scale * datas[i].z;
                    result[i].w += lse_scale * datas[i].w;
                    // Load the next split's data while the current one is being consumed.
                    if (split != my_num_splits-1) {
                        datas[i] = *(float4*)(oaccum_ptr + (split+1)*params.stride_o_accum_split + lane_idx*FLOATS_PER_LOAD + i*FLOATS_PER_WARP_LOAD);
                    }
                }
            // }
        }

        const int h_q_idx = h_block_idx*BLOCK_SIZE_M + warp_idx;
        ElementT* o_ptr = (ElementT*)params.out + batch_idx*params.stride_o_b + s_q_idx*params.stride_o_s_q + h_q_idx*params.stride_o_h_q;
        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < ELEMS_PER_THREAD; ++i) {
            float4 data = result[i];
            ElementT data_converted[4];
            data_converted[0] = (ElementT)(data.x);
            data_converted[1] = (ElementT)(data.y);
            data_converted[2] = (ElementT)(data.z);
            data_converted[3] = (ElementT)(data.w);
            // Four 2-byte elements are written with a single 8-byte store.
            static_assert(sizeof(ElementT) == 2);
            *(uint64_t*)(o_ptr + lane_idx*FLOATS_PER_LOAD + i*FLOATS_PER_WARP_LOAD) = *(uint64_t*)data_converted;
        }
    }
}
...
@@ -188,26 +189,29 @@ template<typename ElementT>
...
@@ -188,26 +189,29 @@ template<typename ElementT>
// Host-side launcher for flash_fwd_mla_combine_kernel.
//
// Picks the MAX_SPLITS template argument from params.num_sm_parts via MLA_NUM_SPLITS_SWITCH and
// launches one CTA per (batch, s_q, block of BLOCK_SIZE_M heads) on params.stream.
// Requires params.d_v == 512 (checked with FLASH_ASSERT).
void run_flash_mla_combine_kernel(CombineParams &params) {
    static constexpr int HEAD_DIM_V = 512;  // Since only this head dimension is supported by Flash MLA
    FLASH_ASSERT(params.d_v == HEAD_DIM_V);
    MLA_NUM_SPLITS_SWITCH(params.num_sm_parts, NUM_SPLITS, [&] {
        constexpr int BLOCK_SIZE_M = 4;                   // heads (== warps) per CTA
        constexpr int NUM_THREADS = BLOCK_SIZE_M*64;      // the kernel asserts NUM_THREADS/64 == BLOCK_SIZE_M
        auto combine_kernel = &flash_fwd_mla_combine_kernel<ElementT, HEAD_DIM_V, BLOCK_SIZE_M, NUM_SPLITS, NUM_THREADS>;

        // The cudaLaunchKernelEx / PDL (Programmatic Dependent Launch) path that used
        // cudaLaunchAttributeProgrammaticStreamSerialization is disabled here; a plain
        // stream-ordered launch is used instead.
        //
        // The kernel keeps its scratch buffer in a static __shared__ array, so no dynamic shared
        // memory is requested (the disabled launch config passed 0 as well).
        combine_kernel<<<
            dim3(params.b, params.s_q, ku::ceil_div(params.h_q, BLOCK_SIZE_M)),
            dim3(NUM_THREADS, 1, 1),
            0,
            params.stream
        >>>(params);
    });
    CHECK_CUDA_KERNEL_LAUNCH();
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment