Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
FlashMLA
Commits
4d897ed1
Commit
4d897ed1
authored
Feb 28, 2026
by
zhanghj2
Browse files
优化nmz tp1性能
parent
3722ec71
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
87 additions
and
345 deletions
+87
-345
csrc/extension/flash_api.h
csrc/extension/flash_api.h
+2
-2
csrc/extension/flash_fwd_mla_kernel_fp8.h
csrc/extension/flash_fwd_mla_kernel_fp8.h
+85
-343
No files found.
csrc/extension/flash_api.h
View file @
4d897ed1
...
...
@@ -683,9 +683,9 @@ mha_fwd_kvcache_mla_fp8(
// auto dprops = at::cuda::getCurrentDeviceProperties();
// bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
// TORCH_CHECK(is_sm90);
static
std
::
string
FLASH_MLA_ROOT_DIR
=
execCommand
(
"python -c 'import site; print(site.getsitepackages()[0])'"
);
//
static std::string FLASH_MLA_ROOT_DIR = execCommand("python -c 'import site; print(site.getsitepackages()[0])'");
setenv
(
"FLASH_MLA_ROOT_DIR"
,
(
FLASH_MLA_ROOT_DIR
+
"/flash_mla/asm/"
).
c_str
(),
1
);
//
setenv("FLASH_MLA_ROOT_DIR", (FLASH_MLA_ROOT_DIR + "/flash_mla/asm/").c_str(), 1);
// std::cout << FLASH_MLA_ROOT_DIR << "\n";
// exit(-1);
at
::
Tensor
vcache
=
vcache_
.
has_value
()
?
vcache_
.
value
()
:
kcache
;
...
...
csrc/extension/flash_fwd_mla_kernel_fp8.h
View file @
4d897ed1
...
...
@@ -1299,6 +1299,41 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
__syncthreads
();
auto
k_lds_read_ptr
=
sK
.
data
().
get
()
+
(
warp_id
/
4
)
*
16
*
64
;
constexpr
static
int
k_read_lds_offset
=
32
*
64
;
// Fp8_storage data[9];
#if 0
Fp8_storage k_data[9];
__builtin_amdgcn_sched_barrier(0);
k_data[0].data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), 0 * 4096, 3, 1, 0);
k_data[1].data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), 1 * 4096, 3, 1, 0);
k_data[2].data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), 2 * 4096, 3, 1, 0);
k_data[3].data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), 3 * 4096, 3, 1, 0);
k_data[4].data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), 4 * 4096, 3, 1, 0);
k_data[5].data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), 5 * 4096, 3, 1, 0);
k_data[6].data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), 6 * 4096, 3, 1, 0);
k_data[7].data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), 7 * 4096, 3, 1, 0);
k_data[8].data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), 8 * 4096, 3, 1, 0);
#pragma unroll
for (int i = 0; i < 9; i++) {
accs_f32[0] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[i].p[0], k_data[i].p[0], accs_f32[0], true, false);
accs_f32[0] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[i].p[1], k_data[i].p[1], accs_f32[0], true, false);
}
k_data[0].data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), 0 * 4096 + k_read_lds_offset, 3, 1, 0);
k_data[1].data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), 1 * 4096 + k_read_lds_offset, 3, 1, 0);
k_data[2].data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), 2 * 4096 + k_read_lds_offset, 3, 1, 0);
k_data[3].data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), 3 * 4096 + k_read_lds_offset, 3, 1, 0);
k_data[4].data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), 4 * 4096 + k_read_lds_offset, 3, 1, 0);
k_data[5].data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), 5 * 4096 + k_read_lds_offset, 3, 1, 0);
k_data[6].data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), 6 * 4096 + k_read_lds_offset, 3, 1, 0);
k_data[7].data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), 7 * 4096 + k_read_lds_offset, 3, 1, 0);
k_data[8].data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), 8 * 4096 + k_read_lds_offset, 3, 1, 0);
#pragma unroll
for (int i = 0; i < 9; i++) {
accs_f32[1] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[i].p[0], k_data[i].p[0], accs_f32[1], true, false);
accs_f32[1] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[i].p[1], k_data[i].p[1], accs_f32[1], true, false);
}
__builtin_amdgcn_sched_barrier(0);
#else
{
constexpr
static
int
k_idx
=
0
;
// k_lds_read_ptr += k_idx * 64 * 64;
...
...
@@ -1383,7 +1418,6 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
accs_f32
[
1
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
1
],
k_data
.
p
[
1
],
accs_f32
[
1
],
true
,
false
);
}
{
constexpr
static
int
k_idx
=
6
;
// k_lds_read_ptr += 64 * 64;
...
...
@@ -1425,6 +1459,7 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
accs_f32
[
1
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
0
],
k_data
.
p
[
0
],
accs_f32
[
1
],
true
,
false
);
accs_f32
[
1
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
1
],
k_data
.
p
[
1
],
accs_f32
[
1
],
true
,
false
);
}
#endif
Tensor
acc_s
=
partition_fragment_C
(
tiled_mma
,
Shape
<
Int
<
kBlockM
>
,
Int
<
kBlockN
>>
{});
acc_s
(
0
,
0
,
0
)
=
accs_f32
[
0
].
x
;
acc_s
(
1
,
0
,
0
)
=
accs_f32
[
0
].
y
;
acc_s
(
2
,
0
,
0
)
=
accs_f32
[
0
].
z
;
acc_s
(
3
,
0
,
0
)
=
accs_f32
[
0
].
w
;
acc_s
(
0
,
0
,
1
)
=
accs_f32
[
1
].
x
;
acc_s
(
1
,
0
,
1
)
=
accs_f32
[
1
].
y
;
acc_s
(
2
,
0
,
1
)
=
accs_f32
[
1
].
z
;
acc_s
(
3
,
0
,
1
)
=
accs_f32
[
1
].
w
;
...
...
@@ -1534,299 +1569,6 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
// asm volatile("s_barrier \n\t");
};
#if 0
auto process_one_block = [&] (int block_idx, auto is_mask_block_t) {
static constexpr bool IS_MASK_BLOCK = std::is_same_v<decltype(is_mask_block_t), IsNoMaskBlock>;
static constexpr bool IS_FIRST_MASK_BLOCK = std::is_same_v<decltype(is_mask_block_t), IsFirstMaskBlock>;
static constexpr bool IS_NO_MASK_BLOCK = std::is_same_v<decltype(is_mask_block_t), IsNoMaskBlock>;
v4f accs_f32[2];
for (int i = 0; i < 2; i++)
{
accs_f32[i].x = 0.0f;
accs_f32[i].y = 0.0f;
accs_f32[i].z = 0.0f;
accs_f32[i].w = 0.0f;
}
Tensor tSrK_copy_view = smem_thr_copy_K.retile_D(tSrK);
// asm volatile("s_barrier \n\t");
int cur_block_table;
const int *cur_block_table_ptr = block_table + block_idx;
// cur_block_table = block_table[block_idx - 1];
asm volatile("s_load_dword %1, %0, 0x0\n\t"
"s_waitcnt lgkmcnt(0)\n\t":
"+s"(cur_block_table_ptr),
"=s"(cur_block_table));
index_t offset_k = cur_block_table * params.k_batch_stride;
// gK.data() = gK.data() + (offset_k);
gK.data() = gK.data() + (offset_k);
auto gK_offset = ((warp_id) / 4) * 64 + ((warp_id) % 4) * 16 * params.k_row_stride;
// auto gK_offset = (offset_k) + ((warp_id) / 4) * 64 + ((warp_id) % 4) * 16 * params.k_row_stride;
// const int k_zero_pad = std::min(std::max(block_idx * kBlockN + ((warp_id) % 4 + 1) * 16 - seqlen_k, 0), 16);
const int k_zero_pad = std::max(block_idx * kBlockN + ((warp_id) % 4 + 1) * 16 - seqlen_k, 0);
uint32x4_t gK_rscr = make_rscr((unsigned char*)(gK.data().get() + gK_offset), params.k_row_stride, k_zero_pad);
auto k_lds_addr = reinterpret_cast<size_t>(sK.data().get() + ((warp_id) / 4) * 64 * 64 + (warp_id % 4) * 16 * 64);
if (block_idx * kBlockN + ((warp_id) % 4) * 16 < seqlen_k || IS_NO_MASK_BLOCK)
{
k_lds_addr |= 0x80000000;
__builtin_hcu_matrix_load_64x16_b8(gK_rscr, (__attribute__((address_space(3))) char*)(k_lds_addr), 0, 1, 1, 0, 0);
k_lds_addr += 64 * 128;
__builtin_hcu_matrix_load_64x16_b8(gK_rscr, (__attribute__((address_space(3))) char*)(k_lds_addr), 128, 1, 1, 0, 0);
k_lds_addr += 64 * 128;
__builtin_hcu_matrix_load_64x16_b8(gK_rscr, (__attribute__((address_space(3))) char*)(k_lds_addr), 256, 1, 1, 0, 0);
k_lds_addr += 64 * 128;
__builtin_hcu_matrix_load_64x16_b8(gK_rscr, (__attribute__((address_space(3))) char*)(k_lds_addr), 256+128, 1, 1, 0, 0);
k_lds_addr += 64 * 128;
if (warp_id < 4)
{
__builtin_hcu_matrix_load_64x16_b8(gK_rscr, (__attribute__((address_space(3))) char*)(k_lds_addr), 512, 1, 1, 0, 0);
}
else
{
lds_direct_copy_qkvfp8_zero_lds(gK, sK, 4);
}
}
else
{
lds_direct_copy_qkvfp8_zero_lds(gK, sK, 0);
lds_direct_copy_qkvfp8_zero_lds(gK, sK, 1);
lds_direct_copy_qkvfp8_zero_lds(gK, sK, 2);
lds_direct_copy_qkvfp8_zero_lds(gK, sK, 3);
lds_direct_copy_qkvfp8_zero_lds(gK, sK, 4);
}
gK.data() = gK.data() + ( - offset_k);
auto k_lds_read_ptr = sK.data().get() + (warp_id / 4) * 16 * 64;
asm volatile("s_waitcnt vmcnt(4) \n\t s_barrier\n\t");
constexpr static int k_read_lds_offset = 32 * 64;
{
constexpr static int k_idx = 0;
// k_lds_read_ptr += k_idx * 64 * 64;
Fp8_storage k_data;
k_data.data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), k_idx * 64 * 64, 3, 1, 0);
accs_f32[0] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[0], k_data.p[0], accs_f32[0], true, false);
accs_f32[0] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[1], k_data.p[1], accs_f32[0], true, false);
k_data.data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), k_idx * 64 * 64 + k_read_lds_offset, 3, 1, 0);
accs_f32[1] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[0], k_data.p[0], accs_f32[1], true, false);
accs_f32[1] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[1], k_data.p[1], accs_f32[1], true, false);
}
{
constexpr static int k_idx = 1;
// k_lds_read_ptr += 64 * 64;
Fp8_storage k_data;
k_data.data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), k_idx * 64 * 64, 3, 1, 0);
accs_f32[0] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[0], k_data.p[0], accs_f32[0], true, false);
accs_f32[0] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[1], k_data.p[1], accs_f32[0], true, false);
k_data.data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), k_idx * 64 * 64 + k_read_lds_offset, 3, 1, 0);
accs_f32[1] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[0], k_data.p[0], accs_f32[1], true, false);
accs_f32[1] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[1], k_data.p[1], accs_f32[1], true, false);
// if (block0())
// {
// printf(" %x %x %x %x %x %x %x %x \n", q_r[k_idx].fp8_array[0], q_r[k_idx].fp8_array[1], q_r[k_idx].fp8_array[2], q_r[k_idx].fp8_array[3], k_data.fp8_array[0], k_data.fp8_array[1], k_data.fp8_array[2], k_data.fp8_array[3]);
// }
}
#if 1
asm volatile("s_waitcnt vmcnt(3) \n\t s_barrier\n\t");
{
constexpr static int k_idx = 2;
// k_lds_read_ptr += 64 * 64;
Fp8_storage k_data;
k_data.data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), k_idx * 64 * 64, 3, 1, 0);
accs_f32[0] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[0], k_data.p[0], accs_f32[0], true, false);
accs_f32[0] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[1], k_data.p[1], accs_f32[0], true, false);
k_data.data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), k_idx * 64 * 64 + k_read_lds_offset, 3, 1, 0);
accs_f32[1] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[0], k_data.p[0], accs_f32[1], true, false);
accs_f32[1] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[1], k_data.p[1], accs_f32[1], true, false);
}
{
constexpr static int k_idx = 3;
// k_lds_read_ptr += 64 * 64;
Fp8_storage k_data;
k_data.data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), k_idx * 64 * 64, 3, 1, 0);
accs_f32[0] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[0], k_data.p[0], accs_f32[0], true, false);
accs_f32[0] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[1], k_data.p[1], accs_f32[0], true, false);
k_data.data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), k_idx * 64 * 64 + k_read_lds_offset, 3, 1, 0);
accs_f32[1] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[0], k_data.p[0], accs_f32[1], true, false);
accs_f32[1] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[1], k_data.p[1], accs_f32[1], true, false);
}
asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t");
{
constexpr static int k_idx = 4;
// k_lds_read_ptr += 64 * 64;
Fp8_storage k_data;
k_data.data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), k_idx * 64 * 64, 3, 1, 0);
accs_f32[0] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[0], k_data.p[0], accs_f32[0], true, false);
accs_f32[0] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[1], k_data.p[1], accs_f32[0], true, false);
k_data.data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), k_idx * 64 * 64 + k_read_lds_offset, 3, 1, 0);
accs_f32[1] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[0], k_data.p[0], accs_f32[1], true, false);
accs_f32[1] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[1], k_data.p[1], accs_f32[1], true, false);
}
{
constexpr static int k_idx = 5;
// k_lds_read_ptr += 64 * 64;
Fp8_storage k_data;
k_data.data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), k_idx * 64 * 64, 3, 1, 0);
accs_f32[0] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[0], k_data.p[0], accs_f32[0], true, false);
accs_f32[0] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[1], k_data.p[1], accs_f32[0], true, false);
k_data.data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), k_idx * 64 * 64 + k_read_lds_offset, 3, 1, 0);
accs_f32[1] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[0], k_data.p[0], accs_f32[1], true, false);
accs_f32[1] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[1], k_data.p[1], accs_f32[1], true, false);
}
asm volatile("s_waitcnt vmcnt(1) \n\t s_barrier\n\t");
{
constexpr static int k_idx = 6;
// k_lds_read_ptr += 64 * 64;
Fp8_storage k_data;
k_data.data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), k_idx * 64 * 64, 3, 1, 0);
accs_f32[0] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[0], k_data.p[0], accs_f32[0], true, false);
accs_f32[0] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[1], k_data.p[1], accs_f32[0], true, false);
k_data.data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), k_idx * 64 * 64 + k_read_lds_offset, 3, 1, 0);
accs_f32[1] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[0], k_data.p[0], accs_f32[1], true, false);
accs_f32[1] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[1], k_data.p[1], accs_f32[1], true, false);
}
{
constexpr static int k_idx = 7;
// k_lds_read_ptr += 64 * 64;
Fp8_storage k_data;
k_data.data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), k_idx * 64 * 64, 3, 1, 0);
accs_f32[0] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[0], k_data.p[0], accs_f32[0], true, false);
accs_f32[0] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[1], k_data.p[1], accs_f32[0], true, false);
k_data.data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), k_idx * 64 * 64 + k_read_lds_offset, 3, 1, 0);
accs_f32[1] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[0], k_data.p[0], accs_f32[1], true, false);
accs_f32[1] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[1], k_data.p[1], accs_f32[1], true, false);
}
asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t");
{
constexpr static int k_idx = 8;
// k_lds_read_ptr += 64 * 64;
Fp8_storage k_data;
k_data.data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), k_idx * 64 * 64, 3, 1, 0);
accs_f32[0] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[0], k_data.p[0], accs_f32[0], true, false);
accs_f32[0] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[1], k_data.p[1], accs_f32[0], true, false);
k_data.data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(k_lds_read_ptr), k_idx * 64 * 64 + k_read_lds_offset, 3, 1, 0);
accs_f32[1] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[0], k_data.p[0], accs_f32[1], true, false);
accs_f32[1] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(q_r[k_idx].p[1], k_data.p[1], accs_f32[1], true, false);
}
#endif
Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});
acc_s(0, 0, 0) = accs_f32[0].x; acc_s(1, 0, 0) = accs_f32[0].y; acc_s(2, 0, 0) = accs_f32[0].z; acc_s(3, 0, 0) = accs_f32[0].w;
acc_s(0, 0, 1) = accs_f32[1].x; acc_s(1, 0, 1) = accs_f32[1].y; acc_s(2, 0, 1) = accs_f32[1].z; acc_s(3, 0, 1) = accs_f32[1].w;
// cute::gemm(tiled_mma, tSrQ(_, _, 0), tSrK(_, _, 0), acc_s);
// #endif
if constexpr (!IS_NO_MASK_BLOCK) {
Tensor cS = make_identity_tensor(Shape<Int<kBlockM>, Int<kBlockN>>{});
Tensor tScS = thr_mma.partition_C(cS);
for (int i = 0; i < size(acc_s); ++i) {
if constexpr (!Is_causal) {
if (int(get<1>(tScS(i))) >= int(seqlen_k - block_idx * kBlockN)) acc_s(i) = -INFINITY;
} else {
// Ensure seqlen_k - 1 - (block_idx * kBlockN + col) >= (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups
// col <= seqlen_k - 1 - block_idx * kBlockN - (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups
int row = int(get<0>(tScS(i)));
int col_limit_right = seqlen_k - 1 - block_idx * kBlockN - (params.seqlen_q - 1 - (m_block * kBlockM + row)) / params.ngroups;
if (int(get<1>(tScS(i))) > col_limit_right) acc_s(i) = -INFINITY;
}
}
}
// asm volatile("s_barrier \n\t");
softmax.template softmax_rescale_o_fp8_tp1</*Is_first=*/IS_FIRST_MASK_BLOCK, /*Check_inf=*/Is_causal, true>(acc_s, sRow_max_reduce_buffer, scale_softmax_log2, acco_f32);
// #if 1
Fp8_storage p_fp8;
{
__builtin_amdgcn_sched_barrier(0);
int tid = threadIdx.x % 64;
int warp_id = threadIdx.x / 64;
int32_t result;
result = __builtin_hcu_cvt_pk_fp8_f32(acc_s(0, 0, 0), acc_s(1, 0, 0), result, false);
result = __builtin_hcu_cvt_pk_fp8_f32(acc_s(2, 0, 0), acc_s(3, 0, 0), result, true);
// int32_t* lds_ptr = reinterpret_cast<int32_t*>(&(sP[ (tid % 16) * 16 + ((tid / 16) % 2 ) * 4 + (tid / 32) * (16 * 16) + (warp_id / 4) * 16 * 32 + (warp_id % 4) * 16 * 64]));
// *lds_ptr = result;
int32_t result1;
result1 = __builtin_hcu_cvt_pk_fp8_f32(acc_s(0, 0, 1), acc_s(1, 0, 1), result1, false);
result1 = __builtin_hcu_cvt_pk_fp8_f32(acc_s(2, 0, 1), acc_s(3, 0, 1), result1, true);
// lds_ptr = reinterpret_cast<int32_t*>(&(sP[ (tid % 16) * 16 + ((tid / 16) % 2 ) * 4 + (tid / 32) * (16 * 16) + (warp_id / 4) * 16 * 32 + (warp_id % 4) * 16 * 64 + 8]));
// *lds_ptr = result1;
int32_t* lds_ptr = reinterpret_cast<int32_t*>(&(sP[ (tid % 16) * 16 + ((tid % 64) / 16) * 16 * 16 + (warp_id / 4) * 4 + (warp_id % 4) * 16 * 64]));
*lds_ptr = result;
int32_t* lds_ptr1 = reinterpret_cast<int32_t*>(&(sP[ (tid % 16) * 16 + ((tid % 64) / 16 )* 16 * 16 + (warp_id / 4) * 4 + (warp_id % 4) * 16 * 64 + 8]));
*lds_ptr1 = result1;
__syncthreads();
p_fp8.data = *reinterpret_cast<intx4_t*>(&(sP[tid * 16 + (warp_id % 4) * 16 * 64]));
__builtin_amdgcn_sched_barrier(0);
int lane_id = tidx % 64;
int row = lane_id / 4;
int col = lane_id % 4;
col = (col + (row / 2) % 4) % 4;
auto lds_offset = row * 64 + col * 16 + (warp_id / 4) * 64 * 64;
// Fp8_storage v0_0, v0_1;
// v0_0.data = __builtin_hcu_ds_read_m64x16_u8_alt4((__attribute__((address_space(3))) int*)(A_smem + lds_offset));
// v0_1.data = __builtin_hcu_ds_read_m64x16_u8_alt4((__attribute__((address_space(3))) int*)(sV.data().get() + lds_offset + 16 * 64));
// acco_f32[i * 4 + j] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(p_fp8.p[k/2], tmp.val_to_mmac, acco_f32[i * 4 + j], true, false);
for (int n = 0; n < 4; n++)
{
Fp8_storage v0_0, v0_1;
v0_0.data = __builtin_hcu_ds_read_m64x16_u8_alt4((__attribute__((address_space(3))) int*)(sV.data().get() + lds_offset + n * 64 * 128));
v0_1.data = __builtin_hcu_ds_read_m64x16_u8_alt4((__attribute__((address_space(3))) int*)(sV.data().get() + lds_offset + 16 * 64 + n * 64 * 128));
for (int j = 0; j < 4; j++)
{
intx2_t v;
v[0] = v0_0.fp8_array[j];
v[1] = v0_1.fp8_array[j];
acco_f32[n * 4 + j] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(p_fp8.p[0], v, acco_f32[n * 4 + j], true, false);
}
v0_0.data = __builtin_hcu_ds_read_m64x16_u8_alt4((__attribute__((address_space(3))) int*)(sV.data().get() + lds_offset + n * 64 * 128 + 32 * 64));
v0_1.data = __builtin_hcu_ds_read_m64x16_u8_alt4((__attribute__((address_space(3))) int*)(sV.data().get() + lds_offset + 16 * 64 + n * 64 * 128 + 32 * 64));
for (int j = 0; j < 4; j++)
{
intx2_t v;
v[0] = v0_0.fp8_array[j];
v[1] = v0_1.fp8_array[j];
acco_f32[n * 4 + j] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(p_fp8.p[1], v, acco_f32[n * 4 + j], true, false);
}
}
}
asm volatile("s_barrier \n\t");
};
#endif
#if 1
if
constexpr
(
n_masking_steps
==
1
)
{
...
...
@@ -2824,9 +2566,9 @@ template<typename Kernel_traits, typename SharedStorage>
void
run_flash_splitkv_fwd_mla_fp8_tp1
(
Flash_fwd_mla_params
&
params
,
cudaStream_t
stream
)
{
FLASH_ASSERT
(
params
.
page_block_size
==
Kernel_traits
::
kBlockN
);
const
int
num_m_block
=
cute
::
ceil_div
(
params
.
seqlen_q
,
Kernel_traits
::
kBlockM
);
const
static
bool
dis
able_asm
=
get_env_
(
"FLASH_MLA_
DIS
ABLE_ASM"
);
//
const static bool
en
able_asm = get_env_("FLASH_MLA_
EN
ABLE_ASM");
if
(
disable_asm
)
{
if
(
1
)
{
BOOL_SWITCH
(
params
.
is_causal
,
Is_causal
,
[
&
]
{
auto
kernel
=
&
flash
::
flash_fwd_splitkv_mla_kernel_fp8_tp1
<
Kernel_traits
,
Is_causal
,
SharedStorage
>
;
constexpr
size_t
smem_size
=
65536
;
...
...
@@ -2834,54 +2576,54 @@ void run_flash_splitkv_fwd_mla_fp8_tp1(Flash_fwd_mla_params ¶ms, cudaStream_
kernel
<<<
dim3
(
num_m_block
,
params
.
h
,
params
.
num_sm_parts
),
Kernel_traits
::
kNThreads
,
smem_size
,
stream
>>>
(
params
);
});
}
else
{
static
char
*
FLASH_MLA_ASM_DIR
=
std
::
getenv
(
"FLASH_MLA_ROOT_DIR"
);
assert
(
FLASH_MLA_ASM_DIR
!=
nullptr
&&
"FLASH_MLA_ASM_DIR nullptr
\n
"
);
constexpr
size_t
smem_size
=
65536
;
std
::
string
co_file
=
std
::
string
(
FLASH_MLA_ASM_DIR
)
+
"flash_fwd_mla_fp8_gfx938-hip-amdgcn-amd-amdhsa-gfx938.co"
;
hipError_t
status
=
hipSuccess
;
static
hipModule_t
fwd_module_sample
;
static
bool
IS_FWD_MODULE_LOADED
=
false
;
//
else {
//
static char* FLASH_MLA_ASM_DIR = std::getenv("FLASH_MLA_ROOT_DIR");
//
assert(FLASH_MLA_ASM_DIR != nullptr && "FLASH_MLA_ASM_DIR nullptr \n");
//
constexpr size_t smem_size = 65536;
//
std::string co_file = std::string(FLASH_MLA_ASM_DIR) +
//
"flash_fwd_mla_fp8_gfx938-hip-amdgcn-amd-amdhsa-gfx938.co";
//
hipError_t status = hipSuccess;
//
static hipModule_t fwd_module_sample;
//
static bool IS_FWD_MODULE_LOADED = false;
if
(
IS_FWD_MODULE_LOADED
==
false
)
{
status
=
hipModuleLoad
(
&
fwd_module_sample
,
co_file
.
c_str
());
if
(
status
not_eq
hipSuccess
)
{
printf
(
"[flashmla] EXIT: failed to load module from %s
\n
"
,
co_file
.
c_str
());
return
;
}
IS_FWD_MODULE_LOADED
=
true
;
}
size_t
params_size
=
sizeof
(
params
);
void
*
config
[]
=
{
HIP_LAUNCH_PARAM_BUFFER_POINTER
,
&
params
,
HIP_LAUNCH_PARAM_BUFFER_SIZE
,
&
params_size
,
HIP_LAUNCH_PARAM_END
};
dim3
grid
(
num_m_block
,
params
.
h
,
params
.
num_sm_parts
);
std
::
string
kernel_name
=
params
.
is_causal
?
"_ZN5flash36flash_fwd_splitkv_mla_kernel_fp8_tp1I38Flash_fwd_kernel_traits_mla_qkvfp8_TP1ILi576ELi64ELi64ELi8EN7cutlass12float_e4m3_tENS2_10bfloat16_tELi512EELb1ENS_26SharedStorageMLAFloat8_TP1IS5_EEEEv20Flash_fwd_mla_params"
:
"_ZN5flash36flash_fwd_splitkv_mla_kernel_fp8_tp1I38Flash_fwd_kernel_traits_mla_qkvfp8_TP1ILi576ELi64ELi64ELi8EN7cutlass12float_e4m3_tENS2_10bfloat16_tELi512EELb0ENS_26SharedStorageMLAFloat8_TP1IS5_EEEEv20Flash_fwd_mla_params"
;
hipFunction_t
flash_mla_func
;
status
=
hipModuleGetFunction
(
&
flash_mla_func
,
fwd_module_sample
,
kernel_name
.
c_str
());
status
=
hipModuleLaunchKernel
(
flash_mla_func
,
grid
.
x
,
grid
.
y
,
grid
.
z
,
Kernel_traits
::
kNThreads
,
1
,
1
,
smem_size
,
// shared memory
stream
,
// stream
NULL
,
(
void
**
)
&
config
);
if
(
status
not_eq
hipSuccess
)
{
printf
(
"[flashmla] EXIT: failed to launch kernel!
\n
"
);
return
;
}
}
//
if (IS_FWD_MODULE_LOADED == false)
//
{
//
status = hipModuleLoad(&fwd_module_sample, co_file.c_str());
//
if (status not_eq hipSuccess) {
//
printf("[flashmla] EXIT: failed to load module from %s\n", co_file.c_str());
//
return;
//
}
//
IS_FWD_MODULE_LOADED = true;
//
}
//
size_t params_size = sizeof(params);
//
void* config[] = {
//
HIP_LAUNCH_PARAM_BUFFER_POINTER,
//
¶ms,
//
HIP_LAUNCH_PARAM_BUFFER_SIZE,
//
¶ms_size,
//
HIP_LAUNCH_PARAM_END
//
};
//
dim3 grid(num_m_block, params.h, params.num_sm_parts);
//
std::string kernel_name = params.is_causal ?
//
"_ZN5flash36flash_fwd_splitkv_mla_kernel_fp8_tp1I38Flash_fwd_kernel_traits_mla_qkvfp8_TP1ILi576ELi64ELi64ELi8EN7cutlass12float_e4m3_tENS2_10bfloat16_tELi512EELb1ENS_26SharedStorageMLAFloat8_TP1IS5_EEEEv20Flash_fwd_mla_params":
//
"_ZN5flash36flash_fwd_splitkv_mla_kernel_fp8_tp1I38Flash_fwd_kernel_traits_mla_qkvfp8_TP1ILi576ELi64ELi64ELi8EN7cutlass12float_e4m3_tENS2_10bfloat16_tELi512EELb0ENS_26SharedStorageMLAFloat8_TP1IS5_EEEEv20Flash_fwd_mla_params";
//
hipFunction_t flash_mla_func;
//
status = hipModuleGetFunction(&flash_mla_func, fwd_module_sample, kernel_name.c_str());
//
status = hipModuleLaunchKernel(
//
flash_mla_func,
//
grid.x, grid.y, grid.z,
//
Kernel_traits::kNThreads, 1, 1,
//
smem_size, // shared memory
//
stream, // stream
//
NULL,
//
(void**)&config
//
);
//
if (status not_eq hipSuccess) {
//
printf("[flashmla] EXIT: failed to launch kernel!\n");
//
return;
//
}
//
}
CHECK_CUDA_KERNEL_LAUNCH
();
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment