Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
FlashMLA
Commits
98b7c697
"platforms/opencl/tests/TestOpenCLRGForce.cpp" did not exist on "0e879806cdd38e58b04481ecf7fcd93c44c7dc27"
Commit
98b7c697
authored
Feb 27, 2026
by
zhanghj2
Browse files
fp8 tp1性能提升
parent
24c52aee
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
367 additions
and
346 deletions
+367
-346
csrc/extension/flash_fwd_mla_kernel_fp8.h
csrc/extension/flash_fwd_mla_kernel_fp8.h
+311
-346
csrc/extension/utils.h
csrc/extension/utils.h
+56
-0
No files found.
csrc/extension/flash_fwd_mla_kernel_fp8.h
View file @
98b7c697
...
@@ -1060,7 +1060,7 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
...
@@ -1060,7 +1060,7 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
using
index_t
=
typename
Kernel_traits
::
index_t
;
using
index_t
=
typename
Kernel_traits
::
index_t
;
const
int
tidx
=
threadIdx
.
x
;
const
int
tidx
=
threadIdx
.
x
;
const
int
warp_id
=
tidx
/
64
;
const
int
warp_id
=
__builtin_amdgcn_readfirstlane
(
tidx
/
64
)
;
const
index_t
row_offset_q
=
bidb
*
params
.
q_batch_stride
+
m_block
*
kBlockM
*
params
.
q_row_stride
+
bidh
*
params
.
q_head_stride
;
const
index_t
row_offset_q
=
bidb
*
params
.
q_batch_stride
+
m_block
*
kBlockM
*
params
.
q_row_stride
+
bidh
*
params
.
q_head_stride
;
Tensor
gQ
=
make_tensor
(
make_gmem_ptr
(
reinterpret_cast
<
Element
*>
(
params
.
q_ptr
)
+
row_offset_q
),
Tensor
gQ
=
make_tensor
(
make_gmem_ptr
(
reinterpret_cast
<
Element
*>
(
params
.
q_ptr
)
+
row_offset_q
),
...
@@ -1100,131 +1100,63 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
...
@@ -1100,131 +1100,63 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
intx2_t
p
[
2
];
intx2_t
p
[
2
];
int32_t
fp8_array
[
4
];
int32_t
fp8_array
[
4
];
};
};
#if 0
Fp8_storage
q_r
[
9
];
auto gmem_tiled_copy_Q = make_tiled_copy_A(Copy_Atom<DefaultCopy, Element>{}, tiled_mma);
#if 1
auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_thread_slice(tidx);
auto
gQ_offset
=
((
warp_id
)
/
4
)
*
64
+
((
warp_id
)
%
4
)
*
16
*
params
.
q_row_stride
;
Tensor tSgQ = gmem_thr_copy_Q.partition_S(gQ);
const
int
q_zero_pad
=
std
::
min
(
std
::
max
(
m_block
*
kBlockM
+
((
warp_id
)
%
4
+
1
)
*
16
-
params
.
seqlen_q
,
0
),
16
);
Tensor tSrQ = thr_mma.partition_fragment_A(gQ);
uint32x4_t
gQ_rscr
=
make_rscr
((
unsigned
char
*
)(
gQ
.
data
().
get
()
+
gQ_offset
),
params
.
q_row_stride
,
q_zero_pad
);
Tensor cQ = make_identity_tensor(make_shape(size<0>(gQ), size<1>(gQ)));
auto
q_lds_addr
=
reinterpret_cast
<
size_t
>
(
sQ
.
data
().
get
()
+
((
warp_id
)
/
4
)
*
64
*
64
+
(
warp_id
%
4
)
*
16
*
64
)
|
0x80000000
;
Tensor tQcQ = gmem_thr_copy_Q.partition_S(cQ);
if
(
m_block
*
kBlockM
+
((
warp_id
)
%
4
)
*
16
<
params
.
seqlen_q
)
Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tSgQ)));
{
__builtin_hcu_matrix_load_64x16_b8
(
gQ_rscr
,
(
__attribute__
((
address_space
(
3
)))
char
*
)(
q_lds_addr
),
0
,
1
,
1
,
0
,
0
);
flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/true>(gmem_tiled_copy_Q, tSgQ, tSrQ, tQcQ, tQpQ,
q_lds_addr
+=
64
*
128
;
params.seqlen_q - m_block * kBlockM);
__builtin_hcu_matrix_load_64x16_b8
(
gQ_rscr
,
(
__attribute__
((
address_space
(
3
)))
char
*
)(
q_lds_addr
),
128
,
1
,
1
,
0
,
0
);
__syncthreads();
q_lds_addr
+=
64
*
128
;
#else
__builtin_hcu_matrix_load_64x16_b8
(
gQ_rscr
,
(
__attribute__
((
address_space
(
3
)))
char
*
)(
q_lds_addr
),
256
,
1
,
1
,
0
,
0
);
Tensor
tSrQ
=
thr_mma
.
partition_fragment_A
(
gQ
);
q_lds_addr
+=
64
*
128
;
lds_direct_copy_qkvfp8_q_tp1
<
false
,
true
>
(
gQ
,
sQ
,
0
,
params
.
q_row_stride
,
params
.
seqlen_q
-
m_block
*
kBlockM
);
__builtin_hcu_matrix_load_64x16_b8
(
gQ_rscr
,
(
__attribute__
((
address_space
(
3
)))
char
*
)(
q_lds_addr
),
256
+
128
,
1
,
1
,
0
,
0
);
lds_direct_copy_qkvfp8_q_tp1
<
false
,
true
>
(
gQ
,
sQ
,
1
,
params
.
q_row_stride
,
params
.
seqlen_q
-
m_block
*
kBlockM
);
q_lds_addr
+=
64
*
128
;
lds_direct_copy_qkvfp8_q_tp1
<
false
,
true
>
(
gQ
,
sQ
,
2
,
params
.
q_row_stride
,
params
.
seqlen_q
-
m_block
*
kBlockM
);
if
(
warp_id
<
4
)
lds_direct_copy_qkvfp8_q_tp1
<
false
,
true
>
(
gQ
,
sQ
,
3
,
params
.
q_row_stride
,
params
.
seqlen_q
-
m_block
*
kBlockM
);
{
lds_direct_copy_qkvfp8_q_tp1
<
false
,
false
>
(
gQ
,
sQ
,
4
,
params
.
q_row_stride
,
params
.
seqlen_q
-
m_block
*
kBlockM
);
__builtin_hcu_matrix_load_64x16_b8
(
gQ_rscr
,
(
__attribute__
((
address_space
(
3
)))
char
*
)(
q_lds_addr
),
512
,
1
,
1
,
0
,
0
);
asm
volatile
(
"s_waitcnt vmcnt(4)
\n\t
s_barrier
\n\t
"
);
uint8_t
*
q_lds_read_ptr
=
reinterpret_cast
<
uint8_t
*>
(
sQ
.
data
().
get
())
+
(
tidx
%
64
)
*
16
+
(
warp_id
%
4
)
*
(
16
*
64
);
{
int
k
=
0
;
for
(
int
i
=
0
;
i
<
16
;
i
++
)
{
tSrQ
(
i
,
0
,
k
).
storage
=
q_lds_read_ptr
[
i
];
}
}
q_lds_read_ptr
+=
64
*
64
;
else
for
(
int
i
=
0
;
i
<
16
;
i
++
)
{
{
tSrQ
(
i
,
0
,
k
+
1
).
storage
=
q_lds_read_ptr
[
i
]
;
lds_direct_copy_qkvfp8_zero_lds
(
gQ
,
sQ
,
4
)
;
}
}
// int k = 0;
// intx4_t * q_lds_read_16 = reinterpret_cast<intx4_t*>(q_lds_read_ptr);
// intx4_t * tSrQ_ptr = reinterpret_cast<intx4_t*>(&(tSrQ(0, 0, k)));
// *tSrQ_ptr = *reinterpret_cast<intx4_t*>(q_lds_read_ptr);
// q_lds_read_ptr += 64*64;
// q_lds_read_16 = reinterpret_cast<intx4_t*>(q_lds_read_ptr);
// tSrQ_ptr = reinterpret_cast<intx4_t*>(&(tSrQ(0, 0, k + 1)));
// *tSrQ_ptr = *reinterpret_cast<intx4_t*>(q_lds_read_ptr);
}
}
asm
volatile
(
"s_waitcnt vmcnt(3)
\n\t
s_barrier
\n\t
"
);
else
{
q_lds_read_ptr
+=
64
*
64
;
int
k
=
2
;
for
(
int
i
=
0
;
i
<
16
;
i
++
)
{
{
tSrQ
(
i
,
0
,
k
).
storage
=
q_lds_read_ptr
[
i
];
lds_direct_copy_qkvfp8_zero_lds
(
gQ
,
sQ
,
0
);
}
lds_direct_copy_qkvfp8_zero_lds
(
gQ
,
sQ
,
1
);
q_lds_read_ptr
+=
64
*
64
;
lds_direct_copy_qkvfp8_zero_lds
(
gQ
,
sQ
,
2
);
for
(
int
i
=
0
;
i
<
16
;
i
++
)
lds_direct_copy_qkvfp8_zero_lds
(
gQ
,
sQ
,
3
);
{
lds_direct_copy_qkvfp8_zero_lds
(
gQ
,
sQ
,
4
);
tSrQ
(
i
,
0
,
k
+
1
).
storage
=
q_lds_read_ptr
[
i
];
}
// q_lds_read_ptr += 64*64;
// int k = 2;
// intx4_t * q_lds_read_16 = reinterpret_cast<intx4_t*>(q_lds_read_ptr);
// intx4_t * tSrQ_ptr = reinterpret_cast<intx4_t*>(&(tSrQ(0, 0, k)));
// *tSrQ_ptr = *reinterpret_cast<intx4_t*>(q_lds_read_ptr);
// q_lds_read_ptr += 64*64;
// q_lds_read_16 = reinterpret_cast<intx4_t*>(q_lds_read_ptr);
// tSrQ_ptr = reinterpret_cast<intx4_t*>(&(tSrQ(0, 0, k + 1)));
// *tSrQ_ptr = *reinterpret_cast<intx4_t*>(q_lds_read_ptr);
}
}
auto
q_lds_read_ptr
=
sQ
.
data
().
get
()
+
(
warp_id
%
4
)
*
16
*
64
;
asm
volatile
(
"s_waitcnt vmcnt(4)
\n\t
s_barrier
\n\t
"
);
q_r
[
0
].
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
q_lds_read_ptr
),
0
,
3
,
1
,
0
);
// q_lds_read_ptr += 64 * 64;
q_r
[
1
].
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
q_lds_read_ptr
),
64
*
64
,
3
,
1
,
0
);
asm
volatile
(
"s_waitcnt vmcnt(3)
\n\t
s_barrier
\n\t
"
);
// q_lds_read_ptr += 64 * 64;
q_r
[
2
].
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
q_lds_read_ptr
),
2
*
64
*
64
,
3
,
1
,
0
);
// q_lds_read_ptr += 64 * 64;
q_r
[
3
].
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
q_lds_read_ptr
),
3
*
64
*
64
,
3
,
1
,
0
);
asm
volatile
(
"s_waitcnt vmcnt(2)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt vmcnt(2)
\n\t
s_barrier
\n\t
"
);
{
// q_lds_read_ptr += 64 * 64;
q_lds_read_ptr
+=
64
*
64
;
q_r
[
4
].
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
q_lds_read_ptr
),
4
*
64
*
64
,
3
,
1
,
0
);
int
k
=
4
;
// q_lds_read_ptr += 64 * 64;
for
(
int
i
=
0
;
i
<
16
;
i
++
)
q_r
[
5
].
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
q_lds_read_ptr
),
5
*
64
*
64
,
3
,
1
,
0
);
{
tSrQ
(
i
,
0
,
k
).
storage
=
q_lds_read_ptr
[
i
];
}
q_lds_read_ptr
+=
64
*
64
;
for
(
int
i
=
0
;
i
<
16
;
i
++
)
{
tSrQ
(
i
,
0
,
k
+
1
).
storage
=
q_lds_read_ptr
[
i
];
}
// q_lds_read_ptr += 64*64;
// int k = 4;
// intx4_t * q_lds_read_16 = reinterpret_cast<intx4_t*>(q_lds_read_ptr);
// intx4_t * tSrQ_ptr = reinterpret_cast<intx4_t*>(&(tSrQ(0, 0, k)));
// *tSrQ_ptr = *reinterpret_cast<intx4_t*>(q_lds_read_ptr);
// q_lds_read_ptr += 64*64;
// q_lds_read_16 = reinterpret_cast<intx4_t*>(q_lds_read_ptr);
// tSrQ_ptr = reinterpret_cast<intx4_t*>(&(tSrQ(0, 0, k + 1)));
// *tSrQ_ptr = *reinterpret_cast<intx4_t*>(q_lds_read_ptr);
}
asm
volatile
(
"s_waitcnt vmcnt(1)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt vmcnt(1)
\n\t
s_barrier
\n\t
"
);
{
// q_lds_read_ptr += 64 * 64;
q_lds_read_ptr
+=
64
*
64
;
q_r
[
6
].
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
q_lds_read_ptr
),
6
*
64
*
64
,
3
,
1
,
0
);
int
k
=
6
;
// q_lds_read_ptr += 64 * 64;
for
(
int
i
=
0
;
i
<
16
;
i
++
)
q_r
[
7
].
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
q_lds_read_ptr
),
7
*
64
*
64
,
3
,
1
,
0
);
{
tSrQ
(
i
,
0
,
k
).
storage
=
q_lds_read_ptr
[
i
];
}
q_lds_read_ptr
+=
64
*
64
;
for
(
int
i
=
0
;
i
<
16
;
i
++
)
{
tSrQ
(
i
,
0
,
k
+
1
).
storage
=
q_lds_read_ptr
[
i
];
}
// q_lds_read_ptr += 64*64;
// int k = 6;
// intx4_t * q_lds_read_16 = reinterpret_cast<intx4_t*>(q_lds_read_ptr);
// intx4_t * tSrQ_ptr = reinterpret_cast<intx4_t*>(&(tSrQ(0, 0, k)));
// *tSrQ_ptr = *reinterpret_cast<intx4_t*>(q_lds_read_ptr);
// q_lds_read_ptr += 64*64;
// q_lds_read_16 = reinterpret_cast<intx4_t*>(q_lds_read_ptr);
// tSrQ_ptr = reinterpret_cast<intx4_t*>(&(tSrQ(0, 0, k + 1)));
// *tSrQ_ptr = *reinterpret_cast<intx4_t*>(q_lds_read_ptr);
}
asm
volatile
(
"s_waitcnt vmcnt(0)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt vmcnt(0)
\n\t
s_barrier
\n\t
"
);
{
// q_lds_read_ptr += 64 * 64;
q_lds_read_ptr
+=
64
*
64
;
q_r
[
8
].
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
q_lds_read_ptr
),
8
*
64
*
64
,
3
,
1
,
0
);
int
k
=
8
;
for
(
int
i
=
0
;
i
<
16
;
i
++
)
{
tSrQ
(
i
,
0
,
k
).
storage
=
q_lds_read_ptr
[
i
];
}
// q_lds_read_ptr += 64*64;
// int k = 8;
// intx4_t * q_lds_read_16 = reinterpret_cast<intx4_t*>(q_lds_read_ptr);
// intx4_t * tSrQ_ptr = reinterpret_cast<intx4_t*>(&(tSrQ(0, 0, k)));
// *tSrQ_ptr = *reinterpret_cast<intx4_t*>(q_lds_read_ptr);
}
__syncthreads
();
__syncthreads
();
#endif
#endif
...
@@ -1263,10 +1195,21 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
...
@@ -1263,10 +1195,21 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
acco_f32
[
i
].
w
=
0.0
f
;
acco_f32
[
i
].
w
=
0.0
f
;
}
}
constexpr
static
int
STAGE
=
8
;
constexpr
static
int
STAGE
=
8
;
#if 1
for
(
int
masking_step
=
0
;
n_block
>=
n_block_min
;
++
masking_step
,
--
n_block
)
{
for
(
int
masking_step
=
0
;
n_block
>=
n_block_min
;
++
masking_step
,
--
n_block
)
{
Tensor
acc_s
=
partition_fragment_C
(
tiled_mma
,
Shape
<
Int
<
kBlockM
>
,
Int
<
kBlockN
>>
{});
v4f
accs_f32
[
2
];
Tensor
tOrVt_copy_view
=
smem_thr_copy_V
.
retile_D
(
tOrVt
);
for
(
int
i
=
0
;
i
<
2
;
i
++
)
clear
(
acc_s
);
{
accs_f32
[
i
].
x
=
0.0
f
;
accs_f32
[
i
].
y
=
0.0
f
;
accs_f32
[
i
].
z
=
0.0
f
;
accs_f32
[
i
].
w
=
0.0
f
;
}
// Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});
// Tensor tOrVt_copy_view = smem_thr_copy_V.retile_D(tOrVt);
// clear(acc_s);
Tensor
tSrK_copy_view
=
smem_thr_copy_K
.
retile_D
(
tSrK
);
Tensor
tSrK_copy_view
=
smem_thr_copy_K
.
retile_D
(
tSrK
);
// asm volatile("s_barrier \n\t");
// asm volatile("s_barrier \n\t");
...
@@ -1278,37 +1221,192 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
...
@@ -1278,37 +1221,192 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
"+s"
(
cur_block_table_ptr
),
"+s"
(
cur_block_table_ptr
),
"=s"
(
cur_block_table
));
"=s"
(
cur_block_table
));
index_t
offset_k
=
cur_block_table
*
params
.
k_batch_stride
;
index_t
offset_k
=
cur_block_table
*
params
.
k_batch_stride
;
// gK.data() = gK.data() + (offset_k);
#if 1
gK
.
data
()
=
gK
.
data
()
+
(
offset_k
);
gK
.
data
()
=
gK
.
data
()
+
(
offset_k
);
lds_direct_copy_qkvfp8_tp1
<
false
,
true
>
(
gK
,
sK
,
0
,
params
.
k_row_stride
,
seqlen_k
-
n_block
*
kBlockN
);
auto
gK_offset
=
((
warp_id
)
/
4
)
*
64
+
((
warp_id
)
%
4
)
*
16
*
params
.
k_row_stride
;
lds_direct_copy_qkvfp8_tp1
<
false
,
true
>
(
gK
,
sK
,
1
,
params
.
k_row_stride
,
seqlen_k
-
n_block
*
kBlockN
);
// auto gK_offset = (offset_k) + ((warp_id) / 4) * 64 + ((warp_id) % 4) * 16 * params.k_row_stride;
lds_direct_copy_qkvfp8_tp1
<
false
,
true
>
(
gK
,
sK
,
2
,
params
.
k_row_stride
,
seqlen_k
-
n_block
*
kBlockN
);
// const int k_zero_pad = std::min(std::max(n_block * kBlockN + ((warp_id) % 4 + 1) * 16 - seqlen_k, 0), 16);
lds_direct_copy_qkvfp8_tp1
<
false
,
true
>
(
gK
,
sK
,
3
,
params
.
k_row_stride
,
seqlen_k
-
n_block
*
kBlockN
);
const
int
k_zero_pad
=
std
::
max
(
n_block
*
kBlockN
+
((
warp_id
)
%
4
+
1
)
*
16
-
seqlen_k
,
0
);
lds_direct_copy_qkvfp8_tp1
<
false
,
false
>
(
gK
,
sK
,
4
,
params
.
k_row_stride
,
seqlen_k
-
n_block
*
kBlockN
);
uint32x4_t
gK_rscr
=
make_rscr
((
unsigned
char
*
)(
gK
.
data
().
get
()
+
gK_offset
),
params
.
k_row_stride
,
k_zero_pad
);
auto
k_lds_addr
=
reinterpret_cast
<
size_t
>
(
sK
.
data
().
get
()
+
((
warp_id
)
/
4
)
*
64
*
64
+
(
warp_id
%
4
)
*
16
*
64
);
if
(
n_block
*
kBlockN
+
((
warp_id
)
%
4
)
*
16
<
seqlen_k
||
masking_step
!=
0
)
{
k_lds_addr
|=
0x80000000
;
__builtin_hcu_matrix_load_64x16_b8
(
gK_rscr
,
(
__attribute__
((
address_space
(
3
)))
char
*
)(
k_lds_addr
),
0
,
1
,
1
,
0
,
0
);
k_lds_addr
+=
64
*
128
;
__builtin_hcu_matrix_load_64x16_b8
(
gK_rscr
,
(
__attribute__
((
address_space
(
3
)))
char
*
)(
k_lds_addr
),
128
,
1
,
1
,
0
,
0
);
k_lds_addr
+=
64
*
128
;
__builtin_hcu_matrix_load_64x16_b8
(
gK_rscr
,
(
__attribute__
((
address_space
(
3
)))
char
*
)(
k_lds_addr
),
256
,
1
,
1
,
0
,
0
);
k_lds_addr
+=
64
*
128
;
__builtin_hcu_matrix_load_64x16_b8
(
gK_rscr
,
(
__attribute__
((
address_space
(
3
)))
char
*
)(
k_lds_addr
),
256
+
128
,
1
,
1
,
0
,
0
);
k_lds_addr
+=
64
*
128
;
if
(
warp_id
<
4
)
{
__builtin_hcu_matrix_load_64x16_b8
(
gK_rscr
,
(
__attribute__
((
address_space
(
3
)))
char
*
)(
k_lds_addr
),
512
,
1
,
1
,
0
,
0
);
}
else
{
lds_direct_copy_qkvfp8_zero_lds
(
gK
,
sK
,
4
);
}
}
else
{
lds_direct_copy_qkvfp8_zero_lds
(
gK
,
sK
,
0
);
lds_direct_copy_qkvfp8_zero_lds
(
gK
,
sK
,
1
);
lds_direct_copy_qkvfp8_zero_lds
(
gK
,
sK
,
2
);
lds_direct_copy_qkvfp8_zero_lds
(
gK
,
sK
,
3
);
lds_direct_copy_qkvfp8_zero_lds
(
gK
,
sK
,
4
);
}
gK
.
data
()
=
gK
.
data
()
+
(
-
offset_k
);
auto
k_lds_read_ptr
=
sK
.
data
().
get
()
+
(
warp_id
/
4
)
*
16
*
64
;
asm
volatile
(
"s_waitcnt vmcnt(4)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt vmcnt(4)
\n\t
s_barrier
\n\t
"
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
0
),
tSrK_copy_view
(
_
,
_
,
0
));
constexpr
static
int
k_read_lds_offset
=
32
*
64
;
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
0
),
tSrK
(
_
,
_
,
0
),
acc_s
);
{
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
1
),
tSrK_copy_view
(
_
,
_
,
1
));
constexpr
static
int
k_idx
=
0
;
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
1
),
tSrK
(
_
,
_
,
1
),
acc_s
);
// k_lds_read_ptr += k_idx * 64 * 64;
Fp8_storage
k_data
;
k_data
.
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
k_lds_read_ptr
),
k_idx
*
64
*
64
,
3
,
1
,
0
);
accs_f32
[
0
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
0
],
k_data
.
p
[
0
],
accs_f32
[
0
],
true
,
false
);
accs_f32
[
0
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
1
],
k_data
.
p
[
1
],
accs_f32
[
0
],
true
,
false
);
k_data
.
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
k_lds_read_ptr
),
k_idx
*
64
*
64
+
k_read_lds_offset
,
3
,
1
,
0
);
accs_f32
[
1
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
0
],
k_data
.
p
[
0
],
accs_f32
[
1
],
true
,
false
);
accs_f32
[
1
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
1
],
k_data
.
p
[
1
],
accs_f32
[
1
],
true
,
false
);
}
{
constexpr
static
int
k_idx
=
1
;
// k_lds_read_ptr += 64 * 64;
Fp8_storage
k_data
;
k_data
.
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
k_lds_read_ptr
),
k_idx
*
64
*
64
,
3
,
1
,
0
);
accs_f32
[
0
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
0
],
k_data
.
p
[
0
],
accs_f32
[
0
],
true
,
false
);
accs_f32
[
0
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
1
],
k_data
.
p
[
1
],
accs_f32
[
0
],
true
,
false
);
k_data
.
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
k_lds_read_ptr
),
k_idx
*
64
*
64
+
k_read_lds_offset
,
3
,
1
,
0
);
accs_f32
[
1
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
0
],
k_data
.
p
[
0
],
accs_f32
[
1
],
true
,
false
);
accs_f32
[
1
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
1
],
k_data
.
p
[
1
],
accs_f32
[
1
],
true
,
false
);
// if (block0())
// {
// printf(" %x %x %x %x %x %x %x %x \n", q_r[k_idx].fp8_array[0], q_r[k_idx].fp8_array[1], q_r[k_idx].fp8_array[2], q_r[k_idx].fp8_array[3], k_data.fp8_array[0], k_data.fp8_array[1], k_data.fp8_array[2], k_data.fp8_array[3]);
// }
}
#if 1
asm
volatile
(
"s_waitcnt vmcnt(3)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt vmcnt(3)
\n\t
s_barrier
\n\t
"
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
2
),
tSrK_copy_view
(
_
,
_
,
2
));
{
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
2
),
tSrK
(
_
,
_
,
2
),
acc_s
);
constexpr
static
int
k_idx
=
2
;
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
3
),
tSrK_copy_view
(
_
,
_
,
3
));
// k_lds_read_ptr += 64 * 64;
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
3
),
tSrK
(
_
,
_
,
3
),
acc_s
);
Fp8_storage
k_data
;
k_data
.
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
k_lds_read_ptr
),
k_idx
*
64
*
64
,
3
,
1
,
0
);
accs_f32
[
0
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
0
],
k_data
.
p
[
0
],
accs_f32
[
0
],
true
,
false
);
accs_f32
[
0
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
1
],
k_data
.
p
[
1
],
accs_f32
[
0
],
true
,
false
);
k_data
.
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
k_lds_read_ptr
),
k_idx
*
64
*
64
+
k_read_lds_offset
,
3
,
1
,
0
);
accs_f32
[
1
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
0
],
k_data
.
p
[
0
],
accs_f32
[
1
],
true
,
false
);
accs_f32
[
1
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
1
],
k_data
.
p
[
1
],
accs_f32
[
1
],
true
,
false
);
}
{
constexpr
static
int
k_idx
=
3
;
// k_lds_read_ptr += 64 * 64;
Fp8_storage
k_data
;
k_data
.
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
k_lds_read_ptr
),
k_idx
*
64
*
64
,
3
,
1
,
0
);
accs_f32
[
0
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
0
],
k_data
.
p
[
0
],
accs_f32
[
0
],
true
,
false
);
accs_f32
[
0
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
1
],
k_data
.
p
[
1
],
accs_f32
[
0
],
true
,
false
);
k_data
.
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
k_lds_read_ptr
),
k_idx
*
64
*
64
+
k_read_lds_offset
,
3
,
1
,
0
);
accs_f32
[
1
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
0
],
k_data
.
p
[
0
],
accs_f32
[
1
],
true
,
false
);
accs_f32
[
1
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
1
],
k_data
.
p
[
1
],
accs_f32
[
1
],
true
,
false
);
}
asm
volatile
(
"s_waitcnt vmcnt(2)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt vmcnt(2)
\n\t
s_barrier
\n\t
"
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
4
),
tSrK_copy_view
(
_
,
_
,
4
));
{
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
4
),
tSrK
(
_
,
_
,
4
),
acc_s
);
constexpr
static
int
k_idx
=
4
;
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
5
),
tSrK_copy_view
(
_
,
_
,
5
));
// k_lds_read_ptr += 64 * 64;
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
5
),
tSrK
(
_
,
_
,
5
),
acc_s
);
Fp8_storage
k_data
;
k_data
.
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
k_lds_read_ptr
),
k_idx
*
64
*
64
,
3
,
1
,
0
);
accs_f32
[
0
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
0
],
k_data
.
p
[
0
],
accs_f32
[
0
],
true
,
false
);
accs_f32
[
0
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
1
],
k_data
.
p
[
1
],
accs_f32
[
0
],
true
,
false
);
k_data
.
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
k_lds_read_ptr
),
k_idx
*
64
*
64
+
k_read_lds_offset
,
3
,
1
,
0
);
accs_f32
[
1
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
0
],
k_data
.
p
[
0
],
accs_f32
[
1
],
true
,
false
);
accs_f32
[
1
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
1
],
k_data
.
p
[
1
],
accs_f32
[
1
],
true
,
false
);
}
{
constexpr
static
int
k_idx
=
5
;
// k_lds_read_ptr += 64 * 64;
Fp8_storage
k_data
;
k_data
.
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
k_lds_read_ptr
),
k_idx
*
64
*
64
,
3
,
1
,
0
);
accs_f32
[
0
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
0
],
k_data
.
p
[
0
],
accs_f32
[
0
],
true
,
false
);
accs_f32
[
0
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
1
],
k_data
.
p
[
1
],
accs_f32
[
0
],
true
,
false
);
k_data
.
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
k_lds_read_ptr
),
k_idx
*
64
*
64
+
k_read_lds_offset
,
3
,
1
,
0
);
accs_f32
[
1
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
0
],
k_data
.
p
[
0
],
accs_f32
[
1
],
true
,
false
);
accs_f32
[
1
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
1
],
k_data
.
p
[
1
],
accs_f32
[
1
],
true
,
false
);
}
asm
volatile
(
"s_waitcnt vmcnt(1)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt vmcnt(1)
\n\t
s_barrier
\n\t
"
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
6
),
tSrK_copy_view
(
_
,
_
,
6
));
{
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
6
),
tSrK
(
_
,
_
,
6
),
acc_s
);
constexpr
static
int
k_idx
=
6
;
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
7
),
tSrK_copy_view
(
_
,
_
,
7
));
// k_lds_read_ptr += 64 * 64;
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
7
),
tSrK
(
_
,
_
,
7
),
acc_s
);
Fp8_storage
k_data
;
k_data
.
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
k_lds_read_ptr
),
k_idx
*
64
*
64
,
3
,
1
,
0
);
accs_f32
[
0
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
0
],
k_data
.
p
[
0
],
accs_f32
[
0
],
true
,
false
);
accs_f32
[
0
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
1
],
k_data
.
p
[
1
],
accs_f32
[
0
],
true
,
false
);
k_data
.
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
k_lds_read_ptr
),
k_idx
*
64
*
64
+
k_read_lds_offset
,
3
,
1
,
0
);
accs_f32
[
1
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
0
],
k_data
.
p
[
0
],
accs_f32
[
1
],
true
,
false
);
accs_f32
[
1
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
1
],
k_data
.
p
[
1
],
accs_f32
[
1
],
true
,
false
);
}
{
constexpr
static
int
k_idx
=
7
;
// k_lds_read_ptr += 64 * 64;
Fp8_storage
k_data
;
k_data
.
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
k_lds_read_ptr
),
k_idx
*
64
*
64
,
3
,
1
,
0
);
accs_f32
[
0
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
0
],
k_data
.
p
[
0
],
accs_f32
[
0
],
true
,
false
);
accs_f32
[
0
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
1
],
k_data
.
p
[
1
],
accs_f32
[
0
],
true
,
false
);
k_data
.
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
k_lds_read_ptr
),
k_idx
*
64
*
64
+
k_read_lds_offset
,
3
,
1
,
0
);
accs_f32
[
1
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
0
],
k_data
.
p
[
0
],
accs_f32
[
1
],
true
,
false
);
accs_f32
[
1
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
1
],
k_data
.
p
[
1
],
accs_f32
[
1
],
true
,
false
);
}
asm
volatile
(
"s_waitcnt vmcnt(0)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt vmcnt(0)
\n\t
s_barrier
\n\t
"
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
8
),
tSrK_copy_view
(
_
,
_
,
8
));
{
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
8
),
tSrK
(
_
,
_
,
8
),
acc_s
);
constexpr
static
int
k_idx
=
8
;
gK
.
data
()
=
gK
.
data
()
+
(
-
offset_k
);
// k_lds_read_ptr += 64 * 64;
// asm volatile("s_barrier \n\t");
Fp8_storage
k_data
;
k_data
.
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
k_lds_read_ptr
),
k_idx
*
64
*
64
,
3
,
1
,
0
);
accs_f32
[
0
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
0
],
k_data
.
p
[
0
],
accs_f32
[
0
],
true
,
false
);
accs_f32
[
0
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
1
],
k_data
.
p
[
1
],
accs_f32
[
0
],
true
,
false
);
k_data
.
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
k_lds_read_ptr
),
k_idx
*
64
*
64
+
k_read_lds_offset
,
3
,
1
,
0
);
accs_f32
[
1
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
0
],
k_data
.
p
[
0
],
accs_f32
[
1
],
true
,
false
);
accs_f32
[
1
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
1
],
k_data
.
p
[
1
],
accs_f32
[
1
],
true
,
false
);
}
#endif
Tensor
acc_s
=
partition_fragment_C
(
tiled_mma
,
Shape
<
Int
<
kBlockM
>
,
Int
<
kBlockN
>>
{});
acc_s
(
0
,
0
,
0
)
=
accs_f32
[
0
].
x
;
acc_s
(
1
,
0
,
0
)
=
accs_f32
[
0
].
y
;
acc_s
(
2
,
0
,
0
)
=
accs_f32
[
0
].
z
;
acc_s
(
3
,
0
,
0
)
=
accs_f32
[
0
].
w
;
acc_s
(
0
,
0
,
1
)
=
accs_f32
[
1
].
x
;
acc_s
(
1
,
0
,
1
)
=
accs_f32
[
1
].
y
;
acc_s
(
2
,
0
,
1
)
=
accs_f32
[
1
].
z
;
acc_s
(
3
,
0
,
1
)
=
accs_f32
[
1
].
w
;
// cute::gemm(tiled_mma, tSrQ(_, _, 0), tSrK(_, _, 0), acc_s);
#endif
Tensor
cS
=
make_identity_tensor
(
Shape
<
Int
<
kBlockM
>
,
Int
<
kBlockN
>>
{});
Tensor
cS
=
make_identity_tensor
(
Shape
<
Int
<
kBlockM
>
,
Int
<
kBlockN
>>
{});
Tensor
tScS
=
thr_mma
.
partition_C
(
cS
);
Tensor
tScS
=
thr_mma
.
partition_C
(
cS
);
for
(
int
i
=
0
;
i
<
size
(
acc_s
);
++
i
)
{
for
(
int
i
=
0
;
i
<
size
(
acc_s
);
++
i
)
{
...
@@ -1333,20 +1431,6 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
...
@@ -1333,20 +1431,6 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
:
softmax
.
template
softmax_rescale_o_fp8_tp1
<
/*Is_first=*/
false
,
/*Check_inf=*/
Is_causal
,
true
>(
acc_s
,
sRow_max_reduce_buffer
,
scale_softmax_log2
,
acco_f32
);
:
softmax
.
template
softmax_rescale_o_fp8_tp1
<
/*Is_first=*/
false
,
/*Check_inf=*/
Is_causal
,
true
>(
acc_s
,
sRow_max_reduce_buffer
,
scale_softmax_log2
,
acco_f32
);
}
}
// asm volatile("s_barrier \n\t");
// if (block0() && tidx < 64)
// {
// // printf("n_block = %d tidx = %d %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f \n ", n_block, tidx, acco_f32[0].x, acco_f32[0].y, acco_f32[0].z, acco_f32[0].w,
// // acco_f32[1].x, acco_f32[1].y, acco_f32[1].z, acco_f32[1].w,
// // acco_f32[2].x, acco_f32[2].y, acco_f32[2].z, acco_f32[2].w,
// // acco_f32[3].x, acco_f32[3].y, acco_f32[3].z, acco_f32[3].w
// // );
// printf("n_block = %d tidx = %d %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f \n ", n_block, tidx, acc_s(0), acc_s(1), acc_s(2), acc_s(3),
// acc_s(4), acc_s(5), acc_s(6), acc_s(7)
// // acc_s(8), acc_s(9), acc_s(10), acc_s(11),
// // acc_s(12), acc_s(13), acc_s(14), acc_s(15)
// );
// }
#if 1
#if 1
Fp8_storage
p_fp8
;
Fp8_storage
p_fp8
;
{
{
...
@@ -1371,83 +1455,53 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
...
@@ -1371,83 +1455,53 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
__syncthreads
();
__syncthreads
();
p_fp8
.
data
=
*
reinterpret_cast
<
intx4_t
*>
(
&
(
sP
[
tid
*
16
+
(
warp_id
%
4
)
*
16
*
64
]));
p_fp8
.
data
=
*
reinterpret_cast
<
intx4_t
*>
(
&
(
sP
[
tid
*
16
+
(
warp_id
%
4
)
*
16
*
64
]));
__builtin_amdgcn_sched_barrier
(
0
);
__builtin_amdgcn_sched_barrier
(
0
);
}
{
__builtin_amdgcn_sched_barrier
(
0
);
for
(
int
i
=
0
;
i
<
4
;
i
++
)
int
lane_id
=
tidx
%
64
;
{
int
row
=
lane_id
/
4
;
{
int
col
=
lane_id
%
4
;
int
k
=
0
;
col
=
(
col
+
(
row
/
2
)
%
4
)
%
4
;
Fp8_storage
v0_0
,
v0_1
;
auto
lds_offset
=
row
*
64
+
col
*
16
+
(
warp_id
/
4
)
*
64
*
64
;
v0_0
.
data
=
__builtin_hcu_ds_read_m64x16_u8_alt4
((
__attribute__
((
address_space
(
3
)))
int
*
)(
&
(
tOsVt
(
0
,
i
,
k
))));
v0_1
.
data
=
__builtin_hcu_ds_read_m64x16_u8_alt4
((
__attribute__
((
address_space
(
3
)))
int
*
)(
&
(
tOsVt
(
0
,
i
,
k
+
1
))));
// if (block0() && tidx < 64)
// Fp8_storage v0_0, v0_1;
// {
// v0_0.data = __builtin_hcu_ds_read_m64x16_u8_alt4((__attribute__((address_space(3))) int*)(A_smem + lds_offset));
// float v0 = __builtin_amdgcn_cvt_f32_fp8(v0_0.fp8_array[0], 0);
// v0_1.data = __builtin_hcu_ds_read_m64x16_u8_alt4((__attribute__((address_space(3))) int*)(sV.data().get() + lds_offset + 16 * 64));
// float v1 = __builtin_amdgcn_cvt_f32_fp8(v0_0.fp8_array[0], 1);
// float v2 = __builtin_amdgcn_cvt_f32_fp8(v0_0.fp8_array[0], 2);
// float v3 = __builtin_amdgcn_cvt_f32_fp8(v0_0.fp8_array[0], 3);
// printf("tid = %d %.3f %.3f %.3f %.3f %x \n", tidx, v0, v1, v2, v3, v0_0.fp8_array[0]);
// acco_f32[i * 4 + j] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(p_fp8.p[k/2], tmp.val_to_mmac, acco_f32[i * 4 + j], true, false);
// }
for
(
int
j
=
0
;
j
<
4
;
j
++
)
for
(
int
n
=
0
;
n
<
4
;
n
++
)
{
Val
tmp
;
tmp
.
data
[
0
]
=
v0_0
.
fp8_array
[
j
];
tmp
.
data
[
1
]
=
v0_1
.
fp8_array
[
j
];
acco_f32
[
i
*
4
+
j
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
p_fp8
.
p
[
k
/
2
],
tmp
.
val_to_mmac
,
acco_f32
[
i
*
4
+
j
],
true
,
false
);
}
}
{
{
int
k
=
2
;
Fp8_storage
v0_0
,
v0_1
;
Fp8_storage
v0_0
,
v0_1
;
v0_0
.
data
=
__builtin_hcu_ds_read_m64x16_u8_alt4
((
__attribute__
((
address_space
(
3
)))
int
*
)(
&
(
tOsVt
(
0
,
i
,
k
))
));
v0_0
.
data
=
__builtin_hcu_ds_read_m64x16_u8_alt4
((
__attribute__
((
address_space
(
3
)))
int
*
)(
sV
.
data
().
get
()
+
lds_offset
+
n
*
64
*
128
));
v0_1
.
data
=
__builtin_hcu_ds_read_m64x16_u8_alt4
((
__attribute__
((
address_space
(
3
)))
int
*
)(
&
(
tOsVt
(
0
,
i
,
k
+
1
))
));
v0_1
.
data
=
__builtin_hcu_ds_read_m64x16_u8_alt4
((
__attribute__
((
address_space
(
3
)))
int
*
)(
sV
.
data
().
get
()
+
lds_offset
+
16
*
64
+
n
*
64
*
128
));
// if (block0() && tidx < 64
)
for
(
int
j
=
0
;
j
<
4
;
j
++
)
//
{
{
// float v0 = __builtin_amdgcn_cvt_f32_fp8(v0_0.fp8_array[0], 0)
;
intx2_t
v
;
// float v1 = __builtin_amdgcn_cvt_f32_fp8(
v0_0.fp8_array[
0], 1)
;
v
[
0
]
=
v0_0
.
fp8_array
[
j
]
;
// float v2 = __builtin_amdgcn_cvt_f32_fp8(
v0_
0
.fp8_array[
0], 2)
;
v
[
1
]
=
v0_
1
.
fp8_array
[
j
]
;
// float v3 = __builtin_amdgcn_cvt_f32_fp8(v0_0.fp8_array[0], 3
);
acco_f32
[
n
*
4
+
j
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
p_fp8
.
p
[
0
],
v
,
acco_f32
[
n
*
4
+
j
],
true
,
false
);
// printf("tid = %d %.3f %.3f %.3f %.3f %x \n", tidx, v0, v1, v2, v3, v0_0.fp8_array[0]);
}
// }
v0_0
.
data
=
__builtin_hcu_ds_read_m64x16_u8_alt4
((
__attribute__
((
address_space
(
3
)))
int
*
)(
sV
.
data
().
get
()
+
lds_offset
+
n
*
64
*
128
+
32
*
64
));
v0_1
.
data
=
__builtin_hcu_ds_read_m64x16_u8_alt4
((
__attribute__
((
address_space
(
3
)))
int
*
)(
sV
.
data
().
get
()
+
lds_offset
+
16
*
64
+
n
*
64
*
128
+
32
*
64
));
for
(
int
j
=
0
;
j
<
4
;
j
++
)
for
(
int
j
=
0
;
j
<
4
;
j
++
)
{
{
Val
tmp
;
intx2_t
v
;
tmp
.
data
[
0
]
=
v0_0
.
fp8_array
[
j
];
v
[
0
]
=
v0_0
.
fp8_array
[
j
];
tmp
.
data
[
1
]
=
v0_1
.
fp8_array
[
j
];
v
[
1
]
=
v0_1
.
fp8_array
[
j
];
acco_f32
[
i
*
4
+
j
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
p_fp8
.
p
[
k
/
2
],
tmp
.
val_to_mmac
,
acco_f32
[
i
*
4
+
j
],
true
,
false
);
acco_f32
[
n
*
4
+
j
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
p_fp8
.
p
[
1
],
v
,
acco_f32
[
n
*
4
+
j
],
true
,
false
);
}
}
}
}
}
__builtin_amdgcn_sched_barrier
(
0
);
}
}
// if (block0())
// {
// printf("n_block = %d tidx = %d %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f \n ", n_block, tidx, acco_f32[0].x, acco_f32[0].y, acco_f32[0].z, acco_f32[0].w,
// acco_f32[1].x, acco_f32[1].y, acco_f32[1].z, acco_f32[1].w,
// acco_f32[2].x, acco_f32[2].y, acco_f32[2].z, acco_f32[2].w,
// acco_f32[3].x, acco_f32[3].y, acco_f32[3].z, acco_f32[3].w
// );
// // printf("n_block = %d tidx = %d %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f \n ", n_block, tidx, acc_s(0), acc_s(1), acc_s(2), acc_s(3),
// // acc_s(4), acc_s(5), acc_s(6), acc_s(7),
// // acc_s(8), acc_s(9), acc_s(10), acc_s(11),
// // acc_s(12), acc_s(13), acc_s(14), acc_s(15)
// // );
// }
asm
volatile
(
"s_barrier
\n\t
"
);
asm
volatile
(
"s_barrier
\n\t
"
);
#endif
#endif
}
}
#endif
using
ElementO
=
typename
Kernel_traits
::
ElementO
;
using
ElementO
=
typename
Kernel_traits
::
ElementO
;
using
ElementAccum
=
typename
Kernel_traits
::
ElementAccum
;
using
ElementAccum
=
typename
Kernel_traits
::
ElementAccum
;
const
int
split_offset
=
__ldg
(
params
.
num_splits_ptr
+
bidb
);
const
int
split_offset
=
__ldg
(
params
.
num_splits_ptr
+
bidb
);
...
@@ -1458,9 +1512,6 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
...
@@ -1458,9 +1512,6 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
const
index_t
row_offset_lse
=
(
bidb
*
params
.
h
+
bidh
)
*
params
.
seqlen_q
+
m_block
*
kBlockM
;
const
index_t
row_offset_lse
=
(
bidb
*
params
.
h
+
bidh
)
*
params
.
seqlen_q
+
m_block
*
kBlockM
;
const
index_t
row_offset_lseaccum
=
((
split_offset
+
n_split_idx
)
*
params
.
h
+
bidh
)
*
params
.
seqlen_q
+
m_block
*
kBlockM
;
const
index_t
row_offset_lseaccum
=
((
split_offset
+
n_split_idx
)
*
params
.
h
+
bidh
)
*
params
.
seqlen_q
+
m_block
*
kBlockM
;
if
(
NoSplit
)
{
if
(
NoSplit
)
{
constexpr
bool
Split
=
false
;
constexpr
bool
Split
=
false
;
Tensor
gOaccum
=
make_tensor
(
make_gmem_ptr
(
reinterpret_cast
<
ElementO
*>
(
Split
?
params
.
oaccum_ptr
:
params
.
o_ptr
)
+
(
Split
?
row_offset_oaccum
:
row_offset_o
)),
Tensor
gOaccum
=
make_tensor
(
make_gmem_ptr
(
reinterpret_cast
<
ElementO
*>
(
Split
?
params
.
oaccum_ptr
:
params
.
o_ptr
)
+
(
Split
?
row_offset_oaccum
:
row_offset_o
)),
...
@@ -1482,12 +1533,6 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
...
@@ -1482,12 +1533,6 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
if
(
row
<
params
.
seqlen_q
-
m_block
*
kBlockM
)
{
gLSEaccum
(
row
)
=
lse
(
mi
);
}
if
(
row
<
params
.
seqlen_q
-
m_block
*
kBlockM
)
{
gLSEaccum
(
row
)
=
lse
(
mi
);
}
}
}
}
}
// if (tidx == 1)
// {
// printf(" %.4f %.4f %.4f %.4f \n ", acco_f32[0].x, acco_f32[0].y, acco_f32[0].z, acco_f32[0].w);
// }
{
{
using
result_type
=
cutlass
::
Array
<
bfloat16_t
,
2
>
;
using
result_type
=
cutlass
::
Array
<
bfloat16_t
,
2
>
;
int
tidx
=
threadIdx
.
x
;
int
tidx
=
threadIdx
.
x
;
...
@@ -1547,49 +1592,10 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
...
@@ -1547,49 +1592,10 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
gOaccum
(
row
,
col
+
3
)
=
res1
[
1
];
gOaccum
(
row
,
col
+
3
)
=
res1
[
1
];
// col += 16;
// col += 16;
}
}
// for (int j = 0; j < 4; j++)
// {
// auto d0 = __builtin_hcu_cvt_pk_bf16_f32(0, acco_f32[n* 4 + j].x, 0, acco_f32[n * 4 + j].y, 0);
// auto d1 = __builtin_hcu_cvt_pk_bf16_f32(0, acco_f32[n* 4 + j].z, 0, acco_f32[n * 4 + j].w, 0);
// auto res0 = reinterpret_cast<result_type const &>(d0);
// auto res1 = reinterpret_cast<result_type const &>(d1);
// gOaccum(row, col) = res0[0];
// gOaccum(row, col + 1) = res0[1];
// gOaccum(row, col + 2) = res1[0];
// gOaccum(row, col + 3) = res1[1];
// col += 16;
// }
}
}
// for (int n = 0; n < 8; n++)
// {
// using result_type = cutlass::Array<bfloat16_t, 2>;
// auto d0 = __builtin_hcu_cvt_pk_bf16_f32(0, acco_f32[n].x, 0, acco_f32[n].y, 0);
// auto d1 = __builtin_hcu_cvt_pk_bf16_f32(0, acco_f32[n].z, 0, acco_f32[n].w, 0);
// col = (tidx % 64 / 16) * 4 + n * 64;
// auto res0 = reinterpret_cast<result_type const &>(d0);
// auto res1 = reinterpret_cast<result_type const &>(d1);
// gOaccum(row, col) = res0[0];
// gOaccum(row, col + 1) = res0[1];
// gOaccum(row, col + 2) = res1[0];
// gOaccum(row, col + 3) = res1[1];
// }
}
}
}
}
}
}
}
else
{
}
else
{
constexpr
bool
Split
=
true
;
constexpr
bool
Split
=
true
;
Tensor
gOaccum
=
make_tensor
(
make_gmem_ptr
(
reinterpret_cast
<
ElementAccum
*>
(
Split
?
params
.
oaccum_ptr
:
params
.
o_ptr
)
+
(
Split
?
row_offset_oaccum
:
row_offset_o
)),
Tensor
gOaccum
=
make_tensor
(
make_gmem_ptr
(
reinterpret_cast
<
ElementAccum
*>
(
Split
?
params
.
oaccum_ptr
:
params
.
o_ptr
)
+
(
Split
?
row_offset_oaccum
:
row_offset_o
)),
...
@@ -1616,15 +1622,6 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
...
@@ -1616,15 +1622,6 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
for
(
int
m
=
0
;
m
<
1
;
m
++
)
{
for
(
int
m
=
0
;
m
<
1
;
m
++
)
{
const
int
row
=
tidx
%
16
+
(
warpid
%
4
)
*
16
;
const
int
row
=
tidx
%
16
+
(
warpid
%
4
)
*
16
;
if
(
row
<
params
.
seqlen_q
-
m_block
*
kBlockM
)
{
if
(
row
<
params
.
seqlen_q
-
m_block
*
kBlockM
)
{
// for (int n = 0; n < 32; n++)
// {
// col = (tidx % 64 / 16) * 4 + n * 16;
// gOaccum(row, col) = acco_f32[n].x;
// gOaccum(row, col + 1) = acco_f32[n].y;
// gOaccum(row, col + 2) = acco_f32[n].z;
// gOaccum(row, col + 3) = acco_f32[n].w;
// }
for
(
int
n
=
0
;
n
<
4
;
n
++
)
for
(
int
n
=
0
;
n
<
4
;
n
++
)
{
{
col
=
(
tidx
%
64
/
16
)
*
16
+
n
*
128
+
(
warp_id
/
4
)
*
64
;
col
=
(
tidx
%
64
/
16
)
*
16
+
n
*
128
+
(
warp_id
/
4
)
*
64
;
...
@@ -1658,44 +1655,12 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
...
@@ -1658,44 +1655,12 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
gOaccum
(
row
,
col
+
2
)
=
acco_f32
[
n
*
4
+
2
].
w
;
gOaccum
(
row
,
col
+
2
)
=
acco_f32
[
n
*
4
+
2
].
w
;
gOaccum
(
row
,
col
+
3
)
=
acco_f32
[
n
*
4
+
3
].
w
;
gOaccum
(
row
,
col
+
3
)
=
acco_f32
[
n
*
4
+
3
].
w
;
}
}
// for (int j = 0; j < 4; j++) {
// gOaccum(row, col) = acco_f32[n * 4 + j].x;
// gOaccum(row, col + 1) = acco_f32[n * 4 + j].y;
// gOaccum(row, col + 2) = acco_f32[n * 4 + j].z;
// gOaccum(row, col + 3) = acco_f32[n * 4 + j].w;
// col += 16;
// }
}
}
}
}
}
}
}
}
// Tensor acc_o = partition_fragment_C(tiled_mma_o, Shape<Int<kBlockM>, Int<kHeadDimV>>{});
// for (int n = 0; n < 8; n++)
// {
// acc_o(0, 0, n) = acco_f32[n * 2].x;
// acc_o(1, 0, n) = acco_f32[n * 2].y;
// acc_o(2, 0, n) = acco_f32[n * 2].z;
// acc_o(3, 0, n) = acco_f32[n * 2].w;
// acc_o(4, 0, n) = acco_f32[n * 2 + 1].x;
// acc_o(5, 0, n) = acco_f32[n * 2 + 1].y;
// acc_o(6, 0, n) = acco_f32[n * 2 + 1].z;
// acc_o(7, 0, n) = acco_f32[n * 2 + 1].w;
// }
// if (NoSplit)
// store_float8<Kernel_traits, false>(params, bidb, bidh, m_block, n_split_idx, shared_storage, acc_o, softmax, descale_k, scale_softmax);
// else
// store_float8<Kernel_traits, true>(params, bidb, bidh, m_block, n_split_idx, shared_storage, acc_o, softmax, descale_k, scale_softmax);
}
}
template
<
typename
Kernel_traits
,
bool
Is_causal
,
typename
SharedStorage
>
template
<
typename
Kernel_traits
,
bool
Is_causal
,
typename
SharedStorage
>
__forceinline__
__device__
void
compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP4
(
const
Flash_fwd_mla_params
&
params
,
__forceinline__
__device__
void
compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP4
(
const
Flash_fwd_mla_params
&
params
,
...
...
csrc/extension/utils.h
View file @
98b7c697
...
@@ -2692,6 +2692,62 @@ __forceinline__ __device__ void __ds_read_m32x32_row_col_rrow(Tensor0& src, int
...
@@ -2692,6 +2692,62 @@ __forceinline__ __device__ void __ds_read_m32x32_row_col_rrow(Tensor0& src, int
extern
__device__
__attribute__
((
const
))
float
__llvm_exp2_f32
(
float
)
__asm
(
"llvm.exp2.f32"
);
extern
__device__
__attribute__
((
const
))
float
__llvm_exp2_f32
(
float
)
__asm
(
"llvm.exp2.f32"
);
/// Packs a pointer, a stride value and a zero-pad flag into a four-dword
/// descriptor vector.
///
/// Resulting layout (dword index -> content):
///   0: low 32 bits of `ptr`
///   1: high 32 bits of `ptr`
///   2: `stride`, stored unmodified
///   3: bit 16 set, OR'ed with `zero_pad << 8`
///
/// NOTE(review): presumably this is a buffer resource descriptor consumed by
/// buffer_load instructions; the hardware meaning of the dword-2/dword-3 bits
/// is not visible in this file — confirm against the target ISA manual.
///
/// @param ptr       Address placed in dwords 0-1.
/// @param stride    Value placed in dword 2.
/// @param zero_pad  Value shifted left by 8 and merged into dword 3.
/// @return The assembled four-dword descriptor.
__device__ inline uint32x4_t make_rscr(unsigned char* ptr, const int stride, const int zero_pad)
{
    const uint64_t address = reinterpret_cast<uint64_t>(ptr);

    uint32x4_t rscr;
    // Write the two address halves explicitly rather than storing through a
    // uint64_t* alias of the vector: same bits on a little-endian target,
    // without relying on type-punning.
    rscr[0] = static_cast<uint32_t>(address);
    rscr[1] = static_cast<uint32_t>(address >> 32);
    rscr[2] = stride;
    rscr[3] = (1u << 16) | static_cast<uint32_t>(zero_pad << 8);
    return rscr;
}
/// Issues one `buffer_load_dwordx4 ... lds` per wavefront, targeting a slot in
/// `dst` selected by the wavefront index and `k_idx_`.
///
/// The descriptor is built from the base address of `src` with dword 2 set to
/// 0x80000000 and dword 3 set to 0x00020000, and every lane passes a vector
/// offset of -1.
/// NOTE(review): given the "zero_lds" name, the -1 offset presumably forces an
/// out-of-range access so that the LDS slot is filled with zeros — confirm
/// against the target ISA's buffer addressing rules.
///
/// Destination address, in bytes (element size is 1):
///   dst base + (warp_id % 4) * 1024 + k_idx * 64 * 128 + (warp_id / 4) * 64 * 64
///
/// The instruction is emitted only when `__gfx938__` is defined; on any other
/// target the function has no effect.
///
/// @param src     Tensor whose data pointer supplies the descriptor base address.
/// @param dst     Tensor whose data pointer supplies the LDS base address.
/// @param k_idx_  Slot index; made wave-uniform with readfirstlane before use.
template <class SrcEngine, class SrcLayout, class DstEngine, class DstLayout>
CUTE_HOST_DEVICE void lds_direct_copy_qkvfp8_zero_lds(Tensor<SrcEngine, SrcLayout> const& src,
                                                      Tensor<DstEngine, DstLayout>& dst,
                                                      int k_idx_)
{
    constexpr int warp_size = 64;
    constexpr int element_size = 1;
    constexpr int elements_per_thread = 16;
    constexpr int bytes_per_warp = warp_size * elements_per_thread * element_size;  // 64 * 16 * 1

    const int tidx = threadIdx.x;
    // readfirstlane makes these values wave-uniform so they can live in scalar registers.
    const int warp_id = __builtin_amdgcn_readfirstlane(tidx / warp_size);
    const int k_idx = __builtin_amdgcn_readfirstlane(k_idx_);

    // Split the source address into its two 32-bit halves with shifts instead
    // of punning a two-field struct through a uint64_t*.
    const uint64_t src_address = reinterpret_cast<uint64_t>(src.data().get());
    const uint32_t src_address_lo = static_cast<uint32_t>(src_address);
    const uint32_t src_address_hi = static_cast<uint32_t>(src_address >> 32);

    uint32x4_t global_addr = {0};
    global_addr[0] = __builtin_amdgcn_readfirstlane(src_address_lo);
    global_addr[1] = __builtin_amdgcn_readfirstlane(src_address_hi);
    global_addr[2] = 0x80000000;
    global_addr[3] = 0x00020000;

    // Accumulate in size_t, then narrow to the 32-bit value loaded into m0.
    size_t lds_address = reinterpret_cast<size_t>(dst.data().get());
    lds_address += (warp_id % 4) * bytes_per_warp;
    lds_address += (k_idx) * 64 * 128 * element_size;
    lds_address += (warp_id / 4) * 64 * 64;
    int ldsAddrPerWave = static_cast<int>(lds_address);

    int offset_v = -1;       // per-lane vector offset operand (%0)
    const int offset_s = 0;  // scalar offset operand (%3)

#if defined(__gfx938__)
    // %1 is moved into m0 before the load; operand order must match the template.
    asm volatile(
        "s_mov_b32 m0, %1\n\t"
        "buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds\n"
        ::"v"(offset_v), "s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
        :);
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////
}
// namespace flash
}
// namespace flash
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment