Commit 9b3c4ac4 authored by Jun Liu's avatar Jun Liu
Browse files

Merge branch 'develop' into amd-develop

parents 1d784873 7843a8a7
...@@ -829,7 +829,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle ...@@ -829,7 +829,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
if(ck::is_navi3_supported()) if(ck::is_gfx11_supported())
{ {
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>)) if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
{ {
......
...@@ -648,7 +648,7 @@ struct DeviceBatchedGemmMultipleD_Dl : public DeviceBatchedGemmMultiD<ALayout, ...@@ -648,7 +648,7 @@ struct DeviceBatchedGemmMultipleD_Dl : public DeviceBatchedGemmMultiD<ALayout,
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
if(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() || if(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() ||
ck::is_navi2_supported() || ck::is_navi3_supported()) ck::is_gfx103_supported() || ck::is_gfx11_supported())
{ {
bool pass = true; bool pass = true;
pass = pass && arg.K_ % K1 == 0; pass = pass && arg.K_ % K1 == 0;
......
...@@ -587,30 +587,31 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle ...@@ -587,30 +587,31 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
BatchStrideD1s, BatchStrideD1s,
BatchStrideE1} BatchStrideE1}
{ {
#if DEBUG_LOG if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
std::cout << "a0_grid_desc_m_k_{" << a0_grid_desc_m_k_.GetLength(I0) << ", " {
<< a0_grid_desc_m_k_.GetLength(I1) << "}" << std::endl; std::cout << "a0_grid_desc_m_k_{" << a0_grid_desc_m_k_.GetLength(I0) << ", "
std::cout << "b0_grid_desc_n_k_{" << b0_grid_desc_n_k_.GetLength(I0) << ", " << a0_grid_desc_m_k_.GetLength(I1) << "}" << std::endl;
<< b0_grid_desc_n_k_.GetLength(I1) << "}" << std::endl; std::cout << "b0_grid_desc_n_k_{" << b0_grid_desc_n_k_.GetLength(I0) << ", "
std::cout << "d0s_grid_desc_m_n_[I0]{" << d0s_grid_desc_m_n_[I0].GetLength(I0) << ", " << b0_grid_desc_n_k_.GetLength(I1) << "}" << std::endl;
<< d0s_grid_desc_m_n_[I0].GetLength(I1) << "}" << std::endl; std::cout << "d0s_grid_desc_m_n_[I0]{" << d0s_grid_desc_m_n_[I0].GetLength(I0)
std::cout << "b1_grid_desc_n_k_{" << b1_grid_desc_n_k_.GetLength(I0) << ", " << ", " << d0s_grid_desc_m_n_[I0].GetLength(I1) << "}" << std::endl;
<< b1_grid_desc_n_k_.GetLength(I1) << "}" << std::endl; std::cout << "b1_grid_desc_n_k_{" << b1_grid_desc_n_k_.GetLength(I0) << ", "
std::cout << "d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_{" << b1_grid_desc_n_k_.GetLength(I1) << "}" << std::endl;
<< d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I0) << ", " std::cout << "d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_{"
<< d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I1) << ", " << d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I0) << ", "
<< d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I2) << ", " << d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I1) << ", "
<< d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I3) << ", " << d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I2) << ", "
<< d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I4) << ", " << d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I3) << ", "
<< d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I5) << ", " << d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I4) << ", "
<< d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I6) << ", " << d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I5) << ", "
<< d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I7) << ", " << d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I6) << ", "
<< d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I8) << ", " << d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I7) << ", "
<< d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I9) << "}" << d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I8) << ", "
<< std::endl; << d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I9) << "}"
std::cout << "e1_grid_desc_m_n_{" << e1_grid_desc_m_n_.GetLength(I0) << ", " << std::endl;
<< e1_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; std::cout << "e1_grid_desc_m_n_{" << e1_grid_desc_m_n_.GetLength(I0) << ", "
#endif << e1_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
}
static_for<0, NumD0Tensor, 1>{}([&](auto i) { static_for<0, NumD0Tensor, 1>{}([&](auto i) {
using D0Layout = remove_cvref_t<tuple_element_t<i.value, D0sLayout>>; using D0Layout = remove_cvref_t<tuple_element_t<i.value, D0sLayout>>;
......
...@@ -658,27 +658,28 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceO ...@@ -658,27 +658,28 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceO
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
#if DEBUG_LOG if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
{ {
std::cout << "arg.Batch_ = " << arg.Batch_ << std::endl; {
std::cout << "arg.Batch_ = " << arg.Batch_ << std::endl;
std::cout << "arg.a_grid_desc_ak0_m_ak1_{" std::cout << "arg.a_grid_desc_ak0_m_ak1_{"
<< arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", " << arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", "
<< arg.a_grid_desc_ak0_m_ak1_.GetLength(I1) << ", " << arg.a_grid_desc_ak0_m_ak1_.GetLength(I1) << ", "
<< arg.a_grid_desc_ak0_m_ak1_.GetLength(I2) << "}" << std::endl; << arg.a_grid_desc_ak0_m_ak1_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.b_grid_desc_bk0_n_bk1_{" std::cout << "arg.b_grid_desc_bk0_n_bk1_{"
<< arg.b_grid_desc_bk0_n_bk1_.GetLength(I0) << ", " << arg.b_grid_desc_bk0_n_bk1_.GetLength(I0) << ", "
<< arg.b_grid_desc_bk0_n_bk1_.GetLength(I1) << ", " << arg.b_grid_desc_bk0_n_bk1_.GetLength(I1) << ", "
<< arg.b_grid_desc_bk0_n_bk1_.GetLength(I2) << "}" << std::endl; << arg.b_grid_desc_bk0_n_bk1_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0)
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; << ", " << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
std::cout << "arg.reduce_grid_desc_m_{ " << arg.reduce_grid_desc_m_.GetLength(I0) std::cout << "arg.reduce_grid_desc_m_{ "
<< "}" << std::endl; << arg.reduce_grid_desc_m_.GetLength(I0) << "}" << std::endl;
}
} }
#endif
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_, arg.b_grid_desc_bk0_n_bk1_,
......
...@@ -858,7 +858,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle ...@@ -858,7 +858,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
static bool IsSupportedArgument(const RawArg& arg) static bool IsSupportedArgument(const RawArg& arg)
{ {
if(ck::is_navi3_supported()) if(ck::is_gfx11_supported())
{ {
if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>)) if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
{ {
...@@ -1435,7 +1435,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle ...@@ -1435,7 +1435,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
#if 0 #if 0
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
if(ck::is_navi3_supported()) if(ck::is_gfx11_supported())
{ {
if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>)) if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
{ {
......
...@@ -719,9 +719,10 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -719,9 +719,10 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
#if DEBUG_LOG if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
arg.Print(); {
#endif arg.Print();
}
if(!ck::is_xdl_supported()) if(!ck::is_xdl_supported())
{ {
......
...@@ -516,26 +516,27 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -516,26 +516,27 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
float ave_time = 0; float ave_time = 0;
for(size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++) for(size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++)
{ {
#if DEBUG_LOG if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
{ {
std::cout << "arg.a_grid_desc_k0_m_k1_container_{" {
<< arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) << ", " std::cout << "arg.a_grid_desc_k0_m_k1_container_{"
<< arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I1) << ", " << arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) << ", "
<< arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I2) << "}" << arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I1) << ", "
<< std::endl; << arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I2) << "}"
<< std::endl;
std::cout << "arg.b_grid_desc_k0_n_k1_container_{"
<< arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I0) << ", " std::cout << "arg.b_grid_desc_k0_n_k1_container_{"
<< arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I1) << ", " << arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I0) << ", "
<< arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I2) << "}" << arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I1) << ", "
<< std::endl; << arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I2) << "}"
<< std::endl;
std::cout << "arg.c_grid_desc_m_n_container_{ "
<< arg.c_grid_desc_m_n_container_[i].GetLength(I0) << ", " std::cout << "arg.c_grid_desc_m_n_container_{ "
<< arg.c_grid_desc_m_n_container_[i].GetLength(I1) << "}" << arg.c_grid_desc_m_n_container_[i].GetLength(I0) << ", "
<< std::endl; << arg.c_grid_desc_m_n_container_[i].GetLength(I1) << "}"
<< std::endl;
}
} }
#endif
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i], if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i],
arg.b_grid_desc_k0_n_k1_container_[i], arg.b_grid_desc_k0_n_k1_container_[i],
......
...@@ -644,7 +644,7 @@ struct ...@@ -644,7 +644,7 @@ struct
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
#if DEBUG_LOG if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
{ {
std::cout << DeviceOp{}.GetTypeString() << std::endl; std::cout << DeviceOp{}.GetTypeString() << std::endl;
std::cout << "N " << arg.Conv_N_ << ", " std::cout << "N " << arg.Conv_N_ << ", "
...@@ -664,9 +664,7 @@ struct ...@@ -664,9 +664,7 @@ struct
<< arg.input_left_pads_[1] << ", " << std::endl; << arg.input_left_pads_[1] << ", " << std::endl;
std::cout << "InLeftPads " << arg.input_right_pads_[0] << ", " std::cout << "InLeftPads " << arg.input_right_pads_[0] << ", "
<< arg.input_right_pads_[1] << ", " << std::endl; << arg.input_right_pads_[1] << ", " << std::endl;
}
{
std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
<< ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
<< arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl; << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl;
...@@ -684,7 +682,6 @@ struct ...@@ -684,7 +682,6 @@ struct
std::cout << "arg.c1_grid_desc_m_n_{ " << arg.c1_grid_desc_m_n_.GetLength(I0) std::cout << "arg.c1_grid_desc_m_n_{ " << arg.c1_grid_desc_m_n_.GetLength(I0)
<< ", " << arg.c1_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; << ", " << arg.c1_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
} }
#endif
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_, arg.b_grid_desc_k0_n_k1_,
......
...@@ -614,7 +614,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X ...@@ -614,7 +614,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
#if DEBUG_LOG if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
{ {
std::cout << DeviceOp{}.GetTypeString() << std::endl; std::cout << DeviceOp{}.GetTypeString() << std::endl;
std::cout << "N " << arg.Conv_N_ << ", " std::cout << "N " << arg.Conv_N_ << ", "
...@@ -634,9 +634,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X ...@@ -634,9 +634,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
<< arg.input_left_pads_[1] << ", " << std::endl; << arg.input_left_pads_[1] << ", " << std::endl;
std::cout << "InLeftPads " << arg.input_right_pads_[0] << ", " std::cout << "InLeftPads " << arg.input_right_pads_[0] << ", "
<< arg.input_right_pads_[1] << ", " << std::endl; << arg.input_right_pads_[1] << ", " << std::endl;
}
{
std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
<< ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
<< arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl; << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl;
...@@ -651,7 +649,6 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X ...@@ -651,7 +649,6 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
std::cout << "arg.c0_grid_desc_m_n_{ " << arg.c0_grid_desc_m_n_.GetLength(I0) std::cout << "arg.c0_grid_desc_m_n_{ " << arg.c0_grid_desc_m_n_.GetLength(I0)
<< ", " << arg.c0_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; << ", " << arg.c0_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
} }
#endif
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_, arg.b_grid_desc_k0_n_k1_,
......
...@@ -579,7 +579,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W ...@@ -579,7 +579,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
#if DEBUG_LOG if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
{ {
std::cout << DeviceOp{}.GetTypeString() << std::endl; std::cout << DeviceOp{}.GetTypeString() << std::endl;
std::cout << "N " << arg.Conv_N_ << ", " std::cout << "N " << arg.Conv_N_ << ", "
...@@ -599,9 +599,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W ...@@ -599,9 +599,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
<< arg.input_left_pads_[1] << ", " << std::endl; << arg.input_left_pads_[1] << ", " << std::endl;
std::cout << "InLeftPads " << arg.input_right_pads_[0] << ", " std::cout << "InLeftPads " << arg.input_right_pads_[0] << ", "
<< arg.input_right_pads_[1] << ", " << std::endl; << arg.input_right_pads_[1] << ", " << std::endl;
}
{
std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
<< ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
<< arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl; << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl;
...@@ -635,7 +633,6 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W ...@@ -635,7 +633,6 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
.GetLength(I5) .GetLength(I5)
<< "}" << std::endl; << "}" << std::endl;
} }
#endif
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_, arg.b_grid_desc_k0_n_k1_,
......
...@@ -431,7 +431,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -431,7 +431,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
#if DEBUG_LOG if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
{ {
std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
<< ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
...@@ -444,7 +444,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -444,7 +444,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
} }
#endif
if(!GridwiseGemm::CheckValidity( if(!GridwiseGemm::CheckValidity(
arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_)) arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_))
{ {
......
...@@ -401,7 +401,7 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_ ...@@ -401,7 +401,7 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
#if DEBUG_LOG if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
{ {
std::cout << "num_batches_of_GEMM = " << arg.num_subbatches_ << std::endl; std::cout << "num_batches_of_GEMM = " << arg.num_subbatches_ << std::endl;
std::cout << "a_grid_desc_k0_m_k1{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) std::cout << "a_grid_desc_k0_m_k1{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
...@@ -415,7 +415,6 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_ ...@@ -415,7 +415,6 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_
std::cout << "c_grid_desc_m_n{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " std::cout << "c_grid_desc_m_n{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
} }
#endif
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_, arg.b_grid_desc_k0_n_k1_,
......
...@@ -1272,7 +1272,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Dl ...@@ -1272,7 +1272,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Dl
float ave_time = 0; float ave_time = 0;
for(size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++) for(size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++)
{ {
#if DEBUG_LOG if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
{ {
std::cout << "arg.a_grid_desc_k0_m_k1_container_{" std::cout << "arg.a_grid_desc_k0_m_k1_container_{"
<< arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) << ", " << arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) << ", "
...@@ -1305,7 +1305,6 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Dl ...@@ -1305,7 +1305,6 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Dl
<< arg.c_grid_desc_m0_m10_m11_n0_n10_n11_container_[i].GetLength(I5) << arg.c_grid_desc_m0_m10_m11_n0_n10_n11_container_[i].GetLength(I5)
<< " ) " << std::endl; << " ) " << std::endl;
} }
#endif
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i], if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i],
arg.b_grid_desc_k0_n_k1_container_[i], arg.b_grid_desc_k0_n_k1_container_[i],
...@@ -1393,8 +1392,8 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Dl ...@@ -1393,8 +1392,8 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Dl
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
// check device // check device
if(!(ck::get_device_name() == "gfx906" || ck::is_navi2_supported() || if(!(ck::get_device_name() == "gfx906" || ck::is_gfx103_supported() ||
ck::is_navi3_supported())) ck::is_gfx11_supported()))
{ {
return false; return false;
} }
......
...@@ -1220,7 +1220,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -1220,7 +1220,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
float ave_time = 0; float ave_time = 0;
for(size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++) for(size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++)
{ {
#if DEBUG_LOG if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
{ {
std::cout << "arg.a_grid_desc_k0_m_k1{" std::cout << "arg.a_grid_desc_k0_m_k1{"
<< arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) << ", " << arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) << ", "
...@@ -1239,7 +1239,6 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -1239,7 +1239,6 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
<< arg.c_grid_desc_m_n_container_[i].GetLength(I1) << "}" << arg.c_grid_desc_m_n_container_[i].GetLength(I1) << "}"
<< std::endl; << std::endl;
} }
#endif
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i], if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i],
arg.b_grid_desc_k0_n_k1_container_[i], arg.b_grid_desc_k0_n_k1_container_[i],
......
...@@ -509,7 +509,7 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout, ...@@ -509,7 +509,7 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout,
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
if(ck::is_navi3_supported()) if(ck::is_gfx11_supported())
{ {
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, ck::half_t> || if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, ck::half_t> ||
is_same_v<AccDataType, int32_t>)) is_same_v<AccDataType, int32_t>))
......
...@@ -334,7 +334,7 @@ struct DeviceGemmDl : public DeviceGemm<ALayout, ...@@ -334,7 +334,7 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
#if DEBUG_LOG if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
{ {
std::cout << "arg.a_grid_desc_k0_m0_m1_k1_{" std::cout << "arg.a_grid_desc_k0_m0_m1_k1_{"
<< arg.a_grid_desc_k0_m_k1_.GetLength(I0) << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I0) << ", "
...@@ -349,7 +349,6 @@ struct DeviceGemmDl : public DeviceGemm<ALayout, ...@@ -349,7 +349,6 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
} }
#endif
if(!GridwiseGemm::CheckValidity( if(!GridwiseGemm::CheckValidity(
arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_)) arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_))
...@@ -536,8 +535,8 @@ struct DeviceGemmDl : public DeviceGemm<ALayout, ...@@ -536,8 +535,8 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
} }
} }
if(ck::get_device_name() == "gfx906" || ck::is_navi2_supported() || if(ck::get_device_name() == "gfx906" || ck::is_gfx103_supported() ||
ck::is_navi3_supported()) ck::is_gfx11_supported())
{ {
return GridwiseGemm::CheckValidity( return GridwiseGemm::CheckValidity(
arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_); arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_);
......
...@@ -168,7 +168,7 @@ struct DeviceGemmDpp : public DeviceGemm<ALayout, ...@@ -168,7 +168,7 @@ struct DeviceGemmDpp : public DeviceGemm<ALayout,
static bool IsSupportedArgument(const Argument& karg) static bool IsSupportedArgument(const Argument& karg)
{ {
if(ck::is_navi2_supported() || ck::is_navi3_supported()) if(ck::is_gfx103_supported() || ck::is_gfx11_supported())
{ {
return GridwiseGemm::CheckValidity(karg); return GridwiseGemm::CheckValidity(karg);
} }
......
...@@ -552,7 +552,7 @@ struct DeviceGemmMultipleD_Dl : public DeviceGemmMultipleD<ALayout, ...@@ -552,7 +552,7 @@ struct DeviceGemmMultipleD_Dl : public DeviceGemmMultipleD<ALayout,
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
if(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() || if(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() ||
ck::is_navi2_supported() || ck::is_navi3_supported()) ck::is_gfx103_supported() || ck::is_gfx11_supported())
{ {
return GridwiseGemm::CheckValidity( return GridwiseGemm::CheckValidity(
arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.e_grid_desc_m_n_); arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.e_grid_desc_m_n_);
......
...@@ -515,7 +515,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout, ...@@ -515,7 +515,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
if(ck::is_navi3_supported()) if(ck::is_gfx11_supported())
{ {
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>)) if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
{ {
......
...@@ -510,7 +510,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceOperatio ...@@ -510,7 +510,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceOperatio
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
#if DEBUG_LOG if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
{ {
std::cout << "arg.a_grid_desc_ak0_m_ak1_{" std::cout << "arg.a_grid_desc_ak0_m_ak1_{"
<< arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", " << arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", "
...@@ -528,7 +528,6 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceOperatio ...@@ -528,7 +528,6 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceOperatio
std::cout << "arg.reduce_grid_desc_m_{ " << arg.reduce_grid_desc_m_.GetLength(I0) std::cout << "arg.reduce_grid_desc_m_{ " << arg.reduce_grid_desc_m_.GetLength(I0)
<< "}" << std::endl; << "}" << std::endl;
} }
#endif
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_, arg.b_grid_desc_bk0_n_bk1_,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment