Commit 3239201e authored by Adam Osewski's avatar Adam Osewski
Browse files

Fix BlockwiseGemm.

Use default A/B MmaTileKStride equal to KPerThread.
parent aa5d4037
...@@ -266,9 +266,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_v1 ...@@ -266,9 +266,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_v1
MXdlPerWave, MXdlPerWave,
NXdlPerWave, NXdlPerWave,
KPack, KPack,
true, // TransposeC true>; // TransposeC
KPack, // A MMaTileKStride // KPack // A MMaTileKStride
KPack>; // B MMaTileKStride // KPack>; // B MMaTileKStride
return BlockwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(c_grid_desc_m_n); return BlockwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(c_grid_desc_m_n);
} }
...@@ -425,9 +425,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_v1 ...@@ -425,9 +425,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_v1
MXdlPerWave, MXdlPerWave,
NXdlPerWave, NXdlPerWave,
KPack, KPack,
true, // TransposeC true>{}; // TransposeC
KPack, // A MMaTileKStride // KPack, // A MMaTileKStride
KPack>{}; // B MMaTileKStride // KPack>{}; // B MMaTileKStride
auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
...@@ -517,8 +517,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_v1 ...@@ -517,8 +517,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_v1
n_thread_data_on_grid_to_n0_n1_n2_n3_n4_adaptor.CalculateBottomIndex( n_thread_data_on_grid_to_n0_n1_n2_n3_n4_adaptor.CalculateBottomIndex(
make_multi_index(n_thread_data_on_grid)); make_multi_index(n_thread_data_on_grid));
if(threadIdx.x == 0 || threadIdx.x == 15 || threadIdx.x == 33 || threadIdx.x == 60) // if(threadIdx.x == 0
{ // // || threadIdx.x == 15
// // || threadIdx.x == 33
// // || threadIdx.x == 60
// )
// {
auto c_thread_copy = ThreadwiseTensorSliceTransfer_v1r3< auto c_thread_copy = ThreadwiseTensorSliceTransfer_v1r3<
FloatGemmAcc, FloatGemmAcc,
FloatC, FloatC,
...@@ -547,7 +551,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_v1 ...@@ -547,7 +551,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_v1
c_thread_buf, c_thread_buf,
c_grid_desc_m0_n0_m1_n1_m2_n2_n3_n4, c_grid_desc_m0_n0_m1_n1_m2_n2_n3_n4,
c_grid_buf); c_grid_buf);
} // }
// TODO: how SpaceFillingCurve works ? // TODO: how SpaceFillingCurve works ?
// space filling curve for threadwise C in VGPR // space filling curve for threadwise C in VGPR
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment