Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
FlashMLA
Commits
f01246b6
Commit
f01246b6
authored
Mar 23, 2026
by
zhanghj2
Browse files
优化prefill sparse写出
parent
c3a5b02a
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
97 additions
and
21 deletions
+97
-21
csrc/gfx93/prefill/sparse/phase1.cuh
csrc/gfx93/prefill/sparse/phase1.cuh
+75
-21
csrc/utils.h
csrc/utils.h
+22
-0
No files found.
csrc/gfx93/prefill/sparse/phase1.cuh
View file @
f01246b6
...
...
@@ -204,10 +204,10 @@ __device__ void KernelTemplate<D_QK, HAVE_TOPK_LENGTH>::devfunc(const SparseAttn
k_idx
++
;
asm
volatile
(
"s_waitcnt vmcnt(0)
\n\t
s_barrier
\n\t
"
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
%
4
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
flash
::
__ds_read_m32x16_row_col_rrow
<
0
,
0
,
0
>
(
tOsVt
,
tOrVt_copy_view
);
flash
::
__ds_read_m32x16_row_col_rrow
<
0
,
1
,
0
>
(
tOsVt
,
tOrVt_copy_view
);
flash
::
__ds_read_m32x16_row_col_rrow
<
0
,
2
,
0
>
(
tOsVt
,
tOrVt_copy_view
);
flash
::
__ds_read_m32x16_row_col_rrow
<
0
,
3
,
0
>
(
tOsVt
,
tOrVt_copy_view
);
flash
::
__ds_read_m32x16_row_col_rrow
_alt
<
0
,
0
,
0
>
(
tOsVt
,
tOrVt_copy_view
);
flash
::
__ds_read_m32x16_row_col_rrow
_alt
<
0
,
1
,
0
>
(
tOsVt
,
tOrVt_copy_view
);
flash
::
__ds_read_m32x16_row_col_rrow
_alt
<
0
,
2
,
0
>
(
tOsVt
,
tOrVt_copy_view
);
flash
::
__ds_read_m32x16_row_col_rrow
_alt
<
0
,
3
,
0
>
(
tOsVt
,
tOrVt_copy_view
);
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
asm
volatile
(
"s_waitcnt lgkmcnt(0)
\n\t
s_barrier
\n\t
"
);
flash
::
lds_direct_copy_for_prefill_sparse_mla
<
true
,
false
,
false
>
(
gK
,
sK
,
row_offset
,
col
,
4
,
params
.
stride_kv_s_kv
,
params
.
s_kv
);
...
...
@@ -233,10 +233,10 @@ __device__ void KernelTemplate<D_QK, HAVE_TOPK_LENGTH>::devfunc(const SparseAttn
k_idx
++
;
asm
volatile
(
"s_waitcnt vmcnt(0)
\n\t
s_barrier
\n\t
"
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
%
4
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
flash
::
__ds_read_m32x16_row_col_rrow
<
0
,
0
,
1
>
(
tOsVt
,
tOrVt_copy_view
);
flash
::
__ds_read_m32x16_row_col_rrow
<
0
,
1
,
1
>
(
tOsVt
,
tOrVt_copy_view
);
flash
::
__ds_read_m32x16_row_col_rrow
<
0
,
2
,
1
>
(
tOsVt
,
tOrVt_copy_view
);
flash
::
__ds_read_m32x16_row_col_rrow
<
0
,
3
,
1
>
(
tOsVt
,
tOrVt_copy_view
);
flash
::
__ds_read_m32x16_row_col_rrow
_alt
<
0
,
0
,
1
>
(
tOsVt
,
tOrVt_copy_view
);
flash
::
__ds_read_m32x16_row_col_rrow
_alt
<
0
,
1
,
1
>
(
tOsVt
,
tOrVt_copy_view
);
flash
::
__ds_read_m32x16_row_col_rrow
_alt
<
0
,
2
,
1
>
(
tOsVt
,
tOrVt_copy_view
);
flash
::
__ds_read_m32x16_row_col_rrow
_alt
<
0
,
3
,
1
>
(
tOsVt
,
tOrVt_copy_view
);
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
asm
volatile
(
"s_waitcnt lgkmcnt(0)
\n\t
s_barrier
\n\t
"
);
...
...
@@ -263,10 +263,10 @@ __device__ void KernelTemplate<D_QK, HAVE_TOPK_LENGTH>::devfunc(const SparseAttn
k_idx
++
;
asm
volatile
(
"s_waitcnt vmcnt(0)
\n\t
s_barrier
\n\t
"
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
%
4
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
flash
::
__ds_read_m32x16_row_col_rrow
<
0
,
0
,
2
>
(
tOsVt
,
tOrVt_copy_view
);
flash
::
__ds_read_m32x16_row_col_rrow
<
0
,
1
,
2
>
(
tOsVt
,
tOrVt_copy_view
);
flash
::
__ds_read_m32x16_row_col_rrow
<
0
,
2
,
2
>
(
tOsVt
,
tOrVt_copy_view
);
flash
::
__ds_read_m32x16_row_col_rrow
<
0
,
3
,
2
>
(
tOsVt
,
tOrVt_copy_view
);
flash
::
__ds_read_m32x16_row_col_rrow
_alt
<
0
,
0
,
2
>
(
tOsVt
,
tOrVt_copy_view
);
flash
::
__ds_read_m32x16_row_col_rrow
_alt
<
0
,
1
,
2
>
(
tOsVt
,
tOrVt_copy_view
);
flash
::
__ds_read_m32x16_row_col_rrow
_alt
<
0
,
2
,
2
>
(
tOsVt
,
tOrVt_copy_view
);
flash
::
__ds_read_m32x16_row_col_rrow
_alt
<
0
,
3
,
2
>
(
tOsVt
,
tOrVt_copy_view
);
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
asm
volatile
(
"s_waitcnt lgkmcnt(0)
\n\t
s_barrier
\n\t
"
);
...
...
@@ -332,7 +332,7 @@ __device__ void KernelTemplate<D_QK, HAVE_TOPK_LENGTH>::devfunc(const SparseAttn
{
flash
::
__ds_read_m32x16_row_col_rrow
<
0
,
0
,
3
>
(
tOsVt
,
tOrVt_copy_view
);
flash
::
__ds_read_m32x16_row_col_rrow
_alt
<
0
,
0
,
3
>
(
tOsVt
,
tOrVt_copy_view
);
...
...
@@ -344,7 +344,7 @@ __device__ void KernelTemplate<D_QK, HAVE_TOPK_LENGTH>::devfunc(const SparseAttn
// __ds_read_m32x16_row_col<1, 1>(tOsVt, tOrVt_copy_view);
// __ds_read_m32x16_row_col<2, 1>(tOsVt, tOrVt_copy_view);
cute
::
gemm
(
tiled_mma_o
,
tOrP
(
_
,
_
,
0
),
tOrVt
(
_
,
_
,
0
),
acc_o
);
flash
::
__ds_read_m32x16_row_col_rrow
<
0
,
1
,
3
>
(
tOsVt
,
tOrVt_copy_view
);
flash
::
__ds_read_m32x16_row_col_rrow
_alt
<
0
,
1
,
3
>
(
tOsVt
,
tOrVt_copy_view
);
cute
::
gemm
(
tiled_mma_o
,
tOrP
(
_
,
_
,
1
),
tOrVt
(
_
,
_
,
1
),
acc_o
);
// __ds_read_m32x16_row_col<0, 2>(tOsVt, tOrVt_copy_view);
// __ds_read_m32x16_row_col<1, 2>(tOsVt, tOrVt_copy_view);
...
...
@@ -355,9 +355,9 @@ __device__ void KernelTemplate<D_QK, HAVE_TOPK_LENGTH>::devfunc(const SparseAttn
// __ds_read_m32x16_row_col<1, 3>(tOsVt, tOrVt_copy_view);
// __ds_read_m32x16_row_col<2, 3>(tOsVt, tOrVt_copy_view);
flash
::
__ds_read_m32x16_row_col_rrow
<
0
,
2
,
3
>
(
tOsVt
,
tOrVt_copy_view
);
flash
::
__ds_read_m32x16_row_col_rrow
_alt
<
0
,
2
,
3
>
(
tOsVt
,
tOrVt_copy_view
);
cute
::
gemm
(
tiled_mma_o
,
tOrP
(
_
,
_
,
2
),
tOrVt
(
_
,
_
,
2
),
acc_o
);
flash
::
__ds_read_m32x16_row_col_rrow
<
0
,
3
,
3
>
(
tOsVt
,
tOrVt_copy_view
);
flash
::
__ds_read_m32x16_row_col_rrow
_alt
<
0
,
3
,
3
>
(
tOsVt
,
tOrVt_copy_view
);
cute
::
gemm
(
tiled_mma_o
,
tOrP
(
_
,
_
,
3
),
tOrVt
(
_
,
_
,
3
),
acc_o
);
// for (int i = 0; i < size(tOrP); i++)
...
...
@@ -449,7 +449,18 @@ __device__ void KernelTemplate<D_QK, HAVE_TOPK_LENGTH>::devfunc(const SparseAttn
{
// store O and gLSE
auto
rO
=
flash
::
convert_type
<
Element
>
(
acc_o
);
// auto rO = flash::convert_type<Element>(acc_o);
auto
float2bf16
=
[]
(
float
s
)
->
uint16_t
{
uint32_t
x32
=
reinterpret_cast
<
uint32_t
const
&>
(
s
);
#ifndef FLASH_MLA_BF16_TYPE
#define FLASH_MLA_BF16_TYPE 0
#endif
#if FLASH_MLA_BF16_TYPE == 1
x32
+=
0x8000u
;
#endif
return
uint16_t
(
x32
>>
16
);
};
int
row
,
col
;
const
int
warpId
=
tidx
/
64
;
const
int
laneId
=
tidx
%
64
;
...
...
@@ -457,11 +468,54 @@ __device__ void KernelTemplate<D_QK, HAVE_TOPK_LENGTH>::devfunc(const SparseAttn
row
=
mi
*
kBlockM
+
laneId
%
16
;
if
(
row
<
params
.
h_q
)
{
for
(
int
ni
=
0
;
ni
<
size
<
2
>
(
acc_o
);
++
ni
)
{
col
=
(
laneId
/
16
)
+
ni
*
128
+
warpId
*
32
;
for
(
int
ei
=
0
;
ei
<
size
<
0
>
(
acc_o
);
++
ei
)
{
gO
(
row
,
col
)
=
rO
(
ei
,
mi
,
ni
);
col
+=
4
;
col
=
(
laneId
/
16
)
*
2
+
ni
*
128
+
warpId
*
32
;
using
result_type
=
cutlass
::
Array
<
Element
,
2
>
;
for
(
int
ei
=
0
;
ei
<
4
;
ei
++
)
{
#if defined(__gfx938__)
auto
d
=
__builtin_hcu_cvt_pk_bf16_f32
(
0
,
acc_o
(
ei
,
mi
,
ni
),
0
,
acc_o
(
ei
+
4
,
mi
,
ni
),
0
);
auto
res
=
reinterpret_cast
<
result_type
const
&>
(
d
);
#else
result_type
res
;
Element
e0
,
e1
;
e0
.
storage
=
float2bf16
(
acc_o
(
ei
,
mi
,
ni
));
e1
.
storage
=
float2bf16
(
acc_o
(
ei
+
4
,
mi
,
ni
));
res
[
0
]
=
e0
;
res
[
1
]
=
e1
;
#endif
// gO(row, col) = res[0];
// gO(row, col + 1) = res[1];
*
(
result_type
*
)(
&
gO
(
row
,
col
))
=
res
;
col
+=
8
;
}
// gO(row, col) = rO(0, mi, ni);
// gO(row, col + 1) = rO(1, mi, ni);
// col += 8;
// gO(row, col) = rO(2, mi, ni);
// gO(row, col + 1) = rO(3, mi, ni);
// col += 8;
// gO(row, col) = rO(4, mi, ni);
// gO(row, col + 1) = rO(5, mi, ni);
// col += 8;
// gO(row, col) = rO(6, mi, ni);
// gO(row, col + 1) = rO(7, mi, ni);
// gO(row, col) = rO(0, mi, ni);
// gO(row, col + 1) = rO(4, mi, ni);
// col += 8;
// gO(row, col) = rO(1, mi, ni);
// gO(row, col + 1) = rO(5, mi, ni);
// col += 8;
// gO(row, col) = rO(2, mi, ni);
// gO(row, col + 1) = rO(6, mi, ni);
// col += 8;
// gO(row, col) = rO(3, mi, ni);
// gO(row, col + 1) = rO(7, mi, ni);
// for (int ei = 0; ei < size<0>(acc_o); ei += 2) {
// gO(row, col) = rO(ei, mi, ni);
// col += 4;
// }
}
gLSE
[
row
]
=
lse
(
mi
);
gMax_logits
[
row
]
=
topk_length
==
0
?
-
INFINITY
:
softmax
.
row_max
(
mi
)
*
params
.
sm_scale
;
...
...
csrc/utils.h
View file @
f01246b6
...
...
@@ -211,6 +211,28 @@ __forceinline__ __device__ void __ds_read_m32x16_row_col_rrow(Tensor0& src, Ten
dst_ptr
[
6
]
=
d_ptr
[
6
];
dst_ptr
[
7
]
=
d_ptr
[
7
];
}
/// Issues the `_alt` form of the ds_read_m32x16f16 builtin on `src` and stores
/// the eight 16-bit values it returns into `dst` starting at `dst(0, r_row, col)`.
///
/// @tparam row    Second coordinate used when indexing the layout of `src`.
/// @tparam col    Third coordinate used for both the `src` layout lookup and
///                the destination slot in `dst`.
/// @tparam r_row  Second coordinate of the destination slot in `dst`.
/// @param  src    Tensor whose `data().get()` pointer is passed to the builtin
///                as an address-space-3 `__fp16*`.
/// @param  dst    Tensor that receives the loaded values.
///
/// NOTE(review): how the `_alt` builtin differs from the non-`_alt` one used by
/// `__ds_read_m32x16_row_col_rrow` is not visible here -- confirm against the
/// compiler/ISA documentation before changing call sites.
template <int row, int col, int r_row, typename Tensor0, typename Tensor1>
__forceinline__ __device__ void __ds_read_m32x16_row_col_rrow_alt(Tensor0& src, Tensor1& dst)
{
    // Number of 16-bit values copied from the builtin's result into dst.
    constexpr int kHalfwords = 8;

    auto lds_base   = reinterpret_cast<__fp16*>(src.data().get());
    auto src_layout = src.layout();

    // Layout index of (0, row, col) scaled by 2 -- presumably a byte offset
    // for 2-byte elements; it is passed as the builtin's offset operand.
    constexpr short byte_offset = src_layout(0, row, col) * 2;

    auto loaded = __builtin_amdgcn_ds_read_m32x16f16_alt(
        (__attribute__((address_space(3))) __fp16*)(lds_base), byte_offset);

    // View both the loaded value and the destination slot as raw halfwords and
    // copy them one by one; the fixed trip count lets the loop fully unroll.
    uint16_t* from = reinterpret_cast<uint16_t*>(&loaded);
    uint16_t* to   = reinterpret_cast<uint16_t*>(&(dst(0, r_row, col)));
#pragma unroll
    for (int i = 0; i < kHalfwords; ++i) {
        to[i] = from[i];
    }
}
template
<
int
row
,
int
col
,
typename
Tensor0
,
typename
Tensor1
>
__forceinline__
__device__
void
__ds_read_m32x16_row_col
(
Tensor0
&
src
,
Tensor1
&
dst
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment