"git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "50f5adb09ba1a0217d7bf7ef7462847aac5e0c41"
Unverified Commit 6717168c authored by ltqin's avatar ltqin Committed by GitHub
Browse files

Patch for bwd data comments (#174)

* rename the example's helper functions to snake_case and reset the input device buffer to zero directly with SetZero()

* use ck::enable_if instead of std::enable_if
parent 781cacd2
...@@ -83,7 +83,7 @@ using ReferenceConvBwdDataInstance = ...@@ -83,7 +83,7 @@ using ReferenceConvBwdDataInstance =
OutElementOp, OutElementOp,
NumDimSpatial>; NumDimSpatial>;
void PrintUseMsg() void print_use_msg()
{ {
std::cout << "arg1: verification (0=no, 1=yes)\n" std::cout << "arg1: verification (0=no, 1=yes)\n"
<< "arg2: initialization (0=no init, 1=random value, 2= init to 1 )\n" << "arg2: initialization (0=no init, 1=random value, 2= init to 1 )\n"
...@@ -99,7 +99,7 @@ void PrintUseMsg() ...@@ -99,7 +99,7 @@ void PrintUseMsg()
<< " <right padding>, (ie RightPy, RightPx for 2D)\n" << " <right padding>, (ie RightPy, RightPx for 2D)\n"
<< std::endl; << std::endl;
} }
ck::conv_util::ConvParams ParseConvParams(int num_dim_spatial, char* argv[]) ck::conv_util::ConvParams parse_conv_params(int num_dim_spatial, char* argv[])
{ {
// (N, K, C) + num_dim_spatial * 6 (filter, input, strides, dilations, pad left, pad right) // (N, K, C) + num_dim_spatial * 6 (filter, input, strides, dilations, pad left, pad right)
ck::conv_util::ConvParams params; ck::conv_util::ConvParams params;
...@@ -144,8 +144,8 @@ ck::conv_util::ConvParams ParseConvParams(int num_dim_spatial, char* argv[]) ...@@ -144,8 +144,8 @@ ck::conv_util::ConvParams ParseConvParams(int num_dim_spatial, char* argv[])
return params; return params;
} }
HostTensorDescriptor GetInputHostTensorDescriptor(const std::vector<std::size_t>& dims, HostTensorDescriptor get_input_host_tensor_descriptor(const std::vector<std::size_t>& dims,
int num_dim_spatial = 2) int num_dim_spatial = 2)
{ {
namespace tl = ck::tensor_layout::convolution; namespace tl = ck::tensor_layout::convolution;
...@@ -165,8 +165,8 @@ HostTensorDescriptor GetInputHostTensorDescriptor(const std::vector<std::size_t> ...@@ -165,8 +165,8 @@ HostTensorDescriptor GetInputHostTensorDescriptor(const std::vector<std::size_t>
} }
} }
} }
HostTensorDescriptor GetFiltersHostTensorDescriptor(const std::vector<std::size_t>& dims, HostTensorDescriptor get_filters_host_tensor_descriptor(const std::vector<std::size_t>& dims,
int num_dim_spatial = 2) int num_dim_spatial = 2)
{ {
namespace tl = ck::tensor_layout::convolution; namespace tl = ck::tensor_layout::convolution;
...@@ -187,8 +187,8 @@ HostTensorDescriptor GetFiltersHostTensorDescriptor(const std::vector<std::size_ ...@@ -187,8 +187,8 @@ HostTensorDescriptor GetFiltersHostTensorDescriptor(const std::vector<std::size_
} }
} }
HostTensorDescriptor GetOutputHostTensorDescriptor(const std::vector<std::size_t>& dims, HostTensorDescriptor get_output_host_tensor_descriptor(const std::vector<std::size_t>& dims,
int num_dim_spatial = 2) int num_dim_spatial = 2)
{ {
namespace tl = ck::tensor_layout::convolution; namespace tl = ck::tensor_layout::convolution;
...@@ -210,7 +210,7 @@ HostTensorDescriptor GetOutputHostTensorDescriptor(const std::vector<std::size_t ...@@ -210,7 +210,7 @@ HostTensorDescriptor GetOutputHostTensorDescriptor(const std::vector<std::size_t
} }
} }
DeviceConvBwdDataBasePtr GetConvInstance(int num_dim_spatial) DeviceConvBwdDataBasePtr get_conv_instance(int num_dim_spatial)
{ {
switch(num_dim_spatial) switch(num_dim_spatial)
{ {
...@@ -256,15 +256,15 @@ int main(int argc, char* argv[]) ...@@ -256,15 +256,15 @@ int main(int argc, char* argv[])
int cmdline_nargs = conv_args + 5; int cmdline_nargs = conv_args + 5;
if(cmdline_nargs != argc) if(cmdline_nargs != argc)
{ {
PrintUseMsg(); print_use_msg();
exit(1); exit(1);
} }
params = ParseConvParams(num_dim_spatial, argv); params = parse_conv_params(num_dim_spatial, argv);
} }
else if(argc != 1) else if(argc != 1)
{ {
PrintUseMsg(); print_use_msg();
exit(1); exit(1);
} }
...@@ -288,11 +288,13 @@ int main(int argc, char* argv[]) ...@@ -288,11 +288,13 @@ int main(int argc, char* argv[])
std::end(output_spatial_lengths)); std::end(output_spatial_lengths));
Tensor<InDataType> in_n_c_hi_wi_host_result( Tensor<InDataType> in_n_c_hi_wi_host_result(
GetInputHostTensorDescriptor(input_dims, num_dim_spatial)); get_input_host_tensor_descriptor(input_dims, num_dim_spatial));
Tensor<InDataType> in_n_c_hi_wi_device_result( Tensor<InDataType> in_n_c_hi_wi_device_result(
GetInputHostTensorDescriptor(input_dims, num_dim_spatial)); get_input_host_tensor_descriptor(input_dims, num_dim_spatial));
Tensor<WeiDataType> wei_k_c_y_x(GetFiltersHostTensorDescriptor(filter_dims, num_dim_spatial)); Tensor<WeiDataType> wei_k_c_y_x(
Tensor<OutDataType> out_n_k_ho_wo(GetOutputHostTensorDescriptor(output_dims, num_dim_spatial)); get_filters_host_tensor_descriptor(filter_dims, num_dim_spatial));
Tensor<OutDataType> out_n_k_ho_wo(
get_output_host_tensor_descriptor(output_dims, num_dim_spatial));
std::cout << "in_n_c_hi_wi: " << in_n_c_hi_wi_host_result.mDesc << std::endl; std::cout << "in_n_c_hi_wi: " << in_n_c_hi_wi_host_result.mDesc << std::endl;
std::cout << "wei_k_c_y_x: " << wei_k_c_y_x.mDesc << std::endl; std::cout << "wei_k_c_y_x: " << wei_k_c_y_x.mDesc << std::endl;
...@@ -318,11 +320,10 @@ int main(int argc, char* argv[]) ...@@ -318,11 +320,10 @@ int main(int argc, char* argv[])
out_device_buf.ToDevice(out_n_k_ho_wo.mData.data()); out_device_buf.ToDevice(out_n_k_ho_wo.mData.data());
wei_device_buf.ToDevice(wei_k_c_y_x.mData.data()); wei_device_buf.ToDevice(wei_k_c_y_x.mData.data());
// reset input to zero // reset input to zero
in_n_c_hi_wi_device_result.GenerateTensorValue(GeneratorTensor_1<InDataType>{0}); in_device_buf.SetZero();
in_device_buf.ToDevice(in_n_c_hi_wi_device_result.mData.data());
// do GEMM // do GEMM
auto conv = GetConvInstance(num_dim_spatial); auto conv = get_conv_instance(num_dim_spatial);
auto invoker = conv->MakeInvokerPointer(); auto invoker = conv->MakeInvokerPointer();
auto argument = auto argument =
conv->MakeArgumentPointer(static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()), conv->MakeArgumentPointer(static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
......
...@@ -917,21 +917,21 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho ...@@ -917,21 +917,21 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
} // function end } // function end
template <ck::index_t NDim, typename std::enable_if<NDim == 1, bool>::type = false> template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
static auto GetABCGridDesc() static auto GetABCGridDesc()
{ {
return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<1>( return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<1>(
1, 1, 1, {1}, {1}, {1}, {1}, {1}, {1}, {1}, {0}); 1, 1, 1, {1}, {1}, {1}, {1}, {1}, {1}, {1}, {0});
} }
template <ck::index_t NDim, typename std::enable_if<NDim == 2, bool>::type = false> template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
static auto GetABCGridDesc() static auto GetABCGridDesc()
{ {
return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<2>( return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<2>(
1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {0, 0}); 1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {0, 0});
} }
template <ck::index_t NDim, typename std::enable_if<NDim == 3, bool>::type = false> template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
static auto GetABCGridDesc() static auto GetABCGridDesc()
{ {
return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<3>(1, return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<3>(1,
......
...@@ -18,8 +18,8 @@ template <typename InDataType, ...@@ -18,8 +18,8 @@ template <typename InDataType,
typename InElementwiseOperation, typename InElementwiseOperation,
typename WeiElementwiseOperation, typename WeiElementwiseOperation,
typename OutElementwiseOperation, typename OutElementwiseOperation,
ck::index_t NumDimSpatial = 2, ck::index_t NumDimSpatial = 2,
typename std::enable_if<NumDimSpatial >= 1 && NumDimSpatial <= 3, bool>::type = false> typename ck::enable_if<NumDimSpatial >= 1 && NumDimSpatial <= 3, bool>::type = false>
struct ReferenceConvBwdData : public device::BaseOperator struct ReferenceConvBwdData : public device::BaseOperator
{ {
// Argument // Argument
......
...@@ -336,8 +336,7 @@ bool profile_convnd_bwd_data_impl(int do_verification, ...@@ -336,8 +336,7 @@ bool profile_convnd_bwd_data_impl(int do_verification,
wei_device_buf.ToDevice(weights.mData.data()); wei_device_buf.ToDevice(weights.mData.data());
// reset input to zero // reset input to zero
input_device_result.GenerateTensorValue(GeneratorTensor_1<InDataType>{0}); in_device_buf.SetZero();
in_device_buf.ToDevice(input_device_result.mData.data());
if(do_verification) if(do_verification)
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment