Commit c5143bca authored by Wen-Heng (Jack) Chung

Added backward-weight CPU changes in the driver

parent 32850b93
...
@@ -80,6 +80,8 @@ struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_nchw_kcyx_nkhw_lds_double_buf
         const Float* const __restrict__ p_wei_global,
         Float* const __restrict__ p_out_global) const
     {
+        if(blockIdx.x * blockDim.x + threadIdx.x == 0)
+            printf("conv dir %d\n", static_cast<int>(conv_dir));
         constexpr auto I0 = Number<0>{};
         constexpr auto I1 = Number<1>{};
         constexpr auto I2 = Number<2>{};
...
@@ -16,13 +16,13 @@
 #include "gridwise_convolution_implicit_gemm_v4r4_xdlops_fp16_bfp16_nchw_kcyx_nkhw_lds_double_buffer.hpp"

 #define CK_ENABLE_XDLOPS 0
-#define CK_PARAM_PROBLEM_DIRECTION 0
+#define CK_PARAM_PROBLEM_DIRECTION 2
 #define CK_PARAM_EPACK_LENGTH 1
 #define CK_PARAM_TUNABLE_BLOCK_SIZE 64
 #define CK_PARAM_TUNABLE_K_PER_BLOCK 32
 #define CK_PARAM_TUNABLE_B_PER_BLOCK 64
 #define CK_PARAM_TUNABLE_E_PER_BLOCK 8
-#define CK_PARAM_DEPENDENT_GRID_SIZE 16
+#define CK_PARAM_DEPENDENT_GRID_SIZE 2
 #define CK_PARAM_GEMM_M_PER_WAVE 32
 #define CK_PARAM_GEMM_N_PER_WAVE 64
 #define CK_PARAM_IN_BLOCK_COPY_CLUSTER_LENGTHS_E 8
...
@@ -109,17 +109,21 @@ void device_convolution_implicit_gemm_v5_nchw_kcyx_nkhw(InDesc,
     // of the wrw convolution when used in a fwd context
     printf("backward weight is executed\n");

-    constexpr auto tmp_in_nchw_desc =
-        make_ConstantTensorDescriptor_packed(Sequence<N, C, Hi, Wi>{});
-    constexpr auto tmp_wei_kcyx_desc = make_ConstantTensorDescriptor_packed(Sequence<K, C, Y, X>{});
-    constexpr auto tmp_out_nkhw_desc =
-        make_ConstantTensorDescriptor_packed(Sequence<N, K, Ho, Wo>{});
-    constexpr auto in_nchw_desc = tmp_in_nchw_desc.ReorderGivenNew2Old(Sequence<1, 0, 2, 3>{});
-    // wei and out are swapped in the solver
-    constexpr auto wei_kcyx_desc = tmp_out_nkhw_desc.ReorderGivenNew2Old(Sequence<1, 0, 2, 3>{});
-    constexpr auto out_nkhw_desc = tmp_wei_kcyx_desc.ReorderGivenNew2Old(Sequence<1, 0, 2, 3>{});
+    // constexpr auto tmp_in_nchw_desc =
+    //     make_ConstantTensorDescriptor_packed(Sequence<N, C, Hi, Wi>{});
+    // constexpr auto tmp_wei_kcyx_desc = make_ConstantTensorDescriptor_packed(Sequence<K, C, Y, X>{});
+    // constexpr auto tmp_out_nkhw_desc =
+    //     make_ConstantTensorDescriptor_packed(Sequence<N, K, Ho, Wo>{});
+    // constexpr auto in_nchw_desc = tmp_in_nchw_desc.ReorderGivenNew2Old(Sequence<1, 0, 2, 3>{});
+    // // wei and out are swapped in the solver
+    // constexpr auto wei_kcyx_desc = tmp_out_nkhw_desc.ReorderGivenNew2Old(Sequence<1, 0, 2, 3>{});
+    // constexpr auto out_nkhw_desc = tmp_wei_kcyx_desc.ReorderGivenNew2Old(Sequence<1, 0, 2, 3>{});

     constexpr auto dir = ImplicitGemmDirection::BackwardWeight;

+    constexpr auto in_nchw_desc = make_ConstantTensorDescriptor_packed(Sequence<N, C, Hi, Wi>{});
+    constexpr auto wei_kcyx_desc = make_ConstantTensorDescriptor_packed(Sequence<K, C, Y, X>{});
+    constexpr auto out_nkhw_desc = make_ConstantTensorDescriptor_packed(Sequence<N, K, Ho, Wo>{});
+
     // swap stride and dilation
     // using ConvDilations = Sequence<ConvStrideH, ConvStrideW>;
     // using ConvStrides = Sequence<ConvDilationH, ConvDilationW>;
...
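Aside (not part of the commit): the "wei and out are swapped" and "swap stride and dilation" notes above correspond to the standard weight-gradient identity below. All symbols are generic notation introduced here (zero padding, strides s_h, s_w, dilations d_h, d_w); nothing in the formula is taken from this repository.

\[
  \frac{\partial L}{\partial W}[k, c, y, x]
    \;=\; \sum_{n}\sum_{h_o}\sum_{w_o}
      \mathrm{in}\bigl[n,\, c,\; h_o s_h + y\, d_h,\; w_o s_w + x\, d_w\bigr]\,
      \mathrm{dout}\bigl[n,\, k,\, h_o,\, w_o\bigr]
\]

Read as a forward convolution, the input with its N and C axes exchanged plays the input role, the output gradient with its N and K axes exchanged plays the weight role, and strides trade places with dilations. That is exactly what the later hunks of this diff set up: main() builds the wrw descriptors as Sequence<C, N, HI, WI>, Sequence<K, N, HO, WO> and Sequence<C, K, Y, X>, and passes ConvDilations{} where the forward path passes ConvStrides{}.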
@@ -96,7 +96,7 @@ void device_direct_convolution_2_vectorized_nchw_kcyx_nkhw(InDesc,
     in_nchw_vec_device_buf.ToDevice(in_nchw_vec.mData.data());
     wei_kcyx_vec_device_buf.ToDevice(wei_kcyx_vec.mData.data());
     out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());

 #if 0
     // 3x3, 34x34, 128 thread, fp32, vector = 1
     constexpr index_t NPerBlock = 2;
...
@@ -17,6 +17,10 @@
 using namespace ck;

+#define CONV_DIRECTION_FWD_DATA 0
+#define CONV_DIRECTION_BWD_DATA 0
+#define CONV_DIRECTION_BWD_WEIT 1
+
 struct GeneratorTensor_1
 {
     template <class... Is>
...
@@ -29,7 +33,7 @@ struct GeneratorTensor_1
 struct GeneratorTensor_2
 {
     int min_value = 0;
-    int max_value = 1;
+    int max_value = 16;

     template <class... Is>
     double operator()(Is...)
...
@@ -50,13 +54,27 @@ struct GeneratorTensor_3
         return std::accumulate(dims.begin(), dims.end(), index_t(0), f_acc);
     }
 };

+struct GeneratorTensor_fixed
+{
+    template <class... Is>
+    double operator()(Is... is)
+    {
+        std::array<index_t, sizeof...(Is)> dims = {{static_cast<index_t>(is)...}};
+
+        if(dims[0] == 0)
+            return (dims[1]*16 + dims[2]*4 + dims[3]);
+        else
+            return 1;
+    }
+};
+
 struct GeneratorTensor_Checkboard
 {
     template <class... Ts>
     double operator()(Ts... Xs) const
     {
-        std::array<index_t, sizeof...(Ts)> dims = {{Xs...}};
+        std::array<index_t, sizeof...(Ts)> dims = {{static_cast<index_t>(Xs)...}};
         return std::accumulate(dims.begin(),
                                dims.end(),
                                true,
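Aside (not in the commit itself): GeneratorTensor_fixed above is evidently a debugging generator. For a 4-D tensor it gives every element whose first index is 0 the unique ramp value dims[1]*16 + dims[2]*4 + dims[3], and every other element 1, so individual products are easy to trace through a small convolution. The snippet below is purely illustrative (the 1x1x4x4 shape is an assumption, not something the diff uses) and simply prints the ramp it would produce:

#include <cstdio>

int main()
{
    // Values GeneratorTensor_fixed would assign to a 1x1x4x4 tensor:
    // element (0, c, h, w) -> c*16 + h*4 + w; any element with a nonzero
    // leading index would get 1 instead.
    for(int c = 0; c < 1; ++c)
        for(int h = 0; h < 4; ++h)
            for(int w = 0; w < 4; ++w)
                std::printf("(0,%d,%d,%d) -> %d\n", c, h, w, c * 16 + h * 4 + w);
    return 0;
}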
...
@@ -401,7 +419,7 @@ void check_error(const Tensor<T>& ref, const Tensor<T>& result)
     float ref_value = 0, result_value = 0;
     for(int i = 0; i < ref.mData.size(); ++i)
     {
-        std::cout << result.mData[i] << " ";
+        std::cout << result.mData[i] << "," << ref.mData[i] << " ";
         error += std::abs(double(ref.mData[i]) - double(result.mData[i]));
         float diff = std::abs(double(ref.mData[i]) - double(result.mData[i]));
         if(max_diff < diff)
...
@@ -819,15 +837,24 @@ int main(int argc, char* argv[])
     constexpr index_t HPad = 0;
     constexpr index_t WPad = 0;
+
+    constexpr index_t HO = 4;
+    constexpr index_t WO = 4;
 #endif

     auto lower_pads = Sequence<HPad, WPad>{};
     auto upper_pads = Sequence<HPad, WPad>{};

+#if CONV_DIRECTION_FWD_DATA
     auto in_nchw_desc = make_ConstantTensorDescriptor_packed(Sequence<N, C, HI, WI>{});
     auto wei_kcyx_desc = make_ConstantTensorDescriptor_packed(Sequence<K, C, Y, X>{});
     auto out_nkhw_desc = get_convolution_with_padding_output_default_4d_tensor_descriptor(
         in_nchw_desc, wei_kcyx_desc, ConvStrides{}, ConvDilations{}, lower_pads, upper_pads);
+#elif CONV_DIRECTION_BWD_WEIT
+    auto in_nchw_desc = make_ConstantTensorDescriptor_packed(Sequence<C, N, HI, WI>{});
+    auto wei_kcyx_desc = make_ConstantTensorDescriptor_packed(Sequence<C, K, Y, X>{});
+    auto out_nkhw_desc = make_ConstantTensorDescriptor_packed(Sequence<K, N, HO, WO>{});
+#endif

     ostream_ConstantTensorDescriptor(in_nchw_desc, std::cout << "in_nchw_desc: ");
     ostream_ConstantTensorDescriptor(wei_kcyx_desc, std::cout << "wei_kcyx_desc: ");
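Aside (not in the commit): HO and WO are hard-coded to 4 in the hunk above. They would normally follow from the standard output-size formula; the sketch below shows that relation with assumed example values, since the actual HI, WI, Y, X, stride and dilation constants are defined in parts of main() that this diff does not show.

#include <cstdio>

int main()
{
    // Assumed example problem; only HPad = WPad = 0 is taken from the diff.
    const int HI = 6, WI = 6;           // assumed input size
    const int Y = 3, X = 3;             // assumed filter size
    const int stride = 1, dilation = 1; // assumed stride/dilation
    const int HPad = 0, WPad = 0;

    // Standard convolution output-size formula.
    const int HO = (HI + 2 * HPad - ((Y - 1) * dilation + 1)) / stride + 1;
    const int WO = (WI + 2 * WPad - ((X - 1) * dilation + 1)) / stride + 1;

    std::printf("HO = %d, WO = %d\n", HO, WO); // prints 4, 4 for these values
    return 0;
}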
...
@@ -835,10 +862,19 @@ int main(int argc, char* argv[])
     using in_data_t = float;
     using out_data_t = float;

+#if CONV_DIRECTION_FWD_DATA
     Tensor<in_data_t> in_nchw(make_TensorDescriptor(in_nchw_desc));
     Tensor<in_data_t> wei_kcyx(make_TensorDescriptor(wei_kcyx_desc));
     Tensor<out_data_t> out_nkhw_host(make_TensorDescriptor(out_nkhw_desc));
     Tensor<out_data_t> out_nkhw_device(make_TensorDescriptor(out_nkhw_desc));
+#elif CONV_DIRECTION_BWD_WEIT
+    Tensor<in_data_t> in_nchw(make_TensorDescriptor(in_nchw_desc));
+    Tensor<out_data_t> wei_kcyx(make_TensorDescriptor(wei_kcyx_desc));
+    Tensor<out_data_t> wei_kcyx_host(make_TensorDescriptor(wei_kcyx_desc));
+    Tensor<in_data_t> out_nkhw_host(make_TensorDescriptor(out_nkhw_desc));
+    Tensor<in_data_t> out_nkhw(make_TensorDescriptor(out_nkhw_desc));
+#endif

     std::size_t num_thread = std::thread::hardware_concurrency();
...
@@ -854,8 +890,14 @@ int main(int argc, char* argv[])
     if(do_verification)
     {
 #if 1
+#if CONV_DIRECTION_FWD_DATA // fwd data
         in_nchw.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
         wei_kcyx.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
+#elif CONV_DIRECTION_BWD_WEIT // bwd wrw
+        in_nchw.GenerateTensorValue(GeneratorTensor_2{}, num_thread);
+        //out_nkhw_host.GenerateTensorValue(GeneratorTensor_2{}, num_thread);
+        out_nkhw.GenerateTensorValue(GeneratorTensor_2{}, num_thread);
+#endif
 #elif 0
         in_nchw.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
         wei_kcyx.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
...
@@ -891,6 +933,7 @@ int main(int argc, char* argv[])
 #elif 1
         device_convolution_implicit_gemm_v5_nchw_kcyx_nkhw
 #endif
+#if CONV_DIRECTION_FWD_DATA // fwd data
             (in_nchw_desc,
              in_nchw,
              wei_kcyx_desc,
...
@@ -900,6 +943,17 @@ int main(int argc, char* argv[])
              ConvStrides{},
              ConvDilations{},
              nrepeat);
+#elif CONV_DIRECTION_BWD_WEIT // bwd wrw
+            (in_nchw_desc,
+             in_nchw,
+             out_nkhw_desc,
+             out_nkhw,
+             wei_kcyx_desc,
+             wei_kcyx,
+             ConvDilations{},
+             ConvStrides{},
+             nrepeat);
+#endif
 #elif 0
         device_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded(in_nchw_desc,
...
@@ -924,6 +978,8 @@ int main(int argc, char* argv[])
     else
 #endif
     {
+#if CONV_DIRECTION_FWD_DATA // fwd data
         host_direct_convolution(in_nchw,
                                 wei_kcyx,
                                 out_nkhw_host,
...
@@ -931,8 +987,25 @@ int main(int argc, char* argv[])
                                 ConvDilations{},
                                 lower_pads,
                                 upper_pads);
+#elif CONV_DIRECTION_BWD_WEIT // bwd wrw
+        host_direct_convolution(in_nchw,
+                                out_nkhw,
+                                wei_kcyx_host,
+                                ConvDilations{},
+                                ConvStrides{},
+                                lower_pads,
+                                upper_pads);
+#endif
     }

+#if CONV_DIRECTION_FWD_DATA // fwd data
     check_error(out_nkhw_host, out_nkhw_device);
+#elif CONV_DIRECTION_BWD_WEIT // bwd wrw
+    check_error(wei_kcyx_host, wei_kcyx);
+#endif
+
+    LogRange(std::cout << "in_nchw : ", in_nchw.mData, ",") << std::endl;
+    LogRange(std::cout << "out_nkhw_device : ", out_nkhw.mData, ",") << std::endl;
+    //LogRange(std::cout << "wei_kcyx: ", wei_kcyx.mData, ",") << std::endl;

 #if 0
     LogRange(std::cout << "in_nchw : ", in_nchw.mData, ",") << std::endl;
     LogRange(std::cout << "wei_kcyx: ", wei_kcyx.mData, ",") << std::endl;
...
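Aside (not part of the commit): the wrw verification path above reuses host_direct_convolution with the in/out/wei roles rearranged. As a cross-check, the weight gradient it should reproduce can also be written as a plain nested loop. The routine below is only an illustrative sketch: the function name, the flat NCHW/NKHW/KCYX indexing, and the stride = dilation = 1, zero-padding assumptions are mine (note, too, that the driver's wrw path actually declares the weight descriptor as Sequence<C, K, Y, X>, so its memory layout differs from the K-major layout assumed here).

#include <vector>

// dwei[k][c][y][x] = sum over n, ho, wo of in[n][c][ho + y][wo + x] * dout[n][k][ho][wo]
// (stride = dilation = 1, no padding).  The caller is expected to size dwei to K*C*Y*X.
void host_wrw_reference(const std::vector<float>& in,   // N*C*Hi*Wi, packed NCHW
                        const std::vector<float>& dout, // N*K*Ho*Wo, packed NKHW
                        std::vector<float>& dwei,       // K*C*Y*X,  packed KCYX
                        int N, int C, int Hi, int Wi,
                        int K, int Y, int X)
{
    const int Ho = Hi - Y + 1;
    const int Wo = Wi - X + 1;

    for(int k = 0; k < K; ++k)
        for(int c = 0; c < C; ++c)
            for(int y = 0; y < Y; ++y)
                for(int x = 0; x < X; ++x)
                {
                    float acc = 0;
                    for(int n = 0; n < N; ++n)
                        for(int ho = 0; ho < Ho; ++ho)
                            for(int wo = 0; wo < Wo; ++wo)
                                acc += in[((n * C + c) * Hi + (ho + y)) * Wi + (wo + x)] *
                                       dout[((n * K + k) * Ho + ho) * Wo + wo];
                    dwei[((k * C + c) * Y + y) * X + x] = acc;
                }
}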