Commit 3239201e authored by Adam Osewski

Fix BlockwiseGemm.

Use default A/B MmaTileKStride equal to KPerThread.
parent aa5d4037
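
For context on the fix: once the A/B MmaTileKStride template parameters of the blockwise GEMM default to KPerThread, call sites no longer need to pass the explicit trailing KPack arguments, which is exactly what the hunks below do. A minimal, self-contained sketch of the mechanism follows; BlockwiseGemmSketch and its parameter names are hypothetical stand-ins, not the real composable_kernel class.

// Hypothetical sketch of defaulted trailing template parameters.
// AMmaTileKStride / BMmaTileKStride default to KPerThread, so call sites
// that previously spelled out KPack can simply omit both arguments.
#include <cstdio>

template <int KPerThread,
          int AMmaTileKStride = KPerThread, // default == KPerThread
          int BMmaTileKStride = KPerThread> // default == KPerThread
struct BlockwiseGemmSketch
{
    static void Print()
    {
        std::printf("KPerThread=%d, A stride=%d, B stride=%d\n",
                    KPerThread,
                    AMmaTileKStride,
                    BMmaTileKStride);
    }
};

int main()
{
    BlockwiseGemmSketch<8, 8, 8>::Print(); // old style: strides passed explicitly
    BlockwiseGemmSketch<8>::Print();       // new style: defaults kick in
    return 0;
}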
@@ -266,9 +266,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_v1
             MXdlPerWave,
             NXdlPerWave,
             KPack,
-            true,   // TransposeC
-            KPack,  // A MMaTileKStride
-            KPack>; // B MMaTileKStride
+            true>;  // TransposeC
+            // KPack // A MMaTileKStride
+            // KPack>; // B MMaTileKStride

         return BlockwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(c_grid_desc_m_n);
     }
@@ -425,9 +425,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_v1
             MXdlPerWave,
             NXdlPerWave,
             KPack,
-            true,     // TransposeC
-            KPack,    // A MMaTileKStride
-            KPack>{}; // B MMaTileKStride
+            true>{};  // TransposeC
+            // KPack, // A MMaTileKStride
+            // KPack>{}; // B MMaTileKStride

         auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
@@ -488,7 +488,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_v1
         constexpr auto N4 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I7);

         // calculate origin of thread output tensor on global memory
         // blockwise GEMM c matrix starting index
         const auto c_thread_mtx_on_block =
             blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
@@ -517,37 +517,41 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_v1
                 n_thread_data_on_grid_to_n0_n1_n2_n3_n4_adaptor.CalculateBottomIndex(
                     make_multi_index(n_thread_data_on_grid));

-            if(threadIdx.x == 0 || threadIdx.x == 15 || threadIdx.x == 33 || threadIdx.x == 60)
-            {
-                auto c_thread_copy = ThreadwiseTensorSliceTransfer_v1r3<
-                    FloatGemmAcc,
-                    FloatC,
-                    decltype(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4),
-                    decltype(c_grid_desc_m0_n0_m1_n1_m2_n2_n3_n4),
-                    CElementwiseOperation,
-                    Sequence<M0, N0, I1, I1, I1, N2, I1, N4>,
-                    Sequence<0, 2, 4, 1, 3, 5, 6, 7>, // CThreadTransferDstAccessOrder,
-                    7,                                 // CThreadTransferDstVectorDim,
-                    N4.value,                          // CThreadTransferDstScalarPerVector,
-                    CGlobalMemoryDataOperation,
-                    1,
-                    true>{c_grid_desc_m0_n0_m1_n1_m2_n2_n3_n4,
-                          make_multi_index(m_thread_data_on_grid_idx[I0],
-                                           n_thread_data_on_grid_idx[I0],
-                                           m_thread_data_on_grid_idx[I1],
-                                           n_thread_data_on_grid_idx[I1],
-                                           m_thread_data_on_grid_idx[I2],
-                                           n_thread_data_on_grid_idx[I2],
-                                           n_thread_data_on_grid_idx[I3],
-                                           n_thread_data_on_grid_idx[I4]),
-                          c_element_op};
-
-                c_thread_copy.Run(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
-                                  make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
-                                  c_thread_buf,
-                                  c_grid_desc_m0_n0_m1_n1_m2_n2_n3_n4,
-                                  c_grid_buf);
-            }
+            // if(threadIdx.x == 0
+            //    // || threadIdx.x == 15
+            //    // || threadIdx.x == 33
+            //    // || threadIdx.x == 60
+            // )
+            // {
+            auto c_thread_copy = ThreadwiseTensorSliceTransfer_v1r3<
+                FloatGemmAcc,
+                FloatC,
+                decltype(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4),
+                decltype(c_grid_desc_m0_n0_m1_n1_m2_n2_n3_n4),
+                CElementwiseOperation,
+                Sequence<M0, N0, I1, I1, I1, N2, I1, N4>,
+                Sequence<0, 2, 4, 1, 3, 5, 6, 7>, // CThreadTransferDstAccessOrder,
+                7,                                 // CThreadTransferDstVectorDim,
+                N4.value,                          // CThreadTransferDstScalarPerVector,
+                CGlobalMemoryDataOperation,
+                1,
+                true>{c_grid_desc_m0_n0_m1_n1_m2_n2_n3_n4,
+                      make_multi_index(m_thread_data_on_grid_idx[I0],
+                                       n_thread_data_on_grid_idx[I0],
+                                       m_thread_data_on_grid_idx[I1],
+                                       n_thread_data_on_grid_idx[I1],
+                                       m_thread_data_on_grid_idx[I2],
+                                       n_thread_data_on_grid_idx[I2],
+                                       n_thread_data_on_grid_idx[I3],
+                                       n_thread_data_on_grid_idx[I4]),
+                      c_element_op};
+
+            c_thread_copy.Run(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
+                              make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
+                              c_thread_buf,
+                              c_grid_desc_m0_n0_m1_n1_m2_n2_n3_n4,
+                              c_grid_buf);
+            // }

             // TODO: how SpaceFillingCurve works ?
             // space filling curve for threadwise C in VGPR