Commit ec9c2b5e authored by Anthony Chang

dK validates

parent 2d55c14c
@@ -647,7 +647,10 @@ int run(int argc, char* argv[])
                                          1e-2);
         std::cout << "Checking kgrad:\n";
         pass &= ck::utils::check_err(kgrad_gs_ns_ks_device_result.mData,
-                                     kgrad_gs_ns_ks_host_result.mData);
+                                     kgrad_gs_ns_ks_host_result.mData,
+                                     "error",
+                                     1e-2,
+                                     1e-2);
         std::cout << "Checking vgrad:\n";
         pass &= ck::utils::check_err(vgrad_gs_os_ns_device_result.mData,
                                      vgrad_gs_os_ns_host_result.mData,

@@ -656,7 +659,7 @@ int run(int argc, char* argv[])
                                      1e-2);
     }

-    return pass ? 0 : 1;
+    return pass ? (std::cout << "pass\n", 0) : (std::cout << "fail\n", 1);
 }

 int main(int argc, char* argv[]) { return run(argc, argv); }
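For context, the two `1e-2` arguments added to the `check_err` calls are the tolerances used when comparing device output against the host reference. A minimal standalone sketch of that comparison logic, assuming (from the call sites above) that the trailing parameters are a message, a relative tolerance, and an absolute tolerance; the default values below are illustrative, not CK's:

```cpp
// Sketch of a tolerance-based comparison in the spirit of ck::utils::check_err.
#include <cmath>
#include <cstdio>
#include <vector>

bool check_err_sketch(const std::vector<float>& out,
                      const std::vector<float>& ref,
                      const char* msg = "error",
                      double rtol     = 1e-5,
                      double atol     = 1e-8)
{
    if(out.size() != ref.size())
    {
        std::printf("%s: size mismatch %zu vs %zu\n", msg, out.size(), ref.size());
        return false;
    }
    for(std::size_t i = 0; i < out.size(); ++i)
    {
        const double err = std::abs(double(out[i]) - double(ref[i]));
        // pass if within the absolute tolerance, or within rtol of the reference magnitude
        if(err > atol + rtol * std::abs(double(ref[i])))
        {
            std::printf("%s: mismatch at %zu: %f vs %f\n", msg, i, out[i], ref[i]);
            return false;
        }
    }
    return true;
}
```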
@@ -789,7 +789,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
     template <typename CGradDesc_N_O>
     __host__ __device__ static const auto
-    MakeCGridDesc_N0_O0_N1_O1_N2_O2_O3_O4(CGradDesc_N_O c_grid_desc_n_o)
+    MakeCGridDesc_N0_O0_N1_O1_N2_O2_O3_O4(const CGradDesc_N_O& c_grid_desc_n_o)
     {
         // HACK: for unmerge transform, the length of highest dim is irrelevant so we put dummy
         // variable I1 there
@@ -859,7 +859,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
     // PGrad Gemm has the same layout as P = Q * K^T Gemm (A row-major B col-major)
     struct PGradGemmTile_M_N_O
     {
-        // TODO ANT:
+        // TODO:
         // Make all input tensors 2D and transform them into appropriate 3D form in kernel to make
         // things more concise
         template <typename YGradGridDesc_M0_O_M1_>
@@ -957,6 +957,48 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
         }
     };

+    struct KGradGemmTile_N_K_M
+    {
+        // B position
+        template <typename QGridDesc_K0_M_K1_>
+        __device__ static const auto
+        MakeQGridDesc_M0_K_M1(const QGridDesc_K0_M_K1_& q_grid_desc_k0_m_k1)
+        {
+            const auto Q_K0 = q_grid_desc_k0_m_k1.GetLength(I0);
+            const auto M    = q_grid_desc_k0_m_k1.GetLength(I1);
+            const auto Q_K1 = q_grid_desc_k0_m_k1.GetLength(I2);
+
+            constexpr auto Q_M1 = B1K1;
+            const auto Q_M0     = M / Q_M1;
+
+            const auto q_grid_desc_m0_k_m1 = transform_tensor_descriptor(
+                q_grid_desc_k0_m_k1,
+                make_tuple(make_unmerge_transform(make_tuple(Q_M0, Q_M1)),
+                           make_merge_transform_v3_division_mod(make_tuple(Q_K0, Q_K1))),
+                make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
+                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+
+            return q_grid_desc_m0_k_m1;
+        }
+
+        // C position
+        template <typename KGridDesc_K0_N_K1_>
+        __device__ static const auto
+        MakeKGradGridDesc_N_K(const KGridDesc_K0_N_K1_& k_grid_desc_k0_n_k1)
+        {
+            const auto K_K0 = k_grid_desc_k0_n_k1.GetLength(I0);
+            const auto N    = k_grid_desc_k0_n_k1.GetLength(I1);
+            const auto K_K1 = k_grid_desc_k0_n_k1.GetLength(I2);
+
+            return transform_tensor_descriptor(
+                k_grid_desc_k0_n_k1,
+                make_tuple(make_pass_through_transform(N),
+                           make_merge_transform_v3_division_mod(make_tuple(K_K0, K_K1))),
+                make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
+                make_tuple(Sequence<0>{}, Sequence<1>{}));
+        }
+    };
+
     struct SharedMemTrait
     {
         // LDS allocation for A and B: be careful of alignment
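The descriptor transforms above move no data; they only reindex an existing buffer. A plain index-math sketch of what `MakeQGridDesc_M0_K_M1` expresses, assuming a packed row-major (k0, m, k1) layout for illustration: the unmerge transform splits M into (M0, M1 = B1K1), and the merge-with-division/mod transform fuses (K0, K1) back into a single K coordinate.

```cpp
// Index-math sketch of the (K0, M, K1) -> (M0, K, M1) reinterpretation.
#include <cstddef>

struct QViewSketch
{
    std::size_t K0, M, K1, M1; // M1 plays the role of B1K1 in the kernel

    // offset of element (k0, m, k1) in a packed (K0, M, K1) layout
    std::size_t offset_k0_m_k1(std::size_t k0, std::size_t m, std::size_t k1) const
    {
        return (k0 * M + m) * K1 + k1;
    }

    // the same element addressed as (m0, k, m1):
    //   unmerge:             m = m0 * M1 + m1
    //   merge (division/mod): k -> (k0, k1) = (k / K1, k % K1)
    std::size_t offset_m0_k_m1(std::size_t m0, std::size_t k, std::size_t m1) const
    {
        return offset_k0_m_k1(k / K1, m0 * M1 + m1, k % K1);
    }
};
```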
@@ -1067,7 +1109,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
             p_v_grid, v_grid_desc_n0_o_n1.GetElementSpaceSize());
         const auto y_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_y_grid, y_grid_desc_mblock_mperblock_oblock_operblock.GetElementSpaceSize());
-        auto lse_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+        const auto lse_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_lse_grid, lse_grid_desc_m.GetElementSpaceSize());
         const auto ygrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_ygrad_grid, ygrad_grid_desc_m0_o_m1.GetElementSpaceSize());
@@ -1075,6 +1117,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
             p_vgrad_grid, vgrad_grid_desc_n_o.GetElementSpaceSize());
         auto qgrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_qgrad_grid, q_grid_desc_k0_m_k1.GetElementSpaceSize());
+        auto kgrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+            p_kgrad_grid, k_grid_desc_k0_n_k1.GetElementSpaceSize());

         // divide block work by [M, O]
         const auto block_work_idx =
@@ -1095,6 +1139,11 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
         const index_t o_block_data_idx_on_grid =
             __builtin_amdgcn_readfirstlane(block_work_idx[I1] * Gemm1NPerBlock);

+        // 6 GEMM operations are categorized into 3 buckets. SizeK == SizeO == head_dim
+        // S_MNK / dP_MNO Gemm (Gemm0 rcr)
+        // Y_MON / dQ_MKN Gemm (Gemm1 rrr)
+        // dV_NOM / dK_NKM Gemm (Gemm2 crr)
+
         //
         // set up S / dP Gemm (type 1 rcr)
         //
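For orientation, the three buckets in the new comment correspond to the standard attention-backward dataflow (a textbook derivation consistent with the comments in this file, not a quote from the source; S, P, dS are M×N, Q, dQ are M×K, K, dK are N×K, V, dV are N×O, Y, dY are M×O, head_dim = K = O):

```latex
\begin{aligned}
S  &= Q K^{\top}, & P  &= \operatorname{softmax}_{\mathrm{row}}(S), & Y &= P V,\\
dP &= dY\, V^{\top}, & dS &= P \circ \left(dP - D\,\mathbf{1}^{\top}\right), & D_i &= \textstyle\sum_{o} dY_{io}\, Y_{io},\\
dV &= P^{\top} dY, & dQ &= dS\, K, & dK &= dS^{\top} Q.
\end{aligned}
```

S and dP contract over the head dimension (the "rcr" bucket), Y and dQ contract over N ("rrr"), and dV and dK contract over M ("crr"), which is why each pair can share one gemm pipeline.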
@@ -1211,11 +1260,11 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
             QGradGemmTile_M_K_N::MakeQGradGridDesc_MBlock_MPerBlock_KBlock_KPerBlock(
                 q_grid_desc_k0_m_k1);

-        // dQ: Gemm A matrix blockwise copy
+        // dQ: A matrix blockwise copy
         auto qgrad_gemm_tile_sgrad_blockwise_copy =
             typename Gemm1::ABlockwiseCopy{tensor_operation::element_wise::PassThrough{}};

-        // dQ: Gemm B matrix blockwise copy
+        // dQ: B matrix blockwise copy
         auto qgrad_gemm_tile_k_blockwise_copy =
             typename Gemm1::template BBlockwiseCopy<decltype(k_grid_desc_n0_k_n1)>(
                 k_grid_desc_n0_k_n1,
@@ -1357,9 +1406,9 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                 tensor_operation::element_wise::PassThrough{});

         // dV: blockwise gemm
-        auto vgrad_blockwise_gemm = typename Gemm2::BlockwiseGemm{};
-        auto vgrad_thread_buf = vgrad_blockwise_gemm.GetCThreadBuffer();
+        auto v_slash_k_grad_blockwise_gemm = typename Gemm2::BlockwiseGemm{};
+        auto v_slash_k_grad_thread_buf = v_slash_k_grad_blockwise_gemm.GetCThreadBuffer();

         // dV: C VGPR-to-global copy
         const auto vgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4 =
@@ -1376,6 +1425,45 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                 tensor_operation::element_wise::PassThrough{});

         // dK: transform input and output tensor descriptors
+        const auto q_grid_desc_m0_k_m1 =
+            KGradGemmTile_N_K_M::MakeQGridDesc_M0_K_M1(q_grid_desc_k0_m_k1);
+        const auto kgrad_grid_desc_n_k =
+            KGradGemmTile_N_K_M::MakeKGradGridDesc_N_K(k_grid_desc_k0_n_k1);
+
+        // dK: A matrix VGPR-to-LDS blockwise copy
+        auto kgrad_gemm_tile_sgrad_thread_copy_vgpr_to_lds = typename Gemm2::ABlockwiseCopy{
+            Gemm2::a_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
+            Gemm2::MakeAThreadOriginOnBlock_M0_N0_M1_N1_M2_N2_N3_N4(),
+            tensor_operation::element_wise::PassThrough{}};
+
+        // dK: B matrix global-to-LDS blockwise copy
+        auto kgrad_gemm_tile_q_blockwise_copy =
+            typename Gemm2::template BBlockwiseCopy<decltype(q_grid_desc_m0_k_m1)>(
+                q_grid_desc_m0_k_m1,
+                make_multi_index(m_block_data_idx_on_grid / Gemm2Params_N_O_M::B_M1,
+                                 o_block_data_idx_on_grid,
+                                 0),
+                tensor_operation::element_wise::PassThrough{},
+                Gemm2::b_block_desc_m0_o_m1,
+                make_multi_index(0, 0, 0),
+                tensor_operation::element_wise::PassThrough{});
+
+        // dK: blockwise gemm
+        /* reuse v_slash_k_grad_blockwise_gemm, v_slash_k_grad_thread_buf */
+
+        // dK: C VGPR-to-global copy
+        const auto kgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4 =
+            Gemm2::MakeCGridDesc_N0_O0_N1_O1_N2_O2_O3_O4(kgrad_grid_desc_n_k);
+
+        const auto kgrad_thread_origin_on_grid_n0_o0_n1_o1_n2_o2_o3_o4 =
+            Gemm2::GetCThreadOriginOnBlock_N0_O0_N1_O1_N2_O2_O3_O4() +
+            make_multi_index(I0, block_work_idx[I1], I0, I0, I0, I0, I0, I0);
+
+        auto kgrad_thread_copy_vgpr_to_global = typename Gemm2::template CBlockwiseCopy<decltype(
+            kgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4)>(
+            kgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4,
+            kgrad_thread_origin_on_grid_n0_o0_n1_o1_n2_o2_o3_o4,
+            tensor_operation::element_wise::PassThrough{});
+
         //
         // set up Y dot dY
         //
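The "Y dot dY" quantity exists because the softmax-backward correction term D never requires materializing dP ∘ P: since Y = P V and dP = dY V^⊤,

```latex
D_i = \sum_j dP_{ij}\, P_{ij}
    = \sum_j P_{ij} \sum_o dY_{io}\, V_{jo}
    = \sum_o dY_{io} \sum_j P_{ij}\, V_{jo}
    = \sum_o dY_{io}\, Y_{io},
```

so a row-wise dot product of the forward output Y with its gradient dY suffices (with the stored LSE values used to recompute P from S).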
@@ -1618,38 +1706,36 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
             block_sync_lds(); // wait for gemm1 LDS read

-            constexpr auto p_block_slice_lengths_m0_n0_m1_n1 =
-                typename Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1{};
-            SubThreadBlock<BlockSize> p_thread_copy_subgroup(s_blockwise_gemm.GetWaveIdx()[I0],
-                                                             s_blockwise_gemm.GetWaveIdx()[I1]);
+            SubThreadBlock<BlockSize> gemm2_a_copy_subgroup(s_blockwise_gemm.GetWaveIdx()[I0],
+                                                            s_blockwise_gemm.GetWaveIdx()[I1]);

-            constexpr index_t num_vgrad_gemm_loop = MPerBlock / Gemm2Params_N_O_M::Sum_M;
-            static_assert(vgrad_gemm_tile_p_block_slice_window_iterator.GetNumOfAccess() ==
-                              num_vgrad_gemm_loop,
-                          "");
+            constexpr index_t num_gemm2_loop = MPerBlock / Gemm2Params_N_O_M::Sum_M;
+            static_assert(Gemm2::ASrcBlockSliceWindowIterator::GetNumOfAccess() == num_gemm2_loop,
+                          "");

             // TODO: tune gemm2 pipeline
             // dV = P^T * dY
-            vgrad_thread_buf.Clear();
-            static_for<0, num_vgrad_gemm_loop, 1>{}([&](auto vgrad_gemm_loop_idx) { // gemm dV
+            v_slash_k_grad_thread_buf.Clear();
+            static_for<0, num_gemm2_loop, 1>{}([&](auto gemm2_loop_idx) { // gemm dV
                 // load VGrad Gemm B
                 vgrad_gemm_tile_ygrad_blockwise_copy.RunRead(ygrad_grid_desc_m0_o_m1,
                                                              ygrad_grid_buf);

                 // load VGrad Gemm A
-                const auto p_nd_idx =
-                    vgrad_gemm_tile_p_block_slice_window_iterator.GetIndexTupleOfNumber(
-                        vgrad_gemm_loop_idx);
-                constexpr auto mwave_range =
-                    make_tuple(p_nd_idx[I2], p_nd_idx[I2] + p_block_slice_lengths_m0_n0_m1_n1[I2]);
-                constexpr auto nwave_range =
-                    make_tuple(p_nd_idx[I3], p_nd_idx[I3] + p_block_slice_lengths_m0_n0_m1_n1[I3]);
+                const auto p_slice_idx =
+                    Gemm2::ASrcBlockSliceWindowIterator::GetIndexTupleOfNumber(gemm2_loop_idx);
+                constexpr auto mwave_range = make_tuple(
+                    p_slice_idx[I2],
+                    p_slice_idx[I2] + Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1::At(I2));
+                constexpr auto nwave_range = make_tuple(
+                    p_slice_idx[I3],
+                    p_slice_idx[I3] + Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1::At(I3));

-                if(p_thread_copy_subgroup.IsBelong(mwave_range, nwave_range))
+                if(gemm2_a_copy_subgroup.IsBelong(mwave_range, nwave_range))
                 {
                     vgrad_gemm_tile_p_thread_copy_vgpr_to_lds.Run(
                         Gemm2::a_src_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
-                        make_tuple(p_nd_idx[I0], p_nd_idx[I1], I0, I0, I0, I0, I0, I0),
+                        make_tuple(p_slice_idx[I0], p_slice_idx[I1], I0, I0, I0, I0, I0, I0),
                         s_slash_p_thread_buf,
                         Gemm2::a_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
                         gemm2_a_block_buf);
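A note on the gating above: in both Gemm2 loops the A operand (P for dV, dS for dK) already lives in VGPRs as the accumulator of the preceding blockwise GEMM, so `gemm2_a_copy_subgroup.IsBelong(mwave_range, nwave_range)` selects only the waves whose accumulator fragment overlaps the current A-slice window for the VGPR-to-LDS copy, while every wave still participates in the B-tile load and the barriers that follow.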
@@ -1665,13 +1751,14 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                                                               gemm2_b_block_buf);

                 block_sync_lds(); // sync before read
-                vgrad_blockwise_gemm.Run(gemm2_a_block_buf, gemm2_b_block_buf, vgrad_thread_buf);
+                v_slash_k_grad_blockwise_gemm.Run(
+                    gemm2_a_block_buf, gemm2_b_block_buf, v_slash_k_grad_thread_buf);
             }); // end gemm dV

             // atomic_add dV
             vgrad_thread_copy_vgpr_to_global.Run(Gemm2::c_thread_desc_n0_o0_n1_o1_n2_o2_o3_o4,
                                                  make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
-                                                 vgrad_thread_buf,
+                                                 v_slash_k_grad_thread_buf,
                                                  vgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4,
                                                  vgrad_grid_buf);
@@ -1777,6 +1864,56 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                 }
             } // end gemm dQ

+            // dK = dS^T * Q
+            v_slash_k_grad_thread_buf.Clear();
+            static_for<0, num_gemm2_loop, 1>{}([&](auto gemm2_loop_idx) { // gemm dK
+                // load KGrad Gemm B
+                kgrad_gemm_tile_q_blockwise_copy.RunRead(q_grid_desc_m0_k_m1, q_grid_buf);
+
+                // load KGrad Gemm A
+                const auto sgrad_slice_idx =
+                    Gemm2::ASrcBlockSliceWindowIterator::GetIndexTupleOfNumber(gemm2_loop_idx);
+                constexpr auto mwave_range =
+                    make_tuple(sgrad_slice_idx[I2],
+                               sgrad_slice_idx[I2] +
+                                   Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1::At(I2));
+                constexpr auto nwave_range =
+                    make_tuple(sgrad_slice_idx[I3],
+                               sgrad_slice_idx[I3] +
+                                   Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1::At(I3));
+
+                if(gemm2_a_copy_subgroup.IsBelong(mwave_range, nwave_range))
+                {
+                    kgrad_gemm_tile_sgrad_thread_copy_vgpr_to_lds.Run(
+                        Gemm2::a_src_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
+                        make_tuple(
+                            sgrad_slice_idx[I0], sgrad_slice_idx[I1], I0, I0, I0, I0, I0, I0),
+                        sgrad_thread_buf,
+                        Gemm2::a_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
+                        gemm2_a_block_buf);
+                }
+
+                // kgrad slice window is moved with MoveSrcSliceWindow() since it is a dynamic
+                // buffer; sgrad slice window is moved by loop index
+                kgrad_gemm_tile_q_blockwise_copy.MoveSrcSliceWindow(
+                    q_grid_desc_m0_k_m1, Gemm2::b_block_slice_copy_step);
+
+                block_sync_lds(); // sync before write
+
+                kgrad_gemm_tile_q_blockwise_copy.RunWrite(Gemm2::b_block_desc_m0_o_m1,
+                                                          gemm2_b_block_buf);
+
+                block_sync_lds(); // sync before read
+
+                v_slash_k_grad_blockwise_gemm.Run(
+                    gemm2_a_block_buf, gemm2_b_block_buf, v_slash_k_grad_thread_buf);
+            }); // end gemm dK
+
+            // atomic_add dK
+            kgrad_thread_copy_vgpr_to_global.Run(Gemm2::c_thread_desc_n0_o0_n1_o1_n2_o2_o3_o4,
+                                                 make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
+                                                 v_slash_k_grad_thread_buf,
+                                                 kgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4,
+                                                 kgrad_grid_buf);
+
             // move slice window
             s_gemm_tile_q_blockwise_copy.MoveSrcSliceWindow(
                 q_grid_desc_k0_m_k1,
@@ -1794,6 +1931,11 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
             pgrad_gemm_tile_v_blockwise_copy.MoveSrcSliceWindow(
                 v_grid_desc_o0_n_o1,
                 pgrad_gemm_tile_v_block_reset_copy_step); // rewind O and step N
+            kgrad_gemm_tile_q_blockwise_copy.MoveSrcSliceWindow(
+                q_grid_desc_m0_k_m1,
+                Gemm2::b_block_reset_copy_step); // rewind M
+            kgrad_thread_copy_vgpr_to_global.MoveDstSliceWindow(
+                kgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4, Gemm2::c_block_slice_copy_step); // step N
         } while(++gemm1_k_block_outer_index < num_gemm1_k_block_outer_loop); // end j loop
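For reference, the full set of gradients this kernel now emits can be written as one naive host-side loop nest; an illustrative sketch (my own, not CK's host reference implementation), with the softmax scale factor and masking omitted. Shapes follow the kernel comments: Q[M][Kd], K[N][Kd], V[N][O], dY[M][O], head_dim = Kd = O. dV and dK contract over M, which is why the device kernel accumulates them with atomic adds across the M-tile workgroups.

```cpp
#include <algorithm>
#include <cmath>
#include <limits>
#include <vector>

using Mat = std::vector<std::vector<float>>;

void attention_backward(const Mat& Q, const Mat& K, const Mat& V, const Mat& dY,
                        Mat& dQ, Mat& dK, Mat& dV)
{
    const std::size_t M = Q.size(), N = K.size(), Kd = Q[0].size(), O = V[0].size();

    // recompute P = softmax_row(Q * K^T); the kernel instead rebuilds P from the stored LSE
    Mat P(M, std::vector<float>(N));
    for(std::size_t i = 0; i < M; ++i)
    {
        float mx = -std::numeric_limits<float>::infinity();
        for(std::size_t j = 0; j < N; ++j)
        {
            float s = 0;
            for(std::size_t k = 0; k < Kd; ++k)
                s += Q[i][k] * K[j][k];
            P[i][j] = s;
            mx      = std::max(mx, s);
        }
        float sum = 0;
        for(std::size_t j = 0; j < N; ++j)
            sum += (P[i][j] = std::exp(P[i][j] - mx));
        for(std::size_t j = 0; j < N; ++j)
            P[i][j] /= sum;
    }

    // dV = P^T * dY (Gemm2 bucket)
    for(std::size_t j = 0; j < N; ++j)
        for(std::size_t o = 0; o < O; ++o)
        {
            float acc = 0;
            for(std::size_t i = 0; i < M; ++i)
                acc += P[i][j] * dY[i][o];
            dV[j][o] = acc;
        }

    // dS = P o (dP - D), with dP = dY * V^T and D_i = sum_j dP_ij * P_ij
    Mat dS(M, std::vector<float>(N));
    for(std::size_t i = 0; i < M; ++i)
    {
        float D = 0;
        for(std::size_t j = 0; j < N; ++j)
        {
            float dp = 0;
            for(std::size_t o = 0; o < O; ++o)
                dp += dY[i][o] * V[j][o];
            dS[i][j] = dp; // stash dP
            D += dp * P[i][j];
        }
        for(std::size_t j = 0; j < N; ++j)
            dS[i][j] = P[i][j] * (dS[i][j] - D);
    }

    // dQ = dS * K (Gemm1 bucket); dK = dS^T * Q (Gemm2 bucket)
    for(std::size_t i = 0; i < M; ++i)
        for(std::size_t k = 0; k < Kd; ++k)
        {
            float acc = 0;
            for(std::size_t j = 0; j < N; ++j)
                acc += dS[i][j] * K[j][k];
            dQ[i][k] = acc;
        }
    for(std::size_t j = 0; j < N; ++j)
        for(std::size_t k = 0; k < Kd; ++k)
        {
            float acc = 0;
            for(std::size_t i = 0; i < M; ++i)
                acc += dS[i][j] * Q[i][k];
            dK[j][k] = acc;
        }
}
```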