Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
FlashMLA
Commits
98b7c697
Commit
98b7c697
authored
Feb 27, 2026
by
zhanghj2
Browse files
fp8 tp1性能提升
parent
24c52aee
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
367 additions
and
346 deletions
+367
-346
csrc/extension/flash_fwd_mla_kernel_fp8.h
csrc/extension/flash_fwd_mla_kernel_fp8.h
+311
-346
csrc/extension/utils.h
csrc/extension/utils.h
+56
-0
No files found.
csrc/extension/flash_fwd_mla_kernel_fp8.h
View file @
98b7c697
...
...
@@ -1060,7 +1060,7 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
using
index_t
=
typename
Kernel_traits
::
index_t
;
const
int
tidx
=
threadIdx
.
x
;
const
int
warp_id
=
tidx
/
64
;
const
int
warp_id
=
__builtin_amdgcn_readfirstlane
(
tidx
/
64
)
;
const
index_t
row_offset_q
=
bidb
*
params
.
q_batch_stride
+
m_block
*
kBlockM
*
params
.
q_row_stride
+
bidh
*
params
.
q_head_stride
;
Tensor
gQ
=
make_tensor
(
make_gmem_ptr
(
reinterpret_cast
<
Element
*>
(
params
.
q_ptr
)
+
row_offset_q
),
...
...
@@ -1100,131 +1100,63 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
intx2_t
p
[
2
];
int32_t
fp8_array
[
4
];
};
#if 0
auto gmem_tiled_copy_Q = make_tiled_copy_A(Copy_Atom<DefaultCopy, Element>{}, tiled_mma);
auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_thread_slice(tidx);
Tensor tSgQ = gmem_thr_copy_Q.partition_S(gQ);
Tensor tSrQ = thr_mma.partition_fragment_A(gQ);
Tensor cQ = make_identity_tensor(make_shape(size<0>(gQ), size<1>(gQ)));
Tensor tQcQ = gmem_thr_copy_Q.partition_S(cQ);
Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tSgQ)));
flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/true>(gmem_tiled_copy_Q, tSgQ, tSrQ, tQcQ, tQpQ,
params.seqlen_q - m_block * kBlockM);
__syncthreads();
#else
Tensor
tSrQ
=
thr_mma
.
partition_fragment_A
(
gQ
);
lds_direct_copy_qkvfp8_q_tp1
<
false
,
true
>
(
gQ
,
sQ
,
0
,
params
.
q_row_stride
,
params
.
seqlen_q
-
m_block
*
kBlockM
);
lds_direct_copy_qkvfp8_q_tp1
<
false
,
true
>
(
gQ
,
sQ
,
1
,
params
.
q_row_stride
,
params
.
seqlen_q
-
m_block
*
kBlockM
);
lds_direct_copy_qkvfp8_q_tp1
<
false
,
true
>
(
gQ
,
sQ
,
2
,
params
.
q_row_stride
,
params
.
seqlen_q
-
m_block
*
kBlockM
);
lds_direct_copy_qkvfp8_q_tp1
<
false
,
true
>
(
gQ
,
sQ
,
3
,
params
.
q_row_stride
,
params
.
seqlen_q
-
m_block
*
kBlockM
);
lds_direct_copy_qkvfp8_q_tp1
<
false
,
false
>
(
gQ
,
sQ
,
4
,
params
.
q_row_stride
,
params
.
seqlen_q
-
m_block
*
kBlockM
);
asm
volatile
(
"s_waitcnt vmcnt(4)
\n\t
s_barrier
\n\t
"
);
uint8_t
*
q_lds_read_ptr
=
reinterpret_cast
<
uint8_t
*>
(
sQ
.
data
().
get
())
+
(
tidx
%
64
)
*
16
+
(
warp_id
%
4
)
*
(
16
*
64
);
{
int
k
=
0
;
for
(
int
i
=
0
;
i
<
16
;
i
++
)
{
tSrQ
(
i
,
0
,
k
).
storage
=
q_lds_read_ptr
[
i
];
Fp8_storage
q_r
[
9
];
#if 1
auto
gQ_offset
=
((
warp_id
)
/
4
)
*
64
+
((
warp_id
)
%
4
)
*
16
*
params
.
q_row_stride
;
const
int
q_zero_pad
=
std
::
min
(
std
::
max
(
m_block
*
kBlockM
+
((
warp_id
)
%
4
+
1
)
*
16
-
params
.
seqlen_q
,
0
),
16
);
uint32x4_t
gQ_rscr
=
make_rscr
((
unsigned
char
*
)(
gQ
.
data
().
get
()
+
gQ_offset
),
params
.
q_row_stride
,
q_zero_pad
);
auto
q_lds_addr
=
reinterpret_cast
<
size_t
>
(
sQ
.
data
().
get
()
+
((
warp_id
)
/
4
)
*
64
*
64
+
(
warp_id
%
4
)
*
16
*
64
)
|
0x80000000
;
if
(
m_block
*
kBlockM
+
((
warp_id
)
%
4
)
*
16
<
params
.
seqlen_q
)
{
__builtin_hcu_matrix_load_64x16_b8
(
gQ_rscr
,
(
__attribute__
((
address_space
(
3
)))
char
*
)(
q_lds_addr
),
0
,
1
,
1
,
0
,
0
);
q_lds_addr
+=
64
*
128
;
__builtin_hcu_matrix_load_64x16_b8
(
gQ_rscr
,
(
__attribute__
((
address_space
(
3
)))
char
*
)(
q_lds_addr
),
128
,
1
,
1
,
0
,
0
);
q_lds_addr
+=
64
*
128
;
__builtin_hcu_matrix_load_64x16_b8
(
gQ_rscr
,
(
__attribute__
((
address_space
(
3
)))
char
*
)(
q_lds_addr
),
256
,
1
,
1
,
0
,
0
);
q_lds_addr
+=
64
*
128
;
__builtin_hcu_matrix_load_64x16_b8
(
gQ_rscr
,
(
__attribute__
((
address_space
(
3
)))
char
*
)(
q_lds_addr
),
256
+
128
,
1
,
1
,
0
,
0
);
q_lds_addr
+=
64
*
128
;
if
(
warp_id
<
4
)
{
__builtin_hcu_matrix_load_64x16_b8
(
gQ_rscr
,
(
__attribute__
((
address_space
(
3
)))
char
*
)(
q_lds_addr
),
512
,
1
,
1
,
0
,
0
);
}
q_lds_read_ptr
+=
64
*
64
;
for
(
int
i
=
0
;
i
<
16
;
i
++
)
else
{
tSrQ
(
i
,
0
,
k
+
1
).
storage
=
q_lds_read_ptr
[
i
]
;
lds_direct_copy_qkvfp8_zero_lds
(
gQ
,
sQ
,
4
)
;
}
// int k = 0;
// intx4_t * q_lds_read_16 = reinterpret_cast<intx4_t*>(q_lds_read_ptr);
// intx4_t * tSrQ_ptr = reinterpret_cast<intx4_t*>(&(tSrQ(0, 0, k)));
// *tSrQ_ptr = *reinterpret_cast<intx4_t*>(q_lds_read_ptr);
// q_lds_read_ptr += 64*64;
// q_lds_read_16 = reinterpret_cast<intx4_t*>(q_lds_read_ptr);
// tSrQ_ptr = reinterpret_cast<intx4_t*>(&(tSrQ(0, 0, k + 1)));
// *tSrQ_ptr = *reinterpret_cast<intx4_t*>(q_lds_read_ptr);
}
asm
volatile
(
"s_waitcnt vmcnt(3)
\n\t
s_barrier
\n\t
"
);
{
q_lds_read_ptr
+=
64
*
64
;
int
k
=
2
;
for
(
int
i
=
0
;
i
<
16
;
i
++
)
else
{
tSrQ
(
i
,
0
,
k
).
storage
=
q_lds_read_ptr
[
i
];
}
q_lds_read_ptr
+=
64
*
64
;
for
(
int
i
=
0
;
i
<
16
;
i
++
)
{
tSrQ
(
i
,
0
,
k
+
1
).
storage
=
q_lds_read_ptr
[
i
];
}
// q_lds_read_ptr += 64*64;
// int k = 2;
// intx4_t * q_lds_read_16 = reinterpret_cast<intx4_t*>(q_lds_read_ptr);
// intx4_t * tSrQ_ptr = reinterpret_cast<intx4_t*>(&(tSrQ(0, 0, k)));
// *tSrQ_ptr = *reinterpret_cast<intx4_t*>(q_lds_read_ptr);
// q_lds_read_ptr += 64*64;
// q_lds_read_16 = reinterpret_cast<intx4_t*>(q_lds_read_ptr);
// tSrQ_ptr = reinterpret_cast<intx4_t*>(&(tSrQ(0, 0, k + 1)));
// *tSrQ_ptr = *reinterpret_cast<intx4_t*>(q_lds_read_ptr);
lds_direct_copy_qkvfp8_zero_lds
(
gQ
,
sQ
,
0
);
lds_direct_copy_qkvfp8_zero_lds
(
gQ
,
sQ
,
1
);
lds_direct_copy_qkvfp8_zero_lds
(
gQ
,
sQ
,
2
);
lds_direct_copy_qkvfp8_zero_lds
(
gQ
,
sQ
,
3
);
lds_direct_copy_qkvfp8_zero_lds
(
gQ
,
sQ
,
4
);
}
auto
q_lds_read_ptr
=
sQ
.
data
().
get
()
+
(
warp_id
%
4
)
*
16
*
64
;
asm
volatile
(
"s_waitcnt vmcnt(4)
\n\t
s_barrier
\n\t
"
);
q_r
[
0
].
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
q_lds_read_ptr
),
0
,
3
,
1
,
0
);
// q_lds_read_ptr += 64 * 64;
q_r
[
1
].
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
q_lds_read_ptr
),
64
*
64
,
3
,
1
,
0
);
asm
volatile
(
"s_waitcnt vmcnt(3)
\n\t
s_barrier
\n\t
"
);
// q_lds_read_ptr += 64 * 64;
q_r
[
2
].
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
q_lds_read_ptr
),
2
*
64
*
64
,
3
,
1
,
0
);
// q_lds_read_ptr += 64 * 64;
q_r
[
3
].
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
q_lds_read_ptr
),
3
*
64
*
64
,
3
,
1
,
0
);
asm
volatile
(
"s_waitcnt vmcnt(2)
\n\t
s_barrier
\n\t
"
);
{
q_lds_read_ptr
+=
64
*
64
;
int
k
=
4
;
for
(
int
i
=
0
;
i
<
16
;
i
++
)
{
tSrQ
(
i
,
0
,
k
).
storage
=
q_lds_read_ptr
[
i
];
}
q_lds_read_ptr
+=
64
*
64
;
for
(
int
i
=
0
;
i
<
16
;
i
++
)
{
tSrQ
(
i
,
0
,
k
+
1
).
storage
=
q_lds_read_ptr
[
i
];
}
// q_lds_read_ptr += 64*64;
// int k = 4;
// intx4_t * q_lds_read_16 = reinterpret_cast<intx4_t*>(q_lds_read_ptr);
// intx4_t * tSrQ_ptr = reinterpret_cast<intx4_t*>(&(tSrQ(0, 0, k)));
// *tSrQ_ptr = *reinterpret_cast<intx4_t*>(q_lds_read_ptr);
// q_lds_read_ptr += 64*64;
// q_lds_read_16 = reinterpret_cast<intx4_t*>(q_lds_read_ptr);
// tSrQ_ptr = reinterpret_cast<intx4_t*>(&(tSrQ(0, 0, k + 1)));
// *tSrQ_ptr = *reinterpret_cast<intx4_t*>(q_lds_read_ptr);
}
// q_lds_read_ptr += 64 * 64;
q_r
[
4
].
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
q_lds_read_ptr
),
4
*
64
*
64
,
3
,
1
,
0
);
// q_lds_read_ptr += 64 * 64;
q_r
[
5
].
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
q_lds_read_ptr
),
5
*
64
*
64
,
3
,
1
,
0
);
asm
volatile
(
"s_waitcnt vmcnt(1)
\n\t
s_barrier
\n\t
"
);
{
q_lds_read_ptr
+=
64
*
64
;
int
k
=
6
;
for
(
int
i
=
0
;
i
<
16
;
i
++
)
{
tSrQ
(
i
,
0
,
k
).
storage
=
q_lds_read_ptr
[
i
];
}
q_lds_read_ptr
+=
64
*
64
;
for
(
int
i
=
0
;
i
<
16
;
i
++
)
{
tSrQ
(
i
,
0
,
k
+
1
).
storage
=
q_lds_read_ptr
[
i
];
}
// q_lds_read_ptr += 64*64;
// int k = 6;
// intx4_t * q_lds_read_16 = reinterpret_cast<intx4_t*>(q_lds_read_ptr);
// intx4_t * tSrQ_ptr = reinterpret_cast<intx4_t*>(&(tSrQ(0, 0, k)));
// *tSrQ_ptr = *reinterpret_cast<intx4_t*>(q_lds_read_ptr);
// q_lds_read_ptr += 64*64;
// q_lds_read_16 = reinterpret_cast<intx4_t*>(q_lds_read_ptr);
// tSrQ_ptr = reinterpret_cast<intx4_t*>(&(tSrQ(0, 0, k + 1)));
// *tSrQ_ptr = *reinterpret_cast<intx4_t*>(q_lds_read_ptr);
}
// q_lds_read_ptr += 64 * 64;
q_r
[
6
].
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
q_lds_read_ptr
),
6
*
64
*
64
,
3
,
1
,
0
);
// q_lds_read_ptr += 64 * 64;
q_r
[
7
].
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
q_lds_read_ptr
),
7
*
64
*
64
,
3
,
1
,
0
);
asm
volatile
(
"s_waitcnt vmcnt(0)
\n\t
s_barrier
\n\t
"
);
{
q_lds_read_ptr
+=
64
*
64
;
int
k
=
8
;
for
(
int
i
=
0
;
i
<
16
;
i
++
)
{
tSrQ
(
i
,
0
,
k
).
storage
=
q_lds_read_ptr
[
i
];
}
// q_lds_read_ptr += 64*64;
// int k = 8;
// intx4_t * q_lds_read_16 = reinterpret_cast<intx4_t*>(q_lds_read_ptr);
// intx4_t * tSrQ_ptr = reinterpret_cast<intx4_t*>(&(tSrQ(0, 0, k)));
// *tSrQ_ptr = *reinterpret_cast<intx4_t*>(q_lds_read_ptr);
}
// q_lds_read_ptr += 64 * 64;
q_r
[
8
].
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
q_lds_read_ptr
),
8
*
64
*
64
,
3
,
1
,
0
);
__syncthreads
();
#endif
...
...
@@ -1263,10 +1195,21 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
acco_f32
[
i
].
w
=
0.0
f
;
}
constexpr
static
int
STAGE
=
8
;
#if 1
for
(
int
masking_step
=
0
;
n_block
>=
n_block_min
;
++
masking_step
,
--
n_block
)
{
Tensor
acc_s
=
partition_fragment_C
(
tiled_mma
,
Shape
<
Int
<
kBlockM
>
,
Int
<
kBlockN
>>
{});
Tensor
tOrVt_copy_view
=
smem_thr_copy_V
.
retile_D
(
tOrVt
);
clear
(
acc_s
);
v4f
accs_f32
[
2
];
for
(
int
i
=
0
;
i
<
2
;
i
++
)
{
accs_f32
[
i
].
x
=
0.0
f
;
accs_f32
[
i
].
y
=
0.0
f
;
accs_f32
[
i
].
z
=
0.0
f
;
accs_f32
[
i
].
w
=
0.0
f
;
}
// Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});
// Tensor tOrVt_copy_view = smem_thr_copy_V.retile_D(tOrVt);
// clear(acc_s);
Tensor
tSrK_copy_view
=
smem_thr_copy_K
.
retile_D
(
tSrK
);
// asm volatile("s_barrier \n\t");
...
...
@@ -1278,37 +1221,192 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
"+s"
(
cur_block_table_ptr
),
"=s"
(
cur_block_table
));
index_t
offset_k
=
cur_block_table
*
params
.
k_batch_stride
;
// gK.data() = gK.data() + (offset_k);
#if 1
gK
.
data
()
=
gK
.
data
()
+
(
offset_k
);
lds_direct_copy_qkvfp8_tp1
<
false
,
true
>
(
gK
,
sK
,
0
,
params
.
k_row_stride
,
seqlen_k
-
n_block
*
kBlockN
);
lds_direct_copy_qkvfp8_tp1
<
false
,
true
>
(
gK
,
sK
,
1
,
params
.
k_row_stride
,
seqlen_k
-
n_block
*
kBlockN
);
lds_direct_copy_qkvfp8_tp1
<
false
,
true
>
(
gK
,
sK
,
2
,
params
.
k_row_stride
,
seqlen_k
-
n_block
*
kBlockN
);
lds_direct_copy_qkvfp8_tp1
<
false
,
true
>
(
gK
,
sK
,
3
,
params
.
k_row_stride
,
seqlen_k
-
n_block
*
kBlockN
);
lds_direct_copy_qkvfp8_tp1
<
false
,
false
>
(
gK
,
sK
,
4
,
params
.
k_row_stride
,
seqlen_k
-
n_block
*
kBlockN
);
auto
gK_offset
=
((
warp_id
)
/
4
)
*
64
+
((
warp_id
)
%
4
)
*
16
*
params
.
k_row_stride
;
// auto gK_offset = (offset_k) + ((warp_id) / 4) * 64 + ((warp_id) % 4) * 16 * params.k_row_stride;
// const int k_zero_pad = std::min(std::max(n_block * kBlockN + ((warp_id) % 4 + 1) * 16 - seqlen_k, 0), 16);
const
int
k_zero_pad
=
std
::
max
(
n_block
*
kBlockN
+
((
warp_id
)
%
4
+
1
)
*
16
-
seqlen_k
,
0
);
uint32x4_t
gK_rscr
=
make_rscr
((
unsigned
char
*
)(
gK
.
data
().
get
()
+
gK_offset
),
params
.
k_row_stride
,
k_zero_pad
);
auto
k_lds_addr
=
reinterpret_cast
<
size_t
>
(
sK
.
data
().
get
()
+
((
warp_id
)
/
4
)
*
64
*
64
+
(
warp_id
%
4
)
*
16
*
64
);
if
(
n_block
*
kBlockN
+
((
warp_id
)
%
4
)
*
16
<
seqlen_k
||
masking_step
!=
0
)
{
k_lds_addr
|=
0x80000000
;
__builtin_hcu_matrix_load_64x16_b8
(
gK_rscr
,
(
__attribute__
((
address_space
(
3
)))
char
*
)(
k_lds_addr
),
0
,
1
,
1
,
0
,
0
);
k_lds_addr
+=
64
*
128
;
__builtin_hcu_matrix_load_64x16_b8
(
gK_rscr
,
(
__attribute__
((
address_space
(
3
)))
char
*
)(
k_lds_addr
),
128
,
1
,
1
,
0
,
0
);
k_lds_addr
+=
64
*
128
;
__builtin_hcu_matrix_load_64x16_b8
(
gK_rscr
,
(
__attribute__
((
address_space
(
3
)))
char
*
)(
k_lds_addr
),
256
,
1
,
1
,
0
,
0
);
k_lds_addr
+=
64
*
128
;
__builtin_hcu_matrix_load_64x16_b8
(
gK_rscr
,
(
__attribute__
((
address_space
(
3
)))
char
*
)(
k_lds_addr
),
256
+
128
,
1
,
1
,
0
,
0
);
k_lds_addr
+=
64
*
128
;
if
(
warp_id
<
4
)
{
__builtin_hcu_matrix_load_64x16_b8
(
gK_rscr
,
(
__attribute__
((
address_space
(
3
)))
char
*
)(
k_lds_addr
),
512
,
1
,
1
,
0
,
0
);
}
else
{
lds_direct_copy_qkvfp8_zero_lds
(
gK
,
sK
,
4
);
}
}
else
{
lds_direct_copy_qkvfp8_zero_lds
(
gK
,
sK
,
0
);
lds_direct_copy_qkvfp8_zero_lds
(
gK
,
sK
,
1
);
lds_direct_copy_qkvfp8_zero_lds
(
gK
,
sK
,
2
);
lds_direct_copy_qkvfp8_zero_lds
(
gK
,
sK
,
3
);
lds_direct_copy_qkvfp8_zero_lds
(
gK
,
sK
,
4
);
}
gK
.
data
()
=
gK
.
data
()
+
(
-
offset_k
);
auto
k_lds_read_ptr
=
sK
.
data
().
get
()
+
(
warp_id
/
4
)
*
16
*
64
;
asm
volatile
(
"s_waitcnt vmcnt(4)
\n\t
s_barrier
\n\t
"
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
0
),
tSrK_copy_view
(
_
,
_
,
0
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
0
),
tSrK
(
_
,
_
,
0
),
acc_s
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
1
),
tSrK_copy_view
(
_
,
_
,
1
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
1
),
tSrK
(
_
,
_
,
1
),
acc_s
);
constexpr
static
int
k_read_lds_offset
=
32
*
64
;
{
constexpr
static
int
k_idx
=
0
;
// k_lds_read_ptr += k_idx * 64 * 64;
Fp8_storage
k_data
;
k_data
.
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
k_lds_read_ptr
),
k_idx
*
64
*
64
,
3
,
1
,
0
);
accs_f32
[
0
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
0
],
k_data
.
p
[
0
],
accs_f32
[
0
],
true
,
false
);
accs_f32
[
0
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
1
],
k_data
.
p
[
1
],
accs_f32
[
0
],
true
,
false
);
k_data
.
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
k_lds_read_ptr
),
k_idx
*
64
*
64
+
k_read_lds_offset
,
3
,
1
,
0
);
accs_f32
[
1
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
0
],
k_data
.
p
[
0
],
accs_f32
[
1
],
true
,
false
);
accs_f32
[
1
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
1
],
k_data
.
p
[
1
],
accs_f32
[
1
],
true
,
false
);
}
{
constexpr
static
int
k_idx
=
1
;
// k_lds_read_ptr += 64 * 64;
Fp8_storage
k_data
;
k_data
.
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
k_lds_read_ptr
),
k_idx
*
64
*
64
,
3
,
1
,
0
);
accs_f32
[
0
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
0
],
k_data
.
p
[
0
],
accs_f32
[
0
],
true
,
false
);
accs_f32
[
0
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
1
],
k_data
.
p
[
1
],
accs_f32
[
0
],
true
,
false
);
k_data
.
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
k_lds_read_ptr
),
k_idx
*
64
*
64
+
k_read_lds_offset
,
3
,
1
,
0
);
accs_f32
[
1
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
0
],
k_data
.
p
[
0
],
accs_f32
[
1
],
true
,
false
);
accs_f32
[
1
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
1
],
k_data
.
p
[
1
],
accs_f32
[
1
],
true
,
false
);
// if (block0())
// {
// printf(" %x %x %x %x %x %x %x %x \n", q_r[k_idx].fp8_array[0], q_r[k_idx].fp8_array[1], q_r[k_idx].fp8_array[2], q_r[k_idx].fp8_array[3], k_data.fp8_array[0], k_data.fp8_array[1], k_data.fp8_array[2], k_data.fp8_array[3]);
// }
}
#if 1
asm
volatile
(
"s_waitcnt vmcnt(3)
\n\t
s_barrier
\n\t
"
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
2
),
tSrK_copy_view
(
_
,
_
,
2
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
2
),
tSrK
(
_
,
_
,
2
),
acc_s
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
3
),
tSrK_copy_view
(
_
,
_
,
3
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
3
),
tSrK
(
_
,
_
,
3
),
acc_s
);
{
constexpr
static
int
k_idx
=
2
;
// k_lds_read_ptr += 64 * 64;
Fp8_storage
k_data
;
k_data
.
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
k_lds_read_ptr
),
k_idx
*
64
*
64
,
3
,
1
,
0
);
accs_f32
[
0
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
0
],
k_data
.
p
[
0
],
accs_f32
[
0
],
true
,
false
);
accs_f32
[
0
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
1
],
k_data
.
p
[
1
],
accs_f32
[
0
],
true
,
false
);
k_data
.
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
k_lds_read_ptr
),
k_idx
*
64
*
64
+
k_read_lds_offset
,
3
,
1
,
0
);
accs_f32
[
1
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
0
],
k_data
.
p
[
0
],
accs_f32
[
1
],
true
,
false
);
accs_f32
[
1
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
1
],
k_data
.
p
[
1
],
accs_f32
[
1
],
true
,
false
);
}
{
constexpr
static
int
k_idx
=
3
;
// k_lds_read_ptr += 64 * 64;
Fp8_storage
k_data
;
k_data
.
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
k_lds_read_ptr
),
k_idx
*
64
*
64
,
3
,
1
,
0
);
accs_f32
[
0
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
0
],
k_data
.
p
[
0
],
accs_f32
[
0
],
true
,
false
);
accs_f32
[
0
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
1
],
k_data
.
p
[
1
],
accs_f32
[
0
],
true
,
false
);
k_data
.
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
k_lds_read_ptr
),
k_idx
*
64
*
64
+
k_read_lds_offset
,
3
,
1
,
0
);
accs_f32
[
1
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
0
],
k_data
.
p
[
0
],
accs_f32
[
1
],
true
,
false
);
accs_f32
[
1
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
1
],
k_data
.
p
[
1
],
accs_f32
[
1
],
true
,
false
);
}
asm
volatile
(
"s_waitcnt vmcnt(2)
\n\t
s_barrier
\n\t
"
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
4
),
tSrK_copy_view
(
_
,
_
,
4
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
4
),
tSrK
(
_
,
_
,
4
),
acc_s
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
5
),
tSrK_copy_view
(
_
,
_
,
5
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
5
),
tSrK
(
_
,
_
,
5
),
acc_s
);
{
constexpr
static
int
k_idx
=
4
;
// k_lds_read_ptr += 64 * 64;
Fp8_storage
k_data
;
k_data
.
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
k_lds_read_ptr
),
k_idx
*
64
*
64
,
3
,
1
,
0
);
accs_f32
[
0
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
0
],
k_data
.
p
[
0
],
accs_f32
[
0
],
true
,
false
);
accs_f32
[
0
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
1
],
k_data
.
p
[
1
],
accs_f32
[
0
],
true
,
false
);
k_data
.
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
k_lds_read_ptr
),
k_idx
*
64
*
64
+
k_read_lds_offset
,
3
,
1
,
0
);
accs_f32
[
1
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
0
],
k_data
.
p
[
0
],
accs_f32
[
1
],
true
,
false
);
accs_f32
[
1
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
1
],
k_data
.
p
[
1
],
accs_f32
[
1
],
true
,
false
);
}
{
constexpr
static
int
k_idx
=
5
;
// k_lds_read_ptr += 64 * 64;
Fp8_storage
k_data
;
k_data
.
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
k_lds_read_ptr
),
k_idx
*
64
*
64
,
3
,
1
,
0
);
accs_f32
[
0
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
0
],
k_data
.
p
[
0
],
accs_f32
[
0
],
true
,
false
);
accs_f32
[
0
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
1
],
k_data
.
p
[
1
],
accs_f32
[
0
],
true
,
false
);
k_data
.
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
k_lds_read_ptr
),
k_idx
*
64
*
64
+
k_read_lds_offset
,
3
,
1
,
0
);
accs_f32
[
1
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
0
],
k_data
.
p
[
0
],
accs_f32
[
1
],
true
,
false
);
accs_f32
[
1
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
1
],
k_data
.
p
[
1
],
accs_f32
[
1
],
true
,
false
);
}
asm
volatile
(
"s_waitcnt vmcnt(1)
\n\t
s_barrier
\n\t
"
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
6
),
tSrK_copy_view
(
_
,
_
,
6
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
6
),
tSrK
(
_
,
_
,
6
),
acc_s
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
7
),
tSrK_copy_view
(
_
,
_
,
7
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
7
),
tSrK
(
_
,
_
,
7
),
acc_s
);
{
constexpr
static
int
k_idx
=
6
;
// k_lds_read_ptr += 64 * 64;
Fp8_storage
k_data
;
k_data
.
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
k_lds_read_ptr
),
k_idx
*
64
*
64
,
3
,
1
,
0
);
accs_f32
[
0
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
0
],
k_data
.
p
[
0
],
accs_f32
[
0
],
true
,
false
);
accs_f32
[
0
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
1
],
k_data
.
p
[
1
],
accs_f32
[
0
],
true
,
false
);
k_data
.
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
k_lds_read_ptr
),
k_idx
*
64
*
64
+
k_read_lds_offset
,
3
,
1
,
0
);
accs_f32
[
1
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
0
],
k_data
.
p
[
0
],
accs_f32
[
1
],
true
,
false
);
accs_f32
[
1
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
1
],
k_data
.
p
[
1
],
accs_f32
[
1
],
true
,
false
);
}
{
constexpr
static
int
k_idx
=
7
;
// k_lds_read_ptr += 64 * 64;
Fp8_storage
k_data
;
k_data
.
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
k_lds_read_ptr
),
k_idx
*
64
*
64
,
3
,
1
,
0
);
accs_f32
[
0
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
0
],
k_data
.
p
[
0
],
accs_f32
[
0
],
true
,
false
);
accs_f32
[
0
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
1
],
k_data
.
p
[
1
],
accs_f32
[
0
],
true
,
false
);
k_data
.
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
k_lds_read_ptr
),
k_idx
*
64
*
64
+
k_read_lds_offset
,
3
,
1
,
0
);
accs_f32
[
1
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
0
],
k_data
.
p
[
0
],
accs_f32
[
1
],
true
,
false
);
accs_f32
[
1
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
1
],
k_data
.
p
[
1
],
accs_f32
[
1
],
true
,
false
);
}
asm
volatile
(
"s_waitcnt vmcnt(0)
\n\t
s_barrier
\n\t
"
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
8
),
tSrK_copy_view
(
_
,
_
,
8
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
8
),
tSrK
(
_
,
_
,
8
),
acc_s
);
gK
.
data
()
=
gK
.
data
()
+
(
-
offset_k
);
// asm volatile("s_barrier \n\t");
{
constexpr
static
int
k_idx
=
8
;
// k_lds_read_ptr += 64 * 64;
Fp8_storage
k_data
;
k_data
.
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
k_lds_read_ptr
),
k_idx
*
64
*
64
,
3
,
1
,
0
);
accs_f32
[
0
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
0
],
k_data
.
p
[
0
],
accs_f32
[
0
],
true
,
false
);
accs_f32
[
0
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
1
],
k_data
.
p
[
1
],
accs_f32
[
0
],
true
,
false
);
k_data
.
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
k_lds_read_ptr
),
k_idx
*
64
*
64
+
k_read_lds_offset
,
3
,
1
,
0
);
accs_f32
[
1
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
0
],
k_data
.
p
[
0
],
accs_f32
[
1
],
true
,
false
);
accs_f32
[
1
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
1
],
k_data
.
p
[
1
],
accs_f32
[
1
],
true
,
false
);
}
#endif
Tensor
acc_s
=
partition_fragment_C
(
tiled_mma
,
Shape
<
Int
<
kBlockM
>
,
Int
<
kBlockN
>>
{});
acc_s
(
0
,
0
,
0
)
=
accs_f32
[
0
].
x
;
acc_s
(
1
,
0
,
0
)
=
accs_f32
[
0
].
y
;
acc_s
(
2
,
0
,
0
)
=
accs_f32
[
0
].
z
;
acc_s
(
3
,
0
,
0
)
=
accs_f32
[
0
].
w
;
acc_s
(
0
,
0
,
1
)
=
accs_f32
[
1
].
x
;
acc_s
(
1
,
0
,
1
)
=
accs_f32
[
1
].
y
;
acc_s
(
2
,
0
,
1
)
=
accs_f32
[
1
].
z
;
acc_s
(
3
,
0
,
1
)
=
accs_f32
[
1
].
w
;
// cute::gemm(tiled_mma, tSrQ(_, _, 0), tSrK(_, _, 0), acc_s);
#endif
Tensor
cS
=
make_identity_tensor
(
Shape
<
Int
<
kBlockM
>
,
Int
<
kBlockN
>>
{});
Tensor
tScS
=
thr_mma
.
partition_C
(
cS
);
for
(
int
i
=
0
;
i
<
size
(
acc_s
);
++
i
)
{
...
...
@@ -1333,20 +1431,6 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
:
softmax
.
template
softmax_rescale_o_fp8_tp1
<
/*Is_first=*/
false
,
/*Check_inf=*/
Is_causal
,
true
>(
acc_s
,
sRow_max_reduce_buffer
,
scale_softmax_log2
,
acco_f32
);
}
// asm volatile("s_barrier \n\t");
// if (block0() && tidx < 64)
// {
// // printf("n_block = %d tidx = %d %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f \n ", n_block, tidx, acco_f32[0].x, acco_f32[0].y, acco_f32[0].z, acco_f32[0].w,
// // acco_f32[1].x, acco_f32[1].y, acco_f32[1].z, acco_f32[1].w,
// // acco_f32[2].x, acco_f32[2].y, acco_f32[2].z, acco_f32[2].w,
// // acco_f32[3].x, acco_f32[3].y, acco_f32[3].z, acco_f32[3].w
// // );
// printf("n_block = %d tidx = %d %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f \n ", n_block, tidx, acc_s(0), acc_s(1), acc_s(2), acc_s(3),
// acc_s(4), acc_s(5), acc_s(6), acc_s(7)
// // acc_s(8), acc_s(9), acc_s(10), acc_s(11),
// // acc_s(12), acc_s(13), acc_s(14), acc_s(15)
// );
// }
#if 1
Fp8_storage
p_fp8
;
{
...
...
@@ -1371,83 +1455,53 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
__syncthreads
();
p_fp8
.
data
=
*
reinterpret_cast
<
intx4_t
*>
(
&
(
sP
[
tid
*
16
+
(
warp_id
%
4
)
*
16
*
64
]));
__builtin_amdgcn_sched_barrier
(
0
);
}
{
__builtin_amdgcn_sched_barrier
(
0
);
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
{
int
k
=
0
;
Fp8_storage
v0_0
,
v0_1
;
v0_0
.
data
=
__builtin_hcu_ds_read_m64x16_u8_alt4
((
__attribute__
((
address_space
(
3
)))
int
*
)(
&
(
tOsVt
(
0
,
i
,
k
))));
v0_1
.
data
=
__builtin_hcu_ds_read_m64x16_u8_alt4
((
__attribute__
((
address_space
(
3
)))
int
*
)(
&
(
tOsVt
(
0
,
i
,
k
+
1
))));
int
lane_id
=
tidx
%
64
;
int
row
=
lane_id
/
4
;
int
col
=
lane_id
%
4
;
col
=
(
col
+
(
row
/
2
)
%
4
)
%
4
;
auto
lds_offset
=
row
*
64
+
col
*
16
+
(
warp_id
/
4
)
*
64
*
64
;
// if (block0() && tidx < 64)
// {
// float v0 = __builtin_amdgcn_cvt_f32_fp8(v0_0.fp8_array[0], 0);
// float v1 = __builtin_amdgcn_cvt_f32_fp8(v0_0.fp8_array[0], 1);
// float v2 = __builtin_amdgcn_cvt_f32_fp8(v0_0.fp8_array[0], 2);
// float v3 = __builtin_amdgcn_cvt_f32_fp8(v0_0.fp8_array[0], 3);
// Fp8_storage v0_0, v0_1;
// v0_0.data = __builtin_hcu_ds_read_m64x16_u8_alt4((__attribute__((address_space(3))) int*)(A_smem + lds_offset));
// v0_1.data = __builtin_hcu_ds_read_m64x16_u8_alt4((__attribute__((address_space(3))) int*)(sV.data().get() + lds_offset + 16 * 64));
// printf("tid = %d %.3f %.3f %.3f %.3f %x \n", tidx, v0, v1, v2, v3, v0_0.fp8_array[0]);
// }
// acco_f32[i * 4 + j] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(p_fp8.p[k/2], tmp.val_to_mmac, acco_f32[i * 4 + j], true, false);
for
(
int
j
=
0
;
j
<
4
;
j
++
)
{
Val
tmp
;
tmp
.
data
[
0
]
=
v0_0
.
fp8_array
[
j
];
tmp
.
data
[
1
]
=
v0_1
.
fp8_array
[
j
];
acco_f32
[
i
*
4
+
j
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
p_fp8
.
p
[
k
/
2
],
tmp
.
val_to_mmac
,
acco_f32
[
i
*
4
+
j
],
true
,
false
);
}
}
for
(
int
n
=
0
;
n
<
4
;
n
++
)
{
int
k
=
2
;
Fp8_storage
v0_0
,
v0_1
;
v0_0
.
data
=
__builtin_hcu_ds_read_m64x16_u8_alt4
((
__attribute__
((
address_space
(
3
)))
int
*
)(
&
(
tOsVt
(
0
,
i
,
k
))
));
v0_1
.
data
=
__builtin_hcu_ds_read_m64x16_u8_alt4
((
__attribute__
((
address_space
(
3
)))
int
*
)(
&
(
tOsVt
(
0
,
i
,
k
+
1
))
));
v0_0
.
data
=
__builtin_hcu_ds_read_m64x16_u8_alt4
((
__attribute__
((
address_space
(
3
)))
int
*
)(
sV
.
data
().
get
()
+
lds_offset
+
n
*
64
*
128
));
v0_1
.
data
=
__builtin_hcu_ds_read_m64x16_u8_alt4
((
__attribute__
((
address_space
(
3
)))
int
*
)(
sV
.
data
().
get
()
+
lds_offset
+
16
*
64
+
n
*
64
*
128
));
// if (block0() && tidx < 64
)
//
{
// float v0 = __builtin_amdgcn_cvt_f32_fp8(v0_0.fp8_array[0], 0)
;
// float v1 = __builtin_amdgcn_cvt_f32_fp8(
v0_0.fp8_array[
0], 1)
;
// float v2 = __builtin_amdgcn_cvt_f32_fp8(
v0_
0
.fp8_array[
0], 2)
;
// float v3 = __builtin_amdgcn_cvt_f32_fp8(v0_0.fp8_array[0], 3
);
for
(
int
j
=
0
;
j
<
4
;
j
++
)
{
intx2_t
v
;
v
[
0
]
=
v0_0
.
fp8_array
[
j
]
;
v
[
1
]
=
v0_
1
.
fp8_array
[
j
]
;
acco_f32
[
n
*
4
+
j
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
p_fp8
.
p
[
0
],
v
,
acco_f32
[
n
*
4
+
j
],
true
,
false
);
// printf("tid = %d %.3f %.3f %.3f %.3f %x \n", tidx, v0, v1, v2, v3, v0_0.fp8_array[0]);
// }
}
v0_0
.
data
=
__builtin_hcu_ds_read_m64x16_u8_alt4
((
__attribute__
((
address_space
(
3
)))
int
*
)(
sV
.
data
().
get
()
+
lds_offset
+
n
*
64
*
128
+
32
*
64
));
v0_1
.
data
=
__builtin_hcu_ds_read_m64x16_u8_alt4
((
__attribute__
((
address_space
(
3
)))
int
*
)(
sV
.
data
().
get
()
+
lds_offset
+
16
*
64
+
n
*
64
*
128
+
32
*
64
));
for
(
int
j
=
0
;
j
<
4
;
j
++
)
{
Val
tmp
;
tmp
.
data
[
0
]
=
v0_0
.
fp8_array
[
j
];
tmp
.
data
[
1
]
=
v0_1
.
fp8_array
[
j
];
acco_f32
[
i
*
4
+
j
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
p_fp8
.
p
[
k
/
2
],
tmp
.
val_to_mmac
,
acco_f32
[
i
*
4
+
j
],
true
,
false
);
}
intx2_t
v
;
v
[
0
]
=
v0_0
.
fp8_array
[
j
];
v
[
1
]
=
v0_1
.
fp8_array
[
j
];
acco_f32
[
n
*
4
+
j
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
p_fp8
.
p
[
1
],
v
,
acco_f32
[
n
*
4
+
j
],
true
,
false
);
}
}
__builtin_amdgcn_sched_barrier
(
0
);
}
// if (block0())
// {
// printf("n_block = %d tidx = %d %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f \n ", n_block, tidx, acco_f32[0].x, acco_f32[0].y, acco_f32[0].z, acco_f32[0].w,
// acco_f32[1].x, acco_f32[1].y, acco_f32[1].z, acco_f32[1].w,
// acco_f32[2].x, acco_f32[2].y, acco_f32[2].z, acco_f32[2].w,
// acco_f32[3].x, acco_f32[3].y, acco_f32[3].z, acco_f32[3].w
// );
// // printf("n_block = %d tidx = %d %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f \n ", n_block, tidx, acc_s(0), acc_s(1), acc_s(2), acc_s(3),
// // acc_s(4), acc_s(5), acc_s(6), acc_s(7),
// // acc_s(8), acc_s(9), acc_s(10), acc_s(11),
// // acc_s(12), acc_s(13), acc_s(14), acc_s(15)
// // );
// }
asm
volatile
(
"s_barrier
\n\t
"
);
#endif
}
#endif
using
ElementO
=
typename
Kernel_traits
::
ElementO
;
using
ElementAccum
=
typename
Kernel_traits
::
ElementAccum
;
const
int
split_offset
=
__ldg
(
params
.
num_splits_ptr
+
bidb
);
...
...
@@ -1458,9 +1512,6 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
const
index_t
row_offset_lse
=
(
bidb
*
params
.
h
+
bidh
)
*
params
.
seqlen_q
+
m_block
*
kBlockM
;
const
index_t
row_offset_lseaccum
=
((
split_offset
+
n_split_idx
)
*
params
.
h
+
bidh
)
*
params
.
seqlen_q
+
m_block
*
kBlockM
;
if
(
NoSplit
)
{
constexpr
bool
Split
=
false
;
Tensor
gOaccum
=
make_tensor
(
make_gmem_ptr
(
reinterpret_cast
<
ElementO
*>
(
Split
?
params
.
oaccum_ptr
:
params
.
o_ptr
)
+
(
Split
?
row_offset_oaccum
:
row_offset_o
)),
...
...
@@ -1482,12 +1533,6 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
if
(
row
<
params
.
seqlen_q
-
m_block
*
kBlockM
)
{
gLSEaccum
(
row
)
=
lse
(
mi
);
}
}
}
// if (tidx == 1)
// {
// printf(" %.4f %.4f %.4f %.4f \n ", acco_f32[0].x, acco_f32[0].y, acco_f32[0].z, acco_f32[0].w);
// }
{
using
result_type
=
cutlass
::
Array
<
bfloat16_t
,
2
>
;
int
tidx
=
threadIdx
.
x
;
...
...
@@ -1547,49 +1592,10 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
gOaccum
(
row
,
col
+
3
)
=
res1
[
1
];
// col += 16;
}
// for (int j = 0; j < 4; j++)
// {
// auto d0 = __builtin_hcu_cvt_pk_bf16_f32(0, acco_f32[n* 4 + j].x, 0, acco_f32[n * 4 + j].y, 0);
// auto d1 = __builtin_hcu_cvt_pk_bf16_f32(0, acco_f32[n* 4 + j].z, 0, acco_f32[n * 4 + j].w, 0);
// auto res0 = reinterpret_cast<result_type const &>(d0);
// auto res1 = reinterpret_cast<result_type const &>(d1);
// gOaccum(row, col) = res0[0];
// gOaccum(row, col + 1) = res0[1];
// gOaccum(row, col + 2) = res1[0];
// gOaccum(row, col + 3) = res1[1];
// col += 16;
// }
}
// for (int n = 0; n < 8; n++)
// {
// using result_type = cutlass::Array<bfloat16_t, 2>;
// auto d0 = __builtin_hcu_cvt_pk_bf16_f32(0, acco_f32[n].x, 0, acco_f32[n].y, 0);
// auto d1 = __builtin_hcu_cvt_pk_bf16_f32(0, acco_f32[n].z, 0, acco_f32[n].w, 0);
// col = (tidx % 64 / 16) * 4 + n * 64;
// auto res0 = reinterpret_cast<result_type const &>(d0);
// auto res1 = reinterpret_cast<result_type const &>(d1);
// gOaccum(row, col) = res0[0];
// gOaccum(row, col + 1) = res0[1];
// gOaccum(row, col + 2) = res1[0];
// gOaccum(row, col + 3) = res1[1];
// }
}
}
}
}
else
{
constexpr
bool
Split
=
true
;
Tensor
gOaccum
=
make_tensor
(
make_gmem_ptr
(
reinterpret_cast
<
ElementAccum
*>
(
Split
?
params
.
oaccum_ptr
:
params
.
o_ptr
)
+
(
Split
?
row_offset_oaccum
:
row_offset_o
)),
...
...
@@ -1616,15 +1622,6 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
for
(
int
m
=
0
;
m
<
1
;
m
++
)
{
const
int
row
=
tidx
%
16
+
(
warpid
%
4
)
*
16
;
if
(
row
<
params
.
seqlen_q
-
m_block
*
kBlockM
)
{
// for (int n = 0; n < 32; n++)
// {
// col = (tidx % 64 / 16) * 4 + n * 16;
// gOaccum(row, col) = acco_f32[n].x;
// gOaccum(row, col + 1) = acco_f32[n].y;
// gOaccum(row, col + 2) = acco_f32[n].z;
// gOaccum(row, col + 3) = acco_f32[n].w;
// }
for
(
int
n
=
0
;
n
<
4
;
n
++
)
{
col
=
(
tidx
%
64
/
16
)
*
16
+
n
*
128
+
(
warp_id
/
4
)
*
64
;
...
...
@@ -1658,44 +1655,12 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
gOaccum
(
row
,
col
+
2
)
=
acco_f32
[
n
*
4
+
2
].
w
;
gOaccum
(
row
,
col
+
3
)
=
acco_f32
[
n
*
4
+
3
].
w
;
}
// for (int j = 0; j < 4; j++) {
// gOaccum(row, col) = acco_f32[n * 4 + j].x;
// gOaccum(row, col + 1) = acco_f32[n * 4 + j].y;
// gOaccum(row, col + 2) = acco_f32[n * 4 + j].z;
// gOaccum(row, col + 3) = acco_f32[n * 4 + j].w;
// col += 16;
// }
}
}
}
}
// Tensor acc_o = partition_fragment_C(tiled_mma_o, Shape<Int<kBlockM>, Int<kHeadDimV>>{});
// for (int n = 0; n < 8; n++)
// {
// acc_o(0, 0, n) = acco_f32[n * 2].x;
// acc_o(1, 0, n) = acco_f32[n * 2].y;
// acc_o(2, 0, n) = acco_f32[n * 2].z;
// acc_o(3, 0, n) = acco_f32[n * 2].w;
// acc_o(4, 0, n) = acco_f32[n * 2 + 1].x;
// acc_o(5, 0, n) = acco_f32[n * 2 + 1].y;
// acc_o(6, 0, n) = acco_f32[n * 2 + 1].z;
// acc_o(7, 0, n) = acco_f32[n * 2 + 1].w;
// }
// if (NoSplit)
// store_float8<Kernel_traits, false>(params, bidb, bidh, m_block, n_split_idx, shared_storage, acc_o, softmax, descale_k, scale_softmax);
// else
// store_float8<Kernel_traits, true>(params, bidb, bidh, m_block, n_split_idx, shared_storage, acc_o, softmax, descale_k, scale_softmax);
}
template
<
typename
Kernel_traits
,
bool
Is_causal
,
typename
SharedStorage
>
__forceinline__
__device__
void
compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP4
(
const
Flash_fwd_mla_params
&
params
,
...
...
csrc/extension/utils.h
View file @
98b7c697
...
...
@@ -2692,6 +2692,62 @@ __forceinline__ __device__ void __ds_read_m32x32_row_col_rrow(Tensor0& src, int
// Device-side declaration of a float -> float function whose symbol name is
// overridden to "llvm.exp2.f32" by the trailing asm label, so every call to
// __llvm_exp2_f32 is emitted against that name.
// __attribute__((const)) declares that the result depends only on the argument.
// NOTE(review): presumably the label makes the call resolve to the LLVM exp2
// intrinsic for float - confirm with the target toolchain.
extern
__device__
__attribute__
((
const
))
float
__llvm_exp2_f32
(
float
)
__asm
(
"llvm.exp2.f32"
);
// Packs an address, a stride and a flag word into a four-dword vector.
//
//   dword 0 : low  32 bits of `ptr`
//   dword 1 : high 32 bits of `ptr`
//   dword 2 : `stride`
//   dword 3 : bit 16 set, OR-ed with `zero_pad << 8`
//
// @param ptr       address stored in dwords 0 and 1
// @param stride    value stored verbatim in dword 2
// @param zero_pad  value shifted left by 8 and merged into dword 3
// @return          the assembled four-dword vector
//
// NOTE(review): the name suggests the result is used as a buffer resource
// descriptor for buffer_load/buffer_store style instructions - confirm the
// meaning of each field against the target ISA documentation.
__device__ inline uint32x4_t make_rscr(unsigned char* ptr, const int stride, const int zero_pad) {
    const uint64_t address = reinterpret_cast<uint64_t>(ptr);

    uint32x4_t descriptor;
    // The device target is little-endian, so the low half of the address
    // belongs in element 0 and the high half in element 1.
    descriptor[0] = static_cast<uint32_t>(address);
    descriptor[1] = static_cast<uint32_t>(address >> 32);
    descriptor[2] = stride;
    descriptor[3] = (1 << 16) | (zero_pad << 8);
    return descriptor;
}
// Issues a single `buffer_load_dwordx4 ... lds` instruction from the calling
// thread, targeting a slot inside `dst` that is selected by the caller's wave
// index and by `k_idx_`.
//
// Steps performed:
//   1. Derive the wave index from threadIdx.x (64 lanes per wave) and pass it,
//      together with `k_idx_`, through __builtin_amdgcn_readfirstlane so the
//      values can be bound to the scalar ("s") asm operands below.
//   2. Build a four-dword vector: dwords 0/1 hold the low/high halves of the
//      address returned by src.data().get(); dwords 2/3 are the constants
//      0x80000000 and 0x00020000.
//   3. Compute a 32-bit destination address from dst.data().get() plus
//      per-wave and per-k_idx offsets, and move it into m0.
//   4. Execute the buffer load with a per-lane offset of -1.
//
// The instruction is only emitted when __gfx938__ is defined; on any other
// target the function computes the operands and does nothing else.
//
// @param src     tensor whose data pointer supplies dwords 0/1 of the vector
// @param dst     tensor whose data pointer is the base of the m0 address
// @param k_idx_  index scaled by 64 * 128 elements in the destination address
//
// NOTE(review): the per-lane offset of -1 is presumably chosen to be out of
// range so that the load deposits zeros (matching the `_zero_lds` suffix) -
// confirm against the ISA's out-of-bounds buffer semantics.
// NOTE(review): the asm statement lists no clobbers although it writes m0 and
// targets LDS - confirm this is intentional.
template <class SrcEngine, class SrcLayout, class DstEngine, class DstLayout>
CUTE_HOST_DEVICE void lds_direct_copy_qkvfp8_zero_lds(Tensor<SrcEngine, SrcLayout> const& src,
                                                      Tensor<DstEngine, DstLayout>& dst,
                                                      int k_idx_) {
    constexpr int kLanesPerWave = 64;
    constexpr int kBytesPerElement = 1;
    constexpr int kElementsPerLane = 16;
    // 64 lanes * 16 elements * 1 byte
    constexpr int kBytesPerWave = kLanesPerWave * kElementsPerLane * kBytesPerElement;

    const int thread = threadIdx.x;
    const int wave = __builtin_amdgcn_readfirstlane(thread / kLanesPerWave);
    // The original annotation on this value read "576" - TODO confirm what it refers to.
    const int tile = __builtin_amdgcn_readfirstlane(k_idx_);

    // Split the 64-bit source address into two 32-bit halves for dwords 0 and 1.
    const uint64_t src_address = reinterpret_cast<uint64_t>(src.data().get());
    uint32x4_t src_descriptor = {0};
    src_descriptor[0] = __builtin_amdgcn_readfirstlane(static_cast<uint32_t>(src_address));
    src_descriptor[1] = __builtin_amdgcn_readfirstlane(static_cast<uint32_t>(src_address >> 32));
    src_descriptor[2] = 0x80000000;
    src_descriptor[3] = 0x00020000;

    const int scalar_offset = 0;
    int lane_offset = -1;

    // Destination address written to m0: base of dst, plus one kBytesPerWave
    // slot per (wave % 4), plus 64 * 128 elements per tile, plus 64 * 64 bytes
    // per (wave / 4).
    int lds_address = reinterpret_cast<size_t>(dst.data().get())
                      + (wave % 4) * kBytesPerWave
                      + (tile) * 64 * 128 * kBytesPerElement
                      + (wave / 4) * 64 * 64;

#if defined(__gfx938__)
    asm volatile("s_mov_b32 m0, %1\n\t"
                 "buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds\n"
                 ::"v"(lane_offset), "s"(lds_address), "s"(src_descriptor), "s"(scalar_offset)
                 :);
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
}
// namespace flash
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment