Commit 1654fcce authored by wangshaojie6
Browse files

clang-format-10 change

parent fc17eb42
...@@ -81,36 +81,38 @@ struct Merge_v3_division_mod_for_wrw ...@@ -81,36 +81,38 @@ struct Merge_v3_division_mod_for_wrw
index_t tmp = idx_up_new[I0]; index_t tmp = idx_up_new[I0];
//if(get_block_1d_id() == 0 && get_thread_local_1d_id() == 0){ // if(get_block_1d_id() == 0 && get_thread_local_1d_id() == 0){
// //printf("%d, %d, %d\n", __LINE__, tmp, tmp2); // //printf("%d, %d, %d\n", __LINE__, tmp, tmp2);
// //printf("%d, %d, %d\n", // //printf("%d, %d, %d\n",
// // __LINE__, // // __LINE__,
// // static_cast<index_t>(this->low_lengths_scan_.At(Number<0>())), // // static_cast<index_t>(this->low_lengths_scan_.At(Number<0>())),
// // static_cast<index_t>(this->low_lengths_scan_.At(Number<1>()))); // // static_cast<index_t>(this->low_lengths_scan_.At(Number<1>())));
// printf("%d, %d, %d, %d, %d, %d\n", __LINE__, NDimLow, idx_low.At(Number<0>()), idx_low.At(Number<1>()), idx_diff_low.At(Number<0>()), idx_diff_low.At(Number<1>())); // printf("%d, %d, %d, %d, %d, %d\n", __LINE__, NDimLow, idx_low.At(Number<0>()),
// idx_low.At(Number<1>()), idx_diff_low.At(Number<0>()), idx_diff_low.At(Number<1>()));
//} //}
//static_for<0, NDimLow - 1, 1>{}([&](auto i) { // static_for<0, NDimLow - 1, 1>{}([&](auto i) {
// const index_t tmp2 = idx_low[i]; // const index_t tmp2 = idx_low[i];
// idx_low(i) = tmp / this->low_lengths_scan_[i]; // idx_low(i) = tmp / this->low_lengths_scan_[i];
// idx_diff_low(i) = idx_low[i] - tmp2; // idx_diff_low(i) = idx_low[i] - tmp2;
// tmp %= this->low_lengths_scan_[i]; // tmp %= this->low_lengths_scan_[i];
//}); //});
//const index_t tmp2 = idx_low[INm1]; // const index_t tmp2 = idx_low[INm1];
//idx_low(INm1) = tmp; // idx_low(INm1) = tmp;
//idx_diff_low(INm1) = idx_low[INm1] - tmp2; // idx_diff_low(INm1) = idx_low[INm1] - tmp2;
idx_low(INm1) = tmp; idx_low(INm1) = tmp;
idx_diff_low(INm1) = idx_up_diff[I0]; idx_diff_low(INm1) = idx_up_diff[I0];
//if(get_block_1d_id() == 0 && get_thread_local_1d_id() == 0){ // if(get_block_1d_id() == 0 && get_thread_local_1d_id() == 0){
// //printf("%d, %d, %d\n", __LINE__, tmp, tmp2); // //printf("%d, %d, %d\n", __LINE__, tmp, tmp2);
// printf("%d, %d, %d\n", // printf("%d, %d, %d\n",
// __LINE__, // __LINE__,
// static_cast<index_t>(this->low_lengths_scan_.At(Number<0>())), // static_cast<index_t>(this->low_lengths_scan_.At(Number<0>())),
// static_cast<index_t>(this->low_lengths_scan_.At(Number<1>()))); // static_cast<index_t>(this->low_lengths_scan_.At(Number<1>())));
// printf("%d, %d, %d, %d, %d, %d\n", __LINE__, NDimLow, idx_low.At(Number<0>()), idx_low.At(Number<1>()), idx_diff_low.At(Number<0>()), idx_diff_low.At(Number<1>())); // printf("%d, %d, %d, %d, %d, %d\n", __LINE__, NDimLow, idx_low.At(Number<0>()),
// idx_low.At(Number<1>()), idx_diff_low.At(Number<0>()), idx_diff_low.At(Number<1>()));
//} //}
} }
...@@ -156,4 +158,4 @@ make_merge_transform_v3_division_mod_for_wrw(const LowLengths& low_lengths) ...@@ -156,4 +158,4 @@ make_merge_transform_v3_division_mod_for_wrw(const LowLengths& low_lengths)
return Merge_v3_division_mod_for_wrw<LowLengths>{low_lengths}; return Merge_v3_division_mod_for_wrw<LowLengths>{low_lengths};
} }
} } // namespace ck
...@@ -149,17 +149,21 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 ...@@ -149,17 +149,21 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
{ {
#if A_BLOCK_BANK_CONFLICT_FREE_WRW #if A_BLOCK_BANK_CONFLICT_FREE_WRW
constexpr auto a_block_desc_k0_m0_m1_k1 = make_naive_tensor_descriptor( constexpr auto a_block_desc_k0_m0_m1_k1 = make_naive_tensor_descriptor(
make_tuple(Number<K0PerBlock>{}, Number<M0PerBlock>{}, Number<M1PerBlock>{}, K1), make_tuple(
make_tuple(Number<M0PerBlock>{} * (Number<M1PerBlock>{} * K1 + M1Padding), Number<M1PerBlock>{} * K1 + M1Padding, K1, I1)); Number<K0PerBlock>{}, Number<M0PerBlock>{}, Number<M1PerBlock>{}, K1),
make_tuple(Number<M0PerBlock>{} * (Number<M1PerBlock>{} * K1 + M1Padding),
Number<M1PerBlock>{} * K1 + M1Padding,
K1,
I1));
constexpr auto a_block_desc_k0_m_k1_tmp = transform_tensor_descriptor( constexpr auto a_block_desc_k0_m_k1_tmp = transform_tensor_descriptor(
a_block_desc_k0_m0_m1_k1, a_block_desc_k0_m0_m1_k1,
make_tuple(make_pass_through_transform(Number<K0PerBlock>{}), make_tuple(make_pass_through_transform(Number<K0PerBlock>{}),
make_merge_transform_v3_division_mod(make_tuple(Number<M0PerBlock>{}, Number<M1PerBlock>{})), make_merge_transform_v3_division_mod(
make_tuple(Number<M0PerBlock>{}, Number<M1PerBlock>{})),
make_pass_through_transform(K1)), make_pass_through_transform(K1)),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}) make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
);
return a_block_desc_k0_m_k1_tmp; return a_block_desc_k0_m_k1_tmp;
#else #else
...@@ -188,30 +192,43 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 ...@@ -188,30 +192,43 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
{ {
#if A_BLOCK_BANK_CONFLICT_FREE_WRW #if A_BLOCK_BANK_CONFLICT_FREE_WRW
constexpr auto a_block_desc_b_k0_m0_m1_k1 = make_naive_tensor_descriptor( constexpr auto a_block_desc_b_k0_m0_m1_k1 = make_naive_tensor_descriptor(
make_tuple(Number<1>{}, Number<K0PerBlock>{}, Number<M0PerBlock>{}, Number<M1PerBlock>{}, K1), make_tuple(Number<1>{},
make_tuple(Number<K0PerBlock>{} * Number<M0PerBlock>{} * (Number<M1PerBlock>{} * K1 + M1Padding), Number<M0PerBlock>{} * (Number<M1PerBlock>{} * K1 + M1Padding), Number<M1PerBlock>{} * K1 + M1Padding, K1, I1)); Number<K0PerBlock>{},
Number<M0PerBlock>{},
Number<M1PerBlock>{},
K1),
make_tuple(Number<K0PerBlock>{} * Number<M0PerBlock>{} *
(Number<M1PerBlock>{} * K1 + M1Padding),
Number<M0PerBlock>{} * (Number<M1PerBlock>{} * K1 + M1Padding),
Number<M1PerBlock>{} * K1 + M1Padding,
K1,
I1));
constexpr auto a_block_desc_b_k0_m_k1_tmp = transform_tensor_descriptor( constexpr auto a_block_desc_b_k0_m_k1_tmp = transform_tensor_descriptor(
a_block_desc_b_k0_m0_m1_k1, a_block_desc_b_k0_m0_m1_k1,
make_tuple(make_pass_through_transform(Number<1>{}), make_tuple(make_pass_through_transform(Number<1>{}),
make_pass_through_transform(Number<K0PerBlock>{}), make_pass_through_transform(Number<K0PerBlock>{}),
make_merge_transform_v3_division_mod_for_wrw(make_tuple(Number<M0PerBlock>{}, Number<M1PerBlock>{})), make_merge_transform_v3_division_mod_for_wrw(
make_tuple(Number<M0PerBlock>{}, Number<M1PerBlock>{})),
make_pass_through_transform(K1)), make_pass_through_transform(K1)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}) make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
);
return a_block_desc_b_k0_m_k1_tmp; return a_block_desc_b_k0_m_k1_tmp;
#else #else
return make_naive_tensor_descriptor( return make_naive_tensor_descriptor(
make_tuple(Number<1>{}, Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), make_tuple(Number<1>{}, Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
make_tuple(Number<K0PerBlock>{} * Number<MPerBlock + 1>{} * K1, Number<MPerBlock + 1>{} * K1, K1, I1)); make_tuple(Number<K0PerBlock>{} * Number<MPerBlock + 1>{} * K1,
Number<MPerBlock + 1>{} * K1,
K1,
I1));
#endif #endif
} }
else else
{ {
return make_naive_tensor_descriptor_aligned( return make_naive_tensor_descriptor_aligned(
make_tuple(Number<1>{}, Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align); make_tuple(Number<1>{}, Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
max_lds_align);
} }
}(); }();
...@@ -228,17 +245,21 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 ...@@ -228,17 +245,21 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
{ {
#if B_BLOCK_BANK_CONFLICT_FREE_WRW #if B_BLOCK_BANK_CONFLICT_FREE_WRW
constexpr auto b_block_desc_k0_n0_n1_k1 = make_naive_tensor_descriptor( constexpr auto b_block_desc_k0_n0_n1_k1 = make_naive_tensor_descriptor(
make_tuple(Number<K0PerBlock>{}, Number<N0PerBlock>{}, Number<N1PerBlock>{}, K1), make_tuple(
make_tuple(Number<N0PerBlock>{} * (Number<N1PerBlock>{} * K1 + N1Padding), Number<N1PerBlock>{} * K1 + N1Padding, K1, I1)); Number<K0PerBlock>{}, Number<N0PerBlock>{}, Number<N1PerBlock>{}, K1),
make_tuple(Number<N0PerBlock>{} * (Number<N1PerBlock>{} * K1 + N1Padding),
Number<N1PerBlock>{} * K1 + N1Padding,
K1,
I1));
constexpr auto b_block_desc_k0_n_k1_tmp = transform_tensor_descriptor( constexpr auto b_block_desc_k0_n_k1_tmp = transform_tensor_descriptor(
b_block_desc_k0_n0_n1_k1, b_block_desc_k0_n0_n1_k1,
make_tuple(make_pass_through_transform(Number<K0PerBlock>{}), make_tuple(make_pass_through_transform(Number<K0PerBlock>{}),
make_merge_transform_v3_division_mod(make_tuple(Number<N0PerBlock>{}, Number<N1PerBlock>{})), make_merge_transform_v3_division_mod(
make_tuple(Number<N0PerBlock>{}, Number<N1PerBlock>{})),
make_pass_through_transform(K1)), make_pass_through_transform(K1)),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}) make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
);
return b_block_desc_k0_n_k1_tmp; return b_block_desc_k0_n_k1_tmp;
#else #else
...@@ -268,30 +289,43 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 ...@@ -268,30 +289,43 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
{ {
#if B_BLOCK_BANK_CONFLICT_FREE_WRW #if B_BLOCK_BANK_CONFLICT_FREE_WRW
constexpr auto b_block_desc_b_k0_n0_n1_k1 = make_naive_tensor_descriptor( constexpr auto b_block_desc_b_k0_n0_n1_k1 = make_naive_tensor_descriptor(
make_tuple(Number<1>{}, Number<K0PerBlock>{}, Number<N0PerBlock>{}, Number<N1PerBlock>{}, K1), make_tuple(Number<1>{},
make_tuple(Number<K0PerBlock>{} * Number<N0PerBlock>{} * (Number<N1PerBlock>{} * K1 + N1Padding), Number<N0PerBlock>{} * (Number<N1PerBlock>{} * K1 + N1Padding), Number<N1PerBlock>{} * K1 + N1Padding, K1, I1)); Number<K0PerBlock>{},
Number<N0PerBlock>{},
Number<N1PerBlock>{},
K1),
make_tuple(Number<K0PerBlock>{} * Number<N0PerBlock>{} *
(Number<N1PerBlock>{} * K1 + N1Padding),
Number<N0PerBlock>{} * (Number<N1PerBlock>{} * K1 + N1Padding),
Number<N1PerBlock>{} * K1 + N1Padding,
K1,
I1));
constexpr auto b_block_desc_b_k0_n_k1_tmp = transform_tensor_descriptor( constexpr auto b_block_desc_b_k0_n_k1_tmp = transform_tensor_descriptor(
b_block_desc_b_k0_n0_n1_k1, b_block_desc_b_k0_n0_n1_k1,
make_tuple(make_pass_through_transform(Number<1>{}), make_tuple(make_pass_through_transform(Number<1>{}),
make_pass_through_transform(Number<K0PerBlock>{}), make_pass_through_transform(Number<K0PerBlock>{}),
make_merge_transform_v3_division_mod_for_wrw(make_tuple(Number<N0PerBlock>{}, Number<N1PerBlock>{})), make_merge_transform_v3_division_mod_for_wrw(
make_tuple(Number<N0PerBlock>{}, Number<N1PerBlock>{})),
make_pass_through_transform(K1)), make_pass_through_transform(K1)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}) make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
);
return b_block_desc_b_k0_n_k1_tmp; return b_block_desc_b_k0_n_k1_tmp;
#else #else
return make_naive_tensor_descriptor( return make_naive_tensor_descriptor(
make_tuple(Number<1>{}, Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), make_tuple(Number<1>{}, Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
make_tuple(Number<K0PerBlock>{} * Number<NPerBlock + 1>{} * K1, Number<NPerBlock + 1>{} * K1, K1, I1)); make_tuple(Number<K0PerBlock>{} * Number<NPerBlock + 1>{} * K1,
Number<NPerBlock + 1>{} * K1,
K1,
I1));
#endif #endif
} }
else else
{ {
return make_naive_tensor_descriptor_aligned( return make_naive_tensor_descriptor_aligned(
make_tuple(Number<1>{}, Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align); make_tuple(Number<1>{}, Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
max_lds_align);
} }
}(); }();
...@@ -309,11 +343,11 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 ...@@ -309,11 +343,11 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
constexpr auto b_b_k0_n_k1_block_desc = GetBBlockDescriptor_Batch_K0PerBlock_NPerBlock_K1(); constexpr auto b_b_k0_n_k1_block_desc = GetBBlockDescriptor_Batch_K0PerBlock_NPerBlock_K1();
// LDS allocation for A and B: be careful of alignment // LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size = constexpr auto a_block_space_size = math::integer_least_multiple(
math::integer_least_multiple(a_b_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align); a_b_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_space_size = constexpr auto b_block_space_size = math::integer_least_multiple(
math::integer_least_multiple(b_b_k0_n_k1_block_desc.GetElementSpaceSize(), max_lds_align); b_b_k0_n_k1_block_desc.GetElementSpaceSize(), max_lds_align);
constexpr auto c_block_size = constexpr auto c_block_size =
GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock().GetElementSpaceSize(); GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock().GetElementSpaceSize();
...@@ -570,8 +604,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 ...@@ -570,8 +604,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
// register // register
// sanity check // sanity check
constexpr index_t KPack = math::max( constexpr index_t KPack =
K1, MfmaSelector<FloatAB, MPerXDL, NPerXDL>::selected_mfma.k_per_blk); math::max(K1, MfmaSelector<FloatAB, MPerXDL, NPerXDL>::selected_mfma.k_per_blk);
auto blockwise_gemm = auto blockwise_gemm =
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize, BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment