Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
FlashMLA
Commits
14b2cfc5
Commit
14b2cfc5
authored
Apr 07, 2026
by
zhanghj2
Browse files
优化 nmz和bmz dsa prefill,nhead=64
parent
a9ef79c6
Changes
5
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
994 additions
and
3 deletions
+994
-3
csrc/gfx93/prefill/sparse/config.h
csrc/gfx93/prefill/sparse/config.h
+30
-0
csrc/gfx93/prefill/sparse/phase1.cuh
csrc/gfx93/prefill/sparse/phase1.cuh
+786
-1
csrc/softmax.h
csrc/softmax.h
+89
-0
csrc/utils.h
csrc/utils.h
+85
-0
tests/test_flash_mla_sparse_prefill.py
tests/test_flash_mla_sparse_prefill.py
+4
-2
No files found.
csrc/gfx93/prefill/sparse/config.h
View file @
14b2cfc5
...
...
@@ -124,5 +124,35 @@ static void run(const SparseAttnFwdParams ¶ms);
};
template
<
int
D_QK
,
bool
HAVE_TOPK_LENGTH
,
bool
IS_TOPK_2048
>
class
KernelTemplate_B_H_64
{
public:
static
constexpr
int
D_Q
=
D_QK
;
static
constexpr
int
D_K
=
D_QK
;
static
constexpr
int
D_V
=
512
;
static
constexpr
int
kNWarps
=
4
;
static
constexpr
int
B_H
=
64
;
static
constexpr
int
B_TOPK
=
64
;
// TopK block size
static
constexpr
int
NUM_THREADS
=
kNWarps
*
64
;
static
constexpr
float
MAX_INIT_VAL
=
-
1e30
;
// We use this number as the initial value for mi (max logits)
using
Element
=
cutlass
::
bfloat16_t
;
using
elem_type
=
Element
;
using
ElementAccum
=
float
;
using
index_t
=
int64_t
;
static
constexpr
int
kBlockM
=
B_H
;
static
constexpr
int
kBlockN
=
B_TOPK
;
static
constexpr
int
kHeadDim
=
D_QK
;
static
constexpr
int
kHeadDimV
=
D_V
;
static
__device__
__forceinline__
void
devfunc
(
const
SparseAttnFwdParams
&
params
);
static
void
run
(
const
SparseAttnFwdParams
&
params
);
};
};
csrc/gfx93/prefill/sparse/phase1.cuh
View file @
14b2cfc5
This diff is collapsed.
Click to expand it.
csrc/softmax.h
View file @
14b2cfc5
...
...
@@ -602,6 +602,95 @@ struct Softmax {
}
return
lse
;
};
template
<
bool
Is_first
,
bool
Check_inf
=
false
,
typename
Tensor0
>
__forceinline__
__device__
void
softmax_rescale_o_prefill_4x1
(
Tensor0
&
scores
,
v4f
*
acc_o
,
float
softmax_scale_log2
)
{
// Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
MaxOp
<
float
>
max_op
;
// Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
static_assert
(
decltype
(
size
<
0
>
(
scores
))
::
value
==
kNRows
);
if
constexpr
(
Is_first
)
{
flash
::
template
reduce_max
<
/*zero_init=*/
true
>(
scores
,
row_max
);
flash
::
scale_apply_exp2
(
scores
,
row_max
,
softmax_scale_log2
);
flash
::
reduce_sum
<
/*zero_init=*/
true
>
(
scores
,
row_sum
);
}
else
{
Tensor
scores_max_prev
=
make_fragment_like
(
row_max
);
cute
::
copy
(
row_max
,
scores_max_prev
);
flash
::
template
reduce_max
<
/*zero_init=*/
false
>(
scores
,
row_max
);
// Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
// static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
#pragma unroll
for
(
int
mi
=
0
;
mi
<
size
(
row_max
);
++
mi
)
{
float
scores_max_cur
=
!
true
?
row_max
(
mi
)
:
(
row_max
(
mi
)
==
-
INFINITY
?
0.0
f
:
row_max
(
mi
));
#if 0
float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
#else
float
scores_scale
=
__builtin_amdgcn_exp2f
((
scores_max_prev
(
mi
)
-
scores_max_cur
)
*
softmax_scale_log2
);
#endif
// if (blockIdx.x == 0 && threadIdx.x == 0)
// {
// printf("threadIdx.x %.2f, scores_scale = %.4f\n",row_sum(mi), scores_scale );
// }
row_sum
(
mi
)
*=
scores_scale
;
for
(
int
i
=
0
;
i
<
32
;
i
++
)
{
acc_o
[
i
].
x
*=
scores_scale
;
acc_o
[
i
].
y
*=
scores_scale
;
acc_o
[
i
].
z
*=
scores_scale
;
acc_o
[
i
].
w
*=
scores_scale
;
}
}
// if (blockIdx.x == 2)
// {
// printf("threadIdx.x %.2f \n",row_sum(mi) );
// }
flash
::
scale_apply_exp2
(
scores
,
row_max
,
softmax_scale_log2
);
// We don't do the reduce across threads here since we don't need to use the row_sum.
// We do that reduce at the end when we need to normalize the softmax.
flash
::
reduce_sum
<
/*zero_init=*/
false
>
(
scores
,
row_sum
);
}
// if (thread0())
// {
// printf("max sum %.3f %.3f \n", row_max(0), row_sum(0));
// }
};
template
<
bool
Is_dropout
=
false
,
bool
Split
=
false
>
__forceinline__
__device__
TensorT
normalize_softmax_lse_prefill_4x1
(
v4f
*
acc_o
,
float
softmax_scale
,
float
rp_dropout
=
1.0
)
{
SumOp
<
float
>
sum_op
;
quad_allreduce_
(
row_sum
,
row_sum
,
sum_op
);
// flash::template warp_allreduce_(row_sum, sRow_sum_reduce_buffer, sum_op);
TensorT
lse
=
make_fragment_like
(
row_sum
);
// Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
// static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
// if (thread0())
// {
// printf(" %.3f %.3f \n", row_max(0), row_sum(0));
// }
#pragma unroll
for
(
int
mi
=
0
;
mi
<
1
;
++
mi
)
{
float
sum
=
row_sum
(
mi
);
float
inv_sum
=
(
sum
==
0.
f
||
sum
!=
sum
)
?
1.
f
:
1.
f
/
sum
;
lse
(
mi
)
=
(
sum
==
0.
f
||
sum
!=
sum
)
?
(
Split
?
-
INFINITY
:
INFINITY
)
:
row_max
(
mi
)
*
softmax_scale
+
__logf
(
sum
);
float
scale
=
!
Is_dropout
?
inv_sum
:
inv_sum
*
rp_dropout
;
for
(
int
i
=
0
;
i
<
32
;
i
++
)
{
acc_o
[
i
].
x
*=
scale
;
acc_o
[
i
].
y
*=
scale
;
acc_o
[
i
].
z
*=
scale
;
acc_o
[
i
].
w
*=
scale
;
}
}
return
lse
;
};
};
...
...
csrc/utils.h
View file @
14b2cfc5
...
...
@@ -1523,6 +1523,91 @@ __forceinline__ __device__ void gemm1_rs_fp8(Tensor0 &acc, Tensor1 &tCrA, Tensor
}
#endif
typedef
__bf16
__fp16x8_t
__attribute__
((
ext_vector_type
(
8
)));
template
<
typename
Element
,
int
k_idx
>
__forceinline__
__device__
void
qk_gemm
(
const
__fp16x8_t
&
q_data
,
Element
*
k_lds_read_ptr
,
v4f
*
accs_f32
)
{
typedef
__bf16
__fp16x8_t
__attribute__
((
ext_vector_type
(
8
)));
typedef
__bf16
__fp16x4_t
__attribute__
((
ext_vector_type
(
4
)));
union
Bf16_storage
{
__fp16x8_t
data_128
;
__fp16x4_t
data_64
[
2
];
uint16_t
data_array
[
8
];
};
constexpr
int
k_idx_even
=
k_idx
%
4
;
constexpr
int
n_offset
=
16
*
32
;
constexpr
int
k_offset
=
k_idx_even
*
64
*
32
;
Bf16_storage
q_reg
;
Bf16_storage
k_reg
;
q_reg
.
data_128
=
q_data
;
k_reg
.
data_128
=
*
reinterpret_cast
<
__fp16x8_t
*>
(
k_lds_read_ptr
+
k_offset
);
// q_reg.data_128 = __builtin_hcu_ds_read_matrix_trans_format_bf16((__attribute__((address_space(3))) short*)(q_lds_read_ptr), k_offset, 2, 1, 0);
// k_reg.data_128 = __builtin_hcu_ds_read_matrix_trans_format_bf16((__attribute__((address_space(3))) short*)(k_lds_read_ptr), 0 * n_offset + k_offset, 2, 1, 0);
#if defined(__gfx938__)
accs_f32
[
0
]
=
__builtin_hcu_mmac_f32_16x16x16_bf16_lit_lts
(
q_reg
.
data_64
[
0
],
k_reg
.
data_64
[
0
],
accs_f32
[
0
],
true
,
false
);
accs_f32
[
0
]
=
__builtin_hcu_mmac_f32_16x16x16_bf16_lit_lts
(
q_reg
.
data_64
[
1
],
k_reg
.
data_64
[
1
],
accs_f32
[
0
],
true
,
false
);
#else
accs_f32
[
0
]
=
__builtin_amdgcn_mmac_f32_16x16x16bf16
(
q_reg
.
data_64
[
0
],
k_reg
.
data_64
[
0
],
accs_f32
[
0
]);
accs_f32
[
0
]
=
__builtin_amdgcn_mmac_f32_16x16x16bf16
(
q_reg
.
data_64
[
1
],
k_reg
.
data_64
[
1
],
accs_f32
[
0
]);
#endif
k_reg
.
data_128
=
*
reinterpret_cast
<
__fp16x8_t
*>
(
k_lds_read_ptr
+
k_offset
+
1
*
n_offset
);
#if defined(__gfx938__)
accs_f32
[
1
]
=
__builtin_hcu_mmac_f32_16x16x16_bf16_lit_lts
(
q_reg
.
data_64
[
0
],
k_reg
.
data_64
[
0
],
accs_f32
[
1
],
true
,
false
);
accs_f32
[
1
]
=
__builtin_hcu_mmac_f32_16x16x16_bf16_lit_lts
(
q_reg
.
data_64
[
1
],
k_reg
.
data_64
[
1
],
accs_f32
[
1
],
true
,
false
);
#else
accs_f32
[
1
]
=
__builtin_amdgcn_mmac_f32_16x16x16bf16
(
q_reg
.
data_64
[
0
],
k_reg
.
data_64
[
0
],
accs_f32
[
1
]);
accs_f32
[
1
]
=
__builtin_amdgcn_mmac_f32_16x16x16bf16
(
q_reg
.
data_64
[
1
],
k_reg
.
data_64
[
1
],
accs_f32
[
1
]);
#endif
// k_reg.data_128 = __builtin_hcu_ds_read_matrix_trans_format_bf16((__attribute__((address_space(3))) short*)(k_lds_read_ptr), 1 * n_offset + k_offset, 2, 1, 0);
k_reg
.
data_128
=
*
reinterpret_cast
<
__fp16x8_t
*>
(
k_lds_read_ptr
+
k_offset
+
2
*
n_offset
);
// k_reg.data_128 = __builtin_hcu_ds_read_matrix_trans_format_bf16((__attribute__((address_space(3))) short*)(k_lds_read_ptr), 2 * n_offset + k_offset, 2, 1, 0);
#if defined(__gfx938__)
accs_f32
[
2
]
=
__builtin_hcu_mmac_f32_16x16x16_bf16_lit_lts
(
q_reg
.
data_64
[
0
],
k_reg
.
data_64
[
0
],
accs_f32
[
2
],
true
,
false
);
accs_f32
[
2
]
=
__builtin_hcu_mmac_f32_16x16x16_bf16_lit_lts
(
q_reg
.
data_64
[
1
],
k_reg
.
data_64
[
1
],
accs_f32
[
2
],
true
,
false
);
#else
accs_f32
[
2
]
=
__builtin_amdgcn_mmac_f32_16x16x16bf16
(
q_reg
.
data_64
[
0
],
k_reg
.
data_64
[
0
],
accs_f32
[
2
]);
accs_f32
[
2
]
=
__builtin_amdgcn_mmac_f32_16x16x16bf16
(
q_reg
.
data_64
[
1
],
k_reg
.
data_64
[
1
],
accs_f32
[
2
]);
#endif
k_reg
.
data_128
=
*
reinterpret_cast
<
__fp16x8_t
*>
(
k_lds_read_ptr
+
k_offset
+
3
*
n_offset
);
// k_reg.data_128 = __builtin_hcu_ds_read_matrix_trans_format_bf16((__attribute__((address_space(3))) short*)(k_lds_read_ptr), 3 * n_offset + k_offset, 2, 1, 0);
#if defined(__gfx938__)
accs_f32
[
3
]
=
__builtin_hcu_mmac_f32_16x16x16_bf16_lit_lts
(
q_reg
.
data_64
[
0
],
k_reg
.
data_64
[
0
],
accs_f32
[
3
],
true
,
false
);
accs_f32
[
3
]
=
__builtin_hcu_mmac_f32_16x16x16_bf16_lit_lts
(
q_reg
.
data_64
[
1
],
k_reg
.
data_64
[
1
],
accs_f32
[
3
],
true
,
false
);
#else
accs_f32
[
3
]
=
__builtin_amdgcn_mmac_f32_16x16x16bf16
(
q_reg
.
data_64
[
0
],
k_reg
.
data_64
[
0
],
accs_f32
[
3
]);
accs_f32
[
3
]
=
__builtin_amdgcn_mmac_f32_16x16x16bf16
(
q_reg
.
data_64
[
1
],
k_reg
.
data_64
[
1
],
accs_f32
[
3
]);
#endif
}
typedef
__bf16
__fp16x4_t
__attribute__
((
ext_vector_type
(
4
)));
template
<
int
k_idx
,
int
n_idx_val
>
__forceinline__
__device__
void
pv_gemm
(
const
__fp16x4_t
&
p
,
int
v_lds_read_ptr
,
v4f
*
acco_f32
)
{
constexpr
int
k_idx_even
=
k_idx
%
1
;
constexpr
int
n_offset
=
16
*
32
*
2
;
typedef
__bf16
__fp16x8_t
__attribute__
((
ext_vector_type
(
8
)));
union
Bf16_storage
{
__fp16x8_t
data_128
;
__fp16x4_t
data_64
[
2
];
uint16_t
data_array
[
8
];
};
constexpr
int
k_offset
=
k_idx_even
*
16
*
512
*
2
;
// #if 1
Bf16_storage
v_reg
;
v_reg
.
data_128
=
__builtin_amdgcn_ds_read_m32x16f16_alt
((
__attribute__
((
address_space
(
3
)))
__fp16
*
)(
v_lds_read_ptr
),
k_offset
+
n_idx_val
*
n_offset
);
#if defined(__gfx938__)
acco_f32
[
n_idx_val
*
2
]
=
__builtin_hcu_mmac_f32_16x16x16_bf16_lit_lts
(
p
,
v_reg
.
data_64
[
0
],
acco_f32
[
n_idx_val
*
2
],
true
,
false
);
acco_f32
[
n_idx_val
*
2
+
1
]
=
__builtin_hcu_mmac_f32_16x16x16_bf16_lit_lts
(
p
,
v_reg
.
data_64
[
1
],
acco_f32
[
n_idx_val
*
2
+
1
],
true
,
false
);
#else
acco_f32
[
n_idx_val
*
2
]
=
__builtin_amdgcn_mmac_f32_16x16x16bf16
(
p
,
v_reg
.
data_64
[
0
],
acco_f32
[
n_idx_val
*
2
]);
acco_f32
[
n_idx_val
*
2
+
1
]
=
__builtin_amdgcn_mmac_f32_16x16x16bf16
(
p
,
v_reg
.
data_64
[
1
],
acco_f32
[
n_idx_val
*
2
+
1
]);
#endif
}
}
\ No newline at end of file
tests/test_flash_mla_sparse_prefill.py
View file @
14b2cfc5
...
...
@@ -77,7 +77,7 @@ if __name__ == '__main__':
(
1840
,
256
),
(
1592
,
384
),
(
1521
,
512
),
(
3000
,
2048
),
# Irregular shapes with OOB TopK
(
95
,
128
),
(
153
,
256
),
...
...
@@ -146,6 +146,7 @@ if __name__ == '__main__':
performance_case_templates
=
[
# V3.2
(
576
,
128
,
2048
,
[
8192
,
32768
,
65536
,
98304
,
131072
]),
(
576
,
64
,
2048
,
[
8192
,
32768
,
65536
,
98304
,
131072
]),
# MODEL1 CONFIG1
(
512
,
64
,
512
,
[
8192
,
32768
,
49152
,
65536
]),
# MODEL1 CONFIG2
...
...
@@ -154,9 +155,10 @@ if __name__ == '__main__':
]
performance_cases
=
[
TestParam
(
s_q
,
s_kv
,
topk
,
h_q
=
h_q
,
d_qk
=
d_qk
,
have_attn_sink
=
True
)
TestParam
(
s_q
,
s_kv
,
topk
,
h_q
=
h_q
,
d_qk
=
d_qk
,
have_attn_sink
=
have_attn_sink
)
for
(
d_qk
,
h_q
,
topk
,
s_kv_list
)
in
performance_case_templates
for
s_q
in
[
4096
]
for
have_attn_sink
in
[
False
,
True
]
for
s_kv
in
s_kv_list
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment