Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
FlashMLA
Commits
98b7c697
Commit
98b7c697
authored
Feb 27, 2026
by
zhanghj2
Browse files
fp8 tp1性能提升
parent
24c52aee
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
367 additions
and
346 deletions
+367
-346
csrc/extension/flash_fwd_mla_kernel_fp8.h
csrc/extension/flash_fwd_mla_kernel_fp8.h
+311
-346
csrc/extension/utils.h
csrc/extension/utils.h
+56
-0
No files found.
csrc/extension/flash_fwd_mla_kernel_fp8.h
View file @
98b7c697
...
@@ -1060,7 +1060,7 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
...
@@ -1060,7 +1060,7 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
using
index_t
=
typename
Kernel_traits
::
index_t
;
using
index_t
=
typename
Kernel_traits
::
index_t
;
const
int
tidx
=
threadIdx
.
x
;
const
int
tidx
=
threadIdx
.
x
;
const
int
warp_id
=
tidx
/
64
;
const
int
warp_id
=
__builtin_amdgcn_readfirstlane
(
tidx
/
64
)
;
const
index_t
row_offset_q
=
bidb
*
params
.
q_batch_stride
+
m_block
*
kBlockM
*
params
.
q_row_stride
+
bidh
*
params
.
q_head_stride
;
const
index_t
row_offset_q
=
bidb
*
params
.
q_batch_stride
+
m_block
*
kBlockM
*
params
.
q_row_stride
+
bidh
*
params
.
q_head_stride
;
Tensor
gQ
=
make_tensor
(
make_gmem_ptr
(
reinterpret_cast
<
Element
*>
(
params
.
q_ptr
)
+
row_offset_q
),
Tensor
gQ
=
make_tensor
(
make_gmem_ptr
(
reinterpret_cast
<
Element
*>
(
params
.
q_ptr
)
+
row_offset_q
),
...
@@ -1100,131 +1100,63 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
...
@@ -1100,131 +1100,63 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
intx2_t
p
[
2
];
intx2_t
p
[
2
];
int32_t
fp8_array
[
4
];
int32_t
fp8_array
[
4
];
};
};
#if 0
Fp8_storage
q_r
[
9
];
auto gmem_tiled_copy_Q = make_tiled_copy_A(Copy_Atom<DefaultCopy, Element>{}, tiled_mma);
#if 1
auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_thread_slice(tidx);
auto
gQ_offset
=
((
warp_id
)
/
4
)
*
64
+
((
warp_id
)
%
4
)
*
16
*
params
.
q_row_stride
;
Tensor tSgQ = gmem_thr_copy_Q.partition_S(gQ);
const
int
q_zero_pad
=
std
::
min
(
std
::
max
(
m_block
*
kBlockM
+
((
warp_id
)
%
4
+
1
)
*
16
-
params
.
seqlen_q
,
0
),
16
);
Tensor tSrQ = thr_mma.partition_fragment_A(gQ);
uint32x4_t
gQ_rscr
=
make_rscr
((
unsigned
char
*
)(
gQ
.
data
().
get
()
+
gQ_offset
),
params
.
q_row_stride
,
q_zero_pad
);
Tensor cQ = make_identity_tensor(make_shape(size<0>(gQ), size<1>(gQ)));
auto
q_lds_addr
=
reinterpret_cast
<
size_t
>
(
sQ
.
data
().
get
()
+
((
warp_id
)
/
4
)
*
64
*
64
+
(
warp_id
%
4
)
*
16
*
64
)
|
0x80000000
;
Tensor tQcQ = gmem_thr_copy_Q.partition_S(cQ);
if
(
m_block
*
kBlockM
+
((
warp_id
)
%
4
)
*
16
<
params
.
seqlen_q
)
Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tSgQ)));
flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/true>(gmem_tiled_copy_Q, tSgQ, tSrQ, tQcQ, tQpQ,
params.seqlen_q - m_block * kBlockM);
__syncthreads();
#else
Tensor
tSrQ
=
thr_mma
.
partition_fragment_A
(
gQ
);
lds_direct_copy_qkvfp8_q_tp1
<
false
,
true
>
(
gQ
,
sQ
,
0
,
params
.
q_row_stride
,
params
.
seqlen_q
-
m_block
*
kBlockM
);
lds_direct_copy_qkvfp8_q_tp1
<
false
,
true
>
(
gQ
,
sQ
,
1
,
params
.
q_row_stride
,
params
.
seqlen_q
-
m_block
*
kBlockM
);
lds_direct_copy_qkvfp8_q_tp1
<
false
,
true
>
(
gQ
,
sQ
,
2
,
params
.
q_row_stride
,
params
.
seqlen_q
-
m_block
*
kBlockM
);
lds_direct_copy_qkvfp8_q_tp1
<
false
,
true
>
(
gQ
,
sQ
,
3
,
params
.
q_row_stride
,
params
.
seqlen_q
-
m_block
*
kBlockM
);
lds_direct_copy_qkvfp8_q_tp1
<
false
,
false
>
(
gQ
,
sQ
,
4
,
params
.
q_row_stride
,
params
.
seqlen_q
-
m_block
*
kBlockM
);
asm
volatile
(
"s_waitcnt vmcnt(4)
\n\t
s_barrier
\n\t
"
);
uint8_t
*
q_lds_read_ptr
=
reinterpret_cast
<
uint8_t
*>
(
sQ
.
data
().
get
())
+
(
tidx
%
64
)
*
16
+
(
warp_id
%
4
)
*
(
16
*
64
);
{
{
int
k
=
0
;
__builtin_hcu_matrix_load_64x16_b8
(
gQ_rscr
,
(
__attribute__
((
address_space
(
3
)))
char
*
)(
q_lds_addr
),
0
,
1
,
1
,
0
,
0
);
for
(
int
i
=
0
;
i
<
16
;
i
++
)
q_lds_addr
+=
64
*
128
;
__builtin_hcu_matrix_load_64x16_b8
(
gQ_rscr
,
(
__attribute__
((
address_space
(
3
)))
char
*
)(
q_lds_addr
),
128
,
1
,
1
,
0
,
0
);
q_lds_addr
+=
64
*
128
;
__builtin_hcu_matrix_load_64x16_b8
(
gQ_rscr
,
(
__attribute__
((
address_space
(
3
)))
char
*
)(
q_lds_addr
),
256
,
1
,
1
,
0
,
0
);
q_lds_addr
+=
64
*
128
;
__builtin_hcu_matrix_load_64x16_b8
(
gQ_rscr
,
(
__attribute__
((
address_space
(
3
)))
char
*
)(
q_lds_addr
),
256
+
128
,
1
,
1
,
0
,
0
);
q_lds_addr
+=
64
*
128
;
if
(
warp_id
<
4
)
{
{
tSrQ
(
i
,
0
,
k
).
storage
=
q_lds_read_ptr
[
i
];
__builtin_hcu_matrix_load_64x16_b8
(
gQ_rscr
,
(
__attribute__
((
address_space
(
3
)))
char
*
)(
q_lds_addr
),
512
,
1
,
1
,
0
,
0
);
}
}
q_lds_read_ptr
+=
64
*
64
;
else
for
(
int
i
=
0
;
i
<
16
;
i
++
)
{
{
tSrQ
(
i
,
0
,
k
+
1
).
storage
=
q_lds_read_ptr
[
i
]
;
lds_direct_copy_qkvfp8_zero_lds
(
gQ
,
sQ
,
4
)
;
}
}
// int k = 0;
}
// intx4_t * q_lds_read_16 = reinterpret_cast<intx4_t*>(q_lds_read_ptr);
else
// intx4_t * tSrQ_ptr = reinterpret_cast<intx4_t*>(&(tSrQ(0, 0, k)));
// *tSrQ_ptr = *reinterpret_cast<intx4_t*>(q_lds_read_ptr);
// q_lds_read_ptr += 64*64;
// q_lds_read_16 = reinterpret_cast<intx4_t*>(q_lds_read_ptr);
// tSrQ_ptr = reinterpret_cast<intx4_t*>(&(tSrQ(0, 0, k + 1)));
// *tSrQ_ptr = *reinterpret_cast<intx4_t*>(q_lds_read_ptr);
}
asm
volatile
(
"s_waitcnt vmcnt(3)
\n\t
s_barrier
\n\t
"
);
{
{
q_lds_read_ptr
+=
64
*
64
;
lds_direct_copy_qkvfp8_zero_lds
(
gQ
,
sQ
,
0
);
int
k
=
2
;
lds_direct_copy_qkvfp8_zero_lds
(
gQ
,
sQ
,
1
);
for
(
int
i
=
0
;
i
<
16
;
i
++
)
lds_direct_copy_qkvfp8_zero_lds
(
gQ
,
sQ
,
2
);
{
lds_direct_copy_qkvfp8_zero_lds
(
gQ
,
sQ
,
3
);
tSrQ
(
i
,
0
,
k
).
storage
=
q_lds_read_ptr
[
i
];
lds_direct_copy_qkvfp8_zero_lds
(
gQ
,
sQ
,
4
);
}
}
q_lds_read_ptr
+=
64
*
64
;
auto
q_lds_read_ptr
=
sQ
.
data
().
get
()
+
(
warp_id
%
4
)
*
16
*
64
;
for
(
int
i
=
0
;
i
<
16
;
i
++
)
{
asm
volatile
(
"s_waitcnt vmcnt(4)
\n\t
s_barrier
\n\t
"
);
tSrQ
(
i
,
0
,
k
+
1
).
storage
=
q_lds_read_ptr
[
i
];
q_r
[
0
].
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
q_lds_read_ptr
),
0
,
3
,
1
,
0
);
}
// q_lds_read_ptr += 64 * 64;
// q_lds_read_ptr += 64*64;
q_r
[
1
].
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
q_lds_read_ptr
),
64
*
64
,
3
,
1
,
0
);
// int k = 2;
asm
volatile
(
"s_waitcnt vmcnt(3)
\n\t
s_barrier
\n\t
"
);
// intx4_t * q_lds_read_16 = reinterpret_cast<intx4_t*>(q_lds_read_ptr);
// q_lds_read_ptr += 64 * 64;
// intx4_t * tSrQ_ptr = reinterpret_cast<intx4_t*>(&(tSrQ(0, 0, k)));
q_r
[
2
].
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
q_lds_read_ptr
),
2
*
64
*
64
,
3
,
1
,
0
);
// *tSrQ_ptr = *reinterpret_cast<intx4_t*>(q_lds_read_ptr);
// q_lds_read_ptr += 64 * 64;
// q_lds_read_ptr += 64*64;
q_r
[
3
].
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
q_lds_read_ptr
),
3
*
64
*
64
,
3
,
1
,
0
);
// q_lds_read_16 = reinterpret_cast<intx4_t*>(q_lds_read_ptr);
// tSrQ_ptr = reinterpret_cast<intx4_t*>(&(tSrQ(0, 0, k + 1)));
// *tSrQ_ptr = *reinterpret_cast<intx4_t*>(q_lds_read_ptr);
}
asm
volatile
(
"s_waitcnt vmcnt(2)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt vmcnt(2)
\n\t
s_barrier
\n\t
"
);
{
// q_lds_read_ptr += 64 * 64;
q_lds_read_ptr
+=
64
*
64
;
q_r
[
4
].
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
q_lds_read_ptr
),
4
*
64
*
64
,
3
,
1
,
0
);
int
k
=
4
;
// q_lds_read_ptr += 64 * 64;
for
(
int
i
=
0
;
i
<
16
;
i
++
)
q_r
[
5
].
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
q_lds_read_ptr
),
5
*
64
*
64
,
3
,
1
,
0
);
{
tSrQ
(
i
,
0
,
k
).
storage
=
q_lds_read_ptr
[
i
];
}
q_lds_read_ptr
+=
64
*
64
;
for
(
int
i
=
0
;
i
<
16
;
i
++
)
{
tSrQ
(
i
,
0
,
k
+
1
).
storage
=
q_lds_read_ptr
[
i
];
}
// q_lds_read_ptr += 64*64;
// int k = 4;
// intx4_t * q_lds_read_16 = reinterpret_cast<intx4_t*>(q_lds_read_ptr);
// intx4_t * tSrQ_ptr = reinterpret_cast<intx4_t*>(&(tSrQ(0, 0, k)));
// *tSrQ_ptr = *reinterpret_cast<intx4_t*>(q_lds_read_ptr);
// q_lds_read_ptr += 64*64;
// q_lds_read_16 = reinterpret_cast<intx4_t*>(q_lds_read_ptr);
// tSrQ_ptr = reinterpret_cast<intx4_t*>(&(tSrQ(0, 0, k + 1)));
// *tSrQ_ptr = *reinterpret_cast<intx4_t*>(q_lds_read_ptr);
}
asm
volatile
(
"s_waitcnt vmcnt(1)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt vmcnt(1)
\n\t
s_barrier
\n\t
"
);
{
// q_lds_read_ptr += 64 * 64;
q_lds_read_ptr
+=
64
*
64
;
q_r
[
6
].
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
q_lds_read_ptr
),
6
*
64
*
64
,
3
,
1
,
0
);
int
k
=
6
;
// q_lds_read_ptr += 64 * 64;
for
(
int
i
=
0
;
i
<
16
;
i
++
)
q_r
[
7
].
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
q_lds_read_ptr
),
7
*
64
*
64
,
3
,
1
,
0
);
{
tSrQ
(
i
,
0
,
k
).
storage
=
q_lds_read_ptr
[
i
];
}
q_lds_read_ptr
+=
64
*
64
;
for
(
int
i
=
0
;
i
<
16
;
i
++
)
{
tSrQ
(
i
,
0
,
k
+
1
).
storage
=
q_lds_read_ptr
[
i
];
}
// q_lds_read_ptr += 64*64;
// int k = 6;
// intx4_t * q_lds_read_16 = reinterpret_cast<intx4_t*>(q_lds_read_ptr);
// intx4_t * tSrQ_ptr = reinterpret_cast<intx4_t*>(&(tSrQ(0, 0, k)));
// *tSrQ_ptr = *reinterpret_cast<intx4_t*>(q_lds_read_ptr);
// q_lds_read_ptr += 64*64;
// q_lds_read_16 = reinterpret_cast<intx4_t*>(q_lds_read_ptr);
// tSrQ_ptr = reinterpret_cast<intx4_t*>(&(tSrQ(0, 0, k + 1)));
// *tSrQ_ptr = *reinterpret_cast<intx4_t*>(q_lds_read_ptr);
}
asm
volatile
(
"s_waitcnt vmcnt(0)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt vmcnt(0)
\n\t
s_barrier
\n\t
"
);
{
// q_lds_read_ptr += 64 * 64;
q_lds_read_ptr
+=
64
*
64
;
q_r
[
8
].
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
q_lds_read_ptr
),
8
*
64
*
64
,
3
,
1
,
0
);
int
k
=
8
;
for
(
int
i
=
0
;
i
<
16
;
i
++
)
{
tSrQ
(
i
,
0
,
k
).
storage
=
q_lds_read_ptr
[
i
];
}
// q_lds_read_ptr += 64*64;
// int k = 8;
// intx4_t * q_lds_read_16 = reinterpret_cast<intx4_t*>(q_lds_read_ptr);
// intx4_t * tSrQ_ptr = reinterpret_cast<intx4_t*>(&(tSrQ(0, 0, k)));
// *tSrQ_ptr = *reinterpret_cast<intx4_t*>(q_lds_read_ptr);
}
__syncthreads
();
__syncthreads
();
#endif
#endif
...
@@ -1263,10 +1195,21 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
...
@@ -1263,10 +1195,21 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
acco_f32
[
i
].
w
=
0.0
f
;
acco_f32
[
i
].
w
=
0.0
f
;
}
}
constexpr
static
int
STAGE
=
8
;
constexpr
static
int
STAGE
=
8
;
#if 1
for
(
int
masking_step
=
0
;
n_block
>=
n_block_min
;
++
masking_step
,
--
n_block
)
{
for
(
int
masking_step
=
0
;
n_block
>=
n_block_min
;
++
masking_step
,
--
n_block
)
{
Tensor
acc_s
=
partition_fragment_C
(
tiled_mma
,
Shape
<
Int
<
kBlockM
>
,
Int
<
kBlockN
>>
{});
v4f
accs_f32
[
2
];
Tensor
tOrVt_copy_view
=
smem_thr_copy_V
.
retile_D
(
tOrVt
);
for
(
int
i
=
0
;
i
<
2
;
i
++
)
clear
(
acc_s
);
{
accs_f32
[
i
].
x
=
0.0
f
;
accs_f32
[
i
].
y
=
0.0
f
;
accs_f32
[
i
].
z
=
0.0
f
;
accs_f32
[
i
].
w
=
0.0
f
;
}
// Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});
// Tensor tOrVt_copy_view = smem_thr_copy_V.retile_D(tOrVt);
// clear(acc_s);
Tensor
tSrK_copy_view
=
smem_thr_copy_K
.
retile_D
(
tSrK
);
Tensor
tSrK_copy_view
=
smem_thr_copy_K
.
retile_D
(
tSrK
);
// asm volatile("s_barrier \n\t");
// asm volatile("s_barrier \n\t");
...
@@ -1278,37 +1221,192 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
...
@@ -1278,37 +1221,192 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
"+s"
(
cur_block_table_ptr
),
"+s"
(
cur_block_table_ptr
),
"=s"
(
cur_block_table
));
"=s"
(
cur_block_table
));
index_t
offset_k
=
cur_block_table
*
params
.
k_batch_stride
;
index_t
offset_k
=
cur_block_table
*
params
.
k_batch_stride
;
// gK.data() = gK.data() + (offset_k);
#if 1
gK
.
data
()
=
gK
.
data
()
+
(
offset_k
);
gK
.
data
()
=
gK
.
data
()
+
(
offset_k
);
lds_direct_copy_qkvfp8_tp1
<
false
,
true
>
(
gK
,
sK
,
0
,
params
.
k_row_stride
,
seqlen_k
-
n_block
*
kBlockN
);
auto
gK_offset
=
((
warp_id
)
/
4
)
*
64
+
((
warp_id
)
%
4
)
*
16
*
params
.
k_row_stride
;
lds_direct_copy_qkvfp8_tp1
<
false
,
true
>
(
gK
,
sK
,
1
,
params
.
k_row_stride
,
seqlen_k
-
n_block
*
kBlockN
);
// auto gK_offset = (offset_k) + ((warp_id) / 4) * 64 + ((warp_id) % 4) * 16 * params.k_row_stride;
lds_direct_copy_qkvfp8_tp1
<
false
,
true
>
(
gK
,
sK
,
2
,
params
.
k_row_stride
,
seqlen_k
-
n_block
*
kBlockN
);
// const int k_zero_pad = std::min(std::max(n_block * kBlockN + ((warp_id) % 4 + 1) * 16 - seqlen_k, 0), 16);
lds_direct_copy_qkvfp8_tp1
<
false
,
true
>
(
gK
,
sK
,
3
,
params
.
k_row_stride
,
seqlen_k
-
n_block
*
kBlockN
);
const
int
k_zero_pad
=
std
::
max
(
n_block
*
kBlockN
+
((
warp_id
)
%
4
+
1
)
*
16
-
seqlen_k
,
0
);
lds_direct_copy_qkvfp8_tp1
<
false
,
false
>
(
gK
,
sK
,
4
,
params
.
k_row_stride
,
seqlen_k
-
n_block
*
kBlockN
);
uint32x4_t
gK_rscr
=
make_rscr
((
unsigned
char
*
)(
gK
.
data
().
get
()
+
gK_offset
),
params
.
k_row_stride
,
k_zero_pad
);
auto
k_lds_addr
=
reinterpret_cast
<
size_t
>
(
sK
.
data
().
get
()
+
((
warp_id
)
/
4
)
*
64
*
64
+
(
warp_id
%
4
)
*
16
*
64
);
if
(
n_block
*
kBlockN
+
((
warp_id
)
%
4
)
*
16
<
seqlen_k
||
masking_step
!=
0
)
{
k_lds_addr
|=
0x80000000
;
__builtin_hcu_matrix_load_64x16_b8
(
gK_rscr
,
(
__attribute__
((
address_space
(
3
)))
char
*
)(
k_lds_addr
),
0
,
1
,
1
,
0
,
0
);
k_lds_addr
+=
64
*
128
;
__builtin_hcu_matrix_load_64x16_b8
(
gK_rscr
,
(
__attribute__
((
address_space
(
3
)))
char
*
)(
k_lds_addr
),
128
,
1
,
1
,
0
,
0
);
k_lds_addr
+=
64
*
128
;
__builtin_hcu_matrix_load_64x16_b8
(
gK_rscr
,
(
__attribute__
((
address_space
(
3
)))
char
*
)(
k_lds_addr
),
256
,
1
,
1
,
0
,
0
);
k_lds_addr
+=
64
*
128
;
__builtin_hcu_matrix_load_64x16_b8
(
gK_rscr
,
(
__attribute__
((
address_space
(
3
)))
char
*
)(
k_lds_addr
),
256
+
128
,
1
,
1
,
0
,
0
);
k_lds_addr
+=
64
*
128
;
if
(
warp_id
<
4
)
{
__builtin_hcu_matrix_load_64x16_b8
(
gK_rscr
,
(
__attribute__
((
address_space
(
3
)))
char
*
)(
k_lds_addr
),
512
,
1
,
1
,
0
,
0
);
}
else
{
lds_direct_copy_qkvfp8_zero_lds
(
gK
,
sK
,
4
);
}
}
else
{
lds_direct_copy_qkvfp8_zero_lds
(
gK
,
sK
,
0
);
lds_direct_copy_qkvfp8_zero_lds
(
gK
,
sK
,
1
);
lds_direct_copy_qkvfp8_zero_lds
(
gK
,
sK
,
2
);
lds_direct_copy_qkvfp8_zero_lds
(
gK
,
sK
,
3
);
lds_direct_copy_qkvfp8_zero_lds
(
gK
,
sK
,
4
);
}
gK
.
data
()
=
gK
.
data
()
+
(
-
offset_k
);
auto
k_lds_read_ptr
=
sK
.
data
().
get
()
+
(
warp_id
/
4
)
*
16
*
64
;
asm
volatile
(
"s_waitcnt vmcnt(4)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt vmcnt(4)
\n\t
s_barrier
\n\t
"
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
0
),
tSrK_copy_view
(
_
,
_
,
0
));
constexpr
static
int
k_read_lds_offset
=
32
*
64
;
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
0
),
tSrK
(
_
,
_
,
0
),
acc_s
);
{
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
1
),
tSrK_copy_view
(
_
,
_
,
1
));
constexpr
static
int
k_idx
=
0
;
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
1
),
tSrK
(
_
,
_
,
1
),
acc_s
);
// k_lds_read_ptr += k_idx * 64 * 64;
Fp8_storage
k_data
;
k_data
.
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
k_lds_read_ptr
),
k_idx
*
64
*
64
,
3
,
1
,
0
);
accs_f32
[
0
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
0
],
k_data
.
p
[
0
],
accs_f32
[
0
],
true
,
false
);
accs_f32
[
0
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
1
],
k_data
.
p
[
1
],
accs_f32
[
0
],
true
,
false
);
k_data
.
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
k_lds_read_ptr
),
k_idx
*
64
*
64
+
k_read_lds_offset
,
3
,
1
,
0
);
accs_f32
[
1
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
0
],
k_data
.
p
[
0
],
accs_f32
[
1
],
true
,
false
);
accs_f32
[
1
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
1
],
k_data
.
p
[
1
],
accs_f32
[
1
],
true
,
false
);
}
{
constexpr
static
int
k_idx
=
1
;
// k_lds_read_ptr += 64 * 64;
Fp8_storage
k_data
;
k_data
.
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
k_lds_read_ptr
),
k_idx
*
64
*
64
,
3
,
1
,
0
);
accs_f32
[
0
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
0
],
k_data
.
p
[
0
],
accs_f32
[
0
],
true
,
false
);
accs_f32
[
0
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
1
],
k_data
.
p
[
1
],
accs_f32
[
0
],
true
,
false
);
k_data
.
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
k_lds_read_ptr
),
k_idx
*
64
*
64
+
k_read_lds_offset
,
3
,
1
,
0
);
accs_f32
[
1
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
0
],
k_data
.
p
[
0
],
accs_f32
[
1
],
true
,
false
);
accs_f32
[
1
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
1
],
k_data
.
p
[
1
],
accs_f32
[
1
],
true
,
false
);
// if (block0())
// {
// printf(" %x %x %x %x %x %x %x %x \n", q_r[k_idx].fp8_array[0], q_r[k_idx].fp8_array[1], q_r[k_idx].fp8_array[2], q_r[k_idx].fp8_array[3], k_data.fp8_array[0], k_data.fp8_array[1], k_data.fp8_array[2], k_data.fp8_array[3]);
// }
}
#if 1
asm
volatile
(
"s_waitcnt vmcnt(3)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt vmcnt(3)
\n\t
s_barrier
\n\t
"
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
2
),
tSrK_copy_view
(
_
,
_
,
2
));
{
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
2
),
tSrK
(
_
,
_
,
2
),
acc_s
);
constexpr
static
int
k_idx
=
2
;
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
3
),
tSrK_copy_view
(
_
,
_
,
3
));
// k_lds_read_ptr += 64 * 64;
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
3
),
tSrK
(
_
,
_
,
3
),
acc_s
);
Fp8_storage
k_data
;
k_data
.
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
k_lds_read_ptr
),
k_idx
*
64
*
64
,
3
,
1
,
0
);
accs_f32
[
0
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
0
],
k_data
.
p
[
0
],
accs_f32
[
0
],
true
,
false
);
accs_f32
[
0
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
1
],
k_data
.
p
[
1
],
accs_f32
[
0
],
true
,
false
);
k_data
.
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
k_lds_read_ptr
),
k_idx
*
64
*
64
+
k_read_lds_offset
,
3
,
1
,
0
);
accs_f32
[
1
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
0
],
k_data
.
p
[
0
],
accs_f32
[
1
],
true
,
false
);
accs_f32
[
1
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
1
],
k_data
.
p
[
1
],
accs_f32
[
1
],
true
,
false
);
}
{
constexpr
static
int
k_idx
=
3
;
// k_lds_read_ptr += 64 * 64;
Fp8_storage
k_data
;
k_data
.
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
k_lds_read_ptr
),
k_idx
*
64
*
64
,
3
,
1
,
0
);
accs_f32
[
0
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
0
],
k_data
.
p
[
0
],
accs_f32
[
0
],
true
,
false
);
accs_f32
[
0
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
1
],
k_data
.
p
[
1
],
accs_f32
[
0
],
true
,
false
);
k_data
.
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
k_lds_read_ptr
),
k_idx
*
64
*
64
+
k_read_lds_offset
,
3
,
1
,
0
);
accs_f32
[
1
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
0
],
k_data
.
p
[
0
],
accs_f32
[
1
],
true
,
false
);
accs_f32
[
1
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
1
],
k_data
.
p
[
1
],
accs_f32
[
1
],
true
,
false
);
}
asm
volatile
(
"s_waitcnt vmcnt(2)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt vmcnt(2)
\n\t
s_barrier
\n\t
"
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
4
),
tSrK_copy_view
(
_
,
_
,
4
));
{
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
4
),
tSrK
(
_
,
_
,
4
),
acc_s
);
constexpr
static
int
k_idx
=
4
;
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
5
),
tSrK_copy_view
(
_
,
_
,
5
));
// k_lds_read_ptr += 64 * 64;
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
5
),
tSrK
(
_
,
_
,
5
),
acc_s
);
Fp8_storage
k_data
;
k_data
.
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
k_lds_read_ptr
),
k_idx
*
64
*
64
,
3
,
1
,
0
);
accs_f32
[
0
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
0
],
k_data
.
p
[
0
],
accs_f32
[
0
],
true
,
false
);
accs_f32
[
0
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
1
],
k_data
.
p
[
1
],
accs_f32
[
0
],
true
,
false
);
k_data
.
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
k_lds_read_ptr
),
k_idx
*
64
*
64
+
k_read_lds_offset
,
3
,
1
,
0
);
accs_f32
[
1
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
0
],
k_data
.
p
[
0
],
accs_f32
[
1
],
true
,
false
);
accs_f32
[
1
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
1
],
k_data
.
p
[
1
],
accs_f32
[
1
],
true
,
false
);
}
{
constexpr
static
int
k_idx
=
5
;
// k_lds_read_ptr += 64 * 64;
Fp8_storage
k_data
;
k_data
.
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
k_lds_read_ptr
),
k_idx
*
64
*
64
,
3
,
1
,
0
);
accs_f32
[
0
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
0
],
k_data
.
p
[
0
],
accs_f32
[
0
],
true
,
false
);
accs_f32
[
0
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
1
],
k_data
.
p
[
1
],
accs_f32
[
0
],
true
,
false
);
k_data
.
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
k_lds_read_ptr
),
k_idx
*
64
*
64
+
k_read_lds_offset
,
3
,
1
,
0
);
accs_f32
[
1
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
0
],
k_data
.
p
[
0
],
accs_f32
[
1
],
true
,
false
);
accs_f32
[
1
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
1
],
k_data
.
p
[
1
],
accs_f32
[
1
],
true
,
false
);
}
asm
volatile
(
"s_waitcnt vmcnt(1)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt vmcnt(1)
\n\t
s_barrier
\n\t
"
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
6
),
tSrK_copy_view
(
_
,
_
,
6
));
{
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
6
),
tSrK
(
_
,
_
,
6
),
acc_s
);
constexpr
static
int
k_idx
=
6
;
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
7
),
tSrK_copy_view
(
_
,
_
,
7
));
// k_lds_read_ptr += 64 * 64;
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
7
),
tSrK
(
_
,
_
,
7
),
acc_s
);
Fp8_storage
k_data
;
k_data
.
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
k_lds_read_ptr
),
k_idx
*
64
*
64
,
3
,
1
,
0
);
accs_f32
[
0
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
0
],
k_data
.
p
[
0
],
accs_f32
[
0
],
true
,
false
);
accs_f32
[
0
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
1
],
k_data
.
p
[
1
],
accs_f32
[
0
],
true
,
false
);
k_data
.
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
k_lds_read_ptr
),
k_idx
*
64
*
64
+
k_read_lds_offset
,
3
,
1
,
0
);
accs_f32
[
1
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
0
],
k_data
.
p
[
0
],
accs_f32
[
1
],
true
,
false
);
accs_f32
[
1
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
1
],
k_data
.
p
[
1
],
accs_f32
[
1
],
true
,
false
);
}
{
constexpr
static
int
k_idx
=
7
;
// k_lds_read_ptr += 64 * 64;
Fp8_storage
k_data
;
k_data
.
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
k_lds_read_ptr
),
k_idx
*
64
*
64
,
3
,
1
,
0
);
accs_f32
[
0
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
0
],
k_data
.
p
[
0
],
accs_f32
[
0
],
true
,
false
);
accs_f32
[
0
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
1
],
k_data
.
p
[
1
],
accs_f32
[
0
],
true
,
false
);
k_data
.
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
k_lds_read_ptr
),
k_idx
*
64
*
64
+
k_read_lds_offset
,
3
,
1
,
0
);
accs_f32
[
1
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
0
],
k_data
.
p
[
0
],
accs_f32
[
1
],
true
,
false
);
accs_f32
[
1
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
1
],
k_data
.
p
[
1
],
accs_f32
[
1
],
true
,
false
);
}
asm
volatile
(
"s_waitcnt vmcnt(0)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt vmcnt(0)
\n\t
s_barrier
\n\t
"
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
8
),
tSrK_copy_view
(
_
,
_
,
8
));
{
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
8
),
tSrK
(
_
,
_
,
8
),
acc_s
);
constexpr
static
int
k_idx
=
8
;
gK
.
data
()
=
gK
.
data
()
+
(
-
offset_k
);
// k_lds_read_ptr += 64 * 64;
// asm volatile("s_barrier \n\t");
Fp8_storage
k_data
;
k_data
.
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
k_lds_read_ptr
),
k_idx
*
64
*
64
,
3
,
1
,
0
);
accs_f32
[
0
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
0
],
k_data
.
p
[
0
],
accs_f32
[
0
],
true
,
false
);
accs_f32
[
0
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
1
],
k_data
.
p
[
1
],
accs_f32
[
0
],
true
,
false
);
k_data
.
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
k_lds_read_ptr
),
k_idx
*
64
*
64
+
k_read_lds_offset
,
3
,
1
,
0
);
accs_f32
[
1
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
0
],
k_data
.
p
[
0
],
accs_f32
[
1
],
true
,
false
);
accs_f32
[
1
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
q_r
[
k_idx
].
p
[
1
],
k_data
.
p
[
1
],
accs_f32
[
1
],
true
,
false
);
}
#endif
Tensor
acc_s
=
partition_fragment_C
(
tiled_mma
,
Shape
<
Int
<
kBlockM
>
,
Int
<
kBlockN
>>
{});
acc_s
(
0
,
0
,
0
)
=
accs_f32
[
0
].
x
;
acc_s
(
1
,
0
,
0
)
=
accs_f32
[
0
].
y
;
acc_s
(
2
,
0
,
0
)
=
accs_f32
[
0
].
z
;
acc_s
(
3
,
0
,
0
)
=
accs_f32
[
0
].
w
;
acc_s
(
0
,
0
,
1
)
=
accs_f32
[
1
].
x
;
acc_s
(
1
,
0
,
1
)
=
accs_f32
[
1
].
y
;
acc_s
(
2
,
0
,
1
)
=
accs_f32
[
1
].
z
;
acc_s
(
3
,
0
,
1
)
=
accs_f32
[
1
].
w
;
// cute::gemm(tiled_mma, tSrQ(_, _, 0), tSrK(_, _, 0), acc_s);
#endif
Tensor
cS
=
make_identity_tensor
(
Shape
<
Int
<
kBlockM
>
,
Int
<
kBlockN
>>
{});
Tensor
cS
=
make_identity_tensor
(
Shape
<
Int
<
kBlockM
>
,
Int
<
kBlockN
>>
{});
Tensor
tScS
=
thr_mma
.
partition_C
(
cS
);
Tensor
tScS
=
thr_mma
.
partition_C
(
cS
);
for
(
int
i
=
0
;
i
<
size
(
acc_s
);
++
i
)
{
for
(
int
i
=
0
;
i
<
size
(
acc_s
);
++
i
)
{
...
@@ -1333,20 +1431,6 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
...
@@ -1333,20 +1431,6 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
:
softmax
.
template
softmax_rescale_o_fp8_tp1
<
/*Is_first=*/
false
,
/*Check_inf=*/
Is_causal
,
true
>(
acc_s
,
sRow_max_reduce_buffer
,
scale_softmax_log2
,
acco_f32
);
:
softmax
.
template
softmax_rescale_o_fp8_tp1
<
/*Is_first=*/
false
,
/*Check_inf=*/
Is_causal
,
true
>(
acc_s
,
sRow_max_reduce_buffer
,
scale_softmax_log2
,
acco_f32
);
}
}
// asm volatile("s_barrier \n\t");
// if (block0() && tidx < 64)
// {
// // printf("n_block = %d tidx = %d %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f \n ", n_block, tidx, acco_f32[0].x, acco_f32[0].y, acco_f32[0].z, acco_f32[0].w,
// // acco_f32[1].x, acco_f32[1].y, acco_f32[1].z, acco_f32[1].w,
// // acco_f32[2].x, acco_f32[2].y, acco_f32[2].z, acco_f32[2].w,
// // acco_f32[3].x, acco_f32[3].y, acco_f32[3].z, acco_f32[3].w
// // );
// printf("n_block = %d tidx = %d %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f \n ", n_block, tidx, acc_s(0), acc_s(1), acc_s(2), acc_s(3),
// acc_s(4), acc_s(5), acc_s(6), acc_s(7)
// // acc_s(8), acc_s(9), acc_s(10), acc_s(11),
// // acc_s(12), acc_s(13), acc_s(14), acc_s(15)
// );
// }
#if 1
#if 1
Fp8_storage
p_fp8
;
Fp8_storage
p_fp8
;
{
{
...
@@ -1371,83 +1455,53 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
...
@@ -1371,83 +1455,53 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
__syncthreads
();
__syncthreads
();
p_fp8
.
data
=
*
reinterpret_cast
<
intx4_t
*>
(
&
(
sP
[
tid
*
16
+
(
warp_id
%
4
)
*
16
*
64
]));
p_fp8
.
data
=
*
reinterpret_cast
<
intx4_t
*>
(
&
(
sP
[
tid
*
16
+
(
warp_id
%
4
)
*
16
*
64
]));
__builtin_amdgcn_sched_barrier
(
0
);
__builtin_amdgcn_sched_barrier
(
0
);
}
{
int
lane_id
=
tidx
%
64
;
__builtin_amdgcn_sched_barrier
(
0
);
int
row
=
lane_id
/
4
;
int
col
=
lane_id
%
4
;
col
=
(
col
+
(
row
/
2
)
%
4
)
%
4
;
auto
lds_offset
=
row
*
64
+
col
*
16
+
(
warp_id
/
4
)
*
64
*
64
;
for
(
int
i
=
0
;
i
<
4
;
i
++
)
// Fp8_storage v0_0, v0_1;
// v0_0.data = __builtin_hcu_ds_read_m64x16_u8_alt4((__attribute__((address_space(3))) int*)(A_smem + lds_offset));
// v0_1.data = __builtin_hcu_ds_read_m64x16_u8_alt4((__attribute__((address_space(3))) int*)(sV.data().get() + lds_offset + 16 * 64));
// acco_f32[i * 4 + j] = __builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts(p_fp8.p[k/2], tmp.val_to_mmac, acco_f32[i * 4 + j], true, false);
for
(
int
n
=
0
;
n
<
4
;
n
++
)
{
{
{
Fp8_storage
v0_0
,
v0_1
;
int
k
=
0
;
v0_0
.
data
=
__builtin_hcu_ds_read_m64x16_u8_alt4
((
__attribute__
((
address_space
(
3
)))
int
*
)(
sV
.
data
().
get
()
+
lds_offset
+
n
*
64
*
128
));
Fp8_storage
v0_0
,
v0_1
;
v0_1
.
data
=
__builtin_hcu_ds_read_m64x16_u8_alt4
((
__attribute__
((
address_space
(
3
)))
int
*
)(
sV
.
data
().
get
()
+
lds_offset
+
16
*
64
+
n
*
64
*
128
));
v0_0
.
data
=
__builtin_hcu_ds_read_m64x16_u8_alt4
((
__attribute__
((
address_space
(
3
)))
int
*
)(
&
(
tOsVt
(
0
,
i
,
k
))));
v0_1
.
data
=
__builtin_hcu_ds_read_m64x16_u8_alt4
((
__attribute__
((
address_space
(
3
)))
int
*
)(
&
(
tOsVt
(
0
,
i
,
k
+
1
))));
// if (block0() && tidx < 64)
// {
// float v0 = __builtin_amdgcn_cvt_f32_fp8(v0_0.fp8_array[0], 0);
// float v1 = __builtin_amdgcn_cvt_f32_fp8(v0_0.fp8_array[0], 1);
// float v2 = __builtin_amdgcn_cvt_f32_fp8(v0_0.fp8_array[0], 2);
// float v3 = __builtin_amdgcn_cvt_f32_fp8(v0_0.fp8_array[0], 3);
// printf("tid = %d %.3f %.3f %.3f %.3f %x \n", tidx, v0, v1, v2, v3, v0_0.fp8_array[0]);
for
(
int
j
=
0
;
j
<
4
;
j
++
)
// }
{
intx2_t
v
;
v
[
0
]
=
v0_0
.
fp8_array
[
j
];
v
[
1
]
=
v0_1
.
fp8_array
[
j
];
acco_f32
[
n
*
4
+
j
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
p_fp8
.
p
[
0
],
v
,
acco_f32
[
n
*
4
+
j
],
true
,
false
);
for
(
int
j
=
0
;
j
<
4
;
j
++
)
{
Val
tmp
;
tmp
.
data
[
0
]
=
v0_0
.
fp8_array
[
j
];
tmp
.
data
[
1
]
=
v0_1
.
fp8_array
[
j
];
acco_f32
[
i
*
4
+
j
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
p_fp8
.
p
[
k
/
2
],
tmp
.
val_to_mmac
,
acco_f32
[
i
*
4
+
j
],
true
,
false
);
}
}
}
{
int
k
=
2
;
Fp8_storage
v0_0
,
v0_1
;
v0_0
.
data
=
__builtin_hcu_ds_read_m64x16_u8_alt4
((
__attribute__
((
address_space
(
3
)))
int
*
)(
&
(
tOsVt
(
0
,
i
,
k
))));
v0_1
.
data
=
__builtin_hcu_ds_read_m64x16_u8_alt4
((
__attribute__
((
address_space
(
3
)))
int
*
)(
&
(
tOsVt
(
0
,
i
,
k
+
1
))));
// if (block0() && tidx < 64)
// {
// float v0 = __builtin_amdgcn_cvt_f32_fp8(v0_0.fp8_array[0], 0);
// float v1 = __builtin_amdgcn_cvt_f32_fp8(v0_0.fp8_array[0], 1);
// float v2 = __builtin_amdgcn_cvt_f32_fp8(v0_0.fp8_array[0], 2);
// float v3 = __builtin_amdgcn_cvt_f32_fp8(v0_0.fp8_array[0], 3);
// printf("tid = %d %.3f %.3f %.3f %.3f %x \n", tidx, v0, v1, v2, v3, v0_0.fp8_array[0]);
v0_0
.
data
=
__builtin_hcu_ds_read_m64x16_u8_alt4
((
__attribute__
((
address_space
(
3
)))
int
*
)(
sV
.
data
().
get
()
+
lds_offset
+
n
*
64
*
128
+
32
*
64
));
// }
v0_1
.
data
=
__builtin_hcu_ds_read_m64x16_u8_alt4
((
__attribute__
((
address_space
(
3
)))
int
*
)(
sV
.
data
().
get
()
+
lds_offset
+
16
*
64
+
n
*
64
*
128
+
32
*
64
));
for
(
int
j
=
0
;
j
<
4
;
j
++
)
{
intx2_t
v
;
v
[
0
]
=
v0_0
.
fp8_array
[
j
];
v
[
1
]
=
v0_1
.
fp8_array
[
j
];
acco_f32
[
n
*
4
+
j
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
p_fp8
.
p
[
1
],
v
,
acco_f32
[
n
*
4
+
j
],
true
,
false
);
for
(
int
j
=
0
;
j
<
4
;
j
++
)
{
Val
tmp
;
tmp
.
data
[
0
]
=
v0_0
.
fp8_array
[
j
];
tmp
.
data
[
1
]
=
v0_1
.
fp8_array
[
j
];
acco_f32
[
i
*
4
+
j
]
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
p_fp8
.
p
[
k
/
2
],
tmp
.
val_to_mmac
,
acco_f32
[
i
*
4
+
j
],
true
,
false
);
}
}
}
}
}
__builtin_amdgcn_sched_barrier
(
0
);
}
}
// if (block0())
// {
// printf("n_block = %d tidx = %d %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f \n ", n_block, tidx, acco_f32[0].x, acco_f32[0].y, acco_f32[0].z, acco_f32[0].w,
// acco_f32[1].x, acco_f32[1].y, acco_f32[1].z, acco_f32[1].w,
// acco_f32[2].x, acco_f32[2].y, acco_f32[2].z, acco_f32[2].w,
// acco_f32[3].x, acco_f32[3].y, acco_f32[3].z, acco_f32[3].w
// );
// // printf("n_block = %d tidx = %d %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f \n ", n_block, tidx, acc_s(0), acc_s(1), acc_s(2), acc_s(3),
// // acc_s(4), acc_s(5), acc_s(6), acc_s(7),
// // acc_s(8), acc_s(9), acc_s(10), acc_s(11),
// // acc_s(12), acc_s(13), acc_s(14), acc_s(15)
// // );
// }
asm
volatile
(
"s_barrier
\n\t
"
);
asm
volatile
(
"s_barrier
\n\t
"
);
#endif
#endif
}
}
#endif
using
ElementO
=
typename
Kernel_traits
::
ElementO
;
using
ElementO
=
typename
Kernel_traits
::
ElementO
;
using
ElementAccum
=
typename
Kernel_traits
::
ElementAccum
;
using
ElementAccum
=
typename
Kernel_traits
::
ElementAccum
;
const
int
split_offset
=
__ldg
(
params
.
num_splits_ptr
+
bidb
);
const
int
split_offset
=
__ldg
(
params
.
num_splits_ptr
+
bidb
);
...
@@ -1458,9 +1512,6 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
...
@@ -1458,9 +1512,6 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
const
index_t
row_offset_lse
=
(
bidb
*
params
.
h
+
bidh
)
*
params
.
seqlen_q
+
m_block
*
kBlockM
;
const
index_t
row_offset_lse
=
(
bidb
*
params
.
h
+
bidh
)
*
params
.
seqlen_q
+
m_block
*
kBlockM
;
const
index_t
row_offset_lseaccum
=
((
split_offset
+
n_split_idx
)
*
params
.
h
+
bidh
)
*
params
.
seqlen_q
+
m_block
*
kBlockM
;
const
index_t
row_offset_lseaccum
=
((
split_offset
+
n_split_idx
)
*
params
.
h
+
bidh
)
*
params
.
seqlen_q
+
m_block
*
kBlockM
;
if
(
NoSplit
)
{
if
(
NoSplit
)
{
constexpr
bool
Split
=
false
;
constexpr
bool
Split
=
false
;
Tensor
gOaccum
=
make_tensor
(
make_gmem_ptr
(
reinterpret_cast
<
ElementO
*>
(
Split
?
params
.
oaccum_ptr
:
params
.
o_ptr
)
+
(
Split
?
row_offset_oaccum
:
row_offset_o
)),
Tensor
gOaccum
=
make_tensor
(
make_gmem_ptr
(
reinterpret_cast
<
ElementO
*>
(
Split
?
params
.
oaccum_ptr
:
params
.
o_ptr
)
+
(
Split
?
row_offset_oaccum
:
row_offset_o
)),
...
@@ -1482,12 +1533,6 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
...
@@ -1482,12 +1533,6 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
if
(
row
<
params
.
seqlen_q
-
m_block
*
kBlockM
)
{
gLSEaccum
(
row
)
=
lse
(
mi
);
}
if
(
row
<
params
.
seqlen_q
-
m_block
*
kBlockM
)
{
gLSEaccum
(
row
)
=
lse
(
mi
);
}
}
}
}
}
// if (tidx == 1)
// {
// printf(" %.4f %.4f %.4f %.4f \n ", acco_f32[0].x, acco_f32[0].y, acco_f32[0].z, acco_f32[0].w);
// }
{
{
using
result_type
=
cutlass
::
Array
<
bfloat16_t
,
2
>
;
using
result_type
=
cutlass
::
Array
<
bfloat16_t
,
2
>
;
int
tidx
=
threadIdx
.
x
;
int
tidx
=
threadIdx
.
x
;
...
@@ -1547,49 +1592,10 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
...
@@ -1547,49 +1592,10 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
gOaccum
(
row
,
col
+
3
)
=
res1
[
1
];
gOaccum
(
row
,
col
+
3
)
=
res1
[
1
];
// col += 16;
// col += 16;
}
}
// for (int j = 0; j < 4; j++)
// {
// auto d0 = __builtin_hcu_cvt_pk_bf16_f32(0, acco_f32[n* 4 + j].x, 0, acco_f32[n * 4 + j].y, 0);
// auto d1 = __builtin_hcu_cvt_pk_bf16_f32(0, acco_f32[n* 4 + j].z, 0, acco_f32[n * 4 + j].w, 0);
// auto res0 = reinterpret_cast<result_type const &>(d0);
// auto res1 = reinterpret_cast<result_type const &>(d1);
// gOaccum(row, col) = res0[0];
// gOaccum(row, col + 1) = res0[1];
// gOaccum(row, col + 2) = res1[0];
// gOaccum(row, col + 3) = res1[1];
// col += 16;
// }
}
}
// for (int n = 0; n < 8; n++)
// {
// using result_type = cutlass::Array<bfloat16_t, 2>;
// auto d0 = __builtin_hcu_cvt_pk_bf16_f32(0, acco_f32[n].x, 0, acco_f32[n].y, 0);
// auto d1 = __builtin_hcu_cvt_pk_bf16_f32(0, acco_f32[n].z, 0, acco_f32[n].w, 0);
// col = (tidx % 64 / 16) * 4 + n * 64;
// auto res0 = reinterpret_cast<result_type const &>(d0);
// auto res1 = reinterpret_cast<result_type const &>(d1);
// gOaccum(row, col) = res0[0];
// gOaccum(row, col + 1) = res0[1];
// gOaccum(row, col + 2) = res1[0];
// gOaccum(row, col + 3) = res1[1];
// }
}
}
}
}
}
}
}
else
{
}
else
{
constexpr
bool
Split
=
true
;
constexpr
bool
Split
=
true
;
Tensor
gOaccum
=
make_tensor
(
make_gmem_ptr
(
reinterpret_cast
<
ElementAccum
*>
(
Split
?
params
.
oaccum_ptr
:
params
.
o_ptr
)
+
(
Split
?
row_offset_oaccum
:
row_offset_o
)),
Tensor
gOaccum
=
make_tensor
(
make_gmem_ptr
(
reinterpret_cast
<
ElementAccum
*>
(
Split
?
params
.
oaccum_ptr
:
params
.
o_ptr
)
+
(
Split
?
row_offset_oaccum
:
row_offset_o
)),
...
@@ -1610,92 +1616,51 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
...
@@ -1610,92 +1616,51 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
}
}
}
}
int
tidx
=
threadIdx
.
x
;
int
tidx
=
threadIdx
.
x
;
int
col
=
0
;
int
col
=
0
;
int
warpid
=
tidx
/
64
;
int
warpid
=
tidx
/
64
;
for
(
int
m
=
0
;
m
<
1
;
m
++
)
{
for
(
int
m
=
0
;
m
<
1
;
m
++
)
{
const
int
row
=
tidx
%
16
+
(
warpid
%
4
)
*
16
;
const
int
row
=
tidx
%
16
+
(
warpid
%
4
)
*
16
;
if
(
row
<
params
.
seqlen_q
-
m_block
*
kBlockM
)
{
if
(
row
<
params
.
seqlen_q
-
m_block
*
kBlockM
)
{
// for (int n = 0; n < 32; n++)
for
(
int
n
=
0
;
n
<
4
;
n
++
)
// {
{
// col = (tidx % 64 / 16) * 4 + n * 16;
col
=
(
tidx
%
64
/
16
)
*
16
+
n
*
128
+
(
warp_id
/
4
)
*
64
;
// gOaccum(row, col) = acco_f32[n].x;
// gOaccum(row, col + 1) = acco_f32[n].y;
// gOaccum(row, col + 2) = acco_f32[n].z;
// gOaccum(row, col + 3) = acco_f32[n].w;
// }
for
(
int
n
=
0
;
n
<
4
;
n
++
)
{
{
col
=
(
tidx
%
64
/
16
)
*
16
+
n
*
128
+
(
warp_id
/
4
)
*
64
;
gOaccum
(
row
,
col
)
=
acco_f32
[
n
*
4
+
0
].
x
;
gOaccum
(
row
,
col
+
1
)
=
acco_f32
[
n
*
4
+
1
].
x
;
{
gOaccum
(
row
,
col
+
2
)
=
acco_f32
[
n
*
4
+
2
].
x
;
gOaccum
(
row
,
col
)
=
acco_f32
[
n
*
4
+
0
].
x
;
gOaccum
(
row
,
col
+
3
)
=
acco_f32
[
n
*
4
+
3
].
x
;
gOaccum
(
row
,
col
+
1
)
=
acco_f32
[
n
*
4
+
1
].
x
;
col
+=
4
;
gOaccum
(
row
,
col
+
2
)
=
acco_f32
[
n
*
4
+
2
].
x
;
}
gOaccum
(
row
,
col
+
3
)
=
acco_f32
[
n
*
4
+
3
].
x
;
{
col
+=
4
;
gOaccum
(
row
,
col
)
=
acco_f32
[
n
*
4
+
0
].
y
;
}
gOaccum
(
row
,
col
+
1
)
=
acco_f32
[
n
*
4
+
1
].
y
;
{
gOaccum
(
row
,
col
+
2
)
=
acco_f32
[
n
*
4
+
2
].
y
;
gOaccum
(
row
,
col
)
=
acco_f32
[
n
*
4
+
0
].
y
;
gOaccum
(
row
,
col
+
3
)
=
acco_f32
[
n
*
4
+
3
].
y
;
gOaccum
(
row
,
col
+
1
)
=
acco_f32
[
n
*
4
+
1
].
y
;
col
+=
4
;
gOaccum
(
row
,
col
+
2
)
=
acco_f32
[
n
*
4
+
2
].
y
;
gOaccum
(
row
,
col
+
3
)
=
acco_f32
[
n
*
4
+
3
].
y
;
col
+=
4
;
}
{
gOaccum
(
row
,
col
)
=
acco_f32
[
n
*
4
+
0
].
z
;
gOaccum
(
row
,
col
+
1
)
=
acco_f32
[
n
*
4
+
1
].
z
;
gOaccum
(
row
,
col
+
2
)
=
acco_f32
[
n
*
4
+
2
].
z
;
gOaccum
(
row
,
col
+
3
)
=
acco_f32
[
n
*
4
+
3
].
z
;
col
+=
4
;
}
{
gOaccum
(
row
,
col
)
=
acco_f32
[
n
*
4
+
0
].
w
;
gOaccum
(
row
,
col
+
1
)
=
acco_f32
[
n
*
4
+
1
].
w
;
gOaccum
(
row
,
col
+
2
)
=
acco_f32
[
n
*
4
+
2
].
w
;
gOaccum
(
row
,
col
+
3
)
=
acco_f32
[
n
*
4
+
3
].
w
;
}
// for (int j = 0; j < 4; j++) {
// gOaccum(row, col) = acco_f32[n * 4 + j].x;
// gOaccum(row, col + 1) = acco_f32[n * 4 + j].y;
// gOaccum(row, col + 2) = acco_f32[n * 4 + j].z;
// gOaccum(row, col + 3) = acco_f32[n * 4 + j].w;
// col += 16;
// }
}
}
{
gOaccum
(
row
,
col
)
=
acco_f32
[
n
*
4
+
0
].
z
;
gOaccum
(
row
,
col
+
1
)
=
acco_f32
[
n
*
4
+
1
].
z
;
gOaccum
(
row
,
col
+
2
)
=
acco_f32
[
n
*
4
+
2
].
z
;
gOaccum
(
row
,
col
+
3
)
=
acco_f32
[
n
*
4
+
3
].
z
;
col
+=
4
;
}
{
gOaccum
(
row
,
col
)
=
acco_f32
[
n
*
4
+
0
].
w
;
gOaccum
(
row
,
col
+
1
)
=
acco_f32
[
n
*
4
+
1
].
w
;
gOaccum
(
row
,
col
+
2
)
=
acco_f32
[
n
*
4
+
2
].
w
;
gOaccum
(
row
,
col
+
3
)
=
acco_f32
[
n
*
4
+
3
].
w
;
}
}
}
}
}
}
}
}
// Tensor acc_o = partition_fragment_C(tiled_mma_o, Shape<Int<kBlockM>, Int<kHeadDimV>>{});
// for (int n = 0; n < 8; n++)
// {
// acc_o(0, 0, n) = acco_f32[n * 2].x;
// acc_o(1, 0, n) = acco_f32[n * 2].y;
// acc_o(2, 0, n) = acco_f32[n * 2].z;
// acc_o(3, 0, n) = acco_f32[n * 2].w;
// acc_o(4, 0, n) = acco_f32[n * 2 + 1].x;
// acc_o(5, 0, n) = acco_f32[n * 2 + 1].y;
// acc_o(6, 0, n) = acco_f32[n * 2 + 1].z;
// acc_o(7, 0, n) = acco_f32[n * 2 + 1].w;
// }
// if (NoSplit)
// store_float8<Kernel_traits, false>(params, bidb, bidh, m_block, n_split_idx, shared_storage, acc_o, softmax, descale_k, scale_softmax);
// else
// store_float8<Kernel_traits, true>(params, bidb, bidh, m_block, n_split_idx, shared_storage, acc_o, softmax, descale_k, scale_softmax);
}
}
template
<
typename
Kernel_traits
,
bool
Is_causal
,
typename
SharedStorage
>
template
<
typename
Kernel_traits
,
bool
Is_causal
,
typename
SharedStorage
>
__forceinline__
__device__
void
compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP4
(
const
Flash_fwd_mla_params
&
params
,
__forceinline__
__device__
void
compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP4
(
const
Flash_fwd_mla_params
&
params
,
...
...
csrc/extension/utils.h
View file @
98b7c697
...
@@ -2692,6 +2692,62 @@ __forceinline__ __device__ void __ds_read_m32x32_row_col_rrow(Tensor0& src, int
...
@@ -2692,6 +2692,62 @@ __forceinline__ __device__ void __ds_read_m32x32_row_col_rrow(Tensor0& src, int
extern
__device__
__attribute__
((
const
))
float
__llvm_exp2_f32
(
float
)
__asm
(
"llvm.exp2.f32"
);
extern
__device__
__attribute__
((
const
))
float
__llvm_exp2_f32
(
float
)
__asm
(
"llvm.exp2.f32"
);
__device__
inline
uint32x4_t
make_rscr
(
unsigned
char
*
ptr
,
const
int
stride
,
const
int
zero_pad
)
{
uint32x4_t
rscr
;
*
(
uint64_t
*
)
&
rscr
=
(
reinterpret_cast
<
uint64_t
>
(
ptr
));
rscr
[
2
]
=
stride
;
rscr
[
3
]
=
(
1
<<
16
)
&
0XFFFFFFFF
;
rscr
[
3
]
|=
(
zero_pad
)
<<
8
;
return
rscr
;
}
template
<
class
SrcEngine
,
class
SrcLayout
,
class
DstEngine
,
class
DstLayout
>
CUTE_HOST_DEVICE
void
lds_direct_copy_qkvfp8_zero_lds
(
Tensor
<
SrcEngine
,
SrcLayout
>
const
&
src
,
Tensor
<
DstEngine
,
DstLayout
>
&
dst
,
int
k_idx_
)
{
constexpr
int
warp_size
=
64
;
int
tidx
=
threadIdx
.
x
;
//0-256
int
warp_id
=
__builtin_amdgcn_readfirstlane
(
tidx
/
warp_size
);
int
lane
=
tidx
%
warp_size
;
//0-63
constexpr
int
element_size
=
1
;
int
k_idx
=
__builtin_amdgcn_readfirstlane
(
k_idx_
);
//576
const
int
offset_s
=
0
;
struct
PtrWrapper
{
uint32_t
former
;
uint32_t
latter
;
};
PtrWrapper
glob_ptr
;
*
(
uint64_t
*
)
&
glob_ptr
=
reinterpret_cast
<
uint64_t
>
(
src
.
data
().
get
());
uint32x4_t
global_addr
=
{
0
};
global_addr
[
0
]
=
__builtin_amdgcn_readfirstlane
(
glob_ptr
.
former
);
global_addr
[
1
]
=
__builtin_amdgcn_readfirstlane
(
glob_ptr
.
latter
);
global_addr
[
2
]
=
0x80000000
;
global_addr
[
3
]
=
0x00020000
;
constexpr
int
elements_per_thread
=
16
;
constexpr
int
bytes_per_warp
=
warp_size
*
elements_per_thread
*
element_size
;
//64*16*1
int
offset_v
=-
1
;
int
ldsAddrPerWave
=
reinterpret_cast
<
size_t
>
(
dst
.
data
().
get
())
+
(
warp_id
%
4
)
*
bytes_per_warp
+
(
k_idx
)
*
64
*
128
*
element_size
+
(
warp_id
/
4
)
*
64
*
64
;
#if defined(__gfx938__)
asm
volatile
(
"s_mov_b32 m0, %1
\n\t
"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds
\n
"
::
"v"
(
offset_v
),
"s"
(
ldsAddrPerWave
),
"s"
(
global_addr
),
"s"
(
offset_s
)
:
);
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////
}
// namespace flash
}
// namespace flash
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment