Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
FlashMLA
Commits
0728420c
Commit
0728420c
authored
May 22, 2026
by
zhanghj2
Browse files
优化sparse prefill
parent
1cb8a563
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
73 additions
and
137 deletions
+73
-137
csrc/gfx93/prefill/sparse/config.h
csrc/gfx93/prefill/sparse/config.h
+1
-1
csrc/gfx93/prefill/sparse/phase1.cuh
csrc/gfx93/prefill/sparse/phase1.cuh
+72
-136
No files found.
csrc/gfx93/prefill/sparse/config.h
View file @
0728420c
...
...
@@ -124,7 +124,7 @@ static void run(const SparseAttnFwdParams &params);
};
template
<
int
D_QK
,
bool
HAVE_TOPK_LENGTH
,
bool
IS_TOPK_2048
>
template
<
int
D_QK
,
bool
HAVE_TOPK_LENGTH
,
bool
IS_TOPK_2048
,
bool
USE_ATTN_SINK
=
false
,
bool
CACHE_INDICES_IN_LDS
=
false
>
class
KernelTemplate_B_H_64
{
public:
...
...
csrc/gfx93/prefill/sparse/phase1.cuh
View file @
0728420c
...
...
@@ -10,8 +10,8 @@ namespace gfx93::fwd {
using
namespace
cute
;
template
<
int
D_QK
,
bool
HAVE_TOPK_LENGTH
,
bool
IS_TOPK_2048
>
__device__
void
KernelTemplate_B_H_64
<
D_QK
,
HAVE_TOPK_LENGTH
,
IS_TOPK_2048
>::
devfunc
(
const
SparseAttnFwdParams
&
params
)
{
template
<
int
D_QK
,
bool
HAVE_TOPK_LENGTH
,
bool
IS_TOPK_2048
,
bool
USE_ATTN_SINK
,
bool
CACHE_INDICES_IN_LDS
>
__device__
void
KernelTemplate_B_H_64
<
D_QK
,
HAVE_TOPK_LENGTH
,
IS_TOPK_2048
,
USE_ATTN_SINK
,
CACHE_INDICES_IN_LDS
>::
devfunc
(
const
SparseAttnFwdParams
&
params
)
{
const
int
tidx
=
threadIdx
.
x
;
static
constexpr
int
kBlockM
=
B_H
;
...
...
@@ -96,7 +96,7 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048>::dev
#endif
// int row_offset = row + warp_idx * 16 + block_idx * kBlockN;
if
constexpr
(
IS_TOPK_2048
)
{
if
constexpr
(
IS_TOPK_2048
||
CACHE_INDICES_IN_LDS
)
{
row_offset
=
sIndices
[
row_offset
%
1024
];
}
else
{
row_offset
=
gIndices
[
row_offset
];
...
...
@@ -109,7 +109,7 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048>::dev
// int col = lane_idx % 4;
int
row_offset
=
row
+
i
*
16
+
block_idx
*
kBlockN
;;
// int col_offset = col * 8 + warp_idx * 32;
if
constexpr
(
IS_TOPK_2048
)
{
if
constexpr
(
IS_TOPK_2048
||
CACHE_INDICES_IN_LDS
)
{
row_offset
=
sIndices
[
row_offset
%
1024
];
}
else
{
row_offset
=
gIndices
[
row_offset
];
...
...
@@ -132,8 +132,8 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048>::dev
global_addr_q
[
2
]
=
64
;
global_addr_q
[
3
]
=
0x00020000
;
auto
buffer_load_lds_indices
=
[
&
]
(
int
n
)
{
if
constexpr
(
IS_TOPK_2048
)
{
auto
buffer_load_lds_indices
=
[
&
]
(
int
n
,
int
num_indices
)
{
if
constexpr
(
IS_TOPK_2048
||
CACHE_INDICES_IN_LDS
)
{
PtrWrapper
glob_ptr_indices
;
*
(
uint64_t
*
)
&
glob_ptr_indices
=
reinterpret_cast
<
uint64_t
>
(
gIndices
);
glob_ptr_indices
.
latter
|=
0x40000000
;
...
...
@@ -146,18 +146,21 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048>::dev
int
ldsAddrPerWave
=
reinterpret_cast
<
size_t
>
(
sIndices
)
+
warp_idx
*
64
*
4
*
4
;
const
int
offset_v
=
lane_idx
*
4
*
4
+
warp_idx
*
64
*
4
*
4
;
const
int
offset_s
=
n
*
1024
*
4
;
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_mov_b32 m0, %1
\n\t
"
"s_nop 0
\n\t
"
"buffer_load_dwordx4 %0, %2, %3 offen offset:0, lds
\n
"
::
"v"
(
offset_v
),
"s"
(
ldsAddrPerWave
),
"s"
(
global_addr_indices
),
"s"
(
offset_s
)
:
);
__builtin_amdgcn_sched_barrier
(
0
);
const
int
first_index
=
warp_idx
*
256
+
lane_idx
*
4
;
if
(
first_index
<
num_indices
)
{
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_mov_b32 m0, %1
\n\t
"
"s_nop 0
\n\t
"
"buffer_load_dwordx4 %0, %2, %3 offen offset:0, lds
\n
"
::
"v"
(
offset_v
),
"s"
(
ldsAddrPerWave
),
"s"
(
global_addr_indices
),
"s"
(
offset_s
)
:
);
__builtin_amdgcn_sched_barrier
(
0
);
}
}
};
if
constexpr
(
IS_TOPK_2048
)
{
buffer_load_lds_indices
(
0
);
if
constexpr
(
IS_TOPK_2048
||
CACHE_INDICES_IN_LDS
)
{
buffer_load_lds_indices
(
0
,
IS_TOPK_2048
?
1024
:
params
.
topk
);
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(0)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
...
...
@@ -323,71 +326,20 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048>::dev
auto
[
row_offset
,
col
]
=
calc_row_and_col_k
(
block_idx
);
row_offset
=
row_offset
==
-
1
?
params
.
s_kv
:
row_offset
;
#if 1
if
constexpr
(
D_QK
==
512
)
{
#define LOAD_K_AND_QK_GEMM_512(k) \
{ \
constexpr int k_val = (k); \
buffer_load_lds_k(row_offset, col, k_val - 3); \
flash::qk_gemm<Element, k_val>(q_reg[k_val].data_128, k_lds_read_ptr, accs_f32); \
__builtin_amdgcn_sched_barrier(0); \
asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); \
__builtin_amdgcn_sched_barrier(0); \
}
constexpr
int
k_val
=
15
;
buffer_load_lds_k
(
row_offset
,
col
,
k_val
);
buffer_load_lds_k
(
row_offset
,
col
,
k_val
-
1
);
buffer_load_lds_k
(
row_offset
,
col
,
k_val
-
2
);
buffer_load_lds_k
(
row_offset
,
col
,
k_val
-
3
);
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(3)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
flash
::
qk_gemm
<
Element
,
k_val
>
(
q_reg
[
k_val
].
data_128
,
k_lds_read_ptr
,
accs_f32
);
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(2)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
LOAD_K_AND_QK_GEMM_512
(
14
);
LOAD_K_AND_QK_GEMM_512
(
13
);
LOAD_K_AND_QK_GEMM_512
(
12
);
LOAD_K_AND_QK_GEMM_512
(
11
);
LOAD_K_AND_QK_GEMM_512
(
10
);
LOAD_K_AND_QK_GEMM_512
(
9
);
LOAD_K_AND_QK_GEMM_512
(
8
);
LOAD_K_AND_QK_GEMM_512
(
7
);
LOAD_K_AND_QK_GEMM_512
(
6
);
LOAD_K_AND_QK_GEMM_512
(
5
);
LOAD_K_AND_QK_GEMM_512
(
4
);
LOAD_K_AND_QK_GEMM_512
(
3
);
flash
::
qk_gemm
<
Element
,
2
>
(
q_reg
[
2
].
data_128
,
k_lds_read_ptr
,
accs_f32
);
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(1)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
flash
::
qk_gemm
<
Element
,
1
>
(
q_reg
[
1
].
data_128
,
k_lds_read_ptr
,
accs_f32
);
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(0)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
flash
::
qk_gemm
<
Element
,
0
>
(
q_reg
[
0
].
data_128
,
k_lds_read_ptr
,
accs_f32
);
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
#undef LOAD_K_AND_QK_GEMM_512
}
else
{
#define LOAD_K_AND_QK_GEMM(k) \
{ \
constexpr int k_val = (k); \
buffer_load_lds_k(row_offset, col, k_val - 3); \
flash::qk_gemm<Element, k_val>(q_reg[k_val].data_128, k_lds_read_ptr, accs_f32); \
__builtin_amdgcn_sched_barrier(0); \
asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); \
__builtin_amdgcn_sched_barrier
(
0
);
\
if constexpr (k_val < kQkChunks - 1) { \
buffer_load_lds_k(row_offset, col, k_val - 3); \
flash::qk_gemm<Element, k_val>(q_reg[k_val].data_128, k_lds_read_ptr, accs_f32); \
__builtin_amdgcn_sched_barrier(0); \
asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); \
__builtin_amdgcn_sched_barrier(0); \
} \
}
{
constexpr
int
k_val
=
(
17
)
;
constexpr
int
k_val
=
kQkChunks
-
1
;
buffer_load_lds_k
(
row_offset
,
col
,
k_val
);
buffer_load_lds_k
(
row_offset
,
col
,
k_val
-
1
);
buffer_load_lds_k
(
row_offset
,
col
,
k_val
-
2
);
...
...
@@ -415,23 +367,22 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048>::dev
LOAD_K_AND_QK_GEMM
(
4
);
LOAD_K_AND_QK_GEMM
(
3
);
flash
::
qk_gemm
<
Element
,
k_val
-
15
>
(
q_reg
[
k_val
-
15
].
data_128
,
k_lds_read_ptr
,
accs_f32
);
flash
::
qk_gemm
<
Element
,
2
>
(
q_reg
[
2
].
data_128
,
k_lds_read_ptr
,
accs_f32
);
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(1)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
flash
::
qk_gemm
<
Element
,
k_val
-
16
>
(
q_reg
[
k_val
-
16
].
data_128
,
k_lds_read_ptr
,
accs_f32
);
flash
::
qk_gemm
<
Element
,
1
>
(
q_reg
[
1
].
data_128
,
k_lds_read_ptr
,
accs_f32
);
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(0)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
flash
::
qk_gemm
<
Element
,
k_val
-
17
>
(
q_reg
[
k_val
-
17
].
data_128
,
k_lds_read_ptr
,
accs_f32
);
flash
::
qk_gemm
<
Element
,
0
>
(
q_reg
[
0
].
data_128
,
k_lds_read_ptr
,
accs_f32
);
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
}
#undef LOAD_K_AND_QK_GEMM
}
#else
#define LOAD_K_AND_QK_GEMM(k) \
{ \
...
...
@@ -495,7 +446,7 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048>::dev
const
int
n_idx
=
(
lane_idx
/
16
)
*
4
+
(
idx
%
4
)
+
(
idx
/
4
)
*
16
;
int
offs
=
n_idx
+
block_idx
*
kBlockN
;
int
t
;
if
constexpr
(
IS_TOPK_2048
)
{
if
constexpr
(
IS_TOPK_2048
||
CACHE_INDICES_IN_LDS
)
{
t
=
sIndices
[
offs
%
1024
];
}
else
{
t
=
gIndices
[
offs
];
...
...
@@ -724,7 +675,7 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048>::dev
{
process_one_block
(
block_idx
,
IsOtherBlock
{});
}
buffer_load_lds_indices
(
1
);
buffer_load_lds_indices
(
1
,
1024
);
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(0)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
...
...
@@ -765,18 +716,23 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048>::dev
float
*
gMax_logits
=
reinterpret_cast
<
float
*>
(
params
.
max_logits
)
+
row_offset_lse
;
float
attn_sink_o_scale
=
1.0
f
;
if
constexpr
(
D_QK
==
512
&&
HAVE_TOPK_LENGTH
)
{
if
(
params
.
attn_sink
!=
nullptr
)
{
float
rAttn_sink
=
__ldg
((
float
*
)
params
.
attn_sink
+
bidh
*
kBlockM
+
lane_idx
%
16
+
warp_idx
*
16
);
if
(
flash
::
is_positive_infinity
(
rAttn_sink
))
{
attn_sink_o_scale
=
0.0
f
;
}
else
if
(
!
flash
::
is_positive_infinity
(
lse
(
0
)))
{
float
lse_exp2
=
__builtin_amdgcn_exp2f
(
lse
[
0
]
*
CUDART_L2E_F
);
float
rAttn_sink_exp2
=
__builtin_amdgcn_exp2f
(
rAttn_sink
*
CUDART_L2E_F
);
attn_sink_o_scale
=
lse_exp2
/
(
lse_exp2
+
rAttn_sink_exp2
);
}
if
constexpr
(
USE_ATTN_SINK
)
{
float
rAttn_sink
=
__ldg
((
float
*
)
params
.
attn_sink
+
bidh
*
kBlockM
+
lane_idx
%
16
+
warp_idx
*
16
);
if
(
flash
::
is_positive_infinity
(
rAttn_sink
))
{
attn_sink_o_scale
=
0.0
f
;
}
else
if
(
!
flash
::
is_positive_infinity
(
lse
(
0
)))
{
float
lse_exp2
=
__builtin_amdgcn_exp2f
(
lse
[
0
]
*
CUDART_L2E_F
);
float
rAttn_sink_exp2
=
__builtin_amdgcn_exp2f
(
rAttn_sink
*
CUDART_L2E_F
);
attn_sink_o_scale
=
lse_exp2
/
(
lse_exp2
+
rAttn_sink_exp2
);
}
}
auto
maybe_apply_attn_sink
=
[
&
]
(
float
value
)
->
float
{
if
constexpr
(
USE_ATTN_SINK
)
{
return
value
*
attn_sink_o_scale
;
}
else
{
return
value
;
}
};
{
// store O and gLSE
...
...
@@ -792,13 +748,13 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048>::dev
#if defined(__gfx938__)
Bf16_storage
res
;
col
=
(
lane_idx
/
16
)
*
8
+
ni
*
32
;
res
.
data_32
[
0
]
=
__builtin_hcu_cvt_pk_bf16_f32
(
0
,
acco_f32
[
ni
*
2
][
0
]
*
attn_sink_o_scale
,
0
,
acco_f32
[
ni
*
2
+
1
][
0
]
*
attn_sink_o_scale
,
0
);
res
.
data_32
[
0
]
=
__builtin_hcu_cvt_pk_bf16_f32
(
0
,
maybe_apply_attn_sink
(
acco_f32
[
ni
*
2
][
0
]),
0
,
maybe_apply_attn_sink
(
acco_f32
[
ni
*
2
+
1
][
0
]
)
,
0
);
res
.
data_32
[
1
]
=
__builtin_hcu_cvt_pk_bf16_f32
(
0
,
acco_f32
[
ni
*
2
][
1
]
*
attn_sink_o_scale
,
0
,
acco_f32
[
ni
*
2
+
1
][
1
]
*
attn_sink_o_scale
,
0
);
res
.
data_32
[
1
]
=
__builtin_hcu_cvt_pk_bf16_f32
(
0
,
maybe_apply_attn_sink
(
acco_f32
[
ni
*
2
][
1
]),
0
,
maybe_apply_attn_sink
(
acco_f32
[
ni
*
2
+
1
][
1
]
)
,
0
);
res
.
data_32
[
2
]
=
__builtin_hcu_cvt_pk_bf16_f32
(
0
,
acco_f32
[
ni
*
2
][
2
]
*
attn_sink_o_scale
,
0
,
acco_f32
[
ni
*
2
+
1
][
2
]
*
attn_sink_o_scale
,
0
);
res
.
data_32
[
2
]
=
__builtin_hcu_cvt_pk_bf16_f32
(
0
,
maybe_apply_attn_sink
(
acco_f32
[
ni
*
2
][
2
]),
0
,
maybe_apply_attn_sink
(
acco_f32
[
ni
*
2
+
1
][
2
]
)
,
0
);
res
.
data_32
[
3
]
=
__builtin_hcu_cvt_pk_bf16_f32
(
0
,
acco_f32
[
ni
*
2
][
3
]
*
attn_sink_o_scale
,
0
,
acco_f32
[
ni
*
2
+
1
][
3
]
*
attn_sink_o_scale
,
0
);
res
.
data_32
[
3
]
=
__builtin_hcu_cvt_pk_bf16_f32
(
0
,
maybe_apply_attn_sink
(
acco_f32
[
ni
*
2
][
3
]),
0
,
maybe_apply_attn_sink
(
acco_f32
[
ni
*
2
+
1
][
3
]
)
,
0
);
*
(
__fp16x8_t
*
)(
&
gO
(
row
,
col
))
=
res
.
data_128
;
...
...
@@ -809,8 +765,8 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048>::dev
{
result_type
res
;
Element
e0
,
e1
;
e0
.
storage
=
float2bf16
(
acco_f32
[
ni
*
2
][
ei
]
*
attn_sink_o_scale
);
e1
.
storage
=
float2bf16
(
acco_f32
[
ni
*
2
+
1
][
ei
]
*
attn_sink_o_scale
);
e0
.
storage
=
float2bf16
(
maybe_apply_attn_sink
(
acco_f32
[
ni
*
2
][
ei
])
);
e1
.
storage
=
float2bf16
(
maybe_apply_attn_sink
(
acco_f32
[
ni
*
2
+
1
][
ei
]
)
);
res
[
0
]
=
e0
;
res
[
1
]
=
e1
;
// gO(row, col) = res[0];
...
...
@@ -1372,61 +1328,41 @@ void KernelTemplate<D_QK, HAVE_TOPK_LENGTH>::run(const SparseAttnFwdParams &para
KU_CHECK_KERNEL_LAUNCH
();
}
template
<
int
D_QK
,
bool
HAVE_TOPK_LENGTH
,
bool
IS_TOPK_2048
>
void
KernelTemplate_B_H_64
<
D_QK
,
HAVE_TOPK_LENGTH
,
IS_TOPK_2048
>::
run
(
const
SparseAttnFwdParams
&
params
)
{
template
<
int
D_QK
,
bool
HAVE_TOPK_LENGTH
,
bool
IS_TOPK_2048
,
bool
USE_ATTN_SINK
,
bool
CACHE_INDICES_IN_LDS
>
void
KernelTemplate_B_H_64
<
D_QK
,
HAVE_TOPK_LENGTH
,
IS_TOPK_2048
,
USE_ATTN_SINK
,
CACHE_INDICES_IN_LDS
>::
run
(
const
SparseAttnFwdParams
&
params
)
{
KU_ASSERT
(
params
.
h_kv
==
1
);
// KU_ASSERT(params.topk % (2*B_TOPK) == 0); // To save some boundry checkings
KU_ASSERT
(
params
.
topk
>
0
);
// KU_ASSERT(params.h_q % B_H == 0);
auto
kernel
=
&
sparse_attn_fwd_kernel
<
KernelTemplate_B_H_64
<
D_QK
,
HAVE_TOPK_LENGTH
,
IS_TOPK_2048
>>
;
auto
kernel
=
&
sparse_attn_fwd_kernel
<
KernelTemplate_B_H_64
<
D_QK
,
HAVE_TOPK_LENGTH
,
IS_TOPK_2048
,
USE_ATTN_SINK
,
CACHE_INDICES_IN_LDS
>>
;
constexpr
size_t
smem_size
=
16384
+
4096
;
// 做了lds复用
dim3
grid
((
params
.
h_q
+
B_H
-
1
)
/
B_H
,
params
.
s_q
,
1
);
kernel
<<<
grid
,
NUM_THREADS
,
smem_size
,
params
.
stream
>>>
(
params
);
KU_CHECK_KERNEL_LAUNCH
();
}
class
KernelTemplate_D512_H64_TopkLen_AttnSink
{
public:
static
constexpr
int
NUM_THREADS
=
KernelTemplate_B_H_64
<
512
,
true
,
false
>::
NUM_THREADS
;
static
__device__
__forceinline__
void
devfunc
(
const
SparseAttnFwdParams
&
params
)
{
KernelTemplate_B_H_64
<
512
,
true
,
false
>::
devfunc
(
params
);
}
static
void
run
(
const
SparseAttnFwdParams
&
params
)
{
KU_ASSERT
(
params
.
h_kv
==
1
);
KU_ASSERT
(
params
.
topk
>
0
);
auto
kernel
=
&
sparse_attn_fwd_kernel
<
KernelTemplate_D512_H64_TopkLen_AttnSink
>
;
constexpr
size_t
smem_size
=
16384
+
4096
;
dim3
grid
((
params
.
h_q
+
64
-
1
)
/
64
,
params
.
s_q
,
1
);
kernel
<<<
grid
,
NUM_THREADS
,
smem_size
,
params
.
stream
>>>
(
params
);
KU_CHECK_KERNEL_LAUNCH
();
template
<
int
D_QK
,
bool
HAVE_TOPK_LENGTH
,
bool
USE_ATTN_SINK
>
static
void
run_h64_fast_path
(
const
SparseAttnFwdParams
&
params
)
{
if
(
params
.
topk
==
2048
)
{
KernelTemplate_B_H_64
<
D_QK
,
HAVE_TOPK_LENGTH
,
true
,
USE_ATTN_SINK
,
false
>::
run
(
params
);
}
else
if
(
params
.
topk
<=
1024
)
{
KernelTemplate_B_H_64
<
D_QK
,
HAVE_TOPK_LENGTH
,
false
,
USE_ATTN_SINK
,
true
>::
run
(
params
);
}
else
{
KernelTemplate_B_H_64
<
D_QK
,
HAVE_TOPK_LENGTH
,
false
,
USE_ATTN_SINK
,
false
>::
run
(
params
);
}
}
};
template
<
int
D_QK
,
bool
HAVE_TOPK_LENGTH
>
void
run_fwd_phase1_kernel
(
const
SparseAttnFwdParams
&
params
)
{
if
(
D_QK
==
512
&&
HAVE_TOPK_LENGTH
&&
params
.
h_q
==
64
&&
params
.
attn_sink
)
{
KernelTemplate_D512_H64_TopkLen_AttnSink
::
run
(
params
);
}
else
if
(
params
.
h_q
==
64
&&
!
HAVE_TOPK_LENGTH
&&
D_QK
==
576
&&
!
params
.
attn_sink
)
{
if
(
params
.
topk
==
2048
)
{
KernelTemplate_B_H_64
<
D_QK
,
HAVE_TOPK_LENGTH
,
true
>::
run
(
params
);
}
else
{
KernelTemplate_B_H_64
<
D_QK
,
HAVE_TOPK_LENGTH
,
false
>::
run
(
params
);
if
(
params
.
h_q
==
64
)
{
if
(
params
.
attn_sink
)
{
run_h64_fast_path
<
D_QK
,
HAVE_TOPK_LENGTH
,
true
>
(
params
);
}
else
{
run_h64_fast_path
<
D_QK
,
HAVE_TOPK_LENGTH
,
false
>
(
params
);
}
return
;
}
else
{
KernelTemplate
<
D_QK
,
HAVE_TOPK_LENGTH
>::
run
(
params
);
}
KernelTemplate
<
D_QK
,
HAVE_TOPK_LENGTH
>::
run
(
params
);
}
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment