Commit aa432252 authored by Chao Liu

fix bug

parent b6950a3c
@@ -701,21 +701,68 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
         constexpr index_t MPerBlock_CCopy = MWave * MPerXdl;
         constexpr index_t NPerBlock_CCopy = NWave * NPerXdl;
 
-        // TODO: hacky
+        // TODO: hacky, fix it!
         constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
             blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
 
-        constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
+        // TODO: hacky, fix it!
+        // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
+        constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
             blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
 
-        constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I0);
-        constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I1);
-        constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I2);
-        constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I3);
-        constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I4);
-        constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I5);
-        constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I6);
-        constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I7);
+        constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
+        constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
+        constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
+        constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
+        constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
+        constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
+        constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
+        constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
+        // TODO: this is hardcoded, only works for BlockSize = 256. fix it!
+        constexpr index_t MThread_CCopy = 32;
+        constexpr index_t NThread_CCopy = 8;
+
+        constexpr index_t MPerThread_CCopy = MPerBlock_CCopy / MThread_CCopy;
+        constexpr index_t NPerThread_CCopy = NPerBlock_CCopy / NThread_CCopy;
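The 32 x 8 thread grid above is hardcoded and, as the TODO notes, only valid for BlockSize = 256. A minimal standalone sketch of one way to derive the tiling from the block size instead; the rule of fixing the per-thread vector length along N first is an assumption for illustration, not the author's eventual fix:

#include <cstdio>

int main()
{
    constexpr int BlockSize  = 256;
    constexpr int MPerBlock  = 64; // MWave * MPerXdl in the kernel
    constexpr int NPerBlock  = 64; // NWave * NPerXdl in the kernel
    constexpr int NPerThread = 8;  // assumed store-vector length along N

    constexpr int NThread    = NPerBlock / NPerThread; // 8
    constexpr int MThread    = BlockSize / NThread;    // 32
    constexpr int MPerThread = MPerBlock / MThread;    // 2

    static_assert(MThread * NThread == BlockSize, "every thread must own a tile");
    static_assert(MPerBlock % MThread == 0 && NPerBlock % NThread == 0,
                  "per-thread tile must divide the block tile");

    std::printf("MThread=%d NThread=%d MPerThread=%d NPerThread=%d\n",
                MThread, NThread, MPerThread, NPerThread);
}

This reproduces the hardcoded 32 x 8 split for BlockSize = 256 while remaining valid for other block sizes that satisfy the two asserts.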
+        constexpr auto c_block_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl =
+            make_naive_tensor_descriptor_packed(make_tuple(
+                I1, I1, Number<MPerBlock_CCopy>{}, I1, I1, Number<NPerBlock_CCopy>{}));
+
+        static_assert(c_block_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl
+                              .GetElementSpaceSize() == 64 * 64,
+                      "wrong!");
+
+        auto c_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
+            static_cast<FloatAcc*>(p_shared),
+            c_block_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl
+                .GetElementSpaceSize());
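c_block_buf reuses the kernel's raw LDS allocation p_shared as FloatAcc storage for the C shuffle, and the static_assert pins its footprint to MPerBlock_CCopy * NPerBlock_CCopy = 64 * 64 elements. A minimal HIP-style sketch of that reuse pattern, with illustrative names (not the kernel's actual launch plumbing):

#include <hip/hip_runtime.h>

// Sketch only: one dynamically sized LDS allocation, reinterpreted per stage.
__global__ void kernel()
{
    extern __shared__ char p_shared[]; // sized in bytes at launch

    // Earlier in the kernel the same bytes hold the A/B tiles for the GEMM
    // mainloop; a barrier must separate the two uses.
    __syncthreads();

    // Now the same bytes back the C-shuffle staging buffer.
    float* c_block = reinterpret_cast<float*>(p_shared);
    c_block[threadIdx.x] = 0.0f; // elements [0, 64 * 64) are valid here
}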
+        constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
+            c_block_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl,
+            make_tuple(make_freeze_transform(I0),        // freeze mblock
+                       make_pass_through_transform(I1),  // M0 (MRepeat) per shuffle = 1
+                       make_unmerge_transform(
+                           make_tuple(M1, M2, M3, M4)),  // M1 = MWave, M2 * M3 * M4 = MPerXdl
+                       make_freeze_transform(I0),        // freeze nblock
+                       make_pass_through_transform(I1),  // N0 (NRepeat) per shuffle = 1
+                       make_unmerge_transform(
+                           make_tuple(N1, N2))),         // N1 = NWave, N2 = NPerXdl
+            make_tuple(Sequence<0>{},
+                       Sequence<1>{},
+                       Sequence<2>{},
+                       Sequence<3>{},
+                       Sequence<4>{},
+                       Sequence<5>{}),
+            make_tuple(Sequence<>{},
+                       Sequence<0>{},
+                       Sequence<2, 4, 5, 6>{},
+                       Sequence<>{},
+                       Sequence<1>{},
+                       Sequence<3, 7>{}));
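Each make_unmerge_transform above splits one flat coordinate into several: a row within the block decomposes into (M1, M2, M3, M4) = (MWave, M2, M3, M4) sub-coordinates. A standalone sketch of that index arithmetic; the function name and the concrete lengths are illustrative, not the library's API:

#include <array>
#include <cstddef>
#include <cstdio>

// Row-major split of a flat index into coordinates with the given lengths,
// mirroring what an unmerge transform computes.
template <std::size_t N>
std::array<int, N> unmerge(int flat, const std::array<int, N>& lengths)
{
    std::array<int, N> idx{};
    for(int d = N - 1; d >= 0; --d)
    {
        idx[d] = flat % lengths[d];
        flat /= lengths[d];
    }
    return idx;
}

int main()
{
    // e.g. M1 = MWave = 2 and M2 * M3 * M4 = MPerXdl = 32 -> 64 rows per block
    const std::array<int, 4> lengths{2, 4, 2, 4};
    const auto idx = unmerge(37, lengths);
    std::printf("row 37 -> (m1, m2, m3, m4) = (%d, %d, %d, %d)\n",
                idx[0], idx[1], idx[2], idx[3]); // prints (1, 0, 1, 1)
}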
 
         // calculate origin of thread output tensor on global memory
 
         // blockwise GEMM c matrix starting index
@@ -725,23 +772,24 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
         const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
         const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
 
-        const auto m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor =
+        const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
             make_single_stage_tensor_adaptor(
                 make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
                 make_tuple(Sequence<0, 1, 2, 3, 4>{}),
                 make_tuple(Sequence<0>{}));
 
         const auto m_thread_data_on_block_idx =
-            m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
+            m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
                 make_multi_index(m_thread_data_on_block));
 
-        const auto n_thread_data_on_grid_to_n0_n1_n2_adaptor = make_single_stage_tensor_adaptor(
-            make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
-            make_tuple(Sequence<0, 1, 2>{}),
-            make_tuple(Sequence<0>{}));
+        const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
+            make_single_stage_tensor_adaptor(
+                make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
+                make_tuple(Sequence<0, 1, 2>{}),
+                make_tuple(Sequence<0>{}));
 
         const auto n_thread_data_on_block_idx =
-            n_thread_data_on_grid_to_n0_n1_n2_adaptor.CalculateBottomIndex(
+            n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
                 make_multi_index(n_thread_data_on_block));
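The merge adaptors here are the inverse view: CalculateBottomIndex takes the flat on-block row m (or column n) and recovers the multi-index the descriptors use, exactly the decomposition sketched above. For completeness, a sketch of the forward fold (multi-index back to flat), the round trip of that decomposition; names are illustrative:

#include <array>
#include <cstddef>
#include <cstdio>

// Merge: fold a multi-index into a flat index (row-major), the inverse of
// the unmerge sketched earlier.
template <std::size_t N>
int merge(const std::array<int, N>& idx, const std::array<int, N>& lengths)
{
    int flat = 0;
    for(std::size_t d = 0; d < N; ++d)
        flat = flat * lengths[d] + idx[d];
    return flat;
}

int main()
{
    const std::array<int, 5> lengths{1, 2, 4, 2, 4}; // (M0, M1, M2, M3, M4)
    const std::array<int, 5> idx{0, 1, 0, 1, 1};
    std::printf("(0,1,0,1,1) -> m = %d\n", merge(idx, lengths)); // prints 37
}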
 
         // VGPR to LDS
@@ -769,26 +817,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
                     n_thread_data_on_block_idx[I2]),
                 ck::tensor_operation::element_wise::PassThrough{}};
 
-        // TODO: this is hardcoded, only works for BlockSize = 256. fix it!
-        constexpr index_t MThread_CCopy = 32;
-        constexpr index_t NThread_CCopy = 8;
-
-        constexpr index_t MPerThread_CCopy = MPerBlock_CCopy / MThread_CCopy;
-        constexpr index_t NPerThread_CCopy = NPerBlock_CCopy / NThread_CCopy;
-
-        constexpr auto c_block_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl =
-            make_naive_tensor_descriptor_packed(make_tuple(
-                I1, I1, Number<MPerBlock_CCopy>{}, I1, I1, Number<NPerBlock_CCopy>{}));
-
-        static_assert(c_block_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl
-                              .GetElementSpaceSize() == 64 * 64,
-                      "wrong!");
-
-        auto c_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
-            static_cast<FloatAcc*>(p_shared),
-            c_block_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl
-                .GetElementSpaceSize());
 
         auto c_block_copy_lds_to_global = BlockwiseTensorSliceTransfer_v4<
             BlockSize,                                       // index_t BlockSize,
             ck::tensor_operation::element_wise::PassThrough, // SrcElementwiseOperation,
@@ -843,30 +871,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
             c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
             c_block_buf);
 
-#if 0
-        if(get_thread_local_1d_id() == 0)
-        {
-            for(int mwave = 0; mwave < MWave; ++mwave)
-            {
-                for(int mperxdl = 0; mperxdl < MPerXdl; ++mperxdl)
-                {
-                    for(int nwave = 0; nwave < NWave; ++nwave)
-                    {
-                        for(int nperxdl = 0; nperxdl < NPerXdl; ++nperxdl)
-                        {
-                            int m = mwave * MPerXdl + mperxdl;
-                            int n = nwave * NPerXdl + nperxdl;
-
-                            int offset = m * NWave * NPerXdl + n;
-
-                            c_block_buf(offset) = 10 * mwave + nwave;
-                        }
-                    }
-                }
-            }
-        }
-#endif
-
         // make sure ds_write from c_thread_copy_vgpr_to_lds is completed
         block_sync_lds();
@@ -887,7 +891,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
                     c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl,
                     nrepeat_forward_step);
             }
-            else if constexpr((!nrepeat_forward_sweep) && (nrepeat > 1))
+            else if constexpr((!nrepeat_forward_sweep) && (nrepeat > 0))
             {
                 c_block_copy_lds_to_global.MoveDstSliceWindow(
                     c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl,
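The guard change from nrepeat > 1 to nrepeat > 0 appears to be the bug this commit fixes: on the backward sweep the destination window still has to step from position 1 down to 0, which the old condition skipped. A minimal simulation of such a zig-zag sweep; the loop structure is a hedged reconstruction for illustration, not the kernel's exact control flow:

#include <cstdio>

int main()
{
    constexpr int NRepeat = 4;

    int pos = 0;
    for(int pass = 0; pass < 2; ++pass)
    {
        const bool forward = (pass == 0);
        for(int i = 0; i < NRepeat; ++i)
        {
            std::printf("copy at window %d\n", pos);
            if(forward && pos < NRepeat - 1)
                ++pos;
            // With the old guard (pos > 1) the window never steps from 1 to
            // 0, so the final copy of the backward sweep lands on window 1
            // a second time instead of on window 0.
            else if(!forward && pos > 0)
                --pos;
        }
    }
}

The forward pass visits windows 0, 1, 2, 3 and the backward pass 3, 2, 1, 0; under the old guard the backward pass would end 3, 2, 1, 1.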
@@ -70,6 +70,8 @@ struct DeviceConv2dFwdXdl_Output_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N
     static constexpr auto I1 = Number<1>{};
     static constexpr auto I2 = Number<2>{};
     static constexpr auto I3 = Number<3>{};
+    static constexpr auto I4 = Number<4>{};
+    static constexpr auto I5 = Number<5>{};
 
     static constexpr auto K1Number     = Number<K1>{};
     static constexpr auto GemmK1Number = K1Number;
@@ -419,6 +421,27 @@ struct DeviceConv2dFwdXdl_Output_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N
             std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
                       << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
 
+            std::cout
+                << "arg.c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl_{ "
+                << arg.c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl_
+                       .GetLength(I0)
+                << ", "
+                << arg.c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl_
+                       .GetLength(I1)
+                << ", "
+                << arg.c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl_
+                       .GetLength(I2)
+                << ", "
+                << arg.c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl_
+                       .GetLength(I3)
+                << ", "
+                << arg.c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl_
+                       .GetLength(I4)
+                << ", "
+                << arg.c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl_
+                       .GetLength(I5)
+                << "}" << std::endl;
         }
 
         if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
@@ -198,10 +198,6 @@ int main(int argc, char* argv[])
     {
     case 0: break;
     case 1:
-        in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_1<InDataType>{});
-        wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_1<WeiDataType>{});
-        break;
-    case 2:
         in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5});
         wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-5, 5});
         break;
@@ -276,15 +272,5 @@ int main(int argc, char* argv[])
         out_device_buf.FromDevice(out_n_k_ho_wo_device_result.mData.data());
 
         check_error(out_n_k_ho_wo_host_result, out_n_k_ho_wo_device_result);
-
-        {
-            LogRangeAsType<float>(std::cout << "in : ", in_n_c_hi_wi.mData, ",") << std::endl;
-            LogRangeAsType<float>(std::cout << "wei: ", wei_k_c_y_x.mData, ",") << std::endl;
-            LogRangeAsType<float>(std::cout << "out_host : ", out_n_k_ho_wo_host_result.mData, ",")
-                << std::endl;
-            LogRangeAsType<float>(
-                std::cout << "out_device: ", out_n_k_ho_wo_device_result.mData, ",")
-                << std::endl;
-        }
     }
 }