Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
FlashMLA
Commits
b4f69d84
Commit
b4f69d84
authored
May 26, 2026
by
zhanghj2
Browse files
opt h_q 128 sparse prefill
parent
e83a4119
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
37 additions
and
12 deletions
+37
-12
csrc/gfx93/prefill/sparse/phase1.cuh
csrc/gfx93/prefill/sparse/phase1.cuh
+37
-12
No files found.
csrc/gfx93/prefill/sparse/phase1.cuh
View file @
b4f69d84
...
@@ -98,8 +98,6 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048, USE_
...
@@ -98,8 +98,6 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048, USE_
if
constexpr
(
IS_TOPK_2048
||
CACHE_INDICES_IN_LDS
)
{
if
constexpr
(
IS_TOPK_2048
||
CACHE_INDICES_IN_LDS
)
{
row_offset
=
sIndices
[
row_offset
%
1024
];
row_offset
=
sIndices
[
row_offset
%
1024
];
}
else
if
constexpr
(
CACHE_INDICES_IN_LDS
)
{
row_offset
=
sIndices
[
row_offset
];
}
else
{
}
else
{
row_offset
=
gIndices
[
row_offset
];
row_offset
=
gIndices
[
row_offset
];
}
}
...
@@ -111,12 +109,12 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048, USE_
...
@@ -111,12 +109,12 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048, USE_
// int col = lane_idx % 4;
// int col = lane_idx % 4;
int
row_offset
=
row
+
i
*
16
+
block_idx
*
kBlockN
;;
int
row_offset
=
row
+
i
*
16
+
block_idx
*
kBlockN
;;
// int col_offset = col * 8 + warp_idx * 32;
// int col_offset = col * 8 + warp_idx * 32;
if
constexpr
(
IS_TOPK_2048
)
{
if
(
HAVE_TOPK_LENGTH
&&
row_offset
>=
topk_length
)
{
return
params
.
s_kv
;
}
if
constexpr
(
IS_TOPK_2048
||
CACHE_INDICES_IN_LDS
)
{
row_offset
=
sIndices
[
row_offset
%
1024
];
row_offset
=
sIndices
[
row_offset
%
1024
];
}
else
if
constexpr
(
CACHE_INDICES_IN_LDS
)
{
}
else
{
row_offset
=
sIndices
[
row_offset
];
}
else
{
row_offset
=
gIndices
[
row_offset
];
row_offset
=
gIndices
[
row_offset
];
}
}
row_offset
=
row_offset
==
-
1
?
params
.
s_kv
:
row_offset
;
row_offset
=
row_offset
==
-
1
?
params
.
s_kv
:
row_offset
;
...
@@ -183,6 +181,20 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048, USE_
...
@@ -183,6 +181,20 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048, USE_
auto
buffer_load_lds_k
=
[
&
](
int
row_offset
,
int
col
,
int
k_idx
)
{
auto
buffer_load_lds_k
=
[
&
](
int
row_offset
,
int
col
,
int
k_idx
)
{
constexpr
int
element_size
=
2
;
constexpr
int
element_size
=
2
;
// int k_idx = __builtin_amdgcn_readfirstlane(k_idx_);
// struct PtrWrapper {
// uint32_t former;
// uint32_t latter;
// };
// PtrWrapper glob_ptr;
// *(uint64_t*)&glob_ptr = reinterpret_cast<uint64_t>(gK.data().get());
// glob_ptr.latter |= ((row_stride * 2) << 16);
// uint32x4_t global_addr = {0};
// global_addr[0] = (glob_ptr.former);
// global_addr[1] = (glob_ptr.latter);
// global_addr[2] = max_MN;
// global_addr[3] = 0x00020000;
constexpr
int
elements_per_thread
=
8
;
constexpr
int
elements_per_thread
=
8
;
int
col_offset
=
col
;
int
col_offset
=
col
;
int
offset_v
=
col_offset
*
2
;
int
offset_v
=
col_offset
*
2
;
...
@@ -208,6 +220,20 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048, USE_
...
@@ -208,6 +220,20 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048, USE_
auto
buffer_load_lds_v
=
[
&
](
int
row_offset
,
int
col
,
int
k_idx
,
int
n_idx
)
{
auto
buffer_load_lds_v
=
[
&
](
int
row_offset
,
int
col
,
int
k_idx
,
int
n_idx
)
{
constexpr
int
element_size
=
2
;
constexpr
int
element_size
=
2
;
// int k_idx = __builtin_amdgcn_readfirstlane(k_idx_);
// struct PtrWrapper {
// uint32_t former;
// uint32_t latter;
// };
// PtrWrapper glob_ptr;
// *(uint64_t*)&glob_ptr = reinterpret_cast<uint64_t>(gK.data().get());
// glob_ptr.latter |= ((row_stride * 2) << 16);
// uint32x4_t global_addr = {0};
// global_addr[0] = (glob_ptr.former);
// global_addr[1] = (glob_ptr.latter);
// global_addr[2] = max_MN;
// global_addr[3] = 0x00020000;
constexpr
int
elements_per_thread
=
8
;
constexpr
int
elements_per_thread
=
8
;
int
col_offset
=
col
;
int
col_offset
=
col
;
// int v_idx = row_offset;
// int v_idx = row_offset;
...
@@ -302,9 +328,10 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048, USE_
...
@@ -302,9 +328,10 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048, USE_
}
}
auto
[
row_offset
,
col
]
=
calc_row_and_col_k
(
block_idx
);
auto
[
row_offset
,
col
]
=
calc_row_and_col_k
(
block_idx
);
const
int
row_in_topk
=
row_
+
warp_idx
*
16
+
block_idx
*
kBlockN
;
const
int
row_in_topk
=
row_
+
warp_idx
*
16
+
block_idx
*
kBlockN
;
if
(
!
IS_TOPK_2048
&&
HAVE_TOPK_LENGTH
&&
block_idx
==
num_topk_blocks
-
1
&&
row_in_topk
>=
topk_length
)
{
if
(
HAVE_TOPK_LENGTH
&&
row_in_topk
>=
topk_length
)
{
row_offset
=
-
1
;
row_offset
=
-
1
;
}
}
row_offset
=
row_offset
==
-
1
?
params
.
s_kv
:
row_offset
;
row_offset
=
row_offset
==
-
1
?
params
.
s_kv
:
row_offset
;
#if 1
#if 1
#define LOAD_K_AND_QK_GEMM(k) \
#define LOAD_K_AND_QK_GEMM(k) \
...
@@ -370,10 +397,8 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048, USE_
...
@@ -370,10 +397,8 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048, USE_
const
int
n_idx
=
(
lane_idx
/
16
)
*
4
+
(
idx
%
4
)
+
(
idx
/
4
)
*
16
;
const
int
n_idx
=
(
lane_idx
/
16
)
*
4
+
(
idx
%
4
)
+
(
idx
/
4
)
*
16
;
int
offs
=
n_idx
+
block_idx
*
kBlockN
;
int
offs
=
n_idx
+
block_idx
*
kBlockN
;
int
t
;
int
t
;
if
constexpr
(
IS_TOPK_2048
)
{
if
constexpr
(
IS_TOPK_2048
||
CACHE_INDICES_IN_LDS
)
{
t
=
sIndices
[
offs
%
1024
];
t
=
sIndices
[
offs
%
1024
];
}
else
if
constexpr
(
CACHE_INDICES_IN_LDS
)
{
row_offset
=
sIndices
[
row_offset
];
}
else
{
}
else
{
t
=
gIndices
[
offs
];
t
=
gIndices
[
offs
];
}
}
...
@@ -1272,7 +1297,7 @@ static void run_h64_fast_path(const SparseAttnFwdParams& params) {
...
@@ -1272,7 +1297,7 @@ static void run_h64_fast_path(const SparseAttnFwdParams& params) {
template
<
int
D_QK
,
bool
HAVE_TOPK_LENGTH
>
template
<
int
D_QK
,
bool
HAVE_TOPK_LENGTH
>
void
run_fwd_phase1_kernel
(
const
SparseAttnFwdParams
&
params
)
{
void
run_fwd_phase1_kernel
(
const
SparseAttnFwdParams
&
params
)
{
if
(
params
.
h_q
==
64
)
{
if
(
params
.
h_q
==
64
||
params
.
h_q
==
128
)
{
if
(
params
.
attn_sink
)
{
if
(
params
.
attn_sink
)
{
run_h64_fast_path
<
D_QK
,
HAVE_TOPK_LENGTH
,
true
>
(
params
);
run_h64_fast_path
<
D_QK
,
HAVE_TOPK_LENGTH
,
true
>
(
params
);
}
else
{
}
else
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment