Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
FlashMLA
Commits
29b0d13b
Commit
29b0d13b
authored
Apr 14, 2026
by
zhanghj2
Browse files
优化sparse decode fp8
parent
4648ec2f
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
94 additions
and
30 deletions
+94
-30
csrc/gfx93/decode/sparse_fp8/splitkv_mla.cuh
csrc/gfx93/decode/sparse_fp8/splitkv_mla.cuh
+72
-30
csrc/utils.h
csrc/utils.h
+22
-0
No files found.
csrc/gfx93/decode/sparse_fp8/splitkv_mla.cuh
View file @
29b0d13b
...
...
@@ -627,14 +627,14 @@ __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::compute_attn_1rowblock_sp
__syncthreads
();
__builtin_amdgcn_sched_barrier
(
0
);
flash
::
__ds_read_m32x16_row_col_rrow
<
0
,
0
,
2
>
(
tOsVt
,
tOrVt_copy_view
);
flash
::
__ds_read_m32x16_row_col_rrow
<
0
,
1
,
2
>
(
tOsVt
,
tOrVt_copy_view
);
flash
::
__ds_read_m32x16_row_col_rrow
<
0
,
2
,
2
>
(
tOsVt
,
tOrVt_copy_view
);
flash
::
__ds_read_m32x16_row_col_rrow
<
0
,
3
,
2
>
(
tOsVt
,
tOrVt_copy_view
);
flash
::
__ds_read_m32x16_row_col_rrow
<
1
,
0
,
3
>
(
tOsVt
,
tOrVt_copy_view
);
flash
::
__ds_read_m32x16_row_col_rrow
<
1
,
1
,
3
>
(
tOsVt
,
tOrVt_copy_view
);
flash
::
__ds_read_m32x16_row_col_rrow
<
1
,
2
,
3
>
(
tOsVt
,
tOrVt_copy_view
);
flash
::
__ds_read_m32x16_row_col_rrow
<
1
,
3
,
3
>
(
tOsVt
,
tOrVt_copy_view
);
flash
::
__ds_read_m32x16_row_col_rrow
_alt
<
0
,
0
,
2
>
(
tOsVt
,
tOrVt_copy_view
);
flash
::
__ds_read_m32x16_row_col_rrow
_alt
<
0
,
1
,
2
>
(
tOsVt
,
tOrVt_copy_view
);
flash
::
__ds_read_m32x16_row_col_rrow
_alt
<
0
,
2
,
2
>
(
tOsVt
,
tOrVt_copy_view
);
flash
::
__ds_read_m32x16_row_col_rrow
_alt
<
0
,
3
,
2
>
(
tOsVt
,
tOrVt_copy_view
);
flash
::
__ds_read_m32x16_row_col_rrow
_alt
<
1
,
0
,
3
>
(
tOsVt
,
tOrVt_copy_view
);
flash
::
__ds_read_m32x16_row_col_rrow
_alt
<
1
,
1
,
3
>
(
tOsVt
,
tOrVt_copy_view
);
flash
::
__ds_read_m32x16_row_col_rrow
_alt
<
1
,
2
,
3
>
(
tOsVt
,
tOrVt_copy_view
);
flash
::
__ds_read_m32x16_row_col_rrow
_alt
<
1
,
3
,
3
>
(
tOsVt
,
tOrVt_copy_view
);
__syncthreads
();
// if (block0() && threadIdx.x >= 192)
...
...
@@ -681,10 +681,10 @@ __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::compute_attn_1rowblock_sp
}
__syncthreads
();
flash
::
__ds_read_m32x16_row_col_rrow
<
0
,
0
,
0
>
(
tOsVt
,
tOrVt_copy_view
);
flash
::
__ds_read_m32x16_row_col_rrow
<
0
,
1
,
0
>
(
tOsVt
,
tOrVt_copy_view
);
flash
::
__ds_read_m32x16_row_col_rrow
<
0
,
2
,
0
>
(
tOsVt
,
tOrVt_copy_view
);
flash
::
__ds_read_m32x16_row_col_rrow
<
0
,
3
,
0
>
(
tOsVt
,
tOrVt_copy_view
);
flash
::
__ds_read_m32x16_row_col_rrow
_alt
<
0
,
0
,
0
>
(
tOsVt
,
tOrVt_copy_view
);
flash
::
__ds_read_m32x16_row_col_rrow
_alt
<
0
,
1
,
0
>
(
tOsVt
,
tOrVt_copy_view
);
flash
::
__ds_read_m32x16_row_col_rrow
_alt
<
0
,
2
,
0
>
(
tOsVt
,
tOrVt_copy_view
);
flash
::
__ds_read_m32x16_row_col_rrow
_alt
<
0
,
3
,
0
>
(
tOsVt
,
tOrVt_copy_view
);
if
constexpr
(
MODEL_TYPE
==
ModelType
::
V32
)
{
...
...
@@ -735,21 +735,21 @@ __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::compute_attn_1rowblock_sp
{
// __ds_read_m32x16_row_col<0, 0>(tOsVt, tOrVt_copy_view);
flash
::
__ds_read_m32x16_row_col
<
1
,
0
>
(
tOsVt
,
tOrVt_copy_view
);
flash
::
__ds_read_m32x16_row_col
_alt
<
1
,
0
>
(
tOsVt
,
tOrVt_copy_view
);
// __ds_read_m32x16_row_col<2, 0>(tOsVt, tOrVt_copy_view);
// __ds_read_m32x16_row_col<0, 1>(tOsVt, tOrVt_copy_view);
flash
::
__ds_read_m32x16_row_col
<
1
,
1
>
(
tOsVt
,
tOrVt_copy_view
);
flash
::
__ds_read_m32x16_row_col
_alt
<
1
,
1
>
(
tOsVt
,
tOrVt_copy_view
);
// __ds_read_m32x16_row_col<2, 1>(tOsVt, tOrVt_copy_view);
cute
::
gemm
(
tiled_mma_o
,
tOrP
(
_
,
_
,
0
),
tOrVt
(
_
,
_
,
0
),
acc_o
);
cute
::
gemm
(
tiled_mma_o
,
tOrP
(
_
,
_
,
1
),
tOrVt
(
_
,
_
,
1
),
acc_o
);
// __ds_read_m32x16_row_col<0, 2>(tOsVt, tOrVt_copy_view);
flash
::
__ds_read_m32x16_row_col
<
1
,
2
>
(
tOsVt
,
tOrVt_copy_view
);
flash
::
__ds_read_m32x16_row_col
_alt
<
1
,
2
>
(
tOsVt
,
tOrVt_copy_view
);
// __ds_read_m32x16_row_col<2, 2>(tOsVt, tOrVt_copy_view);
// __ds_read_m32x16_row_col<0, 3>(tOsVt, tOrVt_copy_view);
flash
::
__ds_read_m32x16_row_col
<
1
,
3
>
(
tOsVt
,
tOrVt_copy_view
);
flash
::
__ds_read_m32x16_row_col
_alt
<
1
,
3
>
(
tOsVt
,
tOrVt_copy_view
);
// __ds_read_m32x16_row_col<2, 3>(tOsVt, tOrVt_copy_view);
...
...
@@ -774,6 +774,16 @@ __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::compute_attn_1rowblock_sp
// {
// printf(" %.4f %.4f \n", acc_o(0), acc_o(1));
// }
auto
float2bf16
=
[]
(
float
s
)
->
uint16_t
{
uint32_t
x32
=
reinterpret_cast
<
uint32_t
const
&>
(
s
);
#ifndef FLASH_MLA_BF16_TYPE
#define FLASH_MLA_BF16_TYPE 0
#endif
#if FLASH_MLA_BF16_TYPE == 1
x32
+=
0x8000u
;
#endif
return
uint16_t
(
x32
>>
16
);
};
if
(
args
.
is_no_split
)
{
int
start_head_idx
=
head_block_idx
*
BLOCK_M
;
Tensor
lse
=
softmax
.
template
normalize_softmax_lse
<
false
>(
acc_o
,
sRow_sum_reduce_buffer
,
params
.
sm_scale
);
...
...
@@ -805,10 +815,24 @@ __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::compute_attn_1rowblock_sp
}
}
// if (block0() && tidx % 16 == 0)
// {
// printf(" tidx = %d %.3f %.3f %.3f %.3f %.3f %.3f %.3f %3f \n ",
// tidx,
// float(acc_o(0)),
// float(acc_o(1)),
// float(acc_o(2)),
// float(acc_o(3)),
// float(acc_o(4)),
// float(acc_o(5)),
// float(acc_o(6)),
// float(acc_o(7))
// );
// }
float
*
gSoftmaxLse
=
(
float
*
)
params
.
lse
+
batch_idx
*
params
.
stride_lse_b
+
start_head_idx
+
s_q_idx
*
params
.
stride_lse_s_q
;
// (BLOCK_M) : (1)
{
auto
rO
=
flash
::
convert_type
<
Element
>
(
acc_o
);
// auto rO = flash::convert_type<Element>(acc_o);
using
result_type
=
cutlass
::
Array
<
Element
,
2
>
;
int
row
,
col
;
const
int
warpId
=
tidx
/
64
;
const
int
laneId
=
tidx
%
64
;
...
...
@@ -829,13 +853,29 @@ __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::compute_attn_1rowblock_sp
|
|
*/
col
=
(
laneId
/
16
)
+
ni
*
128
+
(
warpId
%
2
)
*
8
+
(
warpId
/
2
)
*
64
;
col
=
(
laneId
/
16
)
*
2
+
ni
*
128
+
(
warpId
%
2
)
*
8
+
(
warpId
/
2
)
*
64
;
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
for
(
int
j
=
0
;
j
<
2
;
j
++
)
{
gO
(
row
,
col
)
=
rO
(
i
*
2
+
j
,
mi
,
ni
);
col
+=
4
;
}
col
+=
8
;
#if defined(__gfx938__)
auto
d
=
__builtin_hcu_cvt_pk_bf16_f32
(
0
,
acc_o
(
i
,
mi
,
ni
),
0
,
acc_o
(
i
+
4
,
mi
,
ni
),
0
);
auto
res
=
reinterpret_cast
<
result_type
const
&>
(
d
);
#else
result_type
res
;
Element
e0
,
e1
;
e0
.
storage
=
float2bf16
(
acc_o
(
i
,
mi
,
ni
));
e1
.
storage
=
float2bf16
(
acc_o
(
i
+
4
,
mi
,
ni
));
res
[
0
]
=
e0
;
res
[
1
]
=
e1
;
#endif
// gO(row, col) = res[0];
// gO(row, col + 1) = res[1];
*
(
result_type
*
)(
&
gO
(
row
,
col
))
=
res
;
col
+=
16
;
// for (int j = 0; j < 2; j++) {
// gO(row, col) = rO(i * 2 + j, mi, ni);
// col += 4;
// }
// col += 8;
}
// for (int ei = 0; ei < size<0>(acc_o); ++ei) {
// gO(row, col) = rO(ei, mi, ni);
...
...
@@ -883,13 +923,15 @@ __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::compute_attn_1rowblock_sp
// gOaccum(row, col) = acc_o(ei, mi, ni);
// col += 4;
// }
col
=
(
laneId
/
16
)
+
ni
*
128
+
(
warpId
%
2
)
*
8
+
(
warpId
/
2
)
*
64
;
col
=
(
laneId
/
16
)
*
2
+
ni
*
128
+
(
warpId
%
2
)
*
8
+
(
warpId
/
2
)
*
64
;
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
for
(
int
j
=
0
;
j
<
2
;
j
++
)
{
gOaccum
(
row
,
col
)
=
acc_o
(
i
*
2
+
j
,
mi
,
ni
);
col
+=
4
;
}
col
+=
8
;
gOaccum
(
row
,
col
)
=
acc_o
(
i
,
mi
,
ni
);
gOaccum
(
row
,
col
+
1
)
=
acc_o
(
i
+
4
,
mi
,
ni
);
// for (int j = 0; j < 2; j++) {
// gOaccum(row, col) = acc_o(i * 2 + j, mi, ni);
// col += 4;
// }
col
+=
16
;
}
}
...
...
csrc/utils.h
View file @
29b0d13b
...
...
@@ -256,6 +256,28 @@ __forceinline__ __device__ void __ds_read_m32x16_row_col(Tensor0& src, Tensor1&
dst_ptr
[
6
]
=
d_ptr
[
6
];
dst_ptr
[
7
]
=
d_ptr
[
7
];
}
template
<
int
row
,
int
col
,
typename
Tensor0
,
typename
Tensor1
>
__forceinline__
__device__
void
__ds_read_m32x16_row_col_alt
(
Tensor0
&
src
,
Tensor1
&
dst
)
{
auto
lds
=
reinterpret_cast
<
__fp16
*>
(
src
.
data
().
get
());
auto
layout
=
src
.
layout
();
constexpr
short
offset
=
layout
(
0
,
row
,
col
)
*
2
;
auto
d
=
__builtin_amdgcn_ds_read_m32x16f16_alt
((
__attribute__
((
address_space
(
3
)))
__fp16
*
)(
lds
),
offset
);
uint16_t
*
d_ptr
=
reinterpret_cast
<
uint16_t
*>
(
&
d
);
uint16_t
*
dst_ptr
=
reinterpret_cast
<
uint16_t
*>
(
&
(
dst
(
0
,
row
,
col
)));
dst_ptr
[
0
]
=
d_ptr
[
0
];
dst_ptr
[
1
]
=
d_ptr
[
1
];
dst_ptr
[
2
]
=
d_ptr
[
2
];
dst_ptr
[
3
]
=
d_ptr
[
3
];
dst_ptr
[
4
]
=
d_ptr
[
4
];
dst_ptr
[
5
]
=
d_ptr
[
5
];
dst_ptr
[
6
]
=
d_ptr
[
6
];
dst_ptr
[
7
]
=
d_ptr
[
7
];
}
inline
__device__
float
fp8e4m3_to_fp32
(
const
fp8
&
input
)
{
const
uint32_t
w
=
(
uint32_t
)
input
<<
24
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment