Unverified Commit 4634b120 authored by Shaojie WANG, committed by GitHub

fix Issue 291 (#294)

* rename for typeconvert functor

* refine code
parent 15c89e81
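
The "refine code" part of this change collapses the duplicated has_main_k0_block_loop true/false branches in the diff below into one generic lambda that receives the flag as an integral_constant, so a single body can still hand it to kernel_gemm_xdlops_bwd_weight as a compile-time template argument. A minimal standalone sketch of that dispatch pattern, assuming nothing from composable_kernel (run_kernel and dispatch are illustrative names, not the library's):

// Sketch only: the flag arrives as a std::integral_constant, so the one
// lambda body can still use it as a compile-time template argument.
#include <iostream>
#include <type_traits>

template <bool HasMainLoop>
float run_kernel(int kbatch)
{
    // Stand-in for instantiating and timing a kernel templated on HasMainLoop;
    // here we only report which instantiation was selected.
    std::cout << "kbatch=" << kbatch << " has_main_loop=" << HasMainLoop << '\n';
    return 0.0f;
}

float dispatch(bool has_main_k0_block_loop, int kbatch)
{
    // Written once; replaces the copy-pasted true/false branches.
    auto launch_kernel = [&](auto has_main_k_block_loop) {
        constexpr bool has_main_loop = has_main_k_block_loop.value;
        return run_kernel<has_main_loop>(kbatch);
    };

    return has_main_k0_block_loop
               ? launch_kernel(std::integral_constant<bool, true>{})
               : launch_kernel(std::integral_constant<bool, false>{});
}

int main()
{
    dispatch(true, 1);  // selects the has-main-loop instantiation
    dispatch(false, 4); // selects the no-main-loop instantiation
    return 0;
}

Built as an ordinary host program, this picks the true/false instantiation at compile time, while the runtime branch on the flag happens exactly once, outside the lambda.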
@@ -433,7 +433,7 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
         using namespace ck;

         const index_t Di = input_spatial_lengths[0];
-        const index_t Hi = input_spatial_lengths[2];
+        const index_t Hi = input_spatial_lengths[1];
         const index_t Wi = input_spatial_lengths[2];

         const index_t Do = output_spatial_lengths[0];
@@ -671,11 +671,14 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
         return PadDescriptor_M0_1d(desc, gridSize, blockSize);
     }

-    using TypeConvertFunctor =
+    using TypeConvertFp32ToBf16Functor =
         ck::tensor_operation::element_wise::UnaryTypeConvert<ck::bhalf_t, float>;

     using GridDesc_M0 = decltype(MakeDescriptor_M0<1>({1}, {1}, 1, 1));
-    using GridwiseUEltwise =
-        GridwiseUnaryElementwise_1D<AccDataType, InDataType, GridDesc_M0, TypeConvertFunctor, 4>;
+    using GridwiseUEltwise = GridwiseUnaryElementwise_1D<AccDataType,
+                                                         InDataType,
+                                                         GridDesc_M0,
+                                                         TypeConvertFp32ToBf16Functor,
+                                                         4>;

     using ABCGridDescs = decltype(GetABCGridDesc<NumDimSpatial>());
@@ -979,19 +982,18 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
             const auto K0 = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1);
-            const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);

             float ave_time = 0;
+            const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);

-            const auto Run = [&](const auto& kernel) {
+            const auto run_conv = [&](const auto& kernel) {
                 hipGetErrorString(hipMemset(
                     arg.p_c_grid_,
                     0,
                     arg.c_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() *
                         sizeof(CDataType)));

-                ave_time =
-                    launch_and_time_kernel(stream_config,
+                return launch_and_time_kernel(stream_config,
                                            kernel,
                                            dim3(grid_size),
                                            dim3(BlockSize),
@@ -1016,8 +1018,7 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
                     arg.c_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() *
                         sizeof(AccDataType)));

-                ave_time =
-                    launch_and_time_kernel(stream_config,
+                return launch_and_time_kernel(stream_config,
                                            kernel,
                                            dim3(grid_size),
                                            dim3(BlockSize),
@@ -1059,7 +1060,7 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
             // run kernel for type conversion
             void* p_c_grid_tmp_ = static_cast<void*>(arg.p_c_grid_);
             InDataType* p_c_grid_tmp_bf16_ = static_cast<InDataType*>(p_c_grid_tmp_);
-            const auto Run_type_convert = [&](const auto& kernel) {
+            const auto run_type_convert = [&](const auto& kernel) {
                 float elapsed_time =
                     launch_and_time_kernel(stream_config,
                                            kernel,
@@ -1070,14 +1071,15 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
                                            p_c_grid_tmp_bf16_,
                                            a_grid_desc_m0_,
                                            b_grid_desc_m0_,
-                                           TypeConvertFunctor{});
+                                           TypeConvertFp32ToBf16Functor{});
                 return elapsed_time;
             };

             if constexpr(std::is_same<InDataType, ck::bhalf_t>::value)
             {
-                if(has_main_k0_block_loop)
-                {
+                auto launch_kernel = [&](auto has_main_k_block_loop) {
+                    constexpr bool has_main_loop = has_main_k_block_loop.value;
+
                     if(kbatch == 1)
                     {
                         const auto kernel = kernel_gemm_xdlops_bwd_weight<
@@ -1092,9 +1094,9 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
                             InElementwiseOperation,
                             WeiElementwiseOperation,
                             remove_reference_t<DeviceOp::Block2CTileMap>,
-                            true>;
-                        Run(kernel);
+                            has_main_loop>;
+                        return run_conv(kernel);
                     }
                     else
                     {
@@ -1103,7 +1105,7 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
                             AccDataType,
                             InDataType,
                             GridDesc_M0,
-                            TypeConvertFunctor>;
+                            TypeConvertFp32ToBf16Functor>;

                         const auto kernel_conv = kernel_gemm_xdlops_bwd_weight<
                             GridwiseGemmAtomicAddFloatBf16Splitk,
@@ -1117,56 +1119,28 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
                             InElementwiseOperation,
                             WeiElementwiseOperation,
                             remove_reference_t<DeviceOp::Block2CTileMap>,
-                            true>;
-                        run_bf16_splitk(kernel_conv);
-                        ave_time += Run_type_convert(kernel_type_convert);
-                    }
-                }
-                else
-                {
-                    if(kbatch == 1)
-                    {
-                        const auto kernel = kernel_gemm_xdlops_bwd_weight<
-                            GridwiseGemm,
-                            ADataType, // TODO: distiguish A/B datatype
-                            CDataType,
-                            remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
-                            remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
-                            remove_reference_t<
-                                DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
-                            OutElementwiseOperation,
-                            InElementwiseOperation,
-                            WeiElementwiseOperation,
-                            remove_reference_t<DeviceOp::Block2CTileMap>,
-                            false>;
-                        Run(kernel);
-                    }
-                    else
-                    {
-                        const auto kernel = kernel_gemm_xdlops_bwd_weight<
-                            GridwiseGemmAtomicAddFloatBf16Splitk,
-                            ADataType, // TODO: distiguish A/B datatype
-                            AccDataType,
-                            remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
-                            remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
-                            remove_reference_t<
-                                DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
-                            OutElementwiseOperation,
-                            InElementwiseOperation,
-                            WeiElementwiseOperation,
-                            remove_reference_t<DeviceOp::Block2CTileMap>,
-                            false>;
-                        run_bf16_splitk(kernel);
-                    }
-                }
+                            has_main_loop>;
+                        float elapsed_time = 0;
+                        elapsed_time += run_bf16_splitk(kernel_conv);
+                        elapsed_time += run_type_convert(kernel_type_convert);
+                        return elapsed_time;
+                    }
+                };
+
+                if(has_main_k0_block_loop)
+                {
+                    ave_time = launch_kernel(integral_constant<bool, true>{});
+                }
+                else
+                {
+                    ave_time = launch_kernel(integral_constant<bool, false>{});
+                }
             }
             else
             {
-                if(has_main_k0_block_loop)
-                {
+                auto launch_kernel = [&](auto has_main_k_block_loop) {
+                    constexpr bool has_main_loop = has_main_k_block_loop.value;
+
                     if(kbatch == 1)
                     {
                         const auto kernel = kernel_gemm_xdlops_bwd_weight<
@@ -1181,9 +1155,9 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
                             InElementwiseOperation,
                             WeiElementwiseOperation,
                             remove_reference_t<DeviceOp::Block2CTileMap>,
-                            true>;
-                        Run(kernel);
+                            has_main_loop>;
+                        return run_conv(kernel);
                     }
                     else
                     {
@@ -1199,49 +1173,18 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
                             InElementwiseOperation,
                             WeiElementwiseOperation,
                             remove_reference_t<DeviceOp::Block2CTileMap>,
-                            true>;
-                        Run(kernel);
-                    }
-                }
-                else
-                {
-                    if(kbatch == 1)
-                    {
-                        const auto kernel = kernel_gemm_xdlops_bwd_weight<
-                            GridwiseGemm,
-                            ADataType, // TODO: distiguish A/B datatype
-                            CDataType,
-                            remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
-                            remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
-                            remove_reference_t<
-                                DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
-                            OutElementwiseOperation,
-                            InElementwiseOperation,
-                            WeiElementwiseOperation,
-                            remove_reference_t<DeviceOp::Block2CTileMap>,
-                            false>;
-                        Run(kernel);
-                    }
-                    else
-                    {
-                        const auto kernel = kernel_gemm_xdlops_bwd_weight<
-                            GridwiseGemmAtomicAdd,
-                            ADataType, // TODO: distiguish A/B datatype
-                            CDataType,
-                            remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
-                            remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
-                            remove_reference_t<
-                                DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
-                            OutElementwiseOperation,
-                            InElementwiseOperation,
-                            WeiElementwiseOperation,
-                            remove_reference_t<DeviceOp::Block2CTileMap>,
-                            false>;
-                        Run(kernel);
-                    }
-                }
+                            has_main_loop>;
+                        return run_conv(kernel);
+                    }
+                };
+
+                if(has_main_k0_block_loop)
+                {
+                    ave_time = launch_kernel(integral_constant<bool, true>{});
+                }
+                else
+                {
+                    ave_time = launch_kernel(integral_constant<bool, false>{});
+                }
             }
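
For reference, the functor renamed to TypeConvertFp32ToBf16Functor above performs a plain fp32-to-bf16 elementwise conversion: the split-K path accumulates in fp32 (AccDataType) and a separate unary-elementwise kernel converts the result back to bf16. A self-contained sketch of that conversion pattern, assuming nothing from composable_kernel (bf16_t and float_to_bf16 below are illustrative stand-ins, not ck::bhalf_t or ck::type_convert):

// Sketch only: a unary functor that converts a float to a bf16-like value.
#include <cstdint>
#include <cstring>
#include <iostream>

struct bf16_t
{
    std::uint16_t data; // upper 16 bits of an IEEE-754 binary32
};

inline bf16_t float_to_bf16(float x)
{
    std::uint32_t bits;
    std::memcpy(&bits, &x, sizeof(bits));
    // Round to nearest even before truncating the low 16 mantissa bits
    // (NaN handling omitted for brevity).
    bits += 0x7FFFu + ((bits >> 16) & 1u);
    return bf16_t{static_cast<std::uint16_t>(bits >> 16)};
}

// Unary elementwise functor: writes the converted value through the output reference.
struct TypeConvertFp32ToBf16
{
    void operator()(bf16_t& y, const float& x) const { y = float_to_bf16(x); }
};

int main()
{
    TypeConvertFp32ToBf16 convert;
    bf16_t y{};
    convert(y, 1.5f);
    std::cout << std::hex << y.data << '\n'; // prints 3fc0, i.e. 1.5 in bf16
    return 0;
}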