"...resnet50_tensorflow.git" did not exist on "a91f37796840fb0f56be17ea384a7e233ce68c20"
Commit 9ae3308a authored by Adam Osewski

Fix BlockwiseGemm.

Use default A/B MmaTileKStride equal to KPerThread.
parent 0a65fc55
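
The diff below drops the explicit A/B MMaTileKStride arguments (both previously set to KPack) from the BlockwiseGemm instantiations so that the template's defaults apply. As a minimal sketch of the mechanism, assuming hypothetical names and a default tied to KPerThread (this is not the composable_kernel definition):

// Minimal sketch, NOT the composable_kernel definition: illustrates how
// omitting the A/B MMaTileKStride template arguments lets a blockwise GEMM
// fall back to defaults equal to KPerThread, as the commit message describes.
#include <cstdio>

template <int KPerThread,
          bool TransposeC,
          int AMmaTileKStride = KPerThread, // hypothetical default
          int BMmaTileKStride = KPerThread> // hypothetical default
struct BlockwiseGemmSketch
{
    static constexpr int a_k_stride = AMmaTileKStride;
    static constexpr int b_k_stride = BMmaTileKStride;
};

int main()
{
    // Before the fix the caller passed KPack explicitly for both strides;
    // after the fix it stops at TransposeC and relies on the defaults.
    using Gemm = BlockwiseGemmSketch</*KPerThread=*/4, /*TransposeC=*/true>;
    std::printf("A stride = %d, B stride = %d\n", Gemm::a_k_stride, Gemm::b_k_stride);
    return 0;
}

With a default like this, an explicit stride argument is only needed when the MMA tile stride along K is meant to differ from KPerThread.
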
@@ -266,9 +266,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_v1
     MXdlPerWave,
     NXdlPerWave,
     KPack,
-    true, // TransposeC
-    KPack, // A MMaTileKStride
-    KPack>; // B MMaTileKStride
+    true>; // TransposeC
+    // KPack // A MMaTileKStride
+    // KPack>; // B MMaTileKStride
     return BlockwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(c_grid_desc_m_n);
 }
@@ -425,9 +425,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_v1
     MXdlPerWave,
     NXdlPerWave,
     KPack,
-    true, // TransposeC
-    KPack, // A MMaTileKStride
-    KPack>{}; // B MMaTileKStride
+    true>{}; // TransposeC
+    // KPack, // A MMaTileKStride
+    // KPack>{}; // B MMaTileKStride
     auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
@@ -488,7 +488,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_v1
     constexpr auto N4 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I7);
     // calculate origin of thread output tensor on global memory
-    // blockwise GEMM c matrix starting index
+    // blockwise GEMM c matrix starting index
     const auto c_thread_mtx_on_block =
         blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
@@ -517,37 +517,41 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_v1
     n_thread_data_on_grid_to_n0_n1_n2_n3_n4_adaptor.CalculateBottomIndex(
         make_multi_index(n_thread_data_on_grid));
-    if(threadIdx.x == 0 || threadIdx.x == 15 || threadIdx.x == 33 || threadIdx.x == 60)
-    {
-        auto c_thread_copy = ThreadwiseTensorSliceTransfer_v1r3<
-            FloatGemmAcc,
-            FloatC,
-            decltype(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4),
-            decltype(c_grid_desc_m0_n0_m1_n1_m2_n2_n3_n4),
-            CElementwiseOperation,
-            Sequence<M0, N0, I1, I1, I1, N2, I1, N4>,
-            Sequence<0, 2, 4, 1, 3, 5, 6, 7>, // CThreadTransferDstAccessOrder,
-            7,                                // CThreadTransferDstVectorDim,
-            N4.value,                         // CThreadTransferDstScalarPerVector,
-            CGlobalMemoryDataOperation,
-            1,
-            true>{c_grid_desc_m0_n0_m1_n1_m2_n2_n3_n4,
-                  make_multi_index(m_thread_data_on_grid_idx[I0],
-                                   n_thread_data_on_grid_idx[I0],
-                                   m_thread_data_on_grid_idx[I1],
-                                   n_thread_data_on_grid_idx[I1],
-                                   m_thread_data_on_grid_idx[I2],
-                                   n_thread_data_on_grid_idx[I2],
-                                   n_thread_data_on_grid_idx[I3],
-                                   n_thread_data_on_grid_idx[I4]),
-                  c_element_op};
-        c_thread_copy.Run(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
-                          make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
-                          c_thread_buf,
-                          c_grid_desc_m0_n0_m1_n1_m2_n2_n3_n4,
-                          c_grid_buf);
-    }
+    // if(threadIdx.x == 0
+    // // || threadIdx.x == 15
+    // // || threadIdx.x == 33
+    // // || threadIdx.x == 60
+    // )
+    // {
+    auto c_thread_copy = ThreadwiseTensorSliceTransfer_v1r3<
+        FloatGemmAcc,
+        FloatC,
+        decltype(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4),
+        decltype(c_grid_desc_m0_n0_m1_n1_m2_n2_n3_n4),
+        CElementwiseOperation,
+        Sequence<M0, N0, I1, I1, I1, N2, I1, N4>,
+        Sequence<0, 2, 4, 1, 3, 5, 6, 7>, // CThreadTransferDstAccessOrder,
+        7,                                // CThreadTransferDstVectorDim,
+        N4.value,                         // CThreadTransferDstScalarPerVector,
+        CGlobalMemoryDataOperation,
+        1,
+        true>{c_grid_desc_m0_n0_m1_n1_m2_n2_n3_n4,
+              make_multi_index(m_thread_data_on_grid_idx[I0],
+                               n_thread_data_on_grid_idx[I0],
+                               m_thread_data_on_grid_idx[I1],
+                               n_thread_data_on_grid_idx[I1],
+                               m_thread_data_on_grid_idx[I2],
+                               n_thread_data_on_grid_idx[I2],
+                               n_thread_data_on_grid_idx[I3],
+                               n_thread_data_on_grid_idx[I4]),
+              c_element_op};
+    c_thread_copy.Run(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
+                      make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
+                      c_thread_buf,
+                      c_grid_desc_m0_n0_m1_n1_m2_n2_n3_n4,
+                      c_grid_buf);
+    // }
     // TODO: how SpaceFillingCurve works ?
     // space filling curve for threadwise C in VGPR
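
The last hunk also comments out a debugging guard that let only threads 0, 15, 33 and 60 reach c_thread_copy.Run(). A host-side sketch, assuming a thread-block size of 64 (not composable_kernel code), of how many lanes end up writing their C slice with and without such a guard:

// Host-side sketch with an assumed thread-block size of 64; NOT
// composable_kernel code. It only counts which lanes would reach the
// C write-out with the guard enabled versus removed.
#include <cstdio>

int main()
{
    constexpr int block_size = 64; // assumed thread-block size

    for(int guarded = 1; guarded >= 0; --guarded)
    {
        int lanes_writing = 0;
        for(int tid = 0; tid < block_size; ++tid)
        {
            // Guarded (old code): only lanes 0, 15, 33 and 60 store their
            // slice of C. Unguarded (new code): every lane stores its slice.
            const bool writes =
                !guarded || tid == 0 || tid == 15 || tid == 33 || tid == 60;
            if(writes)
            {
                ++lanes_writing;
            }
        }
        std::printf("guard %s: %d of %d lanes write their C slice\n",
                    guarded ? "on" : "off", lanes_writing, block_size);
    }
    return 0;
}
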