Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
FlashMLA
Commits
e83a4119
Commit
e83a4119
authored
May 25, 2026
by
zhanghj2
Browse files
tail guard to opt sparse prefill
parent
9a805181
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
16 additions
and
32 deletions
+16
-32
csrc/gfx93/prefill/sparse/phase1.cuh
csrc/gfx93/prefill/sparse/phase1.cuh
+16
-32
No files found.
csrc/gfx93/prefill/sparse/phase1.cuh
View file @
e83a4119
...
@@ -98,6 +98,8 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048, USE_
...
@@ -98,6 +98,8 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048, USE_
if
constexpr
(
IS_TOPK_2048
||
CACHE_INDICES_IN_LDS
)
{
if
constexpr
(
IS_TOPK_2048
||
CACHE_INDICES_IN_LDS
)
{
row_offset
=
sIndices
[
row_offset
%
1024
];
row_offset
=
sIndices
[
row_offset
%
1024
];
}
else
if
constexpr
(
CACHE_INDICES_IN_LDS
)
{
row_offset
=
sIndices
[
row_offset
];
}
else
{
}
else
{
row_offset
=
gIndices
[
row_offset
];
row_offset
=
gIndices
[
row_offset
];
}
}
...
@@ -109,9 +111,12 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048, USE_
...
@@ -109,9 +111,12 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048, USE_
// int col = lane_idx % 4;
// int col = lane_idx % 4;
int
row_offset
=
row
+
i
*
16
+
block_idx
*
kBlockN
;;
int
row_offset
=
row
+
i
*
16
+
block_idx
*
kBlockN
;;
// int col_offset = col * 8 + warp_idx * 32;
// int col_offset = col * 8 + warp_idx * 32;
if
constexpr
(
IS_TOPK_2048
||
CACHE_INDICES_IN_LDS
)
{
if
constexpr
(
IS_TOPK_2048
)
{
row_offset
=
sIndices
[
row_offset
%
1024
];
row_offset
=
sIndices
[
row_offset
%
1024
];
}
else
{
}
else
if
constexpr
(
CACHE_INDICES_IN_LDS
)
{
row_offset
=
sIndices
[
row_offset
];
}
else
{
row_offset
=
gIndices
[
row_offset
];
row_offset
=
gIndices
[
row_offset
];
}
}
row_offset
=
row_offset
==
-
1
?
params
.
s_kv
:
row_offset
;
row_offset
=
row_offset
==
-
1
?
params
.
s_kv
:
row_offset
;
...
@@ -178,20 +183,6 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048, USE_
...
@@ -178,20 +183,6 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048, USE_
auto
buffer_load_lds_k
=
[
&
](
int
row_offset
,
int
col
,
int
k_idx
)
{
auto
buffer_load_lds_k
=
[
&
](
int
row_offset
,
int
col
,
int
k_idx
)
{
constexpr
int
element_size
=
2
;
constexpr
int
element_size
=
2
;
// int k_idx = __builtin_amdgcn_readfirstlane(k_idx_);
// struct PtrWrapper {
// uint32_t former;
// uint32_t latter;
// };
// PtrWrapper glob_ptr;
// *(uint64_t*)&glob_ptr = reinterpret_cast<uint64_t>(gK.data().get());
// glob_ptr.latter |= ((row_stride * 2) << 16);
// uint32x4_t global_addr = {0};
// global_addr[0] = (glob_ptr.former);
// global_addr[1] = (glob_ptr.latter);
// global_addr[2] = max_MN;
// global_addr[3] = 0x00020000;
constexpr
int
elements_per_thread
=
8
;
constexpr
int
elements_per_thread
=
8
;
int
col_offset
=
col
;
int
col_offset
=
col
;
int
offset_v
=
col_offset
*
2
;
int
offset_v
=
col_offset
*
2
;
...
@@ -217,20 +208,6 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048, USE_
...
@@ -217,20 +208,6 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048, USE_
auto
buffer_load_lds_v
=
[
&
](
int
row_offset
,
int
col
,
int
k_idx
,
int
n_idx
)
{
auto
buffer_load_lds_v
=
[
&
](
int
row_offset
,
int
col
,
int
k_idx
,
int
n_idx
)
{
constexpr
int
element_size
=
2
;
constexpr
int
element_size
=
2
;
// int k_idx = __builtin_amdgcn_readfirstlane(k_idx_);
// struct PtrWrapper {
// uint32_t former;
// uint32_t latter;
// };
// PtrWrapper glob_ptr;
// *(uint64_t*)&glob_ptr = reinterpret_cast<uint64_t>(gK.data().get());
// glob_ptr.latter |= ((row_stride * 2) << 16);
// uint32x4_t global_addr = {0};
// global_addr[0] = (glob_ptr.former);
// global_addr[1] = (glob_ptr.latter);
// global_addr[2] = max_MN;
// global_addr[3] = 0x00020000;
constexpr
int
elements_per_thread
=
8
;
constexpr
int
elements_per_thread
=
8
;
int
col_offset
=
col
;
int
col_offset
=
col
;
// int v_idx = row_offset;
// int v_idx = row_offset;
...
@@ -324,6 +301,10 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048, USE_
...
@@ -324,6 +301,10 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048, USE_
accs_f32
[
i
].
w
=
0.0
f
;
accs_f32
[
i
].
w
=
0.0
f
;
}
}
auto
[
row_offset
,
col
]
=
calc_row_and_col_k
(
block_idx
);
auto
[
row_offset
,
col
]
=
calc_row_and_col_k
(
block_idx
);
const
int
row_in_topk
=
row_
+
warp_idx
*
16
+
block_idx
*
kBlockN
;
if
(
!
IS_TOPK_2048
&&
HAVE_TOPK_LENGTH
&&
block_idx
==
num_topk_blocks
-
1
&&
row_in_topk
>=
topk_length
)
{
row_offset
=
-
1
;
}
row_offset
=
row_offset
==
-
1
?
params
.
s_kv
:
row_offset
;
row_offset
=
row_offset
==
-
1
?
params
.
s_kv
:
row_offset
;
#if 1
#if 1
#define LOAD_K_AND_QK_GEMM(k) \
#define LOAD_K_AND_QK_GEMM(k) \
...
@@ -389,8 +370,10 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048, USE_
...
@@ -389,8 +370,10 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048, USE_
const
int
n_idx
=
(
lane_idx
/
16
)
*
4
+
(
idx
%
4
)
+
(
idx
/
4
)
*
16
;
const
int
n_idx
=
(
lane_idx
/
16
)
*
4
+
(
idx
%
4
)
+
(
idx
/
4
)
*
16
;
int
offs
=
n_idx
+
block_idx
*
kBlockN
;
int
offs
=
n_idx
+
block_idx
*
kBlockN
;
int
t
;
int
t
;
if
constexpr
(
IS_TOPK_2048
||
CACHE_INDICES_IN_LDS
)
{
if
constexpr
(
IS_TOPK_2048
)
{
t
=
sIndices
[
offs
%
1024
];
t
=
sIndices
[
offs
%
1024
];
}
else
if
constexpr
(
CACHE_INDICES_IN_LDS
)
{
row_offset
=
sIndices
[
row_offset
];
}
else
{
}
else
{
t
=
gIndices
[
offs
];
t
=
gIndices
[
offs
];
}
}
...
@@ -629,6 +612,7 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048, USE_
...
@@ -629,6 +612,7 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048, USE_
}
}
else
else
{
{
if
(
num_topk_blocks
>
0
)
process_one_block
(
0
,
IsFirstBlock
{});
process_one_block
(
0
,
IsFirstBlock
{});
for
(
int
block_idx
=
1
;
block_idx
<
num_topk_blocks
;
block_idx
++
)
for
(
int
block_idx
=
1
;
block_idx
<
num_topk_blocks
;
block_idx
++
)
{
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment