Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
FlashMLA
Commits
6d3ed1da
Commit
6d3ed1da
authored
Mar 02, 2026
by
zhanghj2
Browse files
优化nmz tp8
parent
2ff340aa
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
120 additions
and
73 deletions
+120
-73
csrc/extension/flash_fwd_mla_kernel_fp8.h
csrc/extension/flash_fwd_mla_kernel_fp8.h
+86
-69
csrc/extension/utils.h
csrc/extension/utils.h
+34
-4
No files found.
csrc/extension/flash_fwd_mla_kernel_fp8.h
View file @
6d3ed1da
...
@@ -563,7 +563,7 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co
...
@@ -563,7 +563,7 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co
const
int
tidx
=
threadIdx
.
x
;
const
int
tidx
=
threadIdx
.
x
;
const
int
lane_idx
=
tidx
%
64
;
const
int
lane_idx
=
tidx
%
64
;
const
int
warp_idx
=
__builtin_amdgcn_readfirstlane
(
tidx
/
64
);
const
int
warp_idx
=
__builtin_amdgcn_readfirstlane
(
tidx
/
64
);
const
index_t
row_offset_k
=
(
bidh
/
params
.
h_h_k_ratio
)
*
params
.
k_head_stride
;
const
index_t
row_offset_k
=
0
;
Tensor
gK
=
make_tensor
(
make_gmem_ptr
(
reinterpret_cast
<
Element
*>
(
params
.
k_ptr
)
+
row_offset_k
),
Tensor
gK
=
make_tensor
(
make_gmem_ptr
(
reinterpret_cast
<
Element
*>
(
params
.
k_ptr
)
+
row_offset_k
),
Shape
<
Int
<
kBlockN
>
,
Int
<
kHeadDim
>>
{},
Shape
<
Int
<
kBlockN
>
,
Int
<
kHeadDim
>>
{},
make_stride
(
params
.
k_row_stride
,
_1
{}));
//64*576
make_stride
(
params
.
k_row_stride
,
_1
{}));
//64*576
...
@@ -863,6 +863,8 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co
...
@@ -863,6 +863,8 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co
constexpr
static
int
STAGE
=
8
;
constexpr
static
int
STAGE
=
8
;
#if 1
#if 1
uint8_t
*
kv_lds_write_ptr_base
=
reinterpret_cast
<
uint8_t
*>
(
&
tSsK
(
0
,
0
,
0
));
v4f
c0_0
,
c0_1
,
c1_0
,
c1_1
,
c2_0
,
c2_1
,
c3_0
,
c3_1
;
v4f
c0_0
,
c0_1
,
c1_0
,
c1_1
,
c2_0
,
c2_1
,
c3_0
,
c3_1
;
c0_0
.
x
=
0.0
f
;
c0_0
.
y
=
0.0
f
;
c0_0
.
z
=
0.0
f
;
c0_0
.
w
=
0.0
f
;
c0_0
.
x
=
0.0
f
;
c0_0
.
y
=
0.0
f
;
c0_0
.
z
=
0.0
f
;
c0_0
.
w
=
0.0
f
;
c0_1
.
x
=
0.0
f
;
c0_1
.
y
=
0.0
f
;
c0_1
.
z
=
0.0
f
;
c0_1
.
w
=
0.0
f
;
c0_1
.
x
=
0.0
f
;
c0_1
.
y
=
0.0
f
;
c0_1
.
z
=
0.0
f
;
c0_1
.
w
=
0.0
f
;
...
@@ -879,6 +881,40 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co
...
@@ -879,6 +881,40 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co
struct
IsMaskBlock
{};
struct
IsMaskBlock
{};
struct
IsFirstMaskBlock
{};
struct
IsFirstMaskBlock
{};
struct
IsNoMaskBlock
{};
struct
IsNoMaskBlock
{};
const
auto
gK_data
=
gK
.
data
();
Fp8_storage
kv_data
[
9
];
{
int
cur_block_table
;
cur_block_table
=
block_table
[
n_block
];
index_t
offset_k
;
offset_k
=
cur_block_table
*
params
.
k_batch_stride
;
gK
.
data
()
=
gK_data
+
(
offset_k
);
for
(
int
i
=
0
;
i
<
8
;
i
++
)
{
buffer_load_copy_qkvfp8
<
false
,
true
,
false
,
false
>
(
gK
,
kv_data
[
i
].
data
,
i
,
params
.
k_row_stride
,
0
,
seqlen_k
-
n_block
*
kBlockN
);
}
buffer_load_copy_qkvfp8
<
false
,
true
,
true
,
false
>
(
gK
,
kv_data
[
8
].
data
,
8
,
params
.
k_row_stride
,
0
,
seqlen_k
-
n_block
*
kBlockN
);
// __syncthreads();
uint8_t
*
kv_lds_write_ptr
=
kv_lds_write_ptr_base
;
for
(
int
i
=
0
;
i
<
8
;
i
++
)
{
*
(
reinterpret_cast
<
intx4_t
*>
(
kv_lds_write_ptr
))
=
kv_data
[
i
].
data
;
kv_lds_write_ptr
+=
64
*
64
;
}
// lds_direct_copy_qkvfp8<false, true>(gK, sK, 0, params.k_row_stride, seqlen_k - n_block * kBlockN);
// lds_direct_copy_qkvfp8<false, true>(gK, sK, 1, params.k_row_stride, seqlen_k - n_block * kBlockN);
// lds_direct_copy_qkvfp8<false, true>(gK, sK, 2, params.k_row_stride, seqlen_k - n_block * kBlockN);
// lds_direct_copy_qkvfp8<false, true>(gK, sK, 3, params.k_row_stride, seqlen_k - n_block * kBlockN);
// lds_direct_copy_qkvfp8<false, true>(gK, sK, 4, params.k_row_stride, seqlen_k - n_block * kBlockN);
// lds_direct_copy_qkvfp8<false, true>(gK, sK, 5, params.k_row_stride, seqlen_k - n_block * kBlockN);
// lds_direct_copy_qkvfp8<false, true>(gK, sK, 6, params.k_row_stride, seqlen_k - n_block * kBlockN);
// lds_direct_copy_qkvfp8<false, true>(gK, sK, 7, params.k_row_stride, seqlen_k - n_block * kBlockN);
// buffer_load_copy_qkvfp8<false, true, true, false>(gK, kv_data[8].data, 8, params.k_row_stride, 0, seqlen_k - n_block * kBlockN);
}
// if (block0())
// {
// printf("threadIdx.x %d kv_lds_write_ptr_base = %p\n ", threadIdx.x, kv_lds_write_ptr_base);
// }
auto
process_one_block
=
[
&
]
(
int
block_idx
,
auto
is_mask_block_t
)
{
auto
process_one_block
=
[
&
]
(
int
block_idx
,
auto
is_mask_block_t
)
{
static
constexpr
bool
IS_MASK_BLOCK
=
std
::
is_same_v
<
decltype
(
is_mask_block_t
),
IsNoMaskBlock
>
;
static
constexpr
bool
IS_MASK_BLOCK
=
std
::
is_same_v
<
decltype
(
is_mask_block_t
),
IsNoMaskBlock
>
;
static
constexpr
bool
IS_FIRST_MASK_BLOCK
=
std
::
is_same_v
<
decltype
(
is_mask_block_t
),
IsFirstMaskBlock
>
;
static
constexpr
bool
IS_FIRST_MASK_BLOCK
=
std
::
is_same_v
<
decltype
(
is_mask_block_t
),
IsFirstMaskBlock
>
;
...
@@ -889,64 +925,30 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co
...
@@ -889,64 +925,30 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co
clear
(
acc_s
);
clear
(
acc_s
);
// asm volatile("s_barrier\n\t");
// asm volatile("s_barrier\n\t");
Tensor
tSrK_copy_view
=
smem_thr_copy_K
.
retile_D
(
tSrK
);
Tensor
tSrK_copy_view
=
smem_thr_copy_K
.
retile_D
(
tSrK
);
int
cur_block_table
;
__syncthreads
();
const
int
*
cur_block_table_ptr
=
block_table
+
block_idx
;
// cur_block_table = block_table[block_idx - 1];
for
(
int
i
=
0
;
i
<
7
;
i
++
)
{
asm
volatile
(
"s_load_dword %1, %0, 0x0
\n\t
"
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
i
),
tSrK_copy_view
(
_
,
_
,
i
));
"s_waitcnt lgkmcnt(0)
\n\t
"
:
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
i
),
tSrK
(
_
,
_
,
i
),
acc_s
);
"+s"
(
cur_block_table_ptr
),
}
"=s"
(
cur_block_table
));
// cute::copy(smem_tiled_copy_K, tSsK(_, _, i), tSrK_copy_view(_, _, i));
index_t
offset_k
=
cur_block_table
*
params
.
k_batch_stride
;
// cute::gemm(tiled_mma, tSrQ(_, _, i), tSrK(_, _, i), acc_s);
gK
.
data
()
=
gK
.
data
()
+
(
offset_k
);
#if 1
lds_direct_copy_qkvfp8
<
false
,
true
>
(
gK
,
sK
,
0
,
params
.
k_row_stride
,
seqlen_k
-
block_idx
*
kBlockN
);
lds_direct_copy_qkvfp8
<
false
,
true
>
(
gK
,
sK
,
1
,
params
.
k_row_stride
,
seqlen_k
-
block_idx
*
kBlockN
);
lds_direct_copy_qkvfp8
<
false
,
true
>
(
gK
,
sK
,
2
,
params
.
k_row_stride
,
seqlen_k
-
block_idx
*
kBlockN
);
lds_direct_copy_qkvfp8
<
false
,
true
>
(
gK
,
sK
,
3
,
params
.
k_row_stride
,
seqlen_k
-
block_idx
*
kBlockN
);
lds_direct_copy_qkvfp8
<
false
,
true
>
(
gK
,
sK
,
4
,
params
.
k_row_stride
,
seqlen_k
-
block_idx
*
kBlockN
);
lds_direct_copy_qkvfp8
<
false
,
true
>
(
gK
,
sK
,
5
,
params
.
k_row_stride
,
seqlen_k
-
block_idx
*
kBlockN
);
lds_direct_copy_qkvfp8
<
false
,
true
>
(
gK
,
sK
,
6
,
params
.
k_row_stride
,
seqlen_k
-
block_idx
*
kBlockN
);
lds_direct_copy_qkvfp8
<
false
,
true
>
(
gK
,
sK
,
7
,
params
.
k_row_stride
,
seqlen_k
-
block_idx
*
kBlockN
);
constexpr
static
int
BUFFER_SIZE
=
1
;
uint128_t
buffer
[
BUFFER_SIZE
];
buffer_load_copy_qkvfp8
<
false
,
true
,
true
,
true
>
(
gK
,
buffer
[
0
],
8
,
params
.
k_row_stride
,
offset_k
,
seqlen_k
-
block_idx
*
kBlockN
);
asm
volatile
(
"s_waitcnt vmcnt(8)
\n\t
s_barrier
\n\t
"
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
0
),
tSrK_copy_view
(
_
,
_
,
0
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
0
),
tSrK
(
_
,
_
,
0
),
acc_s
);
asm
volatile
(
"s_waitcnt vmcnt(7)
\n\t
s_barrier
\n\t
"
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
1
),
tSrK_copy_view
(
_
,
_
,
1
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
1
),
tSrK
(
_
,
_
,
1
),
acc_s
);
asm
volatile
(
"s_waitcnt vmcnt(6)
\n\t
s_barrier
\n\t
"
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
2
),
tSrK_copy_view
(
_
,
_
,
2
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
2
),
tSrK
(
_
,
_
,
2
),
acc_s
);
asm
volatile
(
"s_waitcnt vmcnt(5)
\n\t
s_barrier
\n\t
"
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
3
),
tSrK_copy_view
(
_
,
_
,
3
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
3
),
tSrK
(
_
,
_
,
3
),
acc_s
);
asm
volatile
(
"s_waitcnt vmcnt(4)
\n\t
s_barrier
\n\t
"
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
4
),
tSrK_copy_view
(
_
,
_
,
4
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
4
),
tSrK
(
_
,
_
,
4
),
acc_s
);
asm
volatile
(
"s_waitcnt vmcnt(3)
\n\t
s_barrier
\n\t
"
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
5
),
tSrK_copy_view
(
_
,
_
,
5
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
5
),
tSrK
(
_
,
_
,
5
),
acc_s
);
asm
volatile
(
"s_waitcnt vmcnt(2)
\n\t
s_barrier
\n\t
"
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
6
),
tSrK_copy_view
(
_
,
_
,
6
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
6
),
tSrK
(
_
,
_
,
6
),
acc_s
);
asm
volatile
(
"s_waitcnt vmcnt(1)
\n\t
s_barrier
\n\t
"
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
7
),
tSrK_copy_view
(
_
,
_
,
7
));
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
7
),
tSrK_copy_view
(
_
,
_
,
7
));
Fp8_storage
v3_0
,
v3_1
;
Fp8_storage
v3_0
,
v3_1
;
__ds_read_m32x32_row_col_rrow
<
3
,
0
,
3
>
(
tOsVt
,
v3_0
.
data
);
__ds_read_m32x32_row_col_rrow
<
3
,
0
,
3
>
(
tOsVt
,
v3_0
.
data
);
__ds_read_m32x32_row_col_rrow
<
3
,
1
,
3
>
(
tOsVt
,
v3_1
.
data
);
__ds_read_m32x32_row_col_rrow
<
3
,
1
,
3
>
(
tOsVt
,
v3_1
.
data
);
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
7
),
tSrK
(
_
,
_
,
7
),
acc_s
);
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
7
),
tSrK
(
_
,
_
,
7
),
acc_s
);
asm
volatile
(
"s_waitcnt vmcnt(0)
\n\t
"
);
{
buffer_to_tensor
(
buffer
[
0
],
tSrK
,
8
);
intx4_t
*
d
=
reinterpret_cast
<
intx4_t
*>
(
&
tSrK
(
0
,
0
,
8
));
*
d
=
kv_data
[
8
].
data
;
}
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
8
),
tSrK
(
_
,
_
,
8
),
acc_s
);
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
8
),
tSrK
(
_
,
_
,
8
),
acc_s
);
asm
volatile
(
"s_waitcnt lgkmcnt(0)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt lgkmcnt(0)
\n\t
s_barrier
\n\t
"
);
#else
// if (thread0()) {
#endif
// printf(" %.2f %.2f %.2f %.2f \n", acc_s(0), acc_s(1), acc_s(2), acc_s(3));
gK
.
data
()
=
gK
.
data
()
+
(
-
offset_k
);
// }
if
constexpr
(
!
IS_NO_MASK_BLOCK
)
{
if
constexpr
(
!
IS_NO_MASK_BLOCK
)
{
Tensor
cS
=
make_identity_tensor
(
Shape
<
Int
<
kBlockM
>
,
Int
<
kBlockN
>>
{});
Tensor
cS
=
make_identity_tensor
(
Shape
<
Int
<
kBlockM
>
,
Int
<
kBlockN
>>
{});
Tensor
tScS
=
thr_mma
.
partition_C
(
cS
);
Tensor
tScS
=
thr_mma
.
partition_C
(
cS
);
...
@@ -962,27 +964,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co
...
@@ -962,27 +964,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co
}
}
}
}
}
}
softmax
.
template
softmax_rescale_o_fp8
<
/*Is_first=*/
IS_FIRST_MASK_BLOCK
,
/*Check_inf=*/
Is_causal
>(
acc_s
,
sRow_max_reduce_buffer
,
scale_softmax_log2
,
c0_0
,
c0_1
,
c1_0
,
c1_1
,
c2_0
,
c2_1
,
c3_0
,
c3_1
);
// We have key_padding_mask so we'll need to Check_inf
// if constexpr (n_masking_steps == 1)
// {
// softmax.template softmax_rescale_o_fp8</*Is_first=*/true, /*Check_inf=*/Is_causal>(acc_s, sRow_max_reduce_buffer, scale_softmax_log2, c0_0, c0_1, c1_0, c1_1, c2_0, c2_1, c3_0, c3_1);
// }
// else
{
softmax
.
template
softmax_rescale_o_fp8
<
/*Is_first=*/
IS_FIRST_MASK_BLOCK
,
/*Check_inf=*/
Is_causal
>(
acc_s
,
sRow_max_reduce_buffer
,
scale_softmax_log2
,
c0_0
,
c0_1
,
c1_0
,
c1_1
,
c2_0
,
c2_1
,
c3_0
,
c3_1
);
}
// Tensor rP = flash::convert_type<Element>(acc_s);
Fp8_storage
data_fp8
;
Fp8_storage
data_fp8
;
// convert_layout_acc_Aregs_fp8(tiled_mma, tiled_mma_o, rP, sP, data_fp8.data);
{
{
int
tid
=
threadIdx
.
x
%
64
;
int
tid
=
threadIdx
.
x
%
64
;
int
warp_id
=
threadIdx
.
x
/
64
;
int
warp_id
=
threadIdx
.
x
/
64
;
...
@@ -996,6 +980,26 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co
...
@@ -996,6 +980,26 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co
data_fp8
.
data
=
*
reinterpret_cast
<
intx4_t
*>
(
&
(
sP
[
tid
*
16
]));
data_fp8
.
data
=
*
reinterpret_cast
<
intx4_t
*>
(
&
(
sP
[
tid
*
16
]));
}
}
if
(
block_idx
>
n_block_min
)
{
int
cur_block_table
;
const
int
*
cur_block_table_ptr
;
cur_block_table
=
block_table
[
block_idx
-
1
];
index_t
offset_k
;
// cur_block_table_ptr = block_table + block_idx;
// asm volatile("s_load_dword %1, %0, 0x0\n\t"
// "s_waitcnt lgkmcnt(0)\n\t":
// "+s"(cur_block_table_ptr),
// "=s"(cur_block_table));
offset_k
=
cur_block_table
*
params
.
k_batch_stride
;
gK
.
data
()
=
gK_data
+
(
offset_k
);
for
(
int
i
=
0
;
i
<
8
;
i
++
)
{
buffer_load_copy_qkvfp8
<
true
,
true
,
false
,
false
>
(
gK
,
kv_data
[
i
].
data
,
i
,
params
.
k_row_stride
,
0
);
}
buffer_load_copy_qkvfp8
<
true
,
true
,
true
,
false
>
(
gK
,
kv_data
[
8
].
data
,
8
,
params
.
k_row_stride
,
0
);
}
{
{
Fp8_storage
v0_0
,
v0_1
;
Fp8_storage
v0_0
,
v0_1
;
Fp8_storage
v1_0
,
v1_1
;
Fp8_storage
v1_0
,
v1_1
;
...
@@ -1034,6 +1038,19 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co
...
@@ -1034,6 +1038,19 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co
c2_1
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
data_fp8
.
p
[
1
],
v2_1
.
p
[
1
],
c2_1
,
true
,
false
);
c2_1
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
data_fp8
.
p
[
1
],
v2_1
.
p
[
1
],
c2_1
,
true
,
false
);
__builtin_amdgcn_sched_barrier
(
0
);
__builtin_amdgcn_sched_barrier
(
0
);
// if (thread0()) {
// printf(" %.2f %.2f %.2f %.2f \n ", c0_0.x, c0_0.y, c0_0.z, c0_0.w);
// }
}
if
(
block_idx
>
n_block_min
)
{
__syncthreads
();
uint8_t
*
kv_lds_write_ptr
=
kv_lds_write_ptr_base
;
for
(
int
i
=
0
;
i
<
8
;
i
++
)
{
*
(
reinterpret_cast
<
intx4_t
*>
(
kv_lds_write_ptr
))
=
kv_data
[
i
].
data
;
kv_lds_write_ptr
+=
64
*
64
;
}
}
}
};
};
...
...
csrc/extension/utils.h
View file @
6d3ed1da
...
@@ -570,7 +570,7 @@ CUTE_HOST_DEVICE
...
@@ -570,7 +570,7 @@ CUTE_HOST_DEVICE
void
void
buffer_load_copy_qkvfp8
(
buffer_load_copy_qkvfp8
(
Tensor
<
SrcEngine
,
SrcLayout
>
const
&
src
,
Tensor
<
SrcEngine
,
SrcLayout
>
const
&
src
,
u
int
128
_t
&
dst
,
int
x4
_t
&
dst
,
int
k_idx_
,
const
int
row_stride
,
int
k_idx_
,
const
int
row_stride
,
int
offset_k
,
int
offset_k
,
const
int
max_MN
=
0
)
const
int
max_MN
=
0
)
...
@@ -615,11 +615,41 @@ buffer_load_copy_qkvfp8(
...
@@ -615,11 +615,41 @@ buffer_load_copy_qkvfp8(
);
);
}
}
else
{
else
{
auto
res
=
__builtin_amdgcn_buffer_load_dwordx4
(
global_addr
,
0
,
offset_v
,
false
,
false
);
dst
=
__builtin_amdgcn_buffer_load_dwordx4
(
global_addr
,
0
,
offset_v
,
false
,
false
);
dst
=
*
reinterpret_cast
<
uint128_t
*>
(
&
res
);
}
}
}
else
{
constexpr
int
warp_size
=
64
;
int
tidx
=
threadIdx
.
x
;
//0-256
int
warp_id
=
__builtin_amdgcn_readfirstlane
(
tidx
/
warp_size
);
int
lane
=
tidx
%
warp_size
;
//0-63
constexpr
int
element_size
=
1
;
int
k_idx
=
__builtin_amdgcn_readfirstlane
(
k_idx_
);
//576
const
int
offset_s
=
0
;
struct
PtrWrapper
{
uint32_t
former
;
uint32_t
latter
;
};
PtrWrapper
glob_ptr
;
*
(
uint64_t
*
)
&
glob_ptr
=
reinterpret_cast
<
uint64_t
>
(
src
.
data
().
get
());
// glob_ptr.latter |= 0x40000000; // 62 bit: cache swizzle; 48~61: Stride
constexpr
int
elements_per_thread
=
16
;
uint32x4_t
global_addr
=
{
0
};
global_addr
[
0
]
=
__builtin_amdgcn_readfirstlane
(
glob_ptr
.
former
);
global_addr
[
1
]
=
__builtin_amdgcn_readfirstlane
(
glob_ptr
.
latter
);
global_addr
[
2
]
=
0x80000000
;
global_addr
[
3
]
=
0x00020000
;
int
mma_k
=
32
*
64
;
int
row
=
lane
%
16
;
int
col
=
lane
/
16
;
int
row_offset
=
row
+
(
warp_id
*
16
);
int
col_offset
=
col
*
elements_per_thread
+
k_idx
*
64
;
int
offset_v
=
(
row_offset
*
row_stride
+
col_offset
)
*
element_size
;
// bytes
if
(
!
Is_even_MN
&&
row_offset
>=
max_MN
)
offset_v
=
-
1
;
dst
=
__builtin_amdgcn_buffer_load_dwordx4
(
global_addr
,
0
,
offset_v
,
false
,
false
);
}
}
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment