"vscode:/vscode.git/clone" did not exist on "ead87d72c2c51c6dc9acb71b9ac971a989176a69"
Commit 0d02519a authored by wangshaojie6's avatar wangshaojie6
Browse files

add multik0

parent 4bda6db0
......@@ -49,7 +49,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSplitKCShu
//#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
//#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 16, 128, 4, 8, 16, 16, 1, 2, S<1, 4, 16, 4>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 2, 2, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 16, 1, 16>, 2>;
< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 16, 128, 8, 8, 16, 16, 1, 2, S<1, 8, 16, 2>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 2, 4, true, S<1, 8, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 16>, 2>;
//< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 16, 4, 8, 16, 16, 1, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 4>;
//< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 16, 128, 4, 8, 16, 16, 1, 2, S<1, 4, 16, 4>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 2, 2, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 16>, 2>;
......
......@@ -49,7 +49,7 @@ __global__ void
GridwiseGemm::template Run<HasMainK0BlockLoop>(p_a_grid,
p_b_grid,
p_c_grid,
//p_shared,
// p_shared,
a_grid_desc_k0_m_k1,
a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3,
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
......@@ -115,7 +115,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_all_lds_v1
static constexpr auto I6 = Number<6>{};
static constexpr auto I7 = Number<7>{};
static constexpr auto MultiK0 = 4 * 1;
static constexpr auto MultiK0 = 8 * 1;
// K1 should be Number<...>
static constexpr auto K1 = Number<K1Value>{};
......@@ -238,15 +238,16 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_all_lds_v1
const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
const auto M = a_grid_desc_k0_m_k1.GetLength(I1);
const auto a_griddesc_k0_mblockid_mrepeat_mwaves_mperxdlops_k1 = transform_tensor_descriptor(
a_grid_desc_k0_m_k1,
make_tuple(make_unmerge_transform(
make_tuple(K0 / K0PerBlock, xdlops_gemm.K0PerXdlops, K0PerThread)),
make_unmerge_transform(make_tuple(
M / (MXdlPerWave * MWaves * MPerXDL), MXdlPerWave, MWaves, MPerXDL)),
make_pass_through_transform(K1)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5, 6>{}, Sequence<7>{}));
const auto a_griddesc_k0_mblockid_mrepeat_mwaves_mperxdlops_k1 =
transform_tensor_descriptor(
a_grid_desc_k0_m_k1,
make_tuple(make_unmerge_transform(
make_tuple(K0 / K0PerBlock, xdlops_gemm.K0PerXdlops, K0PerThread)),
make_unmerge_transform(make_tuple(
M / (MXdlPerWave * MWaves * MPerXDL), MXdlPerWave, MWaves, MPerXDL)),
make_pass_through_transform(K1)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5, 6>{}, Sequence<7>{}));
return a_griddesc_k0_mblockid_mrepeat_mwaves_mperxdlops_k1;
}
......@@ -256,15 +257,16 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_all_lds_v1
const auto K0 = b_grid_desc_k0_n_k1.GetLength(I0);
const auto N = b_grid_desc_k0_n_k1.GetLength(I1);
const auto b_griddesc_k0_nblockid_nrepeat_nwaves_nperxdlops_k1 = transform_tensor_descriptor(
b_grid_desc_k0_n_k1,
make_tuple(make_unmerge_transform(
make_tuple(K0 / K0PerBlock, xdlops_gemm.K0PerXdlops, K0PerThread)),
make_unmerge_transform(make_tuple(
N / (NXdlPerWave * NWaves * NPerXDL), NXdlPerWave, NWaves, NPerXDL)),
make_pass_through_transform(K1)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5, 6>{}, Sequence<7>{}));
const auto b_griddesc_k0_nblockid_nrepeat_nwaves_nperxdlops_k1 =
transform_tensor_descriptor(
b_grid_desc_k0_n_k1,
make_tuple(make_unmerge_transform(
make_tuple(K0 / K0PerBlock, xdlops_gemm.K0PerXdlops, K0PerThread)),
make_unmerge_transform(make_tuple(
N / (NXdlPerWave * NWaves * NPerXDL), NXdlPerWave, NWaves, NPerXDL)),
make_pass_through_transform(K1)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5, 6>{}, Sequence<7>{}));
return b_griddesc_k0_nblockid_nrepeat_nwaves_nperxdlops_k1;
}
......@@ -399,7 +401,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_all_lds_v1
Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
//void* __restrict__ p_shared,
// void* __restrict__ p_shared,
const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
const AGridDesc_K0_K1_K2_M0_M1_M2_M3_K3 a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3,
const BGridDesc_K0_K1_K2_N0_N1_N2_N3_K3 b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
......@@ -446,8 +448,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_all_lds_v1
FloatAB,
a_thread_desc_k0_k1_k2_m0_m1_m2_m3_k3.GetElementSpaceSize(),
true>
a_thread_buf_0, a_thread_buf_1, a_thread_buf_2, a_thread_buf_3;
a_thread_buf[MultiK0]; //, a_thread_buf_1, a_thread_buf_2, a_thread_buf_3;
ignore = b_element_op;
// B matrix threadwise copy
......@@ -465,7 +466,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_all_lds_v1
FloatAB,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3.GetElementSpaceSize(),
true>
b_thread_buf_0, b_thread_buf_1, b_thread_buf_2, b_thread_buf_3;
b_thread_buf[MultiK0]; //_0, b_thread_buf_1, b_thread_buf_2, b_thread_buf_3;
const auto wave_id = GetWaveIdx();
const auto wave_k_n_id = GetWaveKNIdx(wave_id[I2]);
......@@ -513,7 +514,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_all_lds_v1
make_multi_index(
0, wave_k_m_id[I0], 0, block_work_idx[I0], 0, wave_id[I1], wave_k_m_id[I1], 0));
auto b_threadwise_copy =
ThreadwiseTensorSliceTransfer_v2<FloatAB,
FloatAB,
......@@ -561,250 +561,96 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_all_lds_v1
auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
// gridwise GEMM pipeline
//constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock * MultiK0, 0, 0);
// constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock * MultiK0, 0, 0);
constexpr auto a_thread_slice_copy_step = make_multi_index(1, 0, 0, 0, 0, 0, 0, 0);
constexpr auto b_thread_slice_copy_step = make_multi_index(1, 0, 0, 0, 0, 0, 0, 0);
// preload data to regiester and LDS
{
// Read
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_buf_0);
a_threadwise_copy.Run(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3,
a_grid_buf,
a_thread_desc_k0_k1_k2_m0_m1_m2_m3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
a_thread_buf_0);
// Move
a_threadwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3,
a_thread_slice_copy_step);
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step);
// Read
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_buf_1);
a_threadwise_copy.Run(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3,
a_grid_buf,
a_thread_desc_k0_k1_k2_m0_m1_m2_m3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
a_thread_buf_1);
// Move
a_threadwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3,
a_thread_slice_copy_step);
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step);
// Read
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_buf_2);
a_threadwise_copy.Run(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3,
a_grid_buf,
a_thread_desc_k0_k1_k2_m0_m1_m2_m3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
a_thread_buf_2);
// Move
a_threadwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3,
a_thread_slice_copy_step);
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step);
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_buf_3);
a_threadwise_copy.Run(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3,
a_grid_buf,
a_thread_desc_k0_k1_k2_m0_m1_m2_m3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
a_thread_buf_3);
index_t i_pre = 0;
do
{
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_buf[i_pre]);
a_threadwise_copy.Run(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3,
a_grid_buf,
a_thread_desc_k0_k1_k2_m0_m1_m2_m3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
a_thread_buf[i_pre]);
// Move
a_threadwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3,
a_thread_slice_copy_step);
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step);
i_pre++;
} while(i_pre < MultiK0);
// Initialize C
c_thread_buf.Clear();
// a data write to lds
// main body
if constexpr(HasMainK0BlockLoop)
{
index_t K0BlockMainLoop =
__builtin_amdgcn_readfirstlane(K0 / (MultiK0 * K0PerBlock));
__builtin_amdgcn_readfirstlane(K0 / K0PerBlock);
index_t i = 0;
do
{
index_t i_k = 0;
do
{
blockwise_gemm.Run(a_thread_buf_0, b_thread_buf_0, c_thread_buf);
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step);
a_threadwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3,
a_thread_slice_copy_step);
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_buf_0);
a_threadwise_copy.Run(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3,
a_grid_buf,
a_thread_desc_k0_k1_k2_m0_m1_m2_m3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
a_thread_buf_0);
blockwise_gemm.Run(a_thread_buf_1, b_thread_buf_1, c_thread_buf);
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step);
a_threadwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3,
a_thread_slice_copy_step);
blockwise_gemm.Run(a_thread_buf[i_k], b_thread_buf[i_k], c_thread_buf);
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_buf_1);
b_thread_buf[i_k]);
a_threadwise_copy.Run(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3,
a_grid_buf,
a_thread_desc_k0_k1_k2_m0_m1_m2_m3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
a_thread_buf_1);
a_thread_buf[i_k]);
blockwise_gemm.Run(a_thread_buf_2, b_thread_buf_2, c_thread_buf);
asm volatile("s_nop 0" ::);
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step);
a_threadwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3,
a_thread_slice_copy_step);
i_k++;
} while(i_k < MultiK0);
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_buf_2);
a_threadwise_copy.Run(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3,
a_grid_buf,
a_thread_desc_k0_k1_k2_m0_m1_m2_m3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
a_thread_buf_2);
blockwise_gemm.Run(a_thread_buf_3, b_thread_buf_3, c_thread_buf);
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step);
a_threadwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3,
a_thread_slice_copy_step);
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_buf_3);
a_threadwise_copy.Run(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3,
a_grid_buf,
a_thread_desc_k0_k1_k2_m0_m1_m2_m3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
a_thread_buf_3);
}
i += 1;
} while(i < (K0BlockMainLoop - 1));
i += MultiK0;
} while(i < (K0BlockMainLoop - MultiK0));
}
// tail
{
static_for<0, MultiK0, 4>{}([&](auto i) {
blockwise_gemm.Run(a_thread_buf_0, b_thread_buf_0, c_thread_buf);
if constexpr(i < MultiK0 - 4)
{ // only move b windows
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step);
a_threadwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3,
a_thread_slice_copy_step);
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_buf_0);
a_threadwise_copy.Run(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3,
a_grid_buf,
a_thread_desc_k0_k1_k2_m0_m1_m2_m3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
a_thread_buf_0);
}
blockwise_gemm.Run(a_thread_buf_1, b_thread_buf_1, c_thread_buf);
static_for<0, MultiK0, 1>{}([&](auto i) {
blockwise_gemm.Run(a_thread_buf[i], b_thread_buf[i], c_thread_buf);
if constexpr(i < MultiK0 - 4)
{
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step);
a_threadwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3,
a_thread_slice_copy_step);
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_buf_1);
a_threadwise_copy.Run(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3,
a_grid_buf,
a_thread_desc_k0_k1_k2_m0_m1_m2_m3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
a_thread_buf_1);
}
blockwise_gemm.Run(a_thread_buf_2, b_thread_buf_2, c_thread_buf);
if constexpr(i < MultiK0 - 4)
{
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step);
a_threadwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3,
a_thread_slice_copy_step);
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_buf_2);
b_thread_buf[i]);
a_threadwise_copy.Run(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3,
a_grid_buf,
a_thread_desc_k0_k1_k2_m0_m1_m2_m3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
a_thread_buf_2);
}
blockwise_gemm.Run(a_thread_buf_3, b_thread_buf_3, c_thread_buf);
if constexpr(i < MultiK0 - 4)
{
a_thread_buf[i]);
// only move b windows
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step);
a_threadwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3,
a_thread_slice_copy_step);
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_buf_3);
a_threadwise_copy.Run(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3,
a_grid_buf,
a_thread_desc_k0_k1_k2_m0_m1_m2_m3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
a_thread_buf_3);
}
});
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment