Unverified Commit 6717168c authored by ltqin's avatar ltqin Committed by GitHub
Browse files

Patch for bwd data comments (#174)

* change function name and way to set input zero

* change enable if
parent 781cacd2
......@@ -83,7 +83,7 @@ using ReferenceConvBwdDataInstance =
OutElementOp,
NumDimSpatial>;
void PrintUseMsg()
void print_use_msg()
{
std::cout << "arg1: verification (0=no, 1=yes)\n"
<< "arg2: initialization (0=no init, 1=random value, 2= init to 1 )\n"
......@@ -99,7 +99,7 @@ void PrintUseMsg()
<< " <right padding>, (ie RightPy, RightPx for 2D)\n"
<< std::endl;
}
ck::conv_util::ConvParams ParseConvParams(int num_dim_spatial, char* argv[])
ck::conv_util::ConvParams parse_conv_params(int num_dim_spatial, char* argv[])
{
// (N, K, C) + num_dim_spatial * 6 (filter, input, strides, dilations, pad left, pad right)
ck::conv_util::ConvParams params;
......@@ -144,7 +144,7 @@ ck::conv_util::ConvParams ParseConvParams(int num_dim_spatial, char* argv[])
return params;
}
HostTensorDescriptor GetInputHostTensorDescriptor(const std::vector<std::size_t>& dims,
HostTensorDescriptor get_input_host_tensor_descriptor(const std::vector<std::size_t>& dims,
int num_dim_spatial = 2)
{
namespace tl = ck::tensor_layout::convolution;
......@@ -165,7 +165,7 @@ HostTensorDescriptor GetInputHostTensorDescriptor(const std::vector<std::size_t>
}
}
}
HostTensorDescriptor GetFiltersHostTensorDescriptor(const std::vector<std::size_t>& dims,
HostTensorDescriptor get_filters_host_tensor_descriptor(const std::vector<std::size_t>& dims,
int num_dim_spatial = 2)
{
namespace tl = ck::tensor_layout::convolution;
......@@ -187,7 +187,7 @@ HostTensorDescriptor GetFiltersHostTensorDescriptor(const std::vector<std::size_
}
}
HostTensorDescriptor GetOutputHostTensorDescriptor(const std::vector<std::size_t>& dims,
HostTensorDescriptor get_output_host_tensor_descriptor(const std::vector<std::size_t>& dims,
int num_dim_spatial = 2)
{
namespace tl = ck::tensor_layout::convolution;
......@@ -210,7 +210,7 @@ HostTensorDescriptor GetOutputHostTensorDescriptor(const std::vector<std::size_t
}
}
DeviceConvBwdDataBasePtr GetConvInstance(int num_dim_spatial)
DeviceConvBwdDataBasePtr get_conv_instance(int num_dim_spatial)
{
switch(num_dim_spatial)
{
......@@ -256,15 +256,15 @@ int main(int argc, char* argv[])
int cmdline_nargs = conv_args + 5;
if(cmdline_nargs != argc)
{
PrintUseMsg();
print_use_msg();
exit(1);
}
params = ParseConvParams(num_dim_spatial, argv);
params = parse_conv_params(num_dim_spatial, argv);
}
else if(argc != 1)
{
PrintUseMsg();
print_use_msg();
exit(1);
}
......@@ -288,11 +288,13 @@ int main(int argc, char* argv[])
std::end(output_spatial_lengths));
Tensor<InDataType> in_n_c_hi_wi_host_result(
GetInputHostTensorDescriptor(input_dims, num_dim_spatial));
get_input_host_tensor_descriptor(input_dims, num_dim_spatial));
Tensor<InDataType> in_n_c_hi_wi_device_result(
GetInputHostTensorDescriptor(input_dims, num_dim_spatial));
Tensor<WeiDataType> wei_k_c_y_x(GetFiltersHostTensorDescriptor(filter_dims, num_dim_spatial));
Tensor<OutDataType> out_n_k_ho_wo(GetOutputHostTensorDescriptor(output_dims, num_dim_spatial));
get_input_host_tensor_descriptor(input_dims, num_dim_spatial));
Tensor<WeiDataType> wei_k_c_y_x(
get_filters_host_tensor_descriptor(filter_dims, num_dim_spatial));
Tensor<OutDataType> out_n_k_ho_wo(
get_output_host_tensor_descriptor(output_dims, num_dim_spatial));
std::cout << "in_n_c_hi_wi: " << in_n_c_hi_wi_host_result.mDesc << std::endl;
std::cout << "wei_k_c_y_x: " << wei_k_c_y_x.mDesc << std::endl;
......@@ -318,11 +320,10 @@ int main(int argc, char* argv[])
out_device_buf.ToDevice(out_n_k_ho_wo.mData.data());
wei_device_buf.ToDevice(wei_k_c_y_x.mData.data());
// reset input to zero
in_n_c_hi_wi_device_result.GenerateTensorValue(GeneratorTensor_1<InDataType>{0});
in_device_buf.ToDevice(in_n_c_hi_wi_device_result.mData.data());
in_device_buf.SetZero();
// do GEMM
auto conv = GetConvInstance(num_dim_spatial);
auto conv = get_conv_instance(num_dim_spatial);
auto invoker = conv->MakeInvokerPointer();
auto argument =
conv->MakeArgumentPointer(static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
......
......@@ -917,21 +917,21 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
} // function end
template <ck::index_t NDim, typename std::enable_if<NDim == 1, bool>::type = false>
template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
static auto GetABCGridDesc()
{
return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<1>(
1, 1, 1, {1}, {1}, {1}, {1}, {1}, {1}, {1}, {0});
}
template <ck::index_t NDim, typename std::enable_if<NDim == 2, bool>::type = false>
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
static auto GetABCGridDesc()
{
return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<2>(
1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {0, 0});
}
template <ck::index_t NDim, typename std::enable_if<NDim == 3, bool>::type = false>
template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
static auto GetABCGridDesc()
{
return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<3>(1,
......
......@@ -19,7 +19,7 @@ template <typename InDataType,
typename WeiElementwiseOperation,
typename OutElementwiseOperation,
ck::index_t NumDimSpatial = 2,
typename std::enable_if<NumDimSpatial >= 1 && NumDimSpatial <= 3, bool>::type = false>
typename ck::enable_if<NumDimSpatial >= 1 && NumDimSpatial <= 3, bool>::type = false>
struct ReferenceConvBwdData : public device::BaseOperator
{
// Argument
......
......@@ -336,8 +336,7 @@ bool profile_convnd_bwd_data_impl(int do_verification,
wei_device_buf.ToDevice(weights.mData.data());
// reset input to zero
input_device_result.GenerateTensorValue(GeneratorTensor_1<InDataType>{0});
in_device_buf.ToDevice(input_device_result.mData.data());
in_device_buf.SetZero();
if(do_verification)
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment