Commit 071ca121 authored by ltqin
Browse files

fix K0PerThread and gridwise GEMM main loop

parent 2159921e
...@@ -43,7 +43,7 @@ using BElementOp = ck::tensor_operation::element_wise::PassThrough; ...@@ -43,7 +43,7 @@ using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using CElementOp = ck::tensor_operation::element_wise::PassThrough; using CElementOp = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
#define NORMAL_CONFIG 0 #define NORMAL_CONFIG 1
// clang-format off // clang-format off
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSkipLds using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSkipLds
//###########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| //###########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
...@@ -229,7 +229,7 @@ int main(int argc, char* argv[]) ...@@ -229,7 +229,7 @@ int main(int argc, char* argv[])
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
#if 1 #if 0
{ {
show_2d_matrix(std::cout << "a : ", a_m_k) << std::endl; show_2d_matrix(std::cout << "a : ", a_m_k) << std::endl;
show_2d_matrix(std::cout << "b: ", b_k_n) << std::endl; show_2d_matrix(std::cout << "b: ", b_k_n) << std::endl;
......
...@@ -41,6 +41,8 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1r1 ...@@ -41,6 +41,8 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1r1
static constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack>{}; static constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack>{};
static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops; static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops;
static constexpr index_t K0PerThread =
BK0NK1BlockDesc{}.GetLength(I0) / xdlops_gemm.K0PerXdlops;
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL); static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL); static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
...@@ -278,7 +280,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1r1 ...@@ -278,7 +280,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1r1
[Number<a_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, k + i))>{}]; [Number<a_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, k + i))>{}];
b_thread_vec.template AsType<FloatAB>()(i) = b_thread_vec.template AsType<FloatAB>()(i) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset( b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(0, k / KPack, 0, n0, 0, 0, i))>{}]; make_tuple(0, 0, k / KPack, 0, n0, 0, 0, i))>{}];
}); });
using mfma_input_type = using mfma_input_type =
...@@ -304,11 +306,12 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1r1 ...@@ -304,11 +306,12 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1r1
// B[N0, N1, N2, KPerThread] // B[N0, N1, N2, KPerThread]
static constexpr auto b_thread_desc_ = static constexpr auto b_thread_desc_ =
make_naive_tensor_descriptor_packed(make_tuple(I1, make_naive_tensor_descriptor_packed(make_tuple(I1,
Number<KPerThread>{}, // KPerThread I1,
I1, // NBlockId Number<K0PerThread>{}, // KPerThread
Number<NRepeat>{}, // repeat I1, // NBlockId
I1, // waves Number<NRepeat>{}, // repeat
I1, // NPerXdlops I1, // waves
I1, // NPerXdlops
Number<KPack>{})); Number<KPack>{}));
// C[M, N, NumRegXdlops] // C[M, N, NumRegXdlops]
......
...@@ -207,7 +207,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3 ...@@ -207,7 +207,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
static constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXDL); static constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXDL);
static constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXDL); static constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXDL);
static constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerXDL, NPerXDL, K1>{}; static constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerXDL, NPerXDL, K1>{};
static constexpr index_t K0PerThread = K0PerBlock / xdlops_gemm.K0PerXdlops; static constexpr index_t K0PerThread = K0PerBlock / xdlops_gemm.K0PerXdlops;
__host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1() __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
...@@ -359,12 +359,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3 ...@@ -359,12 +359,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
const auto b_griddesc_k0_nblockid_nrepeat_waves_nperxdlops_k1 = transform_tensor_descriptor( const auto b_griddesc_k0_nblockid_nrepeat_waves_nperxdlops_k1 = transform_tensor_descriptor(
b_grid_desc_k0_n_k1, b_grid_desc_k0_n_k1,
make_tuple(make_unmerge_transform(make_tuple(K0 / K0PerBlock, K0PerBlock)), make_tuple(make_unmerge_transform(
make_tuple(K0 / K0PerBlock, xdlops_gemm.K0PerXdlops, K0PerThread)),
make_unmerge_transform(make_tuple( make_unmerge_transform(make_tuple(
N / (NXdlPerWave * NWaves * NPerXDL), NXdlPerWave, NWaves, NPerXDL)), N / (NXdlPerWave * NWaves * NPerXDL), NXdlPerWave, NWaves, NPerXDL)),
make_pass_through_transform(K1)), make_pass_through_transform(K1)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3, 4, 5>{}, Sequence<6>{})); make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5, 6>{}, Sequence<7>{}));
return b_griddesc_k0_nblockid_nrepeat_waves_nperxdlops_k1; return b_griddesc_k0_nblockid_nrepeat_waves_nperxdlops_k1;
} }
...@@ -383,7 +384,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3 ...@@ -383,7 +384,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
__device__ static auto GetWaveKNIdx(const index_t thread_id) __device__ static auto GetWaveKNIdx(const index_t thread_id)
{ {
constexpr auto wave_threadid_to_nk_idx_adaptor = make_single_stage_tensor_adaptor( constexpr auto wave_threadid_to_nk_idx_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(K0PerThread, NPerXDL))), make_tuple(make_merge_transform(make_tuple(xdlops_gemm.K0PerXdlops, NPerXDL))),
make_tuple(Sequence<0, 1>{}), make_tuple(Sequence<0, 1>{}),
make_tuple(Sequence<0>{})); make_tuple(Sequence<0>{}));
...@@ -559,7 +560,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3 ...@@ -559,7 +560,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
// B matrix blockwise copy // B matrix blockwise copy
constexpr auto b_thread_desc_k0_k0b_n0_n1_n2_n3_k1 = constexpr auto b_thread_desc_k0_k0b_n0_n1_n2_n3_k1 =
make_naive_tensor_descriptor_packed(make_tuple(I1, make_naive_tensor_descriptor_packed(make_tuple(I1,
Number<K0PerThread>{}, // K0PerThread I1,
Number<K0PerThread>{}, // K0PerThread
I1, // NBlockId I1, // NBlockId
Number<NXdlPerWave>{}, // repeat Number<NXdlPerWave>{}, // repeat
I1, // waves I1, // waves
...@@ -590,21 +592,30 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3 ...@@ -590,21 +592,30 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
wave_id[I2], wave_id[I2],
wave_k_n_id[I0], wave_k_n_id[I0],
wave_k_n_id[I1]); wave_k_n_id[I1]);
printf("mfma thread k per xdlops: %d K0PerThread: %d HasMainK0BlockLoop: %d K0: %d \t", xdlops_gemm.K0PerXdlops, K0PerThread, HasMainK0BlockLoop, b_grid_desc_k0_k0b_n0_n1_n2_n3_k1.GetLength(I0));
#endif #endif
auto b_threadwise_copy = ThreadwiseTensorSliceTransfer_v2< auto b_threadwise_copy =
FloatAB, ThreadwiseTensorSliceTransfer_v2<FloatAB,
FloatAB, FloatAB,
decltype(b_grid_desc_k0_k0b_n0_n1_n2_n3_k1), decltype(b_grid_desc_k0_k0b_n0_n1_n2_n3_k1),
decltype(b_thread_desc_k0_k0b_n0_n1_n2_n3_k1), decltype(b_thread_desc_k0_k0b_n0_n1_n2_n3_k1),
Sequence<I1, Number<K0PerThread>{}, I1, Number<NXdlPerWave>{}, I1, I1, Number<K1>{}>, Sequence<I1,
Sequence<0, 1, 2, 3, 4, 5, 6>, I1,
6, Number<K0PerThread>{},
1, I1,
BThreadTransferSrcResetCoordinateAfterRun, Number<NXdlPerWave>{},
true>(b_grid_desc_k0_k0b_n0_n1_n2_n3_k1, I1,
make_multi_index( I1,
0, wave_k_n_id[I0], block_work_idx[I1], 0, wave_id[I1], wave_k_n_id[I1], 0)); Number<K1>{}>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
7,
1,
BThreadTransferSrcResetCoordinateAfterRun,
true>(
b_grid_desc_k0_k0b_n0_n1_n2_n3_k1,
make_multi_index(
0, wave_k_n_id[I0], 0, block_work_idx[I1], 0, wave_id[I1], wave_k_n_id[I1], 0));
// GEMM definition // GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx // c_mtx += transpose(a_mtx) * b_mtx
...@@ -634,7 +645,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3 ...@@ -634,7 +645,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
// gridwise GEMM pipeline // gridwise GEMM pipeline
constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
constexpr auto b_thread_slice_copy_step = make_multi_index(1, 0, 0, 0, 0, 0, 0); constexpr auto b_thread_slice_copy_step = make_multi_index(1, 0, 0, 0, 0, 0, 0, 0);
// preload data to register and LDS // preload data to register and LDS
{ {
// Read // Read
...@@ -642,7 +653,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3 ...@@ -642,7 +653,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
b_threadwise_copy.Run(b_grid_desc_k0_k0b_n0_n1_n2_n3_k1, b_threadwise_copy.Run(b_grid_desc_k0_k0b_n0_n1_n2_n3_k1,
b_grid_buf, b_grid_buf,
b_thread_desc_k0_k0b_n0_n1_n2_n3_k1, b_thread_desc_k0_k0b_n0_n1_n2_n3_k1,
make_tuple(I0, I0, I0, I0, I0, I0, I0), make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_buf); b_thread_buf);
// Move // Move
...@@ -666,15 +677,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3 ...@@ -666,15 +677,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf); a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf);
block_sync_lds(); block_sync_lds();
blockwise_gemm.Run(a_block_buf, b_thread_buf, c_thread_buf);
// read b after gemm
b_threadwise_copy.Run(b_grid_desc_k0_k0b_n0_n1_n2_n3_k1, b_threadwise_copy.Run(b_grid_desc_k0_k0b_n0_n1_n2_n3_k1,
b_grid_buf, b_grid_buf,
b_thread_desc_k0_k0b_n0_n1_n2_n3_k1, b_thread_desc_k0_k0b_n0_n1_n2_n3_k1,
make_tuple(I0, I0, I0, I0, I0, I0, I0), make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_buf); b_thread_buf);
blockwise_gemm.Run(a_block_buf, b_thread_buf, c_thread_buf);
block_sync_lds(); block_sync_lds();
// move windows // move windows
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment