Commit 4f88629d authored by ltqin
Browse files

fix name

parent 6b4c298c
......@@ -347,7 +347,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
}
__host__ __device__ static constexpr auto
MakeBGridDescriptor_K0_K0B_N0_N1_N2_N3_K1(const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1)
MakeBGridDescriptor_K0_K1_K2_N0_N1_N2_N3_K3(const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1)
{
const auto K0 = b_grid_desc_k0_n_k1.GetLength(I0);
const auto N = b_grid_desc_k0_n_k1.GetLength(I1);
......@@ -549,7 +549,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
ignore = b_element_op;
// B matrix blockwise copy
constexpr auto b_thread_desc_k0_k0b_n0_n1_n2_n3_k1 =
constexpr auto b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3 =
make_naive_tensor_descriptor_packed(make_tuple(I1,
I1,
Number<K0PerThread>{}, // K0PerThread
......@@ -558,13 +558,14 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
I1, // waves
I1, // NPerXdlops
Number<K1>{}));
auto b_thread_buf = StaticBuffer<AddressSpaceEnum::Vgpr,
FloatAB,
b_thread_desc_k0_k0b_n0_n1_n2_n3_k1.GetElementSpaceSize(),
true>{};
auto b_thread_buf =
StaticBuffer<AddressSpaceEnum::Vgpr,
FloatAB,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3.GetElementSpaceSize(),
true>{};
auto b_grid_desc_k0_k0b_n0_n1_n2_n3_k1 =
MakeBGridDescriptor_K0_K0B_N0_N1_N2_N3_K1(b_grid_desc_k0_n_k1);
auto b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3 =
MakeBGridDescriptor_K0_K1_K2_N0_N1_N2_N3_K3(b_grid_desc_k0_n_k1);
const auto wave_id = GetWaveIdx();
const auto wave_k_n_id = GetWaveKNIdx(wave_id[I2]);
......@@ -584,14 +585,14 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
wave_k_n_id[I0],
wave_k_n_id[I1]);
printf("mfma thread k per xdlops: %d K0PerThread: %d HasMainK0BlockLoop: %d K0: %d \t",
xdlops_gemm.K0PerXdlops, K0PerThread, HasMainK0BlockLoop, b_grid_desc_k0_k0b_n0_n1_n2_n3_k1.GetLength(I0));
xdlops_gemm.K0PerXdlops, K0PerThread, HasMainK0BlockLoop, b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3.GetLength(I0));
#endif
auto b_threadwise_copy =
ThreadwiseTensorSliceTransfer_v2<FloatAB,
FloatAB,
decltype(b_grid_desc_k0_k0b_n0_n1_n2_n3_k1),
decltype(b_thread_desc_k0_k0b_n0_n1_n2_n3_k1),
decltype(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3),
decltype(b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3),
Sequence<I1,
I1,
Number<K0PerThread>{},
......@@ -605,7 +606,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
1,
BThreadTransferSrcResetCoordinateAfterRun,
true>(
b_grid_desc_k0_k0b_n0_n1_n2_n3_k1,
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_multi_index(
0, wave_k_n_id[I0], 0, block_work_idx[I1], 0, wave_id[I1], wave_k_n_id[I1], 0));
......@@ -622,7 +623,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
FloatAcc,
decltype(a_block_desc_k0_m_k1),
decltype(b_block_desc_k0_n_k1),
decltype(b_thread_desc_k0_k0b_n0_n1_n2_n3_k1),
decltype(b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3),
MPerXDL,
NPerXDL,
MXdlPerWave,
......@@ -642,15 +643,15 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
{
// Read
a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf);
b_threadwise_copy.Run(b_grid_desc_k0_k0b_n0_n1_n2_n3_k1,
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf,
b_thread_desc_k0_k0b_n0_n1_n2_n3_k1,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_buf);
// Move
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, a_block_slice_copy_step);
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k0b_n0_n1_n2_n3_k1,
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step);
// Initialize C
......@@ -671,16 +672,16 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
block_sync_lds();
blockwise_gemm.Run(a_block_buf, b_thread_buf, c_thread_buf);
// read b after gemm
b_threadwise_copy.Run(b_grid_desc_k0_k0b_n0_n1_n2_n3_k1,
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf,
b_thread_desc_k0_k0b_n0_n1_n2_n3_k1,
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_buf);
block_sync_lds();
// move windows
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1,
a_block_slice_copy_step);
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k0b_n0_n1_n2_n3_k1,
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step);
a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf);
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment