Commit 5cfd01fd authored by ltqin

format

parent 1e9c511c
@@ -77,15 +77,15 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw
     constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
 #endif

-    const auto descs =
-        transform_backward_weight_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad(wei_k_c_y_x_desc,
-                                                                                  in_n_c_hi_wi_desc,
-                                                                                  out_n_k_ho_wo_desc,
-                                                                                  conv_strides,
-                                                                                  conv_dilations,
-                                                                                  in_left_pads,
-                                                                                  in_right_pads,
-                                                                                  Number<GemmK1>{});
+    const auto descs = transform_backward_weight_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad(
+        wei_k_c_y_x_desc,
+        in_n_c_hi_wi_desc,
+        out_n_k_ho_wo_desc,
+        conv_strides,
+        conv_dilations,
+        in_left_pads,
+        in_right_pads,
+        Number<GemmK1>{});

     const auto out_gemmk0_gemmm_gemmk1_grid_desc = descs[I0];
     const auto in_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
@@ -93,13 +93,9 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw
     // HACK: hacks that control index calculation when iterating over A, B, C matrix
     constexpr auto out_gemmk0_gemmm_gemmk1_grid_step_hacks = make_tuple(
-        make_tuple(Sequence<0, 0, 1, 0, 0>{},
-                   Sequence<0, 0, 0, 0, 0>{},
-                   Sequence<0, 0, 1, 0, 0>{}),
-        make_tuple(
-            Sequence<0, 0, 2, 0, 0>{},
-            Sequence<0, 0, 0, 0, 0>{},
-            Sequence<0, 0, 2, 0, 0>{}));
+        make_tuple(Sequence<0, 0, 1, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 1, 0, 0>{}),
+        make_tuple(
+            Sequence<0, 0, 2, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 2, 0, 0>{}));

     constexpr auto in_gemmk0_gemmn_gemmk1_grid_step_hacks =
         make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
...
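The call reflowed above, transform_backward_weight_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad, is what views the backward-weight convolution as an implicit GEMM: descs[I0] and descs[I1] are the A-side and B-side grid descriptors, and the step "hacks" below are compile-time hints for the index arithmetic used while iterating over them. As a minimal sketch of the usual v4r4-style dimension mapping, assuming the conventional roles (output gradient as the A matrix, input as B, weight gradient as C) and made-up problem sizes:

```cpp
// Sketch only: the dimension bookkeeping behind the GemmK0/GemmM/GemmK1 descriptors.
// All sizes below are assumed for illustration; the real descriptors also encode
// strides, dilations, and padding, which this omits.
#include <cstdio>

using index_t = int;

int main()
{
    const index_t N = 128, K = 256, C = 192, Ho = 28, Wo = 28, Y = 3, X = 3;
    const index_t GemmK1 = 4; // corresponds to the Number<GemmK1>{} argument

    const index_t GemmM  = K;           // rows of the weight-gradient (C) matrix
    const index_t GemmN  = C * Y * X;   // columns of the weight-gradient matrix
    const index_t GemmK  = N * Ho * Wo; // reduction dimension
    const index_t GemmK0 = GemmK / GemmK1; // GemmK is split into GemmK0 * GemmK1

    std::printf("A: [GemmK0=%d, GemmM=%d, GemmK1=%d], B: [GemmK0, GemmN=%d, GemmK1]\n",
                GemmK0, GemmM, GemmK1, GemmN);
    return 0;
}
```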
@@ -14,13 +14,12 @@
#include "device_tensor.hpp" #include "device_tensor.hpp"
#include "device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp" #include "device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp"
#define USE_MODE 1 #define USE_MODE 1
#define USE_CONV_WRW_V4R4R2_XDL_NCHW 1 #define USE_CONV_WRW_V4R4R2_XDL_NCHW 1
enum ConvBackwardWeightAlgo enum ConvBackwardWeightAlgo
{ {
V4R4R2XDLNCHW, V4R4R2XDLNCHW,
}; };
int main(int argc, char* argv[]) int main(int argc, char* argv[])
@@ -44,12 +43,12 @@ int main(int argc, char* argv[])
         exit(1);
     }

     const ConvTensorLayout layout = static_cast<ConvTensorLayout>(std::stoi(argv[1]));
     const ConvBackwardWeightAlgo algo = static_cast<ConvBackwardWeightAlgo>(std::stoi(argv[2]));
     const bool do_verification = std::stoi(argv[3]);
     const int init_method = std::stoi(argv[4]);
     const bool do_log = std::stoi(argv[5]);
     const int nrepeat = std::stoi(argv[6]);

     const index_t N = std::stoi(argv[7]);
     const index_t K = std::stoi(argv[8]);
@@ -81,12 +80,12 @@ int main(int argc, char* argv[])
         exit(1);
     }

     const ConvTensorLayout layout = static_cast<ConvTensorLayout>(std::stoi(argv[1]));
     const ConvBackwardWeightAlgo algo = static_cast<ConvBackwardWeightAlgo>(std::stoi(argv[2]));
     const bool do_verification = std::stoi(argv[3]);
     const int init_method = std::stoi(argv[4]);
     const bool do_log = std::stoi(argv[5]);
     const int nrepeat = std::stoi(argv[6]);

     constexpr index_t N = 128;
     constexpr index_t C = 192;
@@ -245,7 +244,6 @@ int main(int argc, char* argv[])
             in_right_pads_dev);
     };

 #if USE_CONV_WRW_V4R4R2_XDL_NCHW
     if(algo == ConvBackwardWeightAlgo::V4R4R2XDLNCHW)
     {
@@ -257,8 +255,8 @@ int main(int argc, char* argv[])
         const auto tmp = f_make_for_device_nchw();

         device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw<in_data_t,
                                                                                       acc_data_t,
                                                                                       out_data_t>(
             tmp[I0],
             tmp[I1],
             tmp[I2],
@@ -275,14 +273,14 @@ int main(int argc, char* argv[])
         if(do_verification)
         {
             host_direct_convolution_backward_weights(out,
                                                      in,
                                                      wei_host,
                                                      make_tuple(conv_stride_h, conv_stride_w),
                                                      make_tuple(conv_dilation_h, conv_dilation_w),
                                                      make_tuple(in_left_pad_h, in_left_pad_w),
                                                      make_tuple(in_right_pad_h, in_right_pad_w),
                                                      layout);

             check_error(wei_host, wei_device);
...
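For reference, check_error above compares the device result against the CPU reference host_direct_convolution_backward_weights. Below is a hedged sketch of what such a direct (naive) backward-weight reference computes in the NCHW case; the function name, signature, and flat-vector tensor layout are illustrative assumptions, not the repo's actual API:

```cpp
// Naive backward-weight convolution (NCHW), sketch only:
// dW[k][c][y][x] = sum over n, ho, wo of dOut[n][k][ho][wo] * In[n][c][hi][wi],
// where hi = ho * stride_h + y * dilation_h - pad_h (and likewise for wi).
#include <cstdint>
#include <vector>

using index_t = std::int32_t;

void naive_conv_bwd_weight(const std::vector<float>& out, // dOut, N*K*Ho*Wo
                           const std::vector<float>& in,  // In,   N*C*Hi*Wi
                           std::vector<float>& wei,       // dW,   K*C*Y*X (written)
                           index_t N, index_t K, index_t C,
                           index_t Hi, index_t Wi, index_t Ho, index_t Wo,
                           index_t Y, index_t X,
                           index_t stride_h, index_t stride_w,
                           index_t dilation_h, index_t dilation_w,
                           index_t pad_h, index_t pad_w)
{
    for(index_t k = 0; k < K; ++k)
        for(index_t c = 0; c < C; ++c)
            for(index_t y = 0; y < Y; ++y)
                for(index_t x = 0; x < X; ++x)
                {
                    float acc = 0;
                    for(index_t n = 0; n < N; ++n)
                        for(index_t ho = 0; ho < Ho; ++ho)
                            for(index_t wo = 0; wo < Wo; ++wo)
                            {
                                const index_t hi = ho * stride_h + y * dilation_h - pad_h;
                                const index_t wi = wo * stride_w + x * dilation_w - pad_w;

                                // Out-of-range taps fall in the padding and contribute zero.
                                if(hi >= 0 && hi < Hi && wi >= 0 && wi < Wi)
                                    acc += out[((n * K + k) * Ho + ho) * Wo + wo] *
                                           in[((n * C + c) * Hi + hi) * Wi + wi];
                            }
                    wei[((k * C + c) * Y + y) * X + x] = acc;
                }
}
```

Out-of-range hi/wi simply contribute nothing, which mirrors the zero-padding behavior that the padded GEMM transform in the first file encodes into its descriptors.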