Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
FlashMLA
Commits
79096f6b
Commit
79096f6b
authored
Mar 17, 2026
by
zhanghj2
Browse files
使用__builtin_hcu_ds_read_m32x32_i8_alt2指令
parent
4b3bcb50
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
126 additions
and
41 deletions
+126
-41
csrc/extension/flash_fwd_mla_kernel_fp8.h
csrc/extension/flash_fwd_mla_kernel_fp8.h
+118
-40
csrc/extension/utils.h
csrc/extension/utils.h
+8
-1
No files found.
csrc/extension/flash_fwd_mla_kernel_fp8.h
View file @
79096f6b
...
@@ -472,23 +472,23 @@ __forceinline__ __device__ void store_float8(const Flash_fwd_mla_params ¶ms,
...
@@ -472,23 +472,23 @@ __forceinline__ __device__ void store_float8(const Flash_fwd_mla_params ¶ms,
Tensor
lse
=
softmax
.
template
normalize_softmax_lse_fp8
<
/*Is_dropout=*/
false
,
Split
>(
tOrO
,
sRow_sum_reduce_buffer
,
scale_softmax
,
descale_k
);
Tensor
lse
=
softmax
.
template
normalize_softmax_lse_fp8
<
/*Is_dropout=*/
false
,
Split
>(
tOrO
,
sRow_sum_reduce_buffer
,
scale_softmax
,
descale_k
);
using
ElementO
=
std
::
conditional_t
<!
Split
,
Element
,
ElementAccum
>
;
using
ElementO
=
std
::
conditional_t
<!
Split
,
Element
,
ElementAccum
>
;
Tensor
sOaccum
=
make_tensor
(
make_smem_ptr
(
reinterpret_cast
<
ElementO
*>
(
shared_storage
.
smem_o
.
data
())),
typename
Kernel_traits
::
SmemLayoutO
{});
// (SMEM_M,SMEM_N)
//
Tensor sOaccum = make_tensor(make_smem_ptr(reinterpret_cast<ElementO *>(shared_storage.smem_o.data())), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N)
// Partition sO to match the accumulator partitioning
//
//
Partition sO to match the accumulator partitioning
using
SmemTiledCopyO
=
std
::
conditional_t
<
//
using SmemTiledCopyO = std::conditional_t<
!
Split
,
//
!Split,
typename
Kernel_traits
::
SmemCopyAtomO
,
//
typename Kernel_traits::SmemCopyAtomO,
typename
Kernel_traits
::
SmemCopyAtomOaccum
//
typename Kernel_traits::SmemCopyAtomOaccum
>
;
//
>;
auto
smem_tiled_copy_Oaccum
=
make_tiled_copy_C
(
SmemTiledCopyO
{},
tiled_mma_o
);
//
auto smem_tiled_copy_Oaccum = make_tiled_copy_C(SmemTiledCopyO{}, tiled_mma_o);
auto
smem_thr_copy_Oaccum
=
smem_tiled_copy_Oaccum
.
get_thread_slice
(
tidx
);
//
auto smem_thr_copy_Oaccum = smem_tiled_copy_Oaccum.get_thread_slice(tidx);
Tensor
rO
=
flash
::
convert_type
<
ElementO
>
(
tOrO
);
Tensor
rO
=
flash
::
convert_type
<
ElementO
>
(
tOrO
);
Tensor
taccOrOaccum
=
smem_thr_copy_Oaccum
.
retile_S
(
rO
);
// ((Atom,AtomNum), MMA_M, MMA_N)
//
Tensor taccOrOaccum = smem_thr_copy_Oaccum.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N)
Tensor
taccOsOaccum
=
smem_thr_copy_Oaccum
.
partition_D
(
sOaccum
);
// ((Atom,AtomNum),PIPE_M,PIPE_N)
//
Tensor taccOsOaccum = smem_thr_copy_Oaccum.partition_D(sOaccum); // ((Atom,AtomNum),PIPE_M,PIPE_N)
// __syncthreads();
//
//
__syncthreads();
cute
::
copy
(
smem_tiled_copy_Oaccum
,
taccOrOaccum
,
taccOsOaccum
);
//
cute::copy(smem_tiled_copy_Oaccum, taccOrOaccum, taccOsOaccum);
const
index_t
row_offset_o
=
bidb
*
params
.
o_batch_stride
+
m_block
*
kBlockM
*
params
.
o_row_stride
+
bidh
*
params
.
o_head_stride
;
const
index_t
row_offset_o
=
bidb
*
params
.
o_batch_stride
+
m_block
*
kBlockM
*
params
.
o_row_stride
+
bidh
*
params
.
o_head_stride
;
const
index_t
row_offset_oaccum
=
(((
split_offset
+
n_split_idx
)
*
params
.
h
+
bidh
)
*
params
.
seqlen_q
+
m_block
*
kBlockM
)
*
params
.
d_v
;
const
index_t
row_offset_oaccum
=
(((
split_offset
+
n_split_idx
)
*
params
.
h
+
bidh
)
*
params
.
seqlen_q
+
m_block
*
kBlockM
)
*
params
.
d_v
;
...
@@ -501,20 +501,20 @@ __forceinline__ __device__ void store_float8(const Flash_fwd_mla_params ¶ms,
...
@@ -501,20 +501,20 @@ __forceinline__ __device__ void store_float8(const Flash_fwd_mla_params ¶ms,
Tensor
gLSEaccum
=
make_tensor
(
make_gmem_ptr
(
reinterpret_cast
<
ElementAccum
*>
(
Split
?
params
.
softmax_lseaccum_ptr
:
params
.
softmax_lse_ptr
)
+
(
Split
?
row_offset_lseaccum
:
row_offset_lse
)),
Tensor
gLSEaccum
=
make_tensor
(
make_gmem_ptr
(
reinterpret_cast
<
ElementAccum
*>
(
Split
?
params
.
softmax_lseaccum_ptr
:
params
.
softmax_lse_ptr
)
+
(
Split
?
row_offset_lseaccum
:
row_offset_lse
)),
Shape
<
Int
<
kBlockM
>>
{},
Stride
<
_1
>
{});
Shape
<
Int
<
kBlockM
>>
{},
Stride
<
_1
>
{});
using
GmemTiledCopyO
=
std
::
conditional_t
<!
Split
,
typename
Kernel_traits
::
GmemTiledCopyO
,
typename
Kernel_traits
::
GmemTiledCopyOaccum
>
;
//
using GmemTiledCopyO = std::conditional_t<!Split, typename Kernel_traits::GmemTiledCopyO, typename Kernel_traits::GmemTiledCopyOaccum>;
GmemTiledCopyO
gmem_tiled_copy_Oaccum
;
//
GmemTiledCopyO gmem_tiled_copy_Oaccum;
auto
gmem_thr_copy_Oaccum
=
gmem_tiled_copy_Oaccum
.
get_thread_slice
(
tidx
);
//
auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx);
Tensor
tOsOaccum
=
gmem_thr_copy_Oaccum
.
partition_S
(
sOaccum
);
// ((Atom,AtomNum),ATOM_M,ATOM_N)
//
Tensor tOsOaccum = gmem_thr_copy_Oaccum.partition_S(sOaccum); // ((Atom,AtomNum),ATOM_M,ATOM_N)
Tensor
tOgOaccum
=
gmem_thr_copy_Oaccum
.
partition_D
(
gOaccum
);
//
Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum);
__syncthreads
();
//
__syncthreads();
// if (tidx >= kNThreadsS) { return; }
// if (tidx >= kNThreadsS) { return; }
Tensor
tOrOaccum
=
make_tensor
<
ElementO
>
(
shape
(
tOgOaccum
));
//
Tensor tOrOaccum = make_tensor<ElementO>(shape(tOgOaccum));
cute
::
copy
(
gmem_tiled_copy_Oaccum
,
tOsOaccum
,
tOrOaccum
);
//
cute::copy(gmem_tiled_copy_Oaccum, tOsOaccum, tOrOaccum);
Tensor
caccO
=
make_identity_tensor
(
Shape
<
Int
<
kBlockM
>
,
Int
<
kHeadDimV
>>
{});
// (BLK_M,BLK_K) -> (blk_m,blk_k)
Tensor
caccO
=
make_identity_tensor
(
Shape
<
Int
<
kBlockM
>
,
Int
<
kHeadDimV
>>
{});
// (BLK_M,BLK_K) -> (blk_m,blk_k)
Tensor
taccOcO
=
thr_mma_o
.
partition_C
(
caccO
);
// ((MMA=4, X), MMA_M, MMA_K=1)
Tensor
taccOcO
=
thr_mma_o
.
partition_C
(
caccO
);
// ((MMA=4, X), MMA_M, MMA_K=1)
...
@@ -528,15 +528,47 @@ __forceinline__ __device__ void store_float8(const Flash_fwd_mla_params ¶ms,
...
@@ -528,15 +528,47 @@ __forceinline__ __device__ void store_float8(const Flash_fwd_mla_params ¶ms,
}
}
}
}
// Construct identity layout for sO
// // Construct identity layout for sO
Tensor
cO
=
make_identity_tensor
(
make_shape
(
size
<
0
>
(
sOaccum
),
size
<
1
>
(
sOaccum
)));
// (BLK_M,BLK_K) -> (blk_m,blk_k)
// Tensor cO = make_identity_tensor(make_shape(size<0>(sOaccum), size<1>(sOaccum))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
// Repeat the partitioning with identity layouts
// // Repeat the partitioning with identity layouts
Tensor
tOcO
=
gmem_thr_copy_Oaccum
.
partition_D
(
cO
);
// (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
// Tensor tOcO = gmem_thr_copy_Oaccum.partition_D(cO); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
Tensor
tOpO
=
make_tensor
<
bool
>
(
make_shape
(
size
<
2
>
(
tOgOaccum
)));
// Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgOaccum)));
// Clear_OOB_K must be false since we don't want to write zeros to gmem
// // Clear_OOB_K must be false since we don't want to write zeros to gmem
flash
::
copy
<
/*Is_even_MN=*/
false
,
/*Is_even_K=*/
true
,
/*Clear_OOB_MN=*/
false
,
/*Clear_OOB_K=*/
false
>
(
// flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/true, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
gmem_tiled_copy_Oaccum
,
tOrOaccum
,
tOgOaccum
,
tOcO
,
tOpO
,
params
.
seqlen_q
-
m_block
*
kBlockM
// gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, params.seqlen_q - m_block * kBlockM
);
// );
{
int
tidx
=
threadIdx
.
x
;
int
col
=
0
;
for
(
int
m
=
0
;
m
<
size
<
1
>
(
rO
);
m
++
)
{
const
int
row
=
get
<
0
>
(
taccOcO
(
0
,
m
,
0
));
if
(
row
<
params
.
seqlen_q
-
m_block
*
kBlockM
)
{
for
(
int
n
=
0
;
n
<
size
<
2
>
(
rO
);
n
++
)
{
// col = (tidx % 64 / 16) * 4 + (tidx / 64) * 32 + n * 128;
// for (int ei = 0; ei < 8; ei += 4) {
// gOaccum(row, col) = rO(ei, m, n);
// gOaccum(row, col + 1) = rO(ei + 1, m, n);
// gOaccum(row, col + 2) = rO(ei + 2, m, n);
// gOaccum(row, col + 3) = rO(ei + 3, m, n);
// col += 16;
// }
col
=
(
tidx
%
64
/
16
)
*
8
+
(
tidx
/
64
)
*
32
+
n
*
128
;
gOaccum
(
row
,
col
)
=
rO
(
0
,
m
,
n
);
gOaccum
(
row
,
col
+
1
)
=
rO
(
1
,
m
,
n
);
gOaccum
(
row
,
col
+
2
)
=
rO
(
2
,
m
,
n
);
gOaccum
(
row
,
col
+
3
)
=
rO
(
3
,
m
,
n
);
gOaccum
(
row
,
col
+
4
)
=
rO
(
4
,
m
,
n
);
gOaccum
(
row
,
col
+
5
)
=
rO
(
5
,
m
,
n
);
gOaccum
(
row
,
col
+
6
)
=
rO
(
6
,
m
,
n
);
gOaccum
(
row
,
col
+
7
)
=
rO
(
7
,
m
,
n
);
}
}
}
}
}
}
...
@@ -963,11 +995,15 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co
...
@@ -963,11 +995,15 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co
int32_t
result
;
int32_t
result
;
result
=
__builtin_hcu_cvt_pk_fp8_f32
(
acc_s
(
0
),
acc_s
(
1
),
result
,
false
);
result
=
__builtin_hcu_cvt_pk_fp8_f32
(
acc_s
(
0
),
acc_s
(
1
),
result
,
false
);
result
=
__builtin_hcu_cvt_pk_fp8_f32
(
acc_s
(
2
),
acc_s
(
3
),
result
,
true
);
result
=
__builtin_hcu_cvt_pk_fp8_f32
(
acc_s
(
2
),
acc_s
(
3
),
result
,
true
);
int32_t
*
lds_ptr
=
reinterpret_cast
<
int32_t
*>
(
&
sP
[
(
tid
)
*
16
+
warp_id
*
4
]);
int32_t
*
lds_ptr
=
reinterpret_cast
<
int32_t
*>
(
&
(
sP
[
(
tid
%
16
)
*
16
+
((
tid
/
16
)
%
2
)
*
4
+
(
tid
/
32
)
*
(
16
*
16
)
+
(
warp_id
%
2
)
*
(
16
*
32
)
+
(
warp_id
/
2
)
*
(
8
)
+
0
]));
//
int32_t* lds_ptr = reinterpret_cast<int32_t*>(&(sP[ (tid % 16) * 16 + ((tid / 16) % 2 ) * 4 + (tid / 32) * (16 * 16) + (warp_id % 2) * (16 * 32) + (warp_id / 2) * (8) + 0]));
*
lds_ptr
=
result
;
*
lds_ptr
=
result
;
__syncthreads
();
__syncthreads
();
data_fp8
.
data
=
*
reinterpret_cast
<
intx4_t
*>
(
&
(
sP
[
tid
*
16
]));
data_fp8
.
data
=
*
reinterpret_cast
<
intx4_t
*>
(
&
(
sP
[
tid
*
16
]));
// data_fp8.bf16[0] = *reinterpret_cast<int*>(&(sP[(tid % 64) * 4]));
// data_fp8.bf16[1] = *reinterpret_cast<int*>(&(sP[(tid % 64) * 4 + 64 * 4]));
// data_fp8.bf16[2] = *reinterpret_cast<int*>(&(sP[(tid % 64) * 4 + 2 * 64 * 4]));
// data_fp8.bf16[3] = *reinterpret_cast<int*>(&(sP[(tid % 64) * 4 + 3 * 64 * 4]));
}
}
if
(
block_idx
>
n_block_min
)
{
if
(
block_idx
>
n_block_min
)
{
...
@@ -999,6 +1035,28 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co
...
@@ -999,6 +1035,28 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co
__ds_read_m32x32_row_col_rrow
<
0
,
0
,
0
>
(
tOsVt
,
v0_0
.
data
);
__ds_read_m32x32_row_col_rrow
<
0
,
0
,
0
>
(
tOsVt
,
v0_0
.
data
);
__ds_read_m32x32_row_col_rrow
<
1
,
0
,
1
>
(
tOsVt
,
v1_0
.
data
);
__ds_read_m32x32_row_col_rrow
<
1
,
0
,
1
>
(
tOsVt
,
v1_0
.
data
);
__ds_read_m32x32_row_col_rrow
<
2
,
0
,
2
>
(
tOsVt
,
v2_0
.
data
);
__ds_read_m32x32_row_col_rrow
<
2
,
0
,
2
>
(
tOsVt
,
v2_0
.
data
);
// if (block0() && tidx < 64) {
// auto res0 = __builtin_amdgcn_cvt_pk_f32_fp8(v0_0.bf16[0], false);
// auto res1 = __builtin_amdgcn_cvt_pk_f32_fp8(v0_0.bf16[0], true);
// auto res2 = __builtin_amdgcn_cvt_pk_f32_fp8(v0_0.bf16[1], false);
// auto res3 = __builtin_amdgcn_cvt_pk_f32_fp8(v0_0.bf16[1], true);
// auto res4 = __builtin_amdgcn_cvt_pk_f32_fp8(v0_0.bf16[2], false);
// auto res5 = __builtin_amdgcn_cvt_pk_f32_fp8(v0_0.bf16[2], true);
// auto res6 = __builtin_amdgcn_cvt_pk_f32_fp8(v0_0.bf16[3], false);
// auto res7 = __builtin_amdgcn_cvt_pk_f32_fp8(v0_0.bf16[3], true);
// printf(" %d %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f \n", tidx, res0[0], res0[1],
// res1[0], res1[1],
// res2[0], res2[1],
// res3[0], res3[1],
// res4[0], res4[1],
// res5[0], res5[1],
// res6[0], res6[1],
// res7[0], res7[1]
// );
// }
c3_0
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
data_fp8
.
p
[
0
],
v3_0
.
p
[
0
],
c3_0
,
true
,
false
);
c3_0
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
data_fp8
.
p
[
0
],
v3_0
.
p
[
0
],
c3_0
,
true
,
false
);
c3_1
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
data_fp8
.
p
[
0
],
v3_0
.
p
[
1
],
c3_1
,
true
,
false
);
c3_1
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8_lit_lts
(
data_fp8
.
p
[
0
],
v3_0
.
p
[
1
],
c3_1
,
true
,
false
);
...
@@ -1062,14 +1120,34 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co
...
@@ -1062,14 +1120,34 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co
#endif
#endif
Tensor
acc_o
=
partition_fragment_C
(
tiled_mma_o
,
Shape
<
Int
<
kBlockM
>
,
Int
<
kHeadDimV
>>
{});
Tensor
acc_o
=
partition_fragment_C
(
tiled_mma_o
,
Shape
<
Int
<
kBlockM
>
,
Int
<
kHeadDimV
>>
{});
acc_o
(
0
,
0
,
0
)
=
c0_0
.
x
;
acc_o
(
1
,
0
,
0
)
=
c0_0
.
y
;
acc_o
(
2
,
0
,
0
)
=
c0_0
.
z
;
acc_o
(
3
,
0
,
0
)
=
c0_0
.
w
;
acc_o
(
0
,
0
,
0
)
=
c0_0
.
x
;
acc_o
(
1
,
0
,
0
)
=
c0_1
.
x
;
acc_o
(
4
,
0
,
0
)
=
c0_1
.
x
;
acc_o
(
5
,
0
,
0
)
=
c0_1
.
y
;
acc_o
(
6
,
0
,
0
)
=
c0_1
.
z
;
acc_o
(
7
,
0
,
0
)
=
c0_1
.
w
;
acc_o
(
2
,
0
,
0
)
=
c0_0
.
y
;
acc_o
(
3
,
0
,
0
)
=
c0_1
.
y
;
acc_o
(
0
,
0
,
1
)
=
c1_0
.
x
;
acc_o
(
1
,
0
,
1
)
=
c1_0
.
y
;
acc_o
(
2
,
0
,
1
)
=
c1_0
.
z
;
acc_o
(
3
,
0
,
1
)
=
c1_0
.
w
;
acc_o
(
4
,
0
,
0
)
=
c0_0
.
z
;
acc_o
(
5
,
0
,
0
)
=
c0_1
.
z
;
acc_o
(
4
,
0
,
1
)
=
c1_1
.
x
;
acc_o
(
5
,
0
,
1
)
=
c1_1
.
y
;
acc_o
(
6
,
0
,
1
)
=
c1_1
.
z
;
acc_o
(
7
,
0
,
1
)
=
c1_1
.
w
;
acc_o
(
6
,
0
,
0
)
=
c0_0
.
w
;
acc_o
(
7
,
0
,
0
)
=
c0_1
.
w
;
acc_o
(
0
,
0
,
2
)
=
c2_0
.
x
;
acc_o
(
1
,
0
,
2
)
=
c2_0
.
y
;
acc_o
(
2
,
0
,
2
)
=
c2_0
.
z
;
acc_o
(
3
,
0
,
2
)
=
c2_0
.
w
;
acc_o
(
4
,
0
,
2
)
=
c2_1
.
x
;
acc_o
(
5
,
0
,
2
)
=
c2_1
.
y
;
acc_o
(
6
,
0
,
2
)
=
c2_1
.
z
;
acc_o
(
7
,
0
,
2
)
=
c2_1
.
w
;
acc_o
(
0
,
0
,
1
)
=
c1_0
.
x
;
acc_o
(
1
,
0
,
1
)
=
c1_1
.
x
;
acc_o
(
0
,
0
,
3
)
=
c3_0
.
x
;
acc_o
(
1
,
0
,
3
)
=
c3_0
.
y
;
acc_o
(
2
,
0
,
3
)
=
c3_0
.
z
;
acc_o
(
3
,
0
,
3
)
=
c3_0
.
w
;
acc_o
(
2
,
0
,
1
)
=
c1_0
.
y
;
acc_o
(
3
,
0
,
1
)
=
c1_1
.
y
;
acc_o
(
4
,
0
,
3
)
=
c3_1
.
x
;
acc_o
(
5
,
0
,
3
)
=
c3_1
.
y
;
acc_o
(
6
,
0
,
3
)
=
c3_1
.
z
;
acc_o
(
7
,
0
,
3
)
=
c3_1
.
w
;
acc_o
(
4
,
0
,
1
)
=
c1_0
.
z
;
acc_o
(
5
,
0
,
1
)
=
c1_1
.
z
;
acc_o
(
6
,
0
,
1
)
=
c1_0
.
w
;
acc_o
(
7
,
0
,
1
)
=
c1_1
.
w
;
acc_o
(
0
,
0
,
2
)
=
c2_0
.
x
;
acc_o
(
1
,
0
,
2
)
=
c2_1
.
x
;
acc_o
(
2
,
0
,
2
)
=
c2_0
.
y
;
acc_o
(
3
,
0
,
2
)
=
c2_1
.
y
;
acc_o
(
4
,
0
,
2
)
=
c2_0
.
z
;
acc_o
(
5
,
0
,
2
)
=
c2_1
.
z
;
acc_o
(
6
,
0
,
2
)
=
c2_0
.
w
;
acc_o
(
7
,
0
,
2
)
=
c2_1
.
w
;
acc_o
(
0
,
0
,
3
)
=
c3_0
.
x
;
acc_o
(
1
,
0
,
3
)
=
c3_1
.
x
;
acc_o
(
2
,
0
,
3
)
=
c3_0
.
y
;
acc_o
(
3
,
0
,
3
)
=
c3_1
.
y
;
acc_o
(
4
,
0
,
3
)
=
c3_0
.
z
;
acc_o
(
5
,
0
,
3
)
=
c3_1
.
z
;
acc_o
(
6
,
0
,
3
)
=
c3_0
.
w
;
acc_o
(
7
,
0
,
3
)
=
c3_1
.
w
;
// acc_o(0, 0, 0) = c0_0.x; acc_o(1, 0, 0) = c0_0.y; acc_o(2, 0, 0) = c0_0.z; acc_o(3, 0, 0) = c0_0.w;
// acc_o(4, 0, 0) = c0_1.x; acc_o(5, 0, 0) = c0_1.y; acc_o(6, 0, 0) = c0_1.z; acc_o(7, 0, 0) = c0_1.w;
// acc_o(0, 0, 1) = c1_0.x; acc_o(1, 0, 1) = c1_0.y; acc_o(2, 0, 1) = c1_0.z; acc_o(3, 0, 1) = c1_0.w;
// acc_o(4, 0, 1) = c1_1.x; acc_o(5, 0, 1) = c1_1.y; acc_o(6, 0, 1) = c1_1.z; acc_o(7, 0, 1) = c1_1.w;
// acc_o(0, 0, 2) = c2_0.x; acc_o(1, 0, 2) = c2_0.y; acc_o(2, 0, 2) = c2_0.z; acc_o(3, 0, 2) = c2_0.w;
// acc_o(4, 0, 2) = c2_1.x; acc_o(5, 0, 2) = c2_1.y; acc_o(6, 0, 2) = c2_1.z; acc_o(7, 0, 2) = c2_1.w;
// acc_o(0, 0, 3) = c3_0.x; acc_o(1, 0, 3) = c3_0.y; acc_o(2, 0, 3) = c3_0.z; acc_o(3, 0, 3) = c3_0.w;
// acc_o(4, 0, 3) = c3_1.x; acc_o(5, 0, 3) = c3_1.y; acc_o(6, 0, 3) = c3_1.z; acc_o(7, 0, 3) = c3_1.w;
if
(
NoSplit
)
if
(
NoSplit
)
store_float8
<
Kernel_traits
,
false
>
(
params
,
bidb
,
bidh
,
m_block
,
n_split_idx
,
shared_storage
,
acc_o
,
softmax
,
descale_k
,
scale_softmax
);
store_float8
<
Kernel_traits
,
false
>
(
params
,
bidb
,
bidh
,
m_block
,
n_split_idx
,
shared_storage
,
acc_o
,
softmax
,
descale_k
,
scale_softmax
);
...
...
csrc/extension/utils.h
View file @
79096f6b
...
@@ -2709,13 +2709,20 @@ __forceinline__ __device__ void __ds_read_m32x32_row_col_rrow(Tensor0& src, Ten
...
@@ -2709,13 +2709,20 @@ __forceinline__ __device__ void __ds_read_m32x32_row_col_rrow(Tensor0& src, Ten
template
<
int
row
,
int
col
,
int
r_row
,
typename
Tensor0
>
template
<
int
row
,
int
col
,
int
r_row
,
typename
Tensor0
>
__forceinline__
__device__
void
__ds_read_m32x32_row_col_rrow
(
Tensor0
&
src
,
intx4_t
&
dst
)
__forceinline__
__device__
void
__ds_read_m32x32_row_col_rrow
(
Tensor0
&
src
,
intx4_t
&
dst
)
{
{
#if 0
auto lds = reinterpret_cast<int *>(src.data().get());
auto lds = reinterpret_cast<int *>(src.data().get());
auto layout = src.layout();
auto layout = src.layout();
constexpr short offset = layout(0, row, col) * 1;
constexpr short offset = layout(0, row, col) * 1;
auto d = __builtin_amdgcn_ds_read_m32x32u8((__attribute__((address_space(3))) int*)(lds), offset);
auto d = __builtin_amdgcn_ds_read_m32x32u8((__attribute__((address_space(3))) int*)(lds), offset);
dst = d;
dst = d;
#else
auto
lds
=
reinterpret_cast
<
uint8_t
*>
(
src
.
data
().
get
());
auto
layout
=
src
.
layout
();
constexpr
short
offset
=
layout
(
0
,
row
,
col
)
*
1
;
lds
+=
offset
;
dst
=
__builtin_hcu_ds_read_m32x32_i8_alt2
((
__attribute__
((
address_space
(
3
)))
int
*
)(
lds
));
#endif
}
}
#endif
#endif
/*
/*
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment