gaoqiong / composable_kernel

Commit 3db3fe42
Authored Dec 14, 2022 by Anthony Chang
Parent: c26b46de

dP
Showing 2 changed files with 232 additions and 39 deletions (+232, -39)
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_fp16.cpp (+19, -2)
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp (+213, -37)
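For context on the one-word commit message: "dP" is the gradient of the softmax input in the attention backward pass. In the flash-attention-style decomposition this kernel family follows (Y = softmax(Q Kᵀ) V, with the per-row log-sum-exp saved by the forward pass), the backward pass splits into several GEMMs; this commit wires up the dP = dY Vᵀ piece on top of the existing dV GEMM. The symbols below are standard notation for reference, not identifiers from the code:

    dV = P^{\top} dY
    dP = dY \, V^{\top}
    dS_{ij} = P_{ij} \, (dP_{ij} - D_i), \qquad D_i = \sum_j dY_{ij} \, Y_{ij}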
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_fp16.cpp
@@ -364,12 +364,29 @@ int run(int argc, char* argv[])
        v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
        ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
        break;
-   default:
+   case 4:
        q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<DataType>{1});
        k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
        v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
        ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_1<DataType>{2});
        break;
    case 5:
        q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<DataType>{1});
        k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
        v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
        // ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_1<DataType>{2});
        ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<2>{}); // dy[g0, g1, m, o]
        break;
    case 6:
        q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<DataType>{1});
        k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
        v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
        ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<3>{}); // dy[g0, g1, m, o]
        break;
    default:
        q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<DataType>{1});
        k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
        v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
        ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_1<DataType>{1}); // dy[g0, g1, m, o]
    }

    // calculate y & log-sum-exp beforehand
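The new cases 5 and 6 pair a diagonal V (GeneratorTensor_Diagonal) with a GeneratorTensor_Sequential dY, which makes the dP values the kernel later prints easy to check by hand: when V is a (scaled) identity, the reference dP = dY Vᵀ is just a scaled copy of dY. A minimal host-side sketch of that check, assuming unit diagonal values and a simple 0, 1, 2, ... sequential fill; illustration only, not part of this commit:

    // Toy check: with V = I, dP = dY * V^T reduces to dY itself.
    #include <cstdio>
    #include <vector>

    int main()
    {
        const int M = 4, N = 4, O = 4; // toy sizes; real tiles are much larger

        std::vector<float> dy(M * O), v(N * O, 0.0f), dp(M * N, 0.0f);

        for(int i = 0; i < M * O; ++i) { dy[i] = static_cast<float>(i); } // sequential dY
        for(int n = 0; n < N; ++n) { v[n * O + n] = 1.0f; }               // diagonal V

        // dP[m][n] = sum_o dY[m][o] * V[n][o]  (V consumed as the transposed/B operand)
        for(int m = 0; m < M; ++m)
            for(int n = 0; n < N; ++n)
                for(int o = 0; o < O; ++o)
                    dp[m * N + n] += dy[m * O + o] * v[n * O + o];

        printf("dP[0:3] = %f, %f, %f, %f\n", dp[0], dp[1], dp[2], dp[3]);
        return 0;
    }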
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp
@@ -30,7 +30,7 @@ template <typename DataType,
          InMemoryDataOperationEnum CGlobalMemoryDataOperation,
          typename AGridDesc_AK0_M_AK1,
          typename BGridDesc_BK0_N_BK1,
-         typename B1GridDesc_BK0_N_BK1,
+         typename VGridDesc_N0_O_N1,
          typename CGridDesc_M_N,
          typename LSEGridDesc_M,
          index_t NumGemmKPrefetchStage,
@@ -186,36 +186,10 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
            [](auto I) { return GetPBlockSliceLengths_M0_N0_M1_N1_M2_N2().At(I); },
            Number<4>{});
    }

    // template <typename PBlockDesc_M0_N_M1>
    // __host__ __device__ static constexpr auto
    // MakePMmaTileDescriptor_N0_N1_N2_M(const PBlockDesc_M0_N_M1&)
    // {
    //     constexpr auto lengths = GetPBlockSliceLengths_M0_N0_M1_N1_M2_N2();
    //     return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K<lengths[I0], lengths[I2],
    //     lengths[I4]>(
    //         PBlockDesc_M0_N_M1{});
    // }

    // template <typename BBlockDesc_BK0_N_BK1>
    // __host__ __device__ static constexpr auto
    // MakeYGradMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&)
    // {
    //     constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
    //     return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K<NXdlPerWave, NWaves, NPerXdl>(
    //         BBlockDesc_BK0_N_BK1{});
    // }
};

using VGradGemmTile_N_O_M = VGradGemmTile_N_O_M_<>; // tune later

// PGrad Gemm
struct PGradGemmTile_M_N_O_
{
};

template <index_t BlockSize_, index_t BlockSliceLength_M_, index_t BlockSliceLength_O_>
struct YDotYGrad_M_O_
{
@@ -363,7 +337,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
    __host__ __device__ static constexpr bool
    CheckValidity(const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
                  const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
-                 const B1GridDesc_BK0_N_BK1& b1_grid_desc_bk0_n_bk1,
+                 const VGridDesc_N0_O_N1& v_grid_desc_n0_o_n1,
                  const CGridDesc_M_N& c_grid_desc_m_n,
                  const Block2CTileMap& block_2_ctile_map)
    {
@@ -374,7 +348,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
        const auto M = a_grid_desc_ak0_m_ak1.GetLength(I1);
        const auto N = b_grid_desc_bk0_n_bk1.GetLength(I1);
        const auto K = a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2);
-       const auto Gemm1N = b1_grid_desc_bk0_n_bk1.GetLength(I1);
+       const auto Gemm1N = v_grid_desc_n0_o_n1.GetLength(I1);

        if(!(M == c_grid_desc_m_n.GetLength(I0) && Gemm1N == c_grid_desc_m_n.GetLength(I1)))
        {
@@ -472,6 +446,81 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
    using DefaultBlock2CTileMap =
        remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}))>;

    // PGrad Gemm has the same layout as P Gemm (A row-major B col-major)
    struct PGradGemmTile_M_N_O
    {
        private:
        static constexpr auto ygrad_block_desc_o0_m_o1 =
            GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
        static constexpr auto v_block_desc_o0_n_o1 =
            GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();

        static constexpr index_t KPack =
            math::max(math::lcm(AK1, BK1),
                      MfmaSelector<DataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);

        public:
        using BlockwiseGemm = BlockwiseGemmXdlops_v2<
            BlockSize,
            DataType,
            FloatGemmAcc,
            decltype(ygrad_block_desc_o0_m_o1),
            decltype(v_block_desc_o0_n_o1),
            decltype(MakeGemm0AMmaTileDescriptor_M0_M1_M2_K(ygrad_block_desc_o0_m_o1)),
            decltype(MakeGemm0BMmaTileDescriptor_N0_N1_N2_K(v_block_desc_o0_n_o1)),
            MPerBlock,
            NPerBlock,
            KPerBlock,
            MPerXdl,
            NPerXdl,
            MXdlPerWave,
            NXdlPerWave,
            KPack,
            true>;

        // Should have made all input tensors 2D and transform them into appropriate 3D form in
        // kernel to make things more concise - if we can get the compiler to behave
        template <typename YGradGridDesc_M0_O_M1_>
        __device__ static const auto
        MakeYGradGridDesc_O0_M_O1(const YGradGridDesc_M0_O_M1_& ygrad_grid_desc_m0_o_m1)
        {
            const auto M0 = ygrad_grid_desc_m0_o_m1.GetLength(I0);
            const auto O  = ygrad_grid_desc_m0_o_m1.GetLength(I1);
            const auto M1 = ygrad_grid_desc_m0_o_m1.GetLength(I2);

            constexpr auto Y_O1 = AK1;
            const auto Y_O0     = O / Y_O1;

            const auto ygrad_grid_desc_o0_m_o1 = transform_tensor_descriptor(
                ygrad_grid_desc_m0_o_m1,
                make_tuple(make_unmerge_transform(make_tuple(Y_O0, Y_O1)),
                           make_merge_transform_v3_division_mod(make_tuple(M0, M1))),
                make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return ygrad_grid_desc_o0_m_o1;
        }

        template <typename VGridDesc_N0_O_N1_>
        __device__ static const auto
        MakeVGridDesc_O0_N_O1(const VGridDesc_N0_O_N1_& v_grid_desc_n0_o_n1)
        {
            const auto N0 = v_grid_desc_n0_o_n1.GetLength(I0);
            const auto O  = v_grid_desc_n0_o_n1.GetLength(I1);
            const auto N1 = v_grid_desc_n0_o_n1.GetLength(I2);

            constexpr auto V_O1 = BK1;
            const auto V_O0     = O / V_O1;

            const auto v_grid_desc_o0_n_o1 = transform_tensor_descriptor(
                v_grid_desc_n0_o_n1,
                make_tuple(make_unmerge_transform(make_tuple(V_O0, V_O1)),
                           make_merge_transform_v3_division_mod(make_tuple(N0, N1))),
                make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return v_grid_desc_o0_n_o1;
        }
    };

    struct SharedMemTrait
    {
        // LDS allocation for A and B: be careful of alignment
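The two helpers added above only re-describe memory that is already laid out; no data moves. MakeYGradGridDesc_O0_M_O1 unmerges the O axis of dY into (O0, O1) with O1 = AK1 and merges (M0, M1) back into a single M, so the dP GEMM can stream dY through the same (K0, M, K1) access pattern as the regular A operand; MakeVGridDesc_O0_N_O1 does the same for V with O1 = BK1. A standalone sketch of the index arithmetic, with names of my own choosing and assuming the merge uses plain division/modulo:

    #include <cstdio>

    struct Idx3 { int d0, d1, d2; };

    // Map an (o0, m, o1) coordinate of the re-described view back to the
    // underlying (m0, o, m1) coordinate of the original dY descriptor.
    Idx3 o0_m_o1_to_m0_o_m1(Idx3 c, int M1, int O1)
    {
        const int o  = c.d0 * O1 + c.d2; // unmerge: o = o0 * O1 + o1
        const int m0 = c.d1 / M1;        // merge:   m = m0 * M1 + m1
        const int m1 = c.d1 % M1;
        return {m0, o, m1};
    }

    int main()
    {
        const int M1 = 2, O1 = 8; // O1 plays the role of AK1 here
        const Idx3 src = o0_m_o1_to_m0_o_m1({1, 5, 3}, M1, O1);
        printf("(o0=1, m=5, o1=3) -> (m0=%d, o=%d, m1=%d)\n", src.d0, src.d1, src.d2);
        return 0;
    }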
@@ -525,7 +574,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
              typename YGradGridDesc_M0_O_M1>
    __device__ static void Run(const DataType* __restrict__ p_a_grid,
                               const DataType* __restrict__ p_b_grid,
-                              const DataType* __restrict__ p_b1_grid,
+                              const DataType* __restrict__ p_v_grid,
                               const DataType* __restrict__ p_y_grid,
                               const FloatLSE* __restrict__ p_lse_grid,
                               const DataType* __restrict__ p_ygrad_grid,
@@ -540,7 +589,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                               const CElementwiseOperation& c_element_op,
                               const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
                               const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
-                              const B1GridDesc_BK0_N_BK1& b1_grid_desc_bk0_n_bk1,
+                              const VGridDesc_N0_O_N1& v_grid_desc_n0_o_n1,
                               const YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock&
                                   y_grid_desc_mblock_mperblock_oblock_operblock,
                               const LSEGridDesc_M& lse_grid_desc_m,
@@ -553,8 +602,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
            p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
        const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
-       const auto b1_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-           p_b1_grid, b1_grid_desc_bk0_n_bk1.GetElementSpaceSize());
+       const auto v_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+           p_v_grid, v_grid_desc_n0_o_n1.GetElementSpaceSize());
        const auto y_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_y_grid, y_grid_desc_mblock_mperblock_oblock_operblock.GetElementSpaceSize());
        auto lse_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
@@ -784,7 +833,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
            B1BlockTransferThreadClusterArrangeOrder,
            DataType,
            DataType,
-           decltype(b1_grid_desc_bk0_n_bk1),
+           decltype(v_grid_desc_n0_o_n1),
            decltype(b1_block_desc_bk0_n_bk1),
            B1BlockTransferSrcAccessOrder,
            Sequence<1, 0, 2>,
@@ -797,7 +846,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
            B1ThreadTransferSrcResetCoordinateAfterRun,
            true, // DstResetCoord
            NumGemmKPrefetchStage>(
-               b1_grid_desc_bk0_n_bk1,
+               v_grid_desc_n0_o_n1,
                make_multi_index(0, gemm1_n_block_data_idx_on_grid, 0),
                b1_element_op,
                b1_block_desc_bk0_n_bk1,
@@ -1298,6 +1347,83 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
        // tiled the same way
        // TODO ANT: dP Gemm can reuse first blockwise gemm and pipeline
        const auto ygrad_grid_desc_o0_m_o1 =
            PGradGemmTile_M_N_O::MakeYGradGridDesc_O0_M_O1(ygrad_grid_desc_m0_o_m1);
        const auto v_grid_desc_o0_n_o1 =
            PGradGemmTile_M_N_O::MakeVGridDesc_O0_N_O1(v_grid_desc_n0_o_n1);

        // A matrix blockwise copy
        auto pgrad_gemm_tile_ygrad_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1<
            ThisThreadBlock,
            tensor_operation::element_wise::PassThrough,
            tensor_operation::element_wise::PassThrough,
            InMemoryDataOperationEnum::Set,
            Sequence<AK0, MPerBlock, AK1>,
            ABlockTransferThreadClusterLengths_AK0_M_AK1,
            ABlockTransferThreadClusterArrangeOrder,
            DataType,
            DataType,
            decltype(ygrad_grid_desc_o0_m_o1),
            decltype(a_block_desc_ak0_m_ak1), // reuse block buf
            ABlockTransferSrcAccessOrder,
            Sequence<1, 0, 2>,
            ABlockTransferSrcVectorDim,
            2,
            ABlockTransferSrcScalarPerVector,
            ABlockTransferDstScalarPerVector_AK1,
            1,
            1,
            true, // SrcResetCoord
            true, // DstResetCoord
            NumGemmKPrefetchStage>(
                ygrad_grid_desc_o0_m_o1,
                make_multi_index(0, m_block_data_idx_on_grid, 0),
                tensor_operation::element_wise::PassThrough{},
                a_block_desc_ak0_m_ak1,
                make_multi_index(0, 0, 0),
                tensor_operation::element_wise::PassThrough{});

        // B matrix blockwise copy
        auto pgrad_gemm_tile_v_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1<
            ThisThreadBlock,
            tensor_operation::element_wise::PassThrough,
            tensor_operation::element_wise::PassThrough,
            InMemoryDataOperationEnum::Set,
            Sequence<BK0, NPerBlock, BK1>,
            BBlockTransferThreadClusterLengths_BK0_N_BK1,
            BBlockTransferThreadClusterArrangeOrder,
            DataType,
            DataType,
            decltype(v_grid_desc_o0_n_o1),
            decltype(b_block_desc_bk0_n_bk1), // reuse block buf
            BBlockTransferSrcAccessOrder,
            Sequence<1, 0, 2>,
            BBlockTransferSrcVectorDim,
            2,
            BBlockTransferSrcScalarPerVector,
            BBlockTransferDstScalarPerVector_BK1,
            1,
            1,
            true, // SrcResetCoord
            true, // DstResetCoord
            NumGemmKPrefetchStage>(
                v_grid_desc_o0_n_o1,
                make_multi_index(0, 0, 0), // will loop over GemmN dimension
                tensor_operation::element_wise::PassThrough{},
                b_block_desc_bk0_n_bk1,
                make_multi_index(0, 0, 0),
                tensor_operation::element_wise::PassThrough{});

        auto pgrad_blockwise_gemm = typename PGradGemmTile_M_N_O::BlockwiseGemm{};

        auto pgrad_acc_thread_buf = pgrad_blockwise_gemm.GetCThreadBuffer();

        const auto pgrad_gemm_tile_ygrad_block_reset_copy_step =
            make_multi_index(-ygrad_grid_desc_o0_m_o1.GetLength(I0), 0, 0);
        const auto pgrad_gemm_tile_v_block_reset_copy_step =
            make_multi_index(-v_grid_desc_o0_n_o1.GetLength(I0), NPerBlock, 0);

        const index_t num_o_block_main_loop = __builtin_amdgcn_readfirstlane(
            (ygrad_grid_desc_o0_m_o1.GetLength(I0) * ygrad_grid_desc_o0_m_o1.GetLength(I2)) /
            KPerBlock);

        auto y_dot_ygrad_thread_copy_lds_to_vgpr = ThreadwiseTensorSliceTransfer_v2<
            FloatGemmAcc,
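In the dP GEMM set up above, the reduction dimension is the head dimension O: num_o_block_main_loop is (O0 * O1) / KPerBlock, the ygrad reset step rewinds the O window by -O0 for the next outer iteration, and the V reset step additionally advances the window by NPerBlock to the next key/value tile. A rough control-flow sketch of how this nests into the existing j loop; illustration only, with made-up sizes, not the kernel's code:

    #include <cstdio>

    int main()
    {
        const int O = 128, KPerBlock = 32, NumJTiles = 4; // hypothetical sizes
        const int num_o_block_main_loop = O / KPerBlock;

        for(int j = 0; j < NumJTiles; ++j) // outer loop over key/value (N) tiles
        {
            for(int o = 0; o < num_o_block_main_loop; ++o)
            {
                // load one dY tile and one V tile, accumulate a KPerBlock slice of dP
            }
            // rewind the O window to 0 and advance the V window to the next N tile
            printf("j=%d: ran %d O-steps, rewound O, stepped N\n", j, num_o_block_main_loop);
        }
        return 0;
    }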
@@ -1525,7 +1651,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
            }
#endif

            // softmax
            // P_i: = softmax(S_i:)
            blockwise_softmax.RunWithPreCalcStats(acc_thread_buf, lse_thread_buf);

#if 0
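RunWithPreCalcStats takes lse_thread_buf alongside the accumulator, which is consistent with rebuilding P from the log-sum-exp computed "beforehand" in the example (file 1) instead of redoing the row max/sum reduction. The identity this relies on, stated here for reference:

    P_{ij} = \frac{e^{s_{ij}}}{\sum_k e^{s_{ik}}} = e^{\,s_{ij} - \mathrm{LSE}_i},
    \qquad \mathrm{LSE}_i = \log \sum_k e^{s_{ik}}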
@@ -1628,13 +1754,58 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
#endif
            });
            // end gemm dV

-           // atomic_add vgrad
+           // atomic_add dV
            vgrad_thread_copy_vgpr_to_global.Run(
                vgrad_thread_desc_n0_o0_n1_o1_n2_o2_o3_o4,
                make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
                vgrad_acc_thread_buf,
                vgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4,
                vgrad_grid_buf);

            // gemm dP
            pgrad_acc_thread_buf.Clear();
#if 0
            if (hipBlockIdx_x == 0 && hipThreadIdx_x % 32 < 4)
            {
                printf("j loop idx %d, tid %zd, clear dP[0:3] = %f, %f, %f, %f\n",
                       gemm1_k_block_outer_index,
                       hipThreadIdx_x,
                       pgrad_acc_thread_buf[I0],
                       pgrad_acc_thread_buf[I1],
                       pgrad_acc_thread_buf[I2],
                       pgrad_acc_thread_buf[I3]);
            }
#endif
            block_sync_lds();

            // assume size K == size O so has main block loop
            gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(
                ygrad_grid_desc_o0_m_o1,
                a_block_desc_ak0_m_ak1, // reuse
                pgrad_gemm_tile_ygrad_blockwise_copy,
                vgrad_grid_buf,
                a_block_buf, // reuse
                a_block_slice_copy_step, // reuse
                v_grid_desc_o0_n_o1,
                b_block_desc_bk0_n_bk1, // reuse
                pgrad_gemm_tile_v_blockwise_copy,
                v_grid_buf,
                b_block_buf, // reuse
                b_block_slice_copy_step, // reuse
                pgrad_blockwise_gemm,
                pgrad_acc_thread_buf,
                num_o_block_main_loop);

#if 1
            if(hipBlockIdx_x == 0 && hipThreadIdx_x % 32 < 4)
            {
                printf("j loop idx %d, tid %zd, dP[0:3] = %f, %f, %f, %f\n",
                       gemm1_k_block_outer_index,
                       hipThreadIdx_x,
                       pgrad_acc_thread_buf[I0],
                       pgrad_acc_thread_buf[I1],
                       pgrad_acc_thread_buf[I2],
                       pgrad_acc_thread_buf[I3]);
            }
#endif

            // move slice window
            a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_ak0_m_ak1,
                                                a_block_reset_copy_step); // rewind K
            b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_bk0_n_bk1,
@@ -1643,6 +1814,11 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                                                    ygrad_block_reset_copy_step); // rewind M
            vgrad_thread_copy_vgpr_to_global.MoveDstSliceWindow(
                vgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4,
                vgrad_block_slice_copy_step); // step N

            pgrad_gemm_tile_ygrad_blockwise_copy.MoveSrcSliceWindow(
                ygrad_grid_desc_o0_m_o1,
                pgrad_gemm_tile_ygrad_block_reset_copy_step); // rewind O
            pgrad_gemm_tile_v_blockwise_copy.MoveSrcSliceWindow(
                v_grid_desc_o0_n_o1,
                pgrad_gemm_tile_v_block_reset_copy_step); // rewind O and step N

        } while(++gemm1_k_block_outer_index < num_gemm1_k_block_outer_loop); // end j loop