Commit b6a07385 authored by wangshaojie6
Browse files

add profiler code

parent c1062da8
......@@ -18,6 +18,7 @@ set(PROFILER_SOURCE
src/profile_conv_fwd_bias_relu_add.cpp
src/profile_convnd_fwd.cpp
src/profile_convnd_bwd_data.cpp
src/profile_conv_bwd_weight.cpp
src/profile_convnd_bwd_weight.cpp
src/profile_reduce.cpp
src/profile_normalization.cpp
......
......@@ -82,12 +82,12 @@ ck::utils::conv::ConvParams parse_conv_params(int num_dim_spatial, char* argv[],
int profile_convnd_bwd_weight(int argc, char* argv[], int num_dim_spatial)
{
const int preParams = 10;
const int preParams = 11;
int conv_args = 3 + num_dim_spatial * 6;
int cmdline_nargs = conv_args + preParams;
if(cmdline_nargs != argc)
{
printf("arg1: tensor operation (conv[1|2|3]d_bwd_weight: BackwardConvolution)\n");
printf("arg1: tensor operation (convnd[1|2|3]d_bwd_weight: BackwardConvolution)\n");
printf("arg2: data type (0: fp32; 1: fp16, 2: bf16)\n");
printf("arg3: input tensor layout (0: NCHW; 1: NHWC)\n");
printf("arg4: weight tensor layout (0: KCYX; 1: KYXC)\n");
......@@ -96,7 +96,8 @@ int profile_convnd_bwd_weight(int argc, char* argv[], int num_dim_spatial)
printf("arg7: initialization (0: no init; 1: integer value; 2: decimal value)\n");
printf("arg8: print tensor value (0: no; 1: yes)\n");
printf("arg9: time kernel (0=n0, 1=yes)\n");
printf("arg10 to 24: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, "
printf("arg10: splitk\n");
printf("arg11 to 25: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, "
"RightPx\n");
return 1;
}
......@@ -110,6 +111,9 @@ int profile_convnd_bwd_weight(int argc, char* argv[], int num_dim_spatial)
const bool do_log = std::stoi(argv[8]);
const bool time_kernel = std::stoi(argv[9]);
ck::index_t split_k = std::stoi(argv[10]);
split_k = std::max(1, split_k);
ck::utils::conv::ConvParams params = parse_conv_params(num_dim_spatial, argv, preParams);
auto Run = [&](auto input_type, auto wei_type, auto out_type) {
......@@ -140,12 +144,59 @@ int profile_convnd_bwd_weight(int argc, char* argv[], int num_dim_spatial)
params.conv_filter_strides_,
params.conv_filter_dilations_,
params.input_left_pads_,
params.input_right_pads_);
params.input_right_pads_,
split_k);
break;
case 2: break;
case 2:
ck::profiler::profile_convnd_bwd_weight_impl<2,
InDataType,
WeiDataType,
OutDataType,
ck::tensor_layout::convolution::NHWC,
ck::tensor_layout::convolution::KYXC,
ck::tensor_layout::convolution::NHWK>(
do_verification,
init_method,
do_log,
time_kernel,
params.N_,
params.K_,
params.C_,
params.input_spatial_lengths_,
params.filter_spatial_lengths_,
params.GetOutputSpatialLengths(),
params.conv_filter_strides_,
params.conv_filter_dilations_,
params.input_left_pads_,
params.input_right_pads_,
split_k);
break;
case 3: break;
case 3:
ck::profiler::profile_convnd_bwd_weight_impl<3,
InDataType,
WeiDataType,
OutDataType,
ck::tensor_layout::convolution::NDHWC,
ck::tensor_layout::convolution::KZYXC,
ck::tensor_layout::convolution::NDHWK>(
do_verification,
init_method,
do_log,
time_kernel,
params.N_,
params.K_,
params.C_,
params.input_spatial_lengths_,
params.filter_spatial_lengths_,
params.GetOutputSpatialLengths(),
params.conv_filter_strides_,
params.conv_filter_dilations_,
params.input_left_pads_,
params.input_right_pads_,
split_k);
break;
default: break;
}
......
......@@ -118,15 +118,15 @@ int main(int argc, char* argv[])
{
return profile_conv_bwd_weight(argc, argv);
}
else if(strcmp(argv[1], "conv1d_bwd_weight") == 0)
else if(strcmp(argv[1], "convnd1d_bwd_weight") == 0)
{
return profile_convnd_bwd_weight(argc, argv, 1);
}
else if(strcmp(argv[1], "conv2d_bwd_weight") == 0)
else if(strcmp(argv[1], "convnd2d_bwd_weight") == 0)
{
return profile_convnd_bwd_weight(argc, argv, 2);
}
else if(strcmp(argv[1], "conv3d_bwd_weight") == 0)
else if(strcmp(argv[1], "convnd3d_bwd_weight") == 0)
{
return profile_convnd_bwd_weight(argc, argv, 3);
}
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment