Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
FlashMLA
Commits
9b54b03c
"platforms/reference/src/ReferenceFloatStreamImpl.h" did not exist on "0e879806cdd38e58b04481ecf7fcd93c44c7dc27"
Commit
9b54b03c
authored
Jan 26, 2026
by
zhanghj2
Browse files
支持attn_sink
parent
5813dcc1
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
12 additions
and
1 deletion
+12
-1
csrc/sm90/decode/sparse_fp8/splitkv_mla.cuh
csrc/sm90/decode/sparse_fp8/splitkv_mla.cuh
+12
-1
No files found.
csrc/sm90/decode/sparse_fp8/splitkv_mla.cuh
View file @
9b54b03c
...
@@ -19,6 +19,7 @@
...
@@ -19,6 +19,7 @@
using
namespace
cute
;
using
namespace
cute
;
namespace
sm90
::
decode
::
sparse_fp8
{
namespace
sm90
::
decode
::
sparse_fp8
{
#define CUDART_L2E_F 1.442695041F
static
constexpr
float
MAX_INIT_VAL
=
-
1e30
;
// Prevent (-inf) - (-inf) = nan
static
constexpr
float
MAX_INIT_VAL
=
-
1e30
;
// Prevent (-inf) - (-inf) = nan
...
@@ -403,6 +404,16 @@ __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::compute_attn_1rowblock_sp
...
@@ -403,6 +404,16 @@ __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::compute_attn_1rowblock_sp
Tensor
gO
=
make_tensor
(
make_gmem_ptr
(
reinterpret_cast
<
Element
*>
(
params
.
out
)
+
row_offset_o
),
Tensor
gO
=
make_tensor
(
make_gmem_ptr
(
reinterpret_cast
<
Element
*>
(
params
.
out
)
+
row_offset_o
),
Shape
<
Int
<
BLOCK_M
>
,
Int
<
HEAD_DIM_V
>>
{},
Shape
<
Int
<
BLOCK_M
>
,
Int
<
HEAD_DIM_V
>>
{},
make_stride
(
params
.
stride_o_h_q
,
_1
{}));
make_stride
(
params
.
stride_o_h_q
,
_1
{}));
if
(
params
.
attn_sink
!=
nullptr
)
{
float
rAttn_sink
=
__ldg
((
float
*
)
params
.
attn_sink
+
start_head_idx
+
lane_idx
%
16
);
float
lse_exp2
=
__builtin_amdgcn_exp2f
(
lse
[
lane_idx
%
16
]
*
CUDART_L2E_F
);
float
rAttn_sink_exp2
=
__builtin_amdgcn_exp2f
(
rAttn_sink
*
CUDART_L2E_F
);
float
o_scale
=
lse_exp2
/
(
lse_exp2
+
rAttn_sink_exp2
);
for
(
int
i
=
0
;
i
<
size
(
acc_o
);
i
++
)
{
acc_o
(
i
)
*=
o_scale
;
}
}
float
*
gSoftmaxLse
=
(
float
*
)
params
.
lse
+
batch_idx
*
params
.
stride_lse_b
+
start_head_idx
+
s_q_idx
*
params
.
stride_lse_s_q
;
// (BLOCK_M) : (1)
float
*
gSoftmaxLse
=
(
float
*
)
params
.
lse
+
batch_idx
*
params
.
stride_lse_b
+
start_head_idx
+
s_q_idx
*
params
.
stride_lse_s_q
;
// (BLOCK_M) : (1)
{
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment