Commit 4f88629d authored by ltqin
Browse files

fix name

parent 6b4c298c
...@@ -347,7 +347,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3 ...@@ -347,7 +347,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
} }
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeBGridDescriptor_K0_K0B_N0_N1_N2_N3_K1(const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1) MakeBGridDescriptor_K0_K1_K2_N0_N1_N2_N3_K3(const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1)
{ {
const auto K0 = b_grid_desc_k0_n_k1.GetLength(I0); const auto K0 = b_grid_desc_k0_n_k1.GetLength(I0);
const auto N = b_grid_desc_k0_n_k1.GetLength(I1); const auto N = b_grid_desc_k0_n_k1.GetLength(I1);
...@@ -549,7 +549,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3 ...@@ -549,7 +549,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
ignore = b_element_op; ignore = b_element_op;
// B matrix blockwise copy // B matrix blockwise copy
constexpr auto b_thread_desc_k0_k0b_n0_n1_n2_n3_k1 = constexpr auto b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3 =
make_naive_tensor_descriptor_packed(make_tuple(I1, make_naive_tensor_descriptor_packed(make_tuple(I1,
I1, I1,
Number<K0PerThread>{}, // K0PerThread Number<K0PerThread>{}, // K0PerThread
...@@ -558,13 +558,14 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3 ...@@ -558,13 +558,14 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
I1, // waves I1, // waves
I1, // NPerXdlops I1, // NPerXdlops
Number<K1>{})); Number<K1>{}));
auto b_thread_buf = StaticBuffer<AddressSpaceEnum::Vgpr, auto b_thread_buf =
FloatAB, StaticBuffer<AddressSpaceEnum::Vgpr,
b_thread_desc_k0_k0b_n0_n1_n2_n3_k1.GetElementSpaceSize(), FloatAB,
true>{}; b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3.GetElementSpaceSize(),
true>{};
auto b_grid_desc_k0_k0b_n0_n1_n2_n3_k1 = auto b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3 =
MakeBGridDescriptor_K0_K0B_N0_N1_N2_N3_K1(b_grid_desc_k0_n_k1); MakeBGridDescriptor_K0_K1_K2_N0_N1_N2_N3_K3(b_grid_desc_k0_n_k1);
const auto wave_id = GetWaveIdx(); const auto wave_id = GetWaveIdx();
const auto wave_k_n_id = GetWaveKNIdx(wave_id[I2]); const auto wave_k_n_id = GetWaveKNIdx(wave_id[I2]);
...@@ -584,14 +585,14 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3 ...@@ -584,14 +585,14 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
wave_k_n_id[I0], wave_k_n_id[I0],
wave_k_n_id[I1]); wave_k_n_id[I1]);
printf("mfma thread k per xdlops: %d K0PerThread: %d HasMainK0BlockLoop: %d K0: %d \t", printf("mfma thread k per xdlops: %d K0PerThread: %d HasMainK0BlockLoop: %d K0: %d \t",
xdlops_gemm.K0PerXdlops, K0PerThread, HasMainK0BlockLoop, b_grid_desc_k0_k0b_n0_n1_n2_n3_k1.GetLength(I0)); xdlops_gemm.K0PerXdlops, K0PerThread, HasMainK0BlockLoop, b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3.GetLength(I0));
#endif #endif
auto b_threadwise_copy = auto b_threadwise_copy =
ThreadwiseTensorSliceTransfer_v2<FloatAB, ThreadwiseTensorSliceTransfer_v2<FloatAB,
FloatAB, FloatAB,
decltype(b_grid_desc_k0_k0b_n0_n1_n2_n3_k1), decltype(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3),
decltype(b_thread_desc_k0_k0b_n0_n1_n2_n3_k1), decltype(b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3),
Sequence<I1, Sequence<I1,
I1, I1,
Number<K0PerThread>{}, Number<K0PerThread>{},
...@@ -605,7 +606,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3 ...@@ -605,7 +606,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
1, 1,
BThreadTransferSrcResetCoordinateAfterRun, BThreadTransferSrcResetCoordinateAfterRun,
true>( true>(
b_grid_desc_k0_k0b_n0_n1_n2_n3_k1, b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_multi_index( make_multi_index(
0, wave_k_n_id[I0], 0, block_work_idx[I1], 0, wave_id[I1], wave_k_n_id[I1], 0)); 0, wave_k_n_id[I0], 0, block_work_idx[I1], 0, wave_id[I1], wave_k_n_id[I1], 0));
...@@ -622,7 +623,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3 ...@@ -622,7 +623,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
FloatAcc, FloatAcc,
decltype(a_block_desc_k0_m_k1), decltype(a_block_desc_k0_m_k1),
decltype(b_block_desc_k0_n_k1), decltype(b_block_desc_k0_n_k1),
decltype(b_thread_desc_k0_k0b_n0_n1_n2_n3_k1), decltype(b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3),
MPerXDL, MPerXDL,
NPerXDL, NPerXDL,
MXdlPerWave, MXdlPerWave,
...@@ -642,15 +643,15 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3 ...@@ -642,15 +643,15 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
{ {
// Read // Read
a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf); a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf);
b_threadwise_copy.Run(b_grid_desc_k0_k0b_n0_n1_n2_n3_k1, b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf, b_grid_buf,
b_thread_desc_k0_k0b_n0_n1_n2_n3_k1, b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_buf); b_thread_buf);
// Move // Move
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, a_block_slice_copy_step); a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, a_block_slice_copy_step);
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k0b_n0_n1_n2_n3_k1, b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step); b_thread_slice_copy_step);
// Initialize C // Initialize C
...@@ -671,16 +672,16 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3 ...@@ -671,16 +672,16 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
block_sync_lds(); block_sync_lds();
blockwise_gemm.Run(a_block_buf, b_thread_buf, c_thread_buf); blockwise_gemm.Run(a_block_buf, b_thread_buf, c_thread_buf);
// read b after gemm // read b after gemm
b_threadwise_copy.Run(b_grid_desc_k0_k0b_n0_n1_n2_n3_k1, b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_grid_buf, b_grid_buf,
b_thread_desc_k0_k0b_n0_n1_n2_n3_k1, b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_buf); b_thread_buf);
block_sync_lds(); block_sync_lds();
// move windows // move windows
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1,
a_block_slice_copy_step); a_block_slice_copy_step);
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k0b_n0_n1_n2_n3_k1, b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step); b_thread_slice_copy_step);
a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf); a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf);
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment