Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
FlashMLA
Commits
dd5d4bb3
Commit
dd5d4bb3
authored
Jan 28, 2026
by
zhanghj2
Browse files
区分dim576和512
parent
c3cf875a
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
58 additions
and
27 deletions
+58
-27
csrc/sm90/prefill/sparse/phase1.cuh
csrc/sm90/prefill/sparse/phase1.cuh
+58
-27
No files found.
csrc/sm90/prefill/sparse/phase1.cuh
View file @
dd5d4bb3
...
@@ -6,6 +6,7 @@
...
@@ -6,6 +6,7 @@
#include "../../helpers.h"
#include "../../helpers.h"
namespace
sm90
::
fwd
{
namespace
sm90
::
fwd
{
#define CUDART_L2E_F 1.442695041F
using
namespace
cute
;
using
namespace
cute
;
...
@@ -54,8 +55,11 @@ __device__ void KernelTemplate<D_QK, HAVE_TOPK_LENGTH>::devfunc(const SparseAttn
...
@@ -54,8 +55,11 @@ __device__ void KernelTemplate<D_QK, HAVE_TOPK_LENGTH>::devfunc(const SparseAttn
flash
::
lds_direct_copy
<
false
,
true
,
true
>
(
gQ
,
sQ
,
0
,
params
.
stride_q_h_q
,
params
.
h_q
-
bidh
*
kBlockM
);
flash
::
lds_direct_copy
<
false
,
true
,
true
>
(
gQ
,
sQ
,
0
,
params
.
stride_q_h_q
,
params
.
h_q
-
bidh
*
kBlockM
);
flash
::
lds_direct_copy
<
false
,
true
,
true
>
(
gQ
,
sQ
,
1
,
params
.
stride_q_h_q
,
params
.
h_q
-
bidh
*
kBlockM
);
flash
::
lds_direct_copy
<
false
,
true
,
true
>
(
gQ
,
sQ
,
1
,
params
.
stride_q_h_q
,
params
.
h_q
-
bidh
*
kBlockM
);
flash
::
lds_direct_copy
<
false
,
true
,
true
>
(
gQ
,
sQ
,
2
,
params
.
stride_q_h_q
,
params
.
h_q
-
bidh
*
kBlockM
);
flash
::
lds_direct_copy
<
false
,
true
,
true
>
(
gQ
,
sQ
,
2
,
params
.
stride_q_h_q
,
params
.
h_q
-
bidh
*
kBlockM
);
flash
::
lds_direct_copy
<
false
,
true
,
true
>
(
gQ
,
sQ
,
3
,
params
.
stride_q_h_q
,
params
.
h_q
-
bidh
*
kBlockM
);
flash
::
lds_direct_copy
<
false
,
true
,
true
>
(
gQ
,
sQ
,
3
,
params
.
stride_q_h_q
,
params
.
h_q
-
bidh
*
kBlockM
);
flash
::
lds_direct_copy
<
false
,
false
,
true
>
(
gQ
,
sQ
,
4
,
params
.
stride_q_h_q
,
params
.
h_q
-
bidh
*
kBlockM
);
if
constexpr
(
D_QK
==
576
)
{
flash
::
lds_direct_copy
<
false
,
false
,
true
>
(
gQ
,
sQ
,
4
,
params
.
stride_q_h_q
,
params
.
h_q
-
bidh
*
kBlockM
);
}
auto
smem_tiled_copy_Q
=
make_tiled_copy_A
(
Copy_Atom
<
DefaultCopy
,
Element
>
{},
tiled_mma
);
auto
smem_tiled_copy_Q
=
make_tiled_copy_A
(
Copy_Atom
<
DefaultCopy
,
Element
>
{},
tiled_mma
);
auto
smem_thr_copy_Q
=
smem_tiled_copy_Q
.
get_thread_slice
(
tidx
);
auto
smem_thr_copy_Q
=
smem_tiled_copy_Q
.
get_thread_slice
(
tidx
);
...
@@ -64,29 +68,56 @@ __device__ void KernelTemplate<D_QK, HAVE_TOPK_LENGTH>::devfunc(const SparseAttn
...
@@ -64,29 +68,56 @@ __device__ void KernelTemplate<D_QK, HAVE_TOPK_LENGTH>::devfunc(const SparseAttn
Tensor
tSrQ_copy_view
=
smem_thr_copy_Q
.
retile_D
(
tSrQ
);
Tensor
tSrQ_copy_view
=
smem_thr_copy_Q
.
retile_D
(
tSrQ
);
// asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t");
// asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t");
asm
volatile
(
"s_waitcnt vmcnt(4)
\n\t
s_barrier
\n\t
"
);
if
constexpr
(
D_QK
==
576
)
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
0
),
tSrQ_copy_view
(
_
,
_
,
0
));
{
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
1
),
tSrQ_copy_view
(
_
,
_
,
1
));
asm
volatile
(
"s_waitcnt vmcnt(4)
\n\t
s_barrier
\n\t
"
);
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
2
),
tSrQ_copy_view
(
_
,
_
,
2
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
0
),
tSrQ_copy_view
(
_
,
_
,
0
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
3
),
tSrQ_copy_view
(
_
,
_
,
3
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
1
),
tSrQ_copy_view
(
_
,
_
,
1
));
asm
volatile
(
"s_waitcnt vmcnt(3)
\n\t
s_barrier
\n\t
"
);
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
2
),
tSrQ_copy_view
(
_
,
_
,
2
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
4
),
tSrQ_copy_view
(
_
,
_
,
4
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
3
),
tSrQ_copy_view
(
_
,
_
,
3
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
5
),
tSrQ_copy_view
(
_
,
_
,
5
));
asm
volatile
(
"s_waitcnt vmcnt(3)
\n\t
s_barrier
\n\t
"
);
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
6
),
tSrQ_copy_view
(
_
,
_
,
6
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
4
),
tSrQ_copy_view
(
_
,
_
,
4
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
7
),
tSrQ_copy_view
(
_
,
_
,
7
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
5
),
tSrQ_copy_view
(
_
,
_
,
5
));
asm
volatile
(
"s_waitcnt vmcnt(2)
\n\t
s_barrier
\n\t
"
);
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
6
),
tSrQ_copy_view
(
_
,
_
,
6
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
8
),
tSrQ_copy_view
(
_
,
_
,
8
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
7
),
tSrQ_copy_view
(
_
,
_
,
7
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
9
),
tSrQ_copy_view
(
_
,
_
,
9
));
asm
volatile
(
"s_waitcnt vmcnt(2)
\n\t
s_barrier
\n\t
"
);
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
10
),
tSrQ_copy_view
(
_
,
_
,
10
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
8
),
tSrQ_copy_view
(
_
,
_
,
8
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
11
),
tSrQ_copy_view
(
_
,
_
,
11
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
9
),
tSrQ_copy_view
(
_
,
_
,
9
));
asm
volatile
(
"s_waitcnt vmcnt(1)
\n\t
s_barrier
\n\t
"
);
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
10
),
tSrQ_copy_view
(
_
,
_
,
10
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
12
),
tSrQ_copy_view
(
_
,
_
,
12
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
11
),
tSrQ_copy_view
(
_
,
_
,
11
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
13
),
tSrQ_copy_view
(
_
,
_
,
13
));
asm
volatile
(
"s_waitcnt vmcnt(1)
\n\t
s_barrier
\n\t
"
);
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
14
),
tSrQ_copy_view
(
_
,
_
,
14
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
12
),
tSrQ_copy_view
(
_
,
_
,
12
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
15
),
tSrQ_copy_view
(
_
,
_
,
15
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
13
),
tSrQ_copy_view
(
_
,
_
,
13
));
asm
volatile
(
"s_waitcnt vmcnt(0)
\n\t
s_barrier
\n\t
"
);
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
14
),
tSrQ_copy_view
(
_
,
_
,
14
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
16
),
tSrQ_copy_view
(
_
,
_
,
16
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
15
),
tSrQ_copy_view
(
_
,
_
,
15
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
17
),
tSrQ_copy_view
(
_
,
_
,
17
));
asm
volatile
(
"s_waitcnt vmcnt(0)
\n\t
s_barrier
\n\t
"
);
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
16
),
tSrQ_copy_view
(
_
,
_
,
16
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
17
),
tSrQ_copy_view
(
_
,
_
,
17
));
}
else
{
asm
volatile
(
"s_waitcnt vmcnt(3)
\n\t
s_barrier
\n\t
"
);
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
0
),
tSrQ_copy_view
(
_
,
_
,
0
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
1
),
tSrQ_copy_view
(
_
,
_
,
1
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
2
),
tSrQ_copy_view
(
_
,
_
,
2
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
3
),
tSrQ_copy_view
(
_
,
_
,
3
));
asm
volatile
(
"s_waitcnt vmcnt(2)
\n\t
s_barrier
\n\t
"
);
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
4
),
tSrQ_copy_view
(
_
,
_
,
4
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
5
),
tSrQ_copy_view
(
_
,
_
,
5
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
6
),
tSrQ_copy_view
(
_
,
_
,
6
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
7
),
tSrQ_copy_view
(
_
,
_
,
7
));
asm
volatile
(
"s_waitcnt vmcnt(1)
\n\t
s_barrier
\n\t
"
);
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
8
),
tSrQ_copy_view
(
_
,
_
,
8
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
9
),
tSrQ_copy_view
(
_
,
_
,
9
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
10
),
tSrQ_copy_view
(
_
,
_
,
10
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
11
),
tSrQ_copy_view
(
_
,
_
,
11
));
asm
volatile
(
"s_waitcnt vmcnt(0)
\n\t
s_barrier
\n\t
"
);
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
12
),
tSrQ_copy_view
(
_
,
_
,
12
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
13
),
tSrQ_copy_view
(
_
,
_
,
13
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
14
),
tSrQ_copy_view
(
_
,
_
,
14
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
15
),
tSrQ_copy_view
(
_
,
_
,
15
));
}
__syncthreads
();
__syncthreads
();
...
@@ -383,7 +414,7 @@ __device__ void KernelTemplate<D_QK, HAVE_TOPK_LENGTH>::devfunc(const SparseAttn
...
@@ -383,7 +414,7 @@ __device__ void KernelTemplate<D_QK, HAVE_TOPK_LENGTH>::devfunc(const SparseAttn
float
*
gMax_logits
=
reinterpret_cast
<
float
*>
(
params
.
max_logits
)
+
row_offset_lse
;
float
*
gMax_logits
=
reinterpret_cast
<
float
*>
(
params
.
max_logits
)
+
row_offset_lse
;
if
(
params
.
attn_sink
!=
nullptr
)
{
if
(
params
.
attn_sink
!=
nullptr
)
{
float
rAttn_sink
=
__ldg
((
float
*
)
params
.
attn_sink
+
start_head_idx
+
lane_idx
%
16
);
float
rAttn_sink
=
__ldg
((
float
*
)
params
.
attn_sink
+
bidh
*
kBlockM
+
lane_idx
%
16
);
if
(
flash
::
is_positive_infinity
(
rAttn_sink
))
if
(
flash
::
is_positive_infinity
(
rAttn_sink
))
{
{
for
(
int
i
=
0
;
i
<
size
(
acc_o
);
i
++
)
for
(
int
i
=
0
;
i
<
size
(
acc_o
);
i
++
)
...
@@ -455,7 +486,7 @@ void KernelTemplate<D_QK, HAVE_TOPK_LENGTH>::run(const SparseAttnFwdParams ¶
...
@@ -455,7 +486,7 @@ void KernelTemplate<D_QK, HAVE_TOPK_LENGTH>::run(const SparseAttnFwdParams ¶
KU_ASSERT
(
params
.
h_q
%
B_H
==
0
);
KU_ASSERT
(
params
.
h_q
%
B_H
==
0
);
auto
kernel
=
&
sparse_attn_fwd_kernel
<
KernelTemplate
<
D_QK
,
HAVE_TOPK_LENGTH
>>
;
auto
kernel
=
&
sparse_attn_fwd_kernel
<
KernelTemplate
<
D_QK
,
HAVE_TOPK_LENGTH
>>
;
constexpr
size_t
smem_size
=
65536
;
constexpr
size_t
smem_size
=
65536
;
dim3
grid
(
(
params
.
s_q
,
params
.
h_q
/
B_H
)
,
1
);
dim3
grid
(
params
.
s_q
,
params
.
h_q
/
B_H
,
1
);
kernel
<<<
grid
,
NUM_THREADS
,
smem_size
,
params
.
stream
>>>
(
params
);
kernel
<<<
grid
,
NUM_THREADS
,
smem_size
,
params
.
stream
>>>
(
params
);
KU_CHECK_KERNEL_LAUNCH
();
KU_CHECK_KERNEL_LAUNCH
();
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment