Commit 071ca121 authored by ltqin

fix K0PerThread and gridwise gemm main loop

parent 2159921e
@@ -43,7 +43,7 @@ using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using CElementOp = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
#define NORMAL_CONFIG 0
#define NORMAL_CONFIG 1
// clang-format off
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSkipLds
//###########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
@@ -229,7 +229,7 @@ int main(int argc, char* argv[])
ref_invoker.Run(ref_argument);
#if 1
#if 0
{
show_2d_matrix(std::cout << "a : ", a_m_k) << std::endl;
show_2d_matrix(std::cout << "b: ", b_k_n) << std::endl;
......
@@ -41,6 +41,8 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1r1
static constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack>{};
static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops;
static constexpr index_t K0PerThread =
BK0NK1BlockDesc{}.GetLength(I0) / xdlops_gemm.K0PerXdlops;
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
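As a reading aid, here is a standalone sketch of the arithmetic this hunk fixes: K0PerThread is now derived from the K0 extent of the B block descriptor (its I0 length) divided by xdlops_gemm.K0PerXdlops, instead of from KPerBlock. The concrete numbers below are assumptions for illustration only, not values from the commit.

#include <cstdio>

int main()
{
    constexpr int K0PerBlock  = 8; // stands in for BK0NK1BlockDesc{}.GetLength(I0)
    constexpr int K0PerXdlops = 2; // stands in for xdlops_gemm.K0PerXdlops

    // the quantity this hunk introduces: each thread's share of the K0 dimension
    constexpr int K0PerThread = K0PerBlock / K0PerXdlops;

    std::printf("K0PerThread = %d\n", K0PerThread); // prints 4 with these assumed values
    return 0;
}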
@@ -278,7 +280,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1r1
[Number<a_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, k + i))>{}];
b_thread_vec.template AsType<FloatAB>()(i) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(0, k / KPack, 0, n0, 0, 0, i))>{}];
make_tuple(0, 0, k / KPack, 0, n0, 0, 0, i))>{}];
});
using mfma_input_type =
@@ -304,11 +306,12 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1r1
// B[N0, N1, N2, KPerThread]
static constexpr auto b_thread_desc_ =
make_naive_tensor_descriptor_packed(make_tuple(I1,
Number<KPerThread>{}, // KPerThread
I1, // NBlockId
Number<NRepeat>{}, // repeat
I1, // waves
I1, // NPerXdlops
I1,
Number<K0PerThread>{}, // KPerThread
I1, // NBlockId
Number<NRepeat>{}, // repeat
I1, // waves
I1, // NPerXdlops
Number<KPack>{}));
// C[M, N, NumRegXdlops]
......
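For orientation, a standalone sketch of the register-side B layout after this change: b_thread_desc_ is now 8-dimensional, with two leading unit dimensions (the K0-block and K0PerXdlops slots, both fixed for a given thread) ahead of K0PerThread, which is why CalculateOffset in the hunk above now takes an extra leading zero. The extents below are assumptions; only the dimension order mirrors the diff, and the offset rule assumes the last-dimension-fastest layout of a packed naive descriptor.

#include <array>
#include <cstdio>

// packed offset with the last dimension fastest-varying
constexpr int packed_offset(const std::array<int, 8>& len, const std::array<int, 8>& idx)
{
    int off = 0;
    for (int d = 0; d < 8; ++d)
        off = off * len[d] + idx[d];
    return off;
}

int main()
{
    constexpr int K0PerThread = 4, NRepeat = 2, KPack = 8; // assumed extents
    // dim order: [K0-block, K0PerXdlops slot, K0PerThread, NBlockId, NRepeat, waves, NPerXdlops, KPack]
    constexpr std::array<int, 8> len{1, 1, K0PerThread, 1, NRepeat, 1, 1, KPack};

    const int k = 12, n0 = 1, i = 3; // example loop indices, as in the earlier hunk
    const std::array<int, 8> idx{0, 0, k / KPack, 0, n0, 0, 0, i};
    std::printf("offset for (k=%d, n0=%d, i=%d) = %d\n", k, n0, i, packed_offset(len, idx));
    return 0;
}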
@@ -207,7 +207,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
static constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXDL);
static constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXDL);
static constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerXDL, NPerXDL, K1>{};
static constexpr index_t K0PerThread = K0PerBlock / xdlops_gemm.K0PerXdlops;
__host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
@@ -359,12 +359,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
const auto b_griddesc_k0_nblockid_nrepeat_waves_nperxdlops_k1 = transform_tensor_descriptor(
b_grid_desc_k0_n_k1,
make_tuple(make_unmerge_transform(make_tuple(K0 / K0PerBlock, K0PerBlock)),
make_tuple(make_unmerge_transform(
make_tuple(K0 / K0PerBlock, xdlops_gemm.K0PerXdlops, K0PerThread)),
make_unmerge_transform(make_tuple(
N / (NXdlPerWave * NWaves * NPerXDL), NXdlPerWave, NWaves, NPerXDL)),
make_pass_through_transform(K1)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3, 4, 5>{}, Sequence<6>{}));
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5, 6>{}, Sequence<7>{}));
return b_griddesc_k0_nblockid_nrepeat_waves_nperxdlops_k1;
}
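In other words, the flat K0 index of the B grid is now factored three ways instead of two. A standalone sketch of that unmerge (assuming the first factor of make_unmerge_transform is most significant; the sizes are illustrative assumptions):

#include <cassert>
#include <cstdio>

int main()
{
    constexpr int K0 = 32, K0PerBlock = 8, K0PerXdlops = 2; // assumed sizes
    constexpr int K0PerThread = K0PerBlock / K0PerXdlops;   // 4

    for (int k0 = 0; k0 < K0; ++k0)
    {
        const int k0_block  = k0 / K0PerBlock;                 // K0 / K0PerBlock axis
        const int k0_xdl    = (k0 % K0PerBlock) / K0PerThread; // K0PerXdlops axis
        const int k0_thread = k0 % K0PerThread;                // K0PerThread axis

        // recombining the three coordinates must give back the flat index
        assert(k0 == (k0_block * K0PerXdlops + k0_xdl) * K0PerThread + k0_thread);
    }
    std::puts("three-way K0 unmerge round-trips");
    return 0;
}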
@@ -383,7 +384,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
__device__ static auto GetWaveKNIdx(const index_t thread_id)
{
constexpr auto wave_threadid_to_nk_idx_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(K0PerThread, NPerXDL))),
make_tuple(make_merge_transform(make_tuple(xdlops_gemm.K0PerXdlops, NPerXDL))),
make_tuple(Sequence<0, 1>{}),
make_tuple(Sequence<0>{}));
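The adaptor now splits a wave-local lane id over (K0PerXdlops, NPerXDL) rather than (K0PerThread, NPerXDL). A standalone sketch of that mapping, assuming the merge places NPerXDL fastest-varying and using illustrative sizes:

#include <cstdio>

int main()
{
    constexpr int K0PerXdlops = 2, NPerXDL = 32; // assumed: 64 lanes per wave

    for (int lane = 0; lane < K0PerXdlops * NPerXDL; lane += 16)
    {
        const int k_slot = lane / NPerXDL; // plays the role of wave_k_n_id[I0]
        const int n_idx  = lane % NPerXDL; // plays the role of wave_k_n_id[I1]
        std::printf("lane %2d -> (k0 slot %d, n %2d)\n", lane, k_slot, n_idx);
    }
    return 0;
}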
@@ -559,7 +560,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
// B matrix blockwise copy
constexpr auto b_thread_desc_k0_k0b_n0_n1_n2_n3_k1 =
make_naive_tensor_descriptor_packed(make_tuple(I1,
Number<K0PerThread>{}, // K0PerThread
I1,
Number<K0PerThread>{}, // K0PerThread
I1, // NBlockId
Number<NXdlPerWave>{}, // repeat
I1, // waves
@@ -590,21 +592,30 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
wave_id[I2],
wave_k_n_id[I0],
wave_k_n_id[I1]);
printf("mfma thread k per xdlops: %d K0PerThread: %d HasMainK0BlockLoop: %d K0: %d \t", xdlops_gemm.K0PerXdlops, K0PerThread, HasMainK0BlockLoop, b_grid_desc_k0_k0b_n0_n1_n2_n3_k1.GetLength(I0));
#endif
auto b_threadwise_copy = ThreadwiseTensorSliceTransfer_v2<
FloatAB,
FloatAB,
decltype(b_grid_desc_k0_k0b_n0_n1_n2_n3_k1),
decltype(b_thread_desc_k0_k0b_n0_n1_n2_n3_k1),
Sequence<I1, Number<K0PerThread>{}, I1, Number<NXdlPerWave>{}, I1, I1, Number<K1>{}>,
Sequence<0, 1, 2, 3, 4, 5, 6>,
6,
1,
BThreadTransferSrcResetCoordinateAfterRun,
true>(b_grid_desc_k0_k0b_n0_n1_n2_n3_k1,
make_multi_index(
0, wave_k_n_id[I0], block_work_idx[I1], 0, wave_id[I1], wave_k_n_id[I1], 0));
auto b_threadwise_copy =
ThreadwiseTensorSliceTransfer_v2<FloatAB,
FloatAB,
decltype(b_grid_desc_k0_k0b_n0_n1_n2_n3_k1),
decltype(b_thread_desc_k0_k0b_n0_n1_n2_n3_k1),
Sequence<I1,
I1,
Number<K0PerThread>{},
I1,
Number<NXdlPerWave>{},
I1,
I1,
Number<K1>{}>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
7,
1,
BThreadTransferSrcResetCoordinateAfterRun,
true>(
b_grid_desc_k0_k0b_n0_n1_n2_n3_k1,
make_multi_index(
0, wave_k_n_id[I0], 0, block_work_idx[I1], 0, wave_id[I1], wave_k_n_id[I1], 0));
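Put together, each thread now owns an 8-D window of the transformed B grid descriptor: slice lengths [1, 1, K0PerThread, 1, NXdlPerWave, 1, 1, K1], positioned at its K0PerXdlops slot and N lane. A standalone sketch of that bookkeeping with assumed sizes (the real kernel derives the ids from the wave and block indices shown above):

#include <array>
#include <cstdio>

int main()
{
    constexpr int K0 = 32, K0PerBlock = 8, K0PerXdlops = 2;
    constexpr int K0PerThread = K0PerBlock / K0PerXdlops;
    constexpr int NBlocks = 4, NXdlPerWave = 2, NWaves = 2, NPerXDL = 32, K1 = 8;

    // dim order of b_grid_desc_k0_k0b_n0_n1_n2_n3_k1 after the transform above
    const std::array<int, 8> extent{K0 / K0PerBlock, K0PerXdlops, K0PerThread,
                                    NBlocks, NXdlPerWave, NWaves, NPerXDL, K1};
    const std::array<int, 8> slice{1, 1, K0PerThread, 1, NXdlPerWave, 1, 1, K1};

    const int k_slot = 1, n_block = 2, n_wave = 0, n_lane = 17; // example thread
    const std::array<int, 8> origin{0, k_slot, 0, n_block, 0, n_wave, n_lane, 0};

    for (int d = 0; d < 8; ++d)
        std::printf("dim %d: origin %2d, slice %2d, extent %2d\n",
                    d, origin[d], slice[d], extent[d]);
    return 0;
}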
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
@@ -634,7 +645,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
// gridwise GEMM pipeline
constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
constexpr auto b_thread_slice_copy_step = make_multi_index(1, 0, 0, 0, 0, 0, 0);
constexpr auto b_thread_slice_copy_step = make_multi_index(1, 0, 0, 0, 0, 0, 0, 0);
// preload data to register and LDS
{
// Read
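The copy step above gains a matching eighth component; its single 1 sits in dimension 0, the K0 / K0PerBlock axis of the transformed grid descriptor, so every window move advances a thread's B window by one K0-block. A minimal standalone sketch of that stepping (the coordinate is treated as a plain array here; names and sizes are illustrative assumptions):

#include <array>
#include <cstdio>

int main()
{
    // example thread window origin in the 8-D B grid descriptor (assumed values)
    std::array<int, 8> origin{0, 1, 0, 2, 0, 0, 17, 0};
    // mirrors b_thread_slice_copy_step: advance only the K0 / K0PerBlock axis
    constexpr std::array<int, 8> step{1, 0, 0, 0, 0, 0, 0, 0};

    for (int iter = 0; iter < 3; ++iter)
    {
        for (int d = 0; d < 8; ++d)
            origin[d] += step[d]; // the MoveSrcSliceWindow analogue for this sketch
        std::printf("after move %d: K0-block index = %d\n", iter + 1, origin[0]);
    }
    return 0;
}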
@@ -642,7 +653,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
b_threadwise_copy.Run(b_grid_desc_k0_k0b_n0_n1_n2_n3_k1,
b_grid_buf,
b_thread_desc_k0_k0b_n0_n1_n2_n3_k1,
make_tuple(I0, I0, I0, I0, I0, I0, I0),
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_buf);
// Move
@@ -666,15 +677,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf);
block_sync_lds();
blockwise_gemm.Run(a_block_buf, b_thread_buf, c_thread_buf);
// read b after gemm
b_threadwise_copy.Run(b_grid_desc_k0_k0b_n0_n1_n2_n3_k1,
b_grid_buf,
b_thread_desc_k0_k0b_n0_n1_n2_n3_k1,
make_tuple(I0, I0, I0, I0, I0, I0, I0),
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_buf);
blockwise_gemm.Run(a_block_buf, b_thread_buf, c_thread_buf);
block_sync_lds();
// move windows
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1,
......
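Finally, the main-loop hunk reorders the pipeline so that the B registers are refilled only after the xdlops GEMM has consumed them, per the new "read b after gemm" comment, presumably so the previous B data is not overwritten before use. A stubbed, standalone walk-through of the ordering visible in this hunk (steps outside the excerpt, such as the LDS write of A, are omitted; all functions are stand-ins):

#include <cstdio>

// stand-ins for the kernel calls visible in the hunk
static void a_run_read()   { std::puts("a_blockwise_copy.RunRead  (fetch next A tile)"); }
static void lds_sync()     { std::puts("block_sync_lds"); }
static void gemm_run()     { std::puts("blockwise_gemm.Run        (A from LDS, B from registers)"); }
static void b_run_read()   { std::puts("b_threadwise_copy.Run     (refill B registers)"); }
static void move_windows() { std::puts("MoveSrcSliceWindow        (advance A and B source windows)"); }

int main()
{
    constexpr int k0_block_loops = 2; // assumed trip count for illustration
    for (int i = 0; i < k0_block_loops; ++i)
    {
        a_run_read();
        lds_sync();
        gemm_run();
        b_run_read();   // reading B here, after the GEMM, is the ordering this hunk establishes
        lds_sync();
        move_windows();
    }
    return 0;
}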