"examples/vscode:/vscode.git/clone" did not exist on "e3103e171f08361dc7c781685bd8127c8f672249"
Unverified Commit c0e95f62 authored by ltqin, committed by GitHub

Patch for bwd data #134 (#168)

* remove switch for NDimSpatial

* change in, out and wei name

* rename reference thumb function name

* remove test
parent cd167e49
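
The main structural change in this patch is the first bullet: `profile_convnd_bwd_data_impl` already receives `NDimSpatial` as a compile-time template parameter, so the runtime `switch` over 1/2/3 spatial dimensions in the verification path is redundant and can be replaced by a single instantiation. Below is a minimal, self-contained sketch of that idea; the names (`ReferenceOp`, `profile_impl`) are illustrative only, not the repository's actual API.

```cpp
// Sketch only: dispatch on a compile-time dimension parameter instead of a
// runtime switch. Hypothetical names, not composable_kernel's real types.
#include <cstddef>
#include <iostream>

template <std::size_t NDimSpatial>
struct ReferenceOp
{
    void Run() const { std::cout << NDimSpatial << "D reference conv bwd-data\n"; }
};

template <std::size_t NDimSpatial>
void profile_impl()
{
    // The old code switched over a runtime value and instantiated the 1D/2D/3D
    // reference op in separate cases; with NDimSpatial known at compile time a
    // single instantiation (plus a static_assert) covers every supported case.
    static_assert(NDimSpatial >= 1 && NDimSpatial <= 3,
                  "only 1, 2 and 3 spatial dimensions are supported");
    ReferenceOp<NDimSpatial>{}.Run();
}

int main()
{
    profile_impl<1>();
    profile_impl<2>();
    profile_impl<3>();
    return 0;
}
```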
@@ -71,7 +71,7 @@ struct ReferenceConvBwdData : public device::BaseOperator
{
if constexpr(NumDimSpatial == 1)
{
- auto f_nchw = [&](auto n, auto c, auto wi) {
+ auto f_ncw = [&](auto n, auto c, auto wi) {
std::size_t K = arg.weight_.mDesc.GetLengths()[0];
std::size_t X = arg.weight_.mDesc.GetLengths()[2];
std::size_t Wo = arg.output_.mDesc.GetLengths()[2];
@@ -108,7 +108,7 @@ struct ReferenceConvBwdData : public device::BaseOperator
arg.input_(n, c, wi) = ck::type_convert<InDataType>(v_in);
};
- make_ParallelTensorFunctor(f_nchw,
+ make_ParallelTensorFunctor(f_ncw,
arg.input_.mDesc.GetLengths()[0],
arg.input_.mDesc.GetLengths()[1],
arg.input_.mDesc.GetLengths()[2])(
@@ -182,7 +182,7 @@ struct ReferenceConvBwdData : public device::BaseOperator
}
else if constexpr(NumDimSpatial == 3)
{
- auto f_nchw = [&](auto n, auto c, auto di, auto hi, auto wi) {
+ auto f_ncdhw = [&](auto n, auto c, auto di, auto hi, auto wi) {
std::size_t K = arg.weight_.mDesc.GetLengths()[0];
std::size_t Z = arg.weight_.mDesc.GetLengths()[2];
std::size_t Y = arg.weight_.mDesc.GetLengths()[3];
@@ -252,7 +252,7 @@ struct ReferenceConvBwdData : public device::BaseOperator
arg.input_(n, c, di, hi, wi) = ck::type_convert<InDataType>(v_in);
};
- make_ParallelTensorFunctor(f_nchw,
+ make_ParallelTensorFunctor(f_ncdhw,
arg.input_.mDesc.GetLengths()[0],
arg.input_.mDesc.GetLengths()[1],
arg.input_.mDesc.GetLengths()[2],
@@ -120,7 +120,6 @@ HostTensorDescriptor get_output_host_ensor_descriptor(const std::vector<std::siz
case 1: {
return ck::conv_util::GetHostTensorDescriptor(dims, OutLayout{});
}
default: {
throw std::runtime_error("Unsupported number of spatial dimensions provided!");
}
@@ -274,13 +273,13 @@ bool profile_convnd_bwd_data_impl(int do_verification,
ck::index_t N,
ck::index_t K,
ck::index_t C,
- std::vector<ck::index_t> input_spatial_lengths,
- std::vector<ck::index_t> filter_spatial_lengths,
- std::vector<ck::index_t> output_spatial_lengths,
- std::vector<ck::index_t> conv_filter_strides,
- std::vector<ck::index_t> conv_filter_dilations,
- std::vector<ck::index_t> input_left_pads,
- std::vector<ck::index_t> input_right_pads)
+ const std::vector<ck::index_t>& input_spatial_lengths,
+ const std::vector<ck::index_t>& filter_spatial_lengths,
+ const std::vector<ck::index_t>& output_spatial_lengths,
+ const std::vector<ck::index_t>& conv_filter_strides,
+ const std::vector<ck::index_t>& conv_filter_dilations,
+ const std::vector<ck::index_t>& input_left_pads,
+ const std::vector<ck::index_t>& input_right_pads)
{
using InElementOp = ck::tensor_operation::element_wise::PassThrough;
using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
@@ -304,51 +303,50 @@ bool profile_convnd_bwd_data_impl(int do_verification,
std::begin(output_spatial_lengths),
std::end(output_spatial_lengths));
- Tensor<InDataType> in_n_c_hi_wi_host_result(
+ Tensor<InDataType> input_host_result(
get_input_host_tensor_descriptor<InLayout>(input_dims, NDimSpatial));
- Tensor<InDataType> in_n_c_hi_wi_device_result(
+ Tensor<InDataType> input_device_result(
get_input_host_tensor_descriptor<InLayout>(input_dims, NDimSpatial));
- Tensor<WeiDataType> wei_k_c_y_x(
+ Tensor<WeiDataType> weights(
get_filters_host_tensor_descriptor<WeiLayout>(filter_dims, NDimSpatial));
- Tensor<OutDataType> out_n_k_ho_wo(
+ Tensor<OutDataType> output(
get_output_host_ensor_descriptor<OutLayout>(output_dims, NDimSpatial));
std::cout << "in_n_c_hi_wi: " << in_n_c_hi_wi_host_result.mDesc << std::endl;
std::cout << "wei_k_c_y_x: " << wei_k_c_y_x.mDesc << std::endl;
std::cout << "out_n_k_ho_wo: " << out_n_k_ho_wo.mDesc << std::endl;
std::cout << "input: " << input_host_result.mDesc << std::endl;
std::cout << "weights: " << weights.mDesc << std::endl;
std::cout << "output: " << output.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
- out_n_k_ho_wo.GenerateTensorValue(GeneratorTensor_2<OutDataType>{-5, 5});
- wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-5, 5});
+ output.GenerateTensorValue(GeneratorTensor_2<OutDataType>{-5, 5});
+ weights.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-5, 5});
break;
default:
- out_n_k_ho_wo.GenerateTensorValue(GeneratorTensor_1<OutDataType>{1});
- wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_1<WeiDataType>{1});
+ output.GenerateTensorValue(GeneratorTensor_1<OutDataType>{1});
+ weights.GenerateTensorValue(GeneratorTensor_1<WeiDataType>{1});
}
- DeviceMem in_device_buf(sizeof(InDataType) *
- in_n_c_hi_wi_device_result.mDesc.GetElementSpace());
- DeviceMem wei_device_buf(sizeof(WeiDataType) * wei_k_c_y_x.mDesc.GetElementSpace());
- DeviceMem out_device_buf(sizeof(OutDataType) * out_n_k_ho_wo.mDesc.GetElementSpace());
+ DeviceMem in_device_buf(sizeof(InDataType) * input_device_result.mDesc.GetElementSpace());
+ DeviceMem wei_device_buf(sizeof(WeiDataType) * weights.mDesc.GetElementSpace());
+ DeviceMem out_device_buf(sizeof(OutDataType) * output.mDesc.GetElementSpace());
- out_device_buf.ToDevice(out_n_k_ho_wo.mData.data());
- wei_device_buf.ToDevice(wei_k_c_y_x.mData.data());
+ out_device_buf.ToDevice(output.mData.data());
+ wei_device_buf.ToDevice(weights.mData.data());
// reset input to zero
- in_n_c_hi_wi_device_result.GenerateTensorValue(GeneratorTensor_1<InDataType>{0});
- in_device_buf.ToDevice(in_n_c_hi_wi_device_result.mData.data());
+ input_device_result.GenerateTensorValue(GeneratorTensor_1<InDataType>{0});
+ in_device_buf.ToDevice(input_device_result.mData.data());
if(do_verification)
{
auto RunReference = [&](auto& ref_conv) {
auto ref_invoker = ref_conv.MakeInvoker();
- auto ref_argument = ref_conv.MakeArgument(in_n_c_hi_wi_host_result,
- wei_k_c_y_x,
- out_n_k_ho_wo,
+ auto ref_argument = ref_conv.MakeArgument(input_host_result,
+ weights,
+ output,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
@@ -358,48 +356,16 @@ bool profile_convnd_bwd_data_impl(int do_verification,
OutElementOp{});
ref_invoker.Run(ref_argument);
};
- switch(NDimSpatial)
- {
- case 3: {
- auto ref_conv = ck::tensor_operation::host::ReferenceConvBwdData<InDataType,
- WeiDataType,
- OutDataType,
- AccDataType,
- InElementOp,
- WeiElementOp,
- OutElementOp,
- 3>();
- RunReference(ref_conv);
- break;
- }
- case 2: {
- auto ref_conv = ck::tensor_operation::host::ReferenceConvBwdData<InDataType,
- WeiDataType,
- OutDataType,
- AccDataType,
- InElementOp,
- WeiElementOp,
- OutElementOp,
- 2>();
- RunReference(ref_conv);
- break;
- }
- case 1: {
- auto ref_conv = ck::tensor_operation::host::ReferenceConvBwdData<InDataType,
- WeiDataType,
- OutDataType,
- AccDataType,
- InElementOp,
- WeiElementOp,
- OutElementOp,
- 1>();
- RunReference(ref_conv);
- break;
- }
- default: {
- throw std::runtime_error("Unsupported number of spatial dimensions provided!");
- }
- }
+ auto ref_conv = ck::tensor_operation::host::ReferenceConvBwdData<InDataType,
+ WeiDataType,
+ OutDataType,
+ AccDataType,
+ InElementOp,
+ WeiElementOp,
+ OutElementOp,
+ NDimSpatial>();
+ RunReference(ref_conv);
}
// add device Conv instances
@@ -468,9 +434,9 @@ bool profile_convnd_bwd_data_impl(int do_verification,
if(do_verification)
{
- in_device_buf.FromDevice(in_n_c_hi_wi_device_result.mData.data());
+ in_device_buf.FromDevice(input_device_result.mData.data());
- if(!check_out(in_n_c_hi_wi_host_result, in_n_c_hi_wi_device_result))
+ if(!check_out(input_host_result, input_device_result))
{
std::cout << "Fail Info: " << conv_ptr->GetTypeString() << std::endl;
@@ -481,24 +447,24 @@ bool profile_convnd_bwd_data_impl(int do_verification,
std::cout << "Pass Info: " << conv_ptr->GetTypeString() << std::endl;
}
- check_error(in_n_c_hi_wi_host_result, in_n_c_hi_wi_device_result);
+ check_error(input_host_result, input_device_result);
if(do_log)
{
std::cout << "in : ";
- show_data_nhwc_layout(out_n_k_ho_wo);
+ show_data_nhwc_layout(output);
std::cout << std::endl;
std::cout << "wei: ";
- show_data_nhwc_layout(wei_k_c_y_x);
+ show_data_nhwc_layout(weights);
std::cout << std::endl;
std::cout << "out_host : ";
- show_data_nhwc_layout(in_n_c_hi_wi_host_result);
+ show_data_nhwc_layout(input_host_result);
std::cout << std::endl;
std::cout << "out_device: ";
- show_data_nhwc_layout(in_n_c_hi_wi_device_result);
+ show_data_nhwc_layout(input_device_result);
std::cout << std::endl;
}
}