Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
FlashMLA
Commits
34489f46
Commit
34489f46
authored
Feb 27, 2026
by
zhanghj2
Browse files
优化代码
parent
98b7c697
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
64 additions
and
50 deletions
+64
-50
csrc/extension/flash_fwd_mla_kernel_fp8.h
csrc/extension/flash_fwd_mla_kernel_fp8.h
+64
-50
No files found.
csrc/extension/flash_fwd_mla_kernel_fp8.h
View file @
34489f46
...
@@ -1052,6 +1052,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
...
@@ -1052,6 +1052,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
const
int
n_split_idx
,
const
int
seqlen_k
,
const
int
n_split_idx
,
const
int
seqlen_k
,
const
int
n_block_min
,
const
int
n_block_max
,
const
bool
NoSplit
,
const
int
n_block_min
,
const
int
n_block_max
,
const
bool
NoSplit
,
SharedStorage
&
shared_storage
,
const
float
descale_k
,
const
float
scale_softmax
,
const
float
scale_softmax_log2
)
{
SharedStorage
&
shared_storage
,
const
float
descale_k
,
const
float
scale_softmax
,
const
float
scale_softmax_log2
)
{
if
(
n_block_max
<=
n_block_min
)
{
return
;
}
constexpr
int
kBlockM
=
Kernel_traits
::
kBlockM
;
constexpr
int
kBlockM
=
Kernel_traits
::
kBlockM
;
constexpr
int
kBlockN
=
Kernel_traits
::
kBlockN
;
constexpr
int
kBlockN
=
Kernel_traits
::
kBlockN
;
constexpr
int
kHeadDim
=
Kernel_traits
::
kHeadDim
;
constexpr
int
kHeadDim
=
Kernel_traits
::
kHeadDim
;
...
@@ -1196,9 +1199,13 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
...
@@ -1196,9 +1199,13 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
}
}
constexpr
static
int
STAGE
=
8
;
constexpr
static
int
STAGE
=
8
;
struct
IsMaskBlock
{};
#if 1
struct
IsFirstMaskBlock
{};
for
(
int
masking_step
=
0
;
n_block
>=
n_block_min
;
++
masking_step
,
--
n_block
)
{
struct
IsNoMaskBlock
{};
auto
process_one_block
=
[
&
]
(
int
block_idx
,
auto
is_mask_block_t
)
{
static
constexpr
bool
IS_MASK_BLOCK
=
std
::
is_same_v
<
decltype
(
is_mask_block_t
),
IsNoMaskBlock
>
;
static
constexpr
bool
IS_FIRST_MASK_BLOCK
=
std
::
is_same_v
<
decltype
(
is_mask_block_t
),
IsFirstMaskBlock
>
;
static
constexpr
bool
IS_NO_MASK_BLOCK
=
std
::
is_same_v
<
decltype
(
is_mask_block_t
),
IsNoMaskBlock
>
;
v4f
accs_f32
[
2
];
v4f
accs_f32
[
2
];
for
(
int
i
=
0
;
i
<
2
;
i
++
)
for
(
int
i
=
0
;
i
<
2
;
i
++
)
{
{
...
@@ -1207,30 +1214,26 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
...
@@ -1207,30 +1214,26 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
accs_f32
[
i
].
z
=
0.0
f
;
accs_f32
[
i
].
z
=
0.0
f
;
accs_f32
[
i
].
w
=
0.0
f
;
accs_f32
[
i
].
w
=
0.0
f
;
}
}
// Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});
Tensor
tSrK_copy_view
=
smem_thr_copy_K
.
retile_D
(
tSrK
);
// Tensor tOrVt_copy_view = smem_thr_copy_V.retile_D(tOrVt);
// clear(acc_s);
Tensor
tSrK_copy_view
=
smem_thr_copy_K
.
retile_D
(
tSrK
);
// asm volatile("s_barrier \n\t");
// asm volatile("s_barrier \n\t");
int
cur_block_table
;
int
cur_block_table
;
const
int
*
cur_block_table_ptr
=
block_table
+
n_
block
;
const
int
*
cur_block_table_ptr
=
block_table
+
block
_idx
;
// cur_block_table = block_table[
n_
block - 1];
// cur_block_table = block_table[block
_idx
- 1];
asm
volatile
(
"s_load_dword %1, %0, 0x0
\n\t
"
asm
volatile
(
"s_load_dword %1, %0, 0x0
\n\t
"
"s_waitcnt lgkmcnt(0)
\n\t
"
:
"s_waitcnt lgkmcnt(0)
\n\t
"
:
"+s"
(
cur_block_table_ptr
),
"+s"
(
cur_block_table_ptr
),
"=s"
(
cur_block_table
));
"=s"
(
cur_block_table
));
index_t
offset_k
=
cur_block_table
*
params
.
k_batch_stride
;
index_t
offset_k
=
cur_block_table
*
params
.
k_batch_stride
;
// gK.data() = gK.data() + (offset_k);
// gK.data() = gK.data() + (offset_k);
#if 1
gK
.
data
()
=
gK
.
data
()
+
(
offset_k
);
gK
.
data
()
=
gK
.
data
()
+
(
offset_k
);
auto
gK_offset
=
((
warp_id
)
/
4
)
*
64
+
((
warp_id
)
%
4
)
*
16
*
params
.
k_row_stride
;
auto
gK_offset
=
((
warp_id
)
/
4
)
*
64
+
((
warp_id
)
%
4
)
*
16
*
params
.
k_row_stride
;
// auto gK_offset = (offset_k) + ((warp_id) / 4) * 64 + ((warp_id) % 4) * 16 * params.k_row_stride;
// auto gK_offset = (offset_k) + ((warp_id) / 4) * 64 + ((warp_id) % 4) * 16 * params.k_row_stride;
// const int k_zero_pad = std::min(std::max(
n_
block * kBlockN + ((warp_id) % 4 + 1) * 16 - seqlen_k, 0), 16);
// const int k_zero_pad = std::min(std::max(block
_idx
* kBlockN + ((warp_id) % 4 + 1) * 16 - seqlen_k, 0), 16);
const
int
k_zero_pad
=
std
::
max
(
n_
block
*
kBlockN
+
((
warp_id
)
%
4
+
1
)
*
16
-
seqlen_k
,
0
);
const
int
k_zero_pad
=
std
::
max
(
block
_idx
*
kBlockN
+
((
warp_id
)
%
4
+
1
)
*
16
-
seqlen_k
,
0
);
uint32x4_t
gK_rscr
=
make_rscr
((
unsigned
char
*
)(
gK
.
data
().
get
()
+
gK_offset
),
params
.
k_row_stride
,
k_zero_pad
);
uint32x4_t
gK_rscr
=
make_rscr
((
unsigned
char
*
)(
gK
.
data
().
get
()
+
gK_offset
),
params
.
k_row_stride
,
k_zero_pad
);
auto
k_lds_addr
=
reinterpret_cast
<
size_t
>
(
sK
.
data
().
get
()
+
((
warp_id
)
/
4
)
*
64
*
64
+
(
warp_id
%
4
)
*
16
*
64
);
auto
k_lds_addr
=
reinterpret_cast
<
size_t
>
(
sK
.
data
().
get
()
+
((
warp_id
)
/
4
)
*
64
*
64
+
(
warp_id
%
4
)
*
16
*
64
);
if
(
n_
block
*
kBlockN
+
((
warp_id
)
%
4
)
*
16
<
seqlen_k
||
masking_step
!=
0
)
if
(
block
_idx
*
kBlockN
+
((
warp_id
)
%
4
)
*
16
<
seqlen_k
||
IS_NO_MASK_BLOCK
)
{
{
k_lds_addr
|=
0x80000000
;
k_lds_addr
|=
0x80000000
;
__builtin_hcu_matrix_load_64x16_b8
(
gK_rscr
,
(
__attribute__
((
address_space
(
3
)))
char
*
)(
k_lds_addr
),
0
,
1
,
1
,
0
,
0
);
__builtin_hcu_matrix_load_64x16_b8
(
gK_rscr
,
(
__attribute__
((
address_space
(
3
)))
char
*
)(
k_lds_addr
),
0
,
1
,
1
,
0
,
0
);
...
@@ -1405,33 +1408,26 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
...
@@ -1405,33 +1408,26 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
acc_s
(
0
,
0
,
0
)
=
accs_f32
[
0
].
x
;
acc_s
(
1
,
0
,
0
)
=
accs_f32
[
0
].
y
;
acc_s
(
2
,
0
,
0
)
=
accs_f32
[
0
].
z
;
acc_s
(
3
,
0
,
0
)
=
accs_f32
[
0
].
w
;
acc_s
(
0
,
0
,
0
)
=
accs_f32
[
0
].
x
;
acc_s
(
1
,
0
,
0
)
=
accs_f32
[
0
].
y
;
acc_s
(
2
,
0
,
0
)
=
accs_f32
[
0
].
z
;
acc_s
(
3
,
0
,
0
)
=
accs_f32
[
0
].
w
;
acc_s
(
0
,
0
,
1
)
=
accs_f32
[
1
].
x
;
acc_s
(
1
,
0
,
1
)
=
accs_f32
[
1
].
y
;
acc_s
(
2
,
0
,
1
)
=
accs_f32
[
1
].
z
;
acc_s
(
3
,
0
,
1
)
=
accs_f32
[
1
].
w
;
acc_s
(
0
,
0
,
1
)
=
accs_f32
[
1
].
x
;
acc_s
(
1
,
0
,
1
)
=
accs_f32
[
1
].
y
;
acc_s
(
2
,
0
,
1
)
=
accs_f32
[
1
].
z
;
acc_s
(
3
,
0
,
1
)
=
accs_f32
[
1
].
w
;
// cute::gemm(tiled_mma, tSrQ(_, _, 0), tSrK(_, _, 0), acc_s);
// cute::gemm(tiled_mma, tSrQ(_, _, 0), tSrK(_, _, 0), acc_s);
#endif
// #endif
if
constexpr
(
!
IS_NO_MASK_BLOCK
)
{
Tensor
cS
=
make_identity_tensor
(
Shape
<
Int
<
kBlockM
>
,
Int
<
kBlockN
>>
{});
Tensor
cS
=
make_identity_tensor
(
Shape
<
Int
<
kBlockM
>
,
Int
<
kBlockN
>>
{});
Tensor
tScS
=
thr_mma
.
partition_C
(
cS
);
Tensor
tScS
=
thr_mma
.
partition_C
(
cS
);
for
(
int
i
=
0
;
i
<
size
(
acc_s
);
++
i
)
{
for
(
int
i
=
0
;
i
<
size
(
acc_s
);
++
i
)
{
if
constexpr
(
!
Is_causal
)
{
if
constexpr
(
!
Is_causal
)
{
if
(
int
(
get
<
1
>
(
tScS
(
i
)))
>=
int
(
seqlen_k
-
n_block
*
kBlockN
))
acc_s
(
i
)
=
-
INFINITY
;
if
(
int
(
get
<
1
>
(
tScS
(
i
)))
>=
int
(
seqlen_k
-
block_idx
*
kBlockN
))
acc_s
(
i
)
=
-
INFINITY
;
}
else
{
}
else
{
// Ensure seqlen_k - 1 - (n_block * kBlockN + col) >= (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups
// Ensure seqlen_k - 1 - (block_idx * kBlockN + col) >= (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups
// col <= seqlen_k - 1 - n_block * kBlockN - (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups
// col <= seqlen_k - 1 - block_idx * kBlockN - (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups
int
row
=
int
(
get
<
0
>
(
tScS
(
i
)));
int
row
=
int
(
get
<
0
>
(
tScS
(
i
)));
int
col_limit_right
=
seqlen_k
-
1
-
n_block
*
kBlockN
-
(
params
.
seqlen_q
-
1
-
(
m_block
*
kBlockM
+
row
))
/
params
.
ngroups
;
int
col_limit_right
=
seqlen_k
-
1
-
block_idx
*
kBlockN
-
(
params
.
seqlen_q
-
1
-
(
m_block
*
kBlockM
+
row
))
/
params
.
ngroups
;
if
(
int
(
get
<
1
>
(
tScS
(
i
)))
>
col_limit_right
)
acc_s
(
i
)
=
-
INFINITY
;
if
(
int
(
get
<
1
>
(
tScS
(
i
)))
>
col_limit_right
)
acc_s
(
i
)
=
-
INFINITY
;
}
}
}
}
}
// asm volatile("s_barrier \n\t");
{
const
bool
is_first_masking_step
=
masking_step
==
0
;
// is_first_masking_step
// ? softmax.template softmax_rescale_o</*Is_first=*/true, /*Check_inf=*/Is_causal>(acc_s, acc_o, sRow_max_reduce_buffer, scale_softmax_log2)
// : softmax.template softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal>(acc_s, acc_o, sRow_max_reduce_buffer, scale_softmax_log2);
is_first_masking_step
?
softmax
.
template
softmax_rescale_o_fp8_tp1
<
/*Is_first=*/
true
,
/*Check_inf=*/
Is_causal
,
true
>(
acc_s
,
sRow_max_reduce_buffer
,
scale_softmax_log2
,
acco_f32
)
:
softmax
.
template
softmax_rescale_o_fp8_tp1
<
/*Is_first=*/
false
,
/*Check_inf=*/
Is_causal
,
true
>(
acc_s
,
sRow_max_reduce_buffer
,
scale_softmax_log2
,
acco_f32
);
}
// asm volatile("s_barrier \n\t");
#if 1
softmax
.
template
softmax_rescale_o_fp8_tp1
<
/*Is_first=*/
IS_FIRST_MASK_BLOCK
,
/*Check_inf=*/
Is_causal
,
true
>(
acc_s
,
sRow_max_reduce_buffer
,
scale_softmax_log2
,
acco_f32
);
// #if 1
Fp8_storage
p_fp8
;
Fp8_storage
p_fp8
;
{
{
__builtin_amdgcn_sched_barrier
(
0
);
__builtin_amdgcn_sched_barrier
(
0
);
...
@@ -1498,30 +1494,41 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
...
@@ -1498,30 +1494,41 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
}
}
asm
volatile
(
"s_barrier
\n\t
"
);
asm
volatile
(
"s_barrier
\n\t
"
);
#endif
};
#if 1
if
constexpr
(
n_masking_steps
==
1
)
{
process_one_block
(
n_block
,
IsFirstMaskBlock
{});
n_block
--
;
}
else
{
int
masking_step
=
1
;
process_one_block
(
n_block
,
IsFirstMaskBlock
{});
n_block
--
;
for
(;
n_block
>=
n_block_min
&&
masking_step
<
n_masking_steps
;
++
masking_step
,
--
n_block
)
{
process_one_block
(
n_block
,
IsMaskBlock
{});
}
}
}
for
(;
n_block
>=
n_block_min
;
--
n_block
)
{
process_one_block
(
n_block
,
IsNoMaskBlock
{});
}
#endif
#endif
using
ElementO
=
typename
Kernel_traits
::
ElementO
;
using
ElementO
=
typename
Kernel_traits
::
ElementO
;
using
ElementAccum
=
typename
Kernel_traits
::
ElementAccum
;
using
ElementAccum
=
typename
Kernel_traits
::
ElementAccum
;
const
int
split_offset
=
__ldg
(
params
.
num_splits_ptr
+
bidb
);
// Tensor sRow_sum_reduce_buffer = make_tensor(make_smem_ptr(shared_storage.smem_row_sum.data()), typename Kernel_traits::SmemLayoutRow{});
const
index_t
row_offset_o
=
bidb
*
params
.
o_batch_stride
+
m_block
*
kBlockM
*
params
.
o_row_stride
+
bidh
*
params
.
o_head_stride
;
const
index_t
row_offset_oaccum
=
(((
split_offset
+
n_split_idx
)
*
params
.
h
+
bidh
)
*
params
.
seqlen_q
+
m_block
*
kBlockM
)
*
params
.
d_v
;
const
index_t
row_offset_lse
=
(
bidb
*
params
.
h
+
bidh
)
*
params
.
seqlen_q
+
m_block
*
kBlockM
;
const
index_t
row_offset_lseaccum
=
((
split_offset
+
n_split_idx
)
*
params
.
h
+
bidh
)
*
params
.
seqlen_q
+
m_block
*
kBlockM
;
if
(
NoSplit
)
{
if
(
NoSplit
)
{
const
index_t
row_offset_o
=
bidb
*
params
.
o_batch_stride
+
m_block
*
kBlockM
*
params
.
o_row_stride
+
bidh
*
params
.
o_head_stride
;
const
index_t
row_offset_lse
=
(
bidb
*
params
.
h
+
bidh
)
*
params
.
seqlen_q
+
m_block
*
kBlockM
;
constexpr
bool
Split
=
false
;
constexpr
bool
Split
=
false
;
Tensor
gOaccum
=
make_tensor
(
make_gmem_ptr
(
reinterpret_cast
<
ElementO
*>
(
Split
?
params
.
oaccum_ptr
:
params
.
o_ptr
)
+
(
Split
?
row_offset_oaccum
:
row_offset_o
)),
Tensor
gOaccum
=
make_tensor
(
make_gmem_ptr
(
reinterpret_cast
<
ElementO
*>
(
Split
?
params
.
oaccum_ptr
:
params
.
o_ptr
)
+
(
row_offset_o
)),
Shape
<
Int
<
kBlockM
>
,
Int
<
kHeadDimV
>>
{},
Shape
<
Int
<
kBlockM
>
,
Int
<
kHeadDimV
>>
{},
make_stride
(
Split
?
kHeadDimV
:
params
.
o_row_stride
,
_1
{}));
make_stride
(
Split
?
kHeadDimV
:
params
.
o_row_stride
,
_1
{}));
Tensor
lse
=
softmax
.
template
normalize_softmax_lse_fp8_tp1
<
/*Is_dropout=*/
false
,
Split
,
true
>(
acco_f32
,
sRow_sum_reduce_buffer
,
scale_softmax
,
descale_k
);
Tensor
lse
=
softmax
.
template
normalize_softmax_lse_fp8_tp1
<
/*Is_dropout=*/
false
,
Split
,
true
>(
acco_f32
,
sRow_sum_reduce_buffer
,
scale_softmax
,
descale_k
);
Tensor
gLSEaccum
=
make_tensor
(
make_gmem_ptr
(
reinterpret_cast
<
ElementAccum
*>
(
Split
?
params
.
softmax_lseaccum_ptr
:
params
.
softmax_lse_ptr
)
+
(
Split
?
row_offset_lseaccum
:
row_offset_lse
)),
Tensor
gLSEaccum
=
make_tensor
(
make_gmem_ptr
(
reinterpret_cast
<
ElementAccum
*>
(
Split
?
params
.
softmax_lseaccum_ptr
:
params
.
softmax_lse_ptr
)
+
(
row_offset_lse
)),
Shape
<
Int
<
kBlockM
>>
{},
Stride
<
_1
>
{});
Shape
<
Int
<
kBlockM
>>
{},
Stride
<
_1
>
{});
Tensor
caccO
=
make_identity_tensor
(
Shape
<
Int
<
kBlockM
>
,
Int
<
kHeadDimV
>>
{});
// (BLK_M,BLK_K) -> (blk_m,blk_k)
Tensor
caccO
=
make_identity_tensor
(
Shape
<
Int
<
kBlockM
>
,
Int
<
kHeadDimV
>>
{});
// (BLK_M,BLK_K) -> (blk_m,blk_k)
Tensor
taccOcO
=
thr_mma_o
.
partition_C
(
caccO
);
Tensor
taccOcO
=
thr_mma_o
.
partition_C
(
caccO
);
...
@@ -1598,12 +1605,19 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
...
@@ -1598,12 +1605,19 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
}
}
}
else
{
}
else
{
constexpr
bool
Split
=
true
;
constexpr
bool
Split
=
true
;
Tensor
gOaccum
=
make_tensor
(
make_gmem_ptr
(
reinterpret_cast
<
ElementAccum
*>
(
Split
?
params
.
oaccum_ptr
:
params
.
o_ptr
)
+
(
Split
?
row_offset_oaccum
:
row_offset_o
)),
const
int
split_offset
=
__ldg
(
params
.
num_splits_ptr
+
bidb
);
const
index_t
row_offset_oaccum
=
(((
split_offset
+
n_split_idx
)
*
params
.
h
+
bidh
)
*
params
.
seqlen_q
+
m_block
*
kBlockM
)
*
params
.
d_v
;
const
index_t
row_offset_lseaccum
=
((
split_offset
+
n_split_idx
)
*
params
.
h
+
bidh
)
*
params
.
seqlen_q
+
m_block
*
kBlockM
;
Tensor
gOaccum
=
make_tensor
(
make_gmem_ptr
(
reinterpret_cast
<
ElementAccum
*>
(
Split
?
params
.
oaccum_ptr
:
params
.
o_ptr
)
+
(
row_offset_oaccum
)),
Shape
<
Int
<
kBlockM
>
,
Int
<
kHeadDimV
>>
{},
Shape
<
Int
<
kBlockM
>
,
Int
<
kHeadDimV
>>
{},
make_stride
(
Split
?
kHeadDimV
:
params
.
o_row_stride
,
_1
{}));
make_stride
(
Split
?
kHeadDimV
:
params
.
o_row_stride
,
_1
{}));
Tensor
lse
=
softmax
.
template
normalize_softmax_lse_fp8_tp1
<
/*Is_dropout=*/
false
,
Split
,
true
>(
acco_f32
,
sRow_sum_reduce_buffer
,
scale_softmax
,
descale_k
);
Tensor
lse
=
softmax
.
template
normalize_softmax_lse_fp8_tp1
<
/*Is_dropout=*/
false
,
Split
,
true
>(
acco_f32
,
sRow_sum_reduce_buffer
,
scale_softmax
,
descale_k
);
Tensor
gLSEaccum
=
make_tensor
(
make_gmem_ptr
(
reinterpret_cast
<
ElementAccum
*>
(
Split
?
params
.
softmax_lseaccum_ptr
:
params
.
softmax_lse_ptr
)
+
(
Split
?
row_offset_lseaccum
:
row_offset_lse
)),
Tensor
gLSEaccum
=
make_tensor
(
make_gmem_ptr
(
reinterpret_cast
<
ElementAccum
*>
(
params
.
softmax_lseaccum_ptr
)
+
(
row_offset_lseaccum
)),
Shape
<
Int
<
kBlockM
>>
{},
Stride
<
_1
>
{});
Shape
<
Int
<
kBlockM
>>
{},
Stride
<
_1
>
{});
Tensor
caccO
=
make_identity_tensor
(
Shape
<
Int
<
kBlockM
>
,
Int
<
kHeadDimV
>>
{});
// (BLK_M,BLK_K) -> (blk_m,blk_k)
Tensor
caccO
=
make_identity_tensor
(
Shape
<
Int
<
kBlockM
>
,
Int
<
kHeadDimV
>>
{});
// (BLK_M,BLK_K) -> (blk_m,blk_k)
Tensor
taccOcO
=
thr_mma_o
.
partition_C
(
caccO
);
Tensor
taccOcO
=
thr_mma_o
.
partition_C
(
caccO
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment