Commit 920a752b authored by Adam Osewski's avatar Adam Osewski
Browse files

clang-format

parent 209c1e50
......@@ -717,53 +717,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_splitk_direct_c_write_out
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_b_block, b_k0_n_k1_block_desc.GetElementSpaceSize());
#if 0
// preload data into LDS
{
a_blockwise_copy.RunRead(a_b_k0_m_k1_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_b_k0_n_k1_grid_desc, b_grid_buf);
a_blockwise_copy.RunWrite(a_b_k0_m_k1_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_b_k0_n_k1_block_desc, b_block_buf);
}
// Initialize C
c_thread_buf.Clear();
// main body
if constexpr(HasMainKBlockLoop)
{
index_t k0_block_data_begin = 0;
do
{
a_blockwise_copy.MoveSrcSliceWindow(a_b_k0_m_k1_grid_desc, a_block_slice_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_b_k0_n_k1_grid_desc, b_block_slice_copy_step);
a_blockwise_copy.RunRead(a_b_k0_m_k1_grid_desc, a_grid_buf);
block_sync_lds();
b_blockwise_copy.RunRead(b_b_k0_n_k1_grid_desc, b_grid_buf);
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
block_sync_lds();
a_blockwise_copy.RunWrite(a_b_k0_m_k1_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_b_k0_n_k1_block_desc, b_block_buf);
k0_block_data_begin += K0PerBlock;
} while(k0_block_data_begin < (karg.K0 - K0PerBlock));
}
// tail
{
block_sync_lds();
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
}
#else
// gridwise GEMM pipeline
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
(a_b_k0_m_k1_grid_desc.GetLength(I1) * a_b_k0_m_k1_grid_desc.GetLength(I3)) /
......@@ -786,7 +739,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_splitk_direct_c_write_out
blockwise_gemm,
c_thread_buf,
num_k_block_main_loop);
#endif
// output: register to global memory
{
......
......@@ -767,10 +767,9 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_splitk_direct_c_write_out
constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I6);
constexpr auto N4 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I7);
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_n2_n3 =
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_n2_n3 =
make_naive_tensor_descriptor_packed(
make_tuple(M0, N0, I1, I1, I2, I1, I1, Number<8>{}));
make_tuple(M0, N0, I1, I1, I2, I1, I1, Number<8>{}));
const auto M0_grid = c_grid_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I0);
const auto N0_grid = c_grid_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I1);
......@@ -781,10 +780,10 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_splitk_direct_c_write_out
const auto N3_grid = c_grid_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I6);
const auto N4_grid = c_grid_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I7);
// if (blockIdx.x == 0 && ThisThreadBlock::GetThreadId() == 0)
// {
// printf("grid: [M0: %d, N0: %d, M1: %d, N1: %d, M2: %d, N2: %d, N3: %d, N4: %d]\n",
// printf("grid: [M0: %d, N0: %d, M1: %d, N1: %d, M2: %d, N2: %d, N3: %d, N4:
// %d]\n",
// M0_grid,
// N0_grid,
// M1_grid,
......@@ -797,28 +796,26 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_splitk_direct_c_write_out
const auto c_grid_desc_m0_n0_m1_n1_m2_n234_tmp = transform_tensor_descriptor(
c_grid_desc_m0_n0_m1_n1_m2_n2_n3_n4,
make_tuple(
make_pass_through_transform(M0_grid),
make_pass_through_transform(N0_grid),
make_pass_through_transform(M1_grid),
make_pass_through_transform(N1_grid),
make_pass_through_transform(M2_grid),
make_merge_transform(make_tuple(N3_grid, N2_grid, N4_grid)) // num_groups_per_blk * group_size
),
make_tuple(make_pass_through_transform(M0_grid),
make_pass_through_transform(N0_grid),
make_pass_through_transform(M1_grid),
make_pass_through_transform(N1_grid),
make_pass_through_transform(M2_grid),
make_merge_transform(make_tuple(
N3_grid, N2_grid, N4_grid)) // num_groups_per_blk * group_size
),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5, 6, 7>{}
),
Sequence<5, 6, 7>{}),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{}
));
Sequence<5>{}));
// if (blockIdx.x == 0 && ThisThreadBlock::GetThreadId() == 0)
// {
......@@ -834,28 +831,25 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_splitk_direct_c_write_out
const auto c_grid_desc_m0_n0_m1_n1_m2_m3_n2_n3_new = transform_tensor_descriptor(
c_grid_desc_m0_n0_m1_n1_m2_n234_tmp,
make_tuple(
make_pass_through_transform(M0_grid), // M0 - MRepeat / MXdlPerWave
make_pass_through_transform(N0_grid), // N0 - NRepeat / NXdlPerWave
make_pass_through_transform(M1_grid), // M1 - MWaves
make_pass_through_transform(N1_grid), // N1 - NWaves
make_unmerge_transform(make_tuple(I2, Number<16>{})), // M2 -> (M2: 2, M3: 16)
make_unmerge_transform(make_tuple(I4, Number<8>{})) // N2, N3
make_pass_through_transform(M0_grid), // M0 - MRepeat / MXdlPerWave
make_pass_through_transform(N0_grid), // N0 - NRepeat / NXdlPerWave
make_pass_through_transform(M1_grid), // M1 - MWaves
make_pass_through_transform(N1_grid), // N1 - NWaves
make_unmerge_transform(make_tuple(I2, Number<16>{})), // M2 -> (M2: 2, M3: 16)
make_unmerge_transform(make_tuple(I4, Number<8>{})) // N2, N3
),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{}
),
Sequence<5>{}),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4, 5>{},
Sequence<6, 7>{}
)
);
Sequence<6, 7>{}));
// if (blockIdx.x == 0 && ThisThreadBlock::GetThreadId() == 0)
// {
......@@ -872,13 +866,12 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_splitk_direct_c_write_out
const auto wave_idx = blockwise_gemm.GetWaveIdx();
const auto lane_id_to_m3_n2_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(Number<16>{}, I4))),
make_tuple(Sequence<0, 1>{}),
make_tuple(Sequence<0>{})
);
make_tuple(make_merge_transform(make_tuple(Number<16>{}, I4))),
make_tuple(Sequence<0, 1>{}),
make_tuple(Sequence<0>{}));
const auto lane_data_idx_on_block = lane_id_to_m3_n2_adaptor.CalculateBottomIndex(
make_multi_index(wave_idx[I2]));
const auto lane_data_idx_on_block =
lane_id_to_m3_n2_adaptor.CalculateBottomIndex(make_multi_index(wave_idx[I2]));
// if (blockIdx.x == 0 && (ThisThreadBlock::GetThreadId() == 0 ||
// ThisThreadBlock::GetThreadId() == 16 ||
......@@ -918,10 +911,10 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_splitk_direct_c_write_out
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_n2_n3),
decltype(c_grid_desc_m0_n0_m1_n1_m2_m3_n2_n3_new),
CElementwiseOperation,
Sequence<M0, N0, I1, I1, I2, I1, I1, 8>, // SliceLengths
Sequence<0, 1, 2, 3, 4, 5, 6, 7>, // CThreadTransferDstAccessOrder,
7, // CThreadTransferDstVectorDim,
8, // CThreadTransferDstScalarPerVector,
Sequence<M0, N0, I1, I1, I2, I1, I1, 8>, // SliceLengths
Sequence<0, 1, 2, 3, 4, 5, 6, 7>, // CThreadTransferDstAccessOrder,
7, // CThreadTransferDstVectorDim,
8, // CThreadTransferDstScalarPerVector,
CGlobalMemoryDataOperation,
1,
true>{c_grid_desc_m0_n0_m1_n1_m2_m3_n2_n3_new,
......@@ -936,18 +929,21 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_splitk_direct_c_write_out
c_element_op};
// if (blockIdx.x == 0 || blockIdx.x == 5)
// { // M1, N1, M2, N2, N3
// { // M1, N1, M2, N2,
// N3
// if (ThisThreadBlock::GetThreadId() == 0 ||
// ThisThreadBlock::GetThreadId() == 3 || // [ 0, 0, 0, 3, 0]
// ThisThreadBlock::GetThreadId() == 16 || // [ 0, 0, 4, 0, 0]
// ThisThreadBlock::GetThreadId() == 33 || // [ 0, 0, 8, 1, 0]
// ThisThreadBlock::GetThreadId() == 64 || // [ 0, 1, 0, 0, 0]
// ThisThreadBlock::GetThreadId() == 96 || // [ 0, 1, 8, 0, 0]
// ThisThreadBlock::GetThreadId() == 130 || // [ 1, 0, 0, 2, 0]
// ThisThreadBlock::GetThreadId() == 224 // [ 1, 1, 8, 0, 0]
// ThisThreadBlock::GetThreadId() == 3 || // [ 0, 0, 0, 3,
// 0] ThisThreadBlock::GetThreadId() == 16 || // [ 0, 0, 4, 0,
// 0] ThisThreadBlock::GetThreadId() == 33 || // [ 0, 0, 8, 1,
// 0] ThisThreadBlock::GetThreadId() == 64 || // [ 0, 1, 0, 0,
// 0] ThisThreadBlock::GetThreadId() == 96 || // [ 0, 1, 8, 0,
// 0] ThisThreadBlock::GetThreadId() == 130 || // [ 1, 0, 0, 2,
// 0] ThisThreadBlock::GetThreadId() == 224 // [ 1, 1, 8, 0,
// 0]
// )
// {
// printf("[B:%d, T:%d] -> dst_slice_origin_idx: [%d, %d, %d, %d, %d, %d, %d]\n",
// printf("[B:%d, T:%d] -> dst_slice_origin_idx: [%d, %d, %d, %d, %d, %d,
// %d]\n",
// get_block_1d_id(),
// ThisThreadBlock::GetThreadId(),
// m_thread_data_on_grid_idx[I0],
......@@ -960,7 +956,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_splitk_direct_c_write_out
// }
// }
c_thread_copy.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_n2_n3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
c_thread_buf,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment