Unverified Commit c0e95f62 authored by ltqin's avatar ltqin Committed by GitHub
Browse files

Patch for bwd data #134 (#168)

* remove switch for NDimSpatial

* change in, out and wei name

* rename reference thumb function name

* remove test
parent cd167e49
...@@ -71,7 +71,7 @@ struct ReferenceConvBwdData : public device::BaseOperator ...@@ -71,7 +71,7 @@ struct ReferenceConvBwdData : public device::BaseOperator
{ {
if constexpr(NumDimSpatial == 1) if constexpr(NumDimSpatial == 1)
{ {
auto f_nchw = [&](auto n, auto c, auto wi) { auto f_ncw = [&](auto n, auto c, auto wi) {
std::size_t K = arg.weight_.mDesc.GetLengths()[0]; std::size_t K = arg.weight_.mDesc.GetLengths()[0];
std::size_t X = arg.weight_.mDesc.GetLengths()[2]; std::size_t X = arg.weight_.mDesc.GetLengths()[2];
std::size_t Wo = arg.output_.mDesc.GetLengths()[2]; std::size_t Wo = arg.output_.mDesc.GetLengths()[2];
...@@ -108,7 +108,7 @@ struct ReferenceConvBwdData : public device::BaseOperator ...@@ -108,7 +108,7 @@ struct ReferenceConvBwdData : public device::BaseOperator
arg.input_(n, c, wi) = ck::type_convert<InDataType>(v_in); arg.input_(n, c, wi) = ck::type_convert<InDataType>(v_in);
}; };
make_ParallelTensorFunctor(f_nchw, make_ParallelTensorFunctor(f_ncw,
arg.input_.mDesc.GetLengths()[0], arg.input_.mDesc.GetLengths()[0],
arg.input_.mDesc.GetLengths()[1], arg.input_.mDesc.GetLengths()[1],
arg.input_.mDesc.GetLengths()[2])( arg.input_.mDesc.GetLengths()[2])(
...@@ -182,7 +182,7 @@ struct ReferenceConvBwdData : public device::BaseOperator ...@@ -182,7 +182,7 @@ struct ReferenceConvBwdData : public device::BaseOperator
} }
else if constexpr(NumDimSpatial == 3) else if constexpr(NumDimSpatial == 3)
{ {
auto f_nchw = [&](auto n, auto c, auto di, auto hi, auto wi) { auto f_ncdhw = [&](auto n, auto c, auto di, auto hi, auto wi) {
std::size_t K = arg.weight_.mDesc.GetLengths()[0]; std::size_t K = arg.weight_.mDesc.GetLengths()[0];
std::size_t Z = arg.weight_.mDesc.GetLengths()[2]; std::size_t Z = arg.weight_.mDesc.GetLengths()[2];
std::size_t Y = arg.weight_.mDesc.GetLengths()[3]; std::size_t Y = arg.weight_.mDesc.GetLengths()[3];
...@@ -252,7 +252,7 @@ struct ReferenceConvBwdData : public device::BaseOperator ...@@ -252,7 +252,7 @@ struct ReferenceConvBwdData : public device::BaseOperator
arg.input_(n, c, di, hi, wi) = ck::type_convert<InDataType>(v_in); arg.input_(n, c, di, hi, wi) = ck::type_convert<InDataType>(v_in);
}; };
make_ParallelTensorFunctor(f_nchw, make_ParallelTensorFunctor(f_ncdhw,
arg.input_.mDesc.GetLengths()[0], arg.input_.mDesc.GetLengths()[0],
arg.input_.mDesc.GetLengths()[1], arg.input_.mDesc.GetLengths()[1],
arg.input_.mDesc.GetLengths()[2], arg.input_.mDesc.GetLengths()[2],
......
...@@ -120,7 +120,6 @@ HostTensorDescriptor get_output_host_ensor_descriptor(const std::vector<std::siz ...@@ -120,7 +120,6 @@ HostTensorDescriptor get_output_host_ensor_descriptor(const std::vector<std::siz
case 1: { case 1: {
return ck::conv_util::GetHostTensorDescriptor(dims, OutLayout{}); return ck::conv_util::GetHostTensorDescriptor(dims, OutLayout{});
} }
default: { default: {
throw std::runtime_error("Unsupported number of spatial dimensions provided!"); throw std::runtime_error("Unsupported number of spatial dimensions provided!");
} }
...@@ -274,13 +273,13 @@ bool profile_convnd_bwd_data_impl(int do_verification, ...@@ -274,13 +273,13 @@ bool profile_convnd_bwd_data_impl(int do_verification,
ck::index_t N, ck::index_t N,
ck::index_t K, ck::index_t K,
ck::index_t C, ck::index_t C,
std::vector<ck::index_t> input_spatial_lengths, const std::vector<ck::index_t>& input_spatial_lengths,
std::vector<ck::index_t> filter_spatial_lengths, const std::vector<ck::index_t>& filter_spatial_lengths,
std::vector<ck::index_t> output_spatial_lengths, const std::vector<ck::index_t>& output_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides, const std::vector<ck::index_t>& conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations, const std::vector<ck::index_t>& conv_filter_dilations,
std::vector<ck::index_t> input_left_pads, const std::vector<ck::index_t>& input_left_pads,
std::vector<ck::index_t> input_right_pads) const std::vector<ck::index_t>& input_right_pads)
{ {
using InElementOp = ck::tensor_operation::element_wise::PassThrough; using InElementOp = ck::tensor_operation::element_wise::PassThrough;
using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
...@@ -304,51 +303,50 @@ bool profile_convnd_bwd_data_impl(int do_verification, ...@@ -304,51 +303,50 @@ bool profile_convnd_bwd_data_impl(int do_verification,
std::begin(output_spatial_lengths), std::begin(output_spatial_lengths),
std::end(output_spatial_lengths)); std::end(output_spatial_lengths));
Tensor<InDataType> in_n_c_hi_wi_host_result( Tensor<InDataType> input_host_result(
get_input_host_tensor_descriptor<InLayout>(input_dims, NDimSpatial)); get_input_host_tensor_descriptor<InLayout>(input_dims, NDimSpatial));
Tensor<InDataType> in_n_c_hi_wi_device_result( Tensor<InDataType> input_device_result(
get_input_host_tensor_descriptor<InLayout>(input_dims, NDimSpatial)); get_input_host_tensor_descriptor<InLayout>(input_dims, NDimSpatial));
Tensor<WeiDataType> wei_k_c_y_x( Tensor<WeiDataType> weights(
get_filters_host_tensor_descriptor<WeiLayout>(filter_dims, NDimSpatial)); get_filters_host_tensor_descriptor<WeiLayout>(filter_dims, NDimSpatial));
Tensor<OutDataType> out_n_k_ho_wo( Tensor<OutDataType> output(
get_output_host_ensor_descriptor<OutLayout>(output_dims, NDimSpatial)); get_output_host_ensor_descriptor<OutLayout>(output_dims, NDimSpatial));
std::cout << "in_n_c_hi_wi: " << in_n_c_hi_wi_host_result.mDesc << std::endl; std::cout << "input: " << input_host_result.mDesc << std::endl;
std::cout << "wei_k_c_y_x: " << wei_k_c_y_x.mDesc << std::endl; std::cout << "weights: " << weights.mDesc << std::endl;
std::cout << "out_n_k_ho_wo: " << out_n_k_ho_wo.mDesc << std::endl; std::cout << "output: " << output.mDesc << std::endl;
switch(init_method) switch(init_method)
{ {
case 0: break; case 0: break;
case 1: case 1:
out_n_k_ho_wo.GenerateTensorValue(GeneratorTensor_2<OutDataType>{-5, 5}); output.GenerateTensorValue(GeneratorTensor_2<OutDataType>{-5, 5});
wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-5, 5}); weights.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-5, 5});
break; break;
default: default:
out_n_k_ho_wo.GenerateTensorValue(GeneratorTensor_1<OutDataType>{1}); output.GenerateTensorValue(GeneratorTensor_1<OutDataType>{1});
wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_1<WeiDataType>{1}); weights.GenerateTensorValue(GeneratorTensor_1<WeiDataType>{1});
} }
DeviceMem in_device_buf(sizeof(InDataType) * DeviceMem in_device_buf(sizeof(InDataType) * input_device_result.mDesc.GetElementSpace());
in_n_c_hi_wi_device_result.mDesc.GetElementSpace()); DeviceMem wei_device_buf(sizeof(WeiDataType) * weights.mDesc.GetElementSpace());
DeviceMem wei_device_buf(sizeof(WeiDataType) * wei_k_c_y_x.mDesc.GetElementSpace()); DeviceMem out_device_buf(sizeof(OutDataType) * output.mDesc.GetElementSpace());
DeviceMem out_device_buf(sizeof(OutDataType) * out_n_k_ho_wo.mDesc.GetElementSpace());
out_device_buf.ToDevice(out_n_k_ho_wo.mData.data()); out_device_buf.ToDevice(output.mData.data());
wei_device_buf.ToDevice(wei_k_c_y_x.mData.data()); wei_device_buf.ToDevice(weights.mData.data());
// reset input to zero // reset input to zero
in_n_c_hi_wi_device_result.GenerateTensorValue(GeneratorTensor_1<InDataType>{0}); input_device_result.GenerateTensorValue(GeneratorTensor_1<InDataType>{0});
in_device_buf.ToDevice(in_n_c_hi_wi_device_result.mData.data()); in_device_buf.ToDevice(input_device_result.mData.data());
if(do_verification) if(do_verification)
{ {
auto RunReference = [&](auto& ref_conv) { auto RunReference = [&](auto& ref_conv) {
auto ref_invoker = ref_conv.MakeInvoker(); auto ref_invoker = ref_conv.MakeInvoker();
auto ref_argument = ref_conv.MakeArgument(in_n_c_hi_wi_host_result, auto ref_argument = ref_conv.MakeArgument(input_host_result,
wei_k_c_y_x, weights,
out_n_k_ho_wo, output,
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
input_left_pads, input_left_pads,
...@@ -358,33 +356,7 @@ bool profile_convnd_bwd_data_impl(int do_verification, ...@@ -358,33 +356,7 @@ bool profile_convnd_bwd_data_impl(int do_verification,
OutElementOp{}); OutElementOp{});
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
}; };
switch(NDimSpatial)
{
case 3: {
auto ref_conv = ck::tensor_operation::host::ReferenceConvBwdData<InDataType,
WeiDataType,
OutDataType,
AccDataType,
InElementOp,
WeiElementOp,
OutElementOp,
3>();
RunReference(ref_conv);
break;
}
case 2: {
auto ref_conv = ck::tensor_operation::host::ReferenceConvBwdData<InDataType,
WeiDataType,
OutDataType,
AccDataType,
InElementOp,
WeiElementOp,
OutElementOp,
2>();
RunReference(ref_conv);
break;
}
case 1: {
auto ref_conv = ck::tensor_operation::host::ReferenceConvBwdData<InDataType, auto ref_conv = ck::tensor_operation::host::ReferenceConvBwdData<InDataType,
WeiDataType, WeiDataType,
OutDataType, OutDataType,
...@@ -392,14 +364,8 @@ bool profile_convnd_bwd_data_impl(int do_verification, ...@@ -392,14 +364,8 @@ bool profile_convnd_bwd_data_impl(int do_verification,
InElementOp, InElementOp,
WeiElementOp, WeiElementOp,
OutElementOp, OutElementOp,
1>(); NDimSpatial>();
RunReference(ref_conv); RunReference(ref_conv);
break;
}
default: {
throw std::runtime_error("Unsupported number of spatial dimensions provided!");
}
}
} }
// add device Conv instances // add device Conv instances
...@@ -468,9 +434,9 @@ bool profile_convnd_bwd_data_impl(int do_verification, ...@@ -468,9 +434,9 @@ bool profile_convnd_bwd_data_impl(int do_verification,
if(do_verification) if(do_verification)
{ {
in_device_buf.FromDevice(in_n_c_hi_wi_device_result.mData.data()); in_device_buf.FromDevice(input_device_result.mData.data());
if(!check_out(in_n_c_hi_wi_host_result, in_n_c_hi_wi_device_result)) if(!check_out(input_host_result, input_device_result))
{ {
std::cout << "Fail Info: " << conv_ptr->GetTypeString() << std::endl; std::cout << "Fail Info: " << conv_ptr->GetTypeString() << std::endl;
...@@ -481,24 +447,24 @@ bool profile_convnd_bwd_data_impl(int do_verification, ...@@ -481,24 +447,24 @@ bool profile_convnd_bwd_data_impl(int do_verification,
std::cout << "Pass Info: " << conv_ptr->GetTypeString() << std::endl; std::cout << "Pass Info: " << conv_ptr->GetTypeString() << std::endl;
} }
check_error(in_n_c_hi_wi_host_result, in_n_c_hi_wi_device_result); check_error(input_host_result, input_device_result);
if(do_log) if(do_log)
{ {
std::cout << "in : "; std::cout << "in : ";
show_data_nhwc_layout(out_n_k_ho_wo); show_data_nhwc_layout(output);
std::cout << std::endl; std::cout << std::endl;
std::cout << "wei: "; std::cout << "wei: ";
show_data_nhwc_layout(wei_k_c_y_x); show_data_nhwc_layout(weights);
std::cout << std::endl; std::cout << std::endl;
std::cout << "out_host : "; std::cout << "out_host : ";
show_data_nhwc_layout(in_n_c_hi_wi_host_result); show_data_nhwc_layout(input_host_result);
std::cout << std::endl; std::cout << std::endl;
std::cout << "out_device: "; std::cout << "out_device: ";
show_data_nhwc_layout(in_n_c_hi_wi_device_result); show_data_nhwc_layout(input_device_result);
std::cout << std::endl; std::cout << std::endl;
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment