Commit 360184cd authored by Chao Liu's avatar Chao Liu
Browse files

update examples

parent f6922d3f
...@@ -26,7 +26,7 @@ void print_helper_msg() ...@@ -26,7 +26,7 @@ void print_helper_msg()
<< "Following arguments (depending on number of spatial dims):\n" << "Following arguments (depending on number of spatial dims):\n"
<< " N, K, C, \n" << " N, K, C, \n"
<< " <filter spatial dimensions>, (ie Y, X for 2D)\n" << " <filter spatial dimensions>, (ie Y, X for 2D)\n"
<< " <in_n_hi_wi_c image spatial dimensions>, (ie Hi, Wi for 2D)\n" << " <input image spatial dimensions>, (ie Hi, Wi for 2D)\n"
<< " <strides>, (ie Sy, Sx for 2D)\n" << " <strides>, (ie Sy, Sx for 2D)\n"
<< " <dilations>, (ie Dy, Dx for 2D)\n" << " <dilations>, (ie Dy, Dx for 2D)\n"
<< " <left padding>, (ie LeftPy, LeftPx for 2D)\n" << " <left padding>, (ie LeftPy, LeftPx for 2D)\n"
...@@ -90,16 +90,19 @@ parse_conv_params(int num_dim_spatial, int arg_idx, char* const argv[]) ...@@ -90,16 +90,19 @@ parse_conv_params(int num_dim_spatial, int arg_idx, char* const argv[])
input_right_pads}; input_right_pads};
} }
// FIXME: current implementation only support NCHW/NHWC layout
template <ck::index_t NDimSpatial, template <ck::index_t NDimSpatial,
typename InLayout,
typename WeiLayout,
typename OutLayout,
typename InDataType, typename InDataType,
typename WeiDataType, typename WeiDataType,
typename OutDataType, typename OutDataType,
typename AccDataType,
typename InElementOp, typename InElementOp,
typename WeiElementOp, typename WeiElementOp,
typename OutElementOp, typename OutElementOp,
typename DeviceConvNDFwdInstance> typename DeviceConvNDFwdInstance>
int run_conv_fwd_nhwc(bool do_verification, int run_conv_fwd(bool do_verification,
int init_method, int init_method,
bool time_kernel, bool time_kernel,
const ck::tensor_operation::device::ConvParams& params, const ck::tensor_operation::device::ConvParams& params,
...@@ -107,6 +110,7 @@ int run_conv_fwd_nhwc(bool do_verification, ...@@ -107,6 +110,7 @@ int run_conv_fwd_nhwc(bool do_verification,
const WeiElementOp& wei_element_op, const WeiElementOp& wei_element_op,
const OutElementOp& out_element_op) const OutElementOp& out_element_op)
{ {
// make host tensor descriptor
auto f_nhwc_host_tensor_descriptor = auto f_nhwc_host_tensor_descriptor =
[](ck::index_t n, ck::index_t c, std::vector<ck::index_t> spatial_lengths) { [](ck::index_t n, ck::index_t c, std::vector<ck::index_t> spatial_lengths) {
std::vector<std::size_t> nhwc_lengths{static_cast<std::size_t>(n), std::vector<std::size_t> nhwc_lengths{static_cast<std::size_t>(n),
...@@ -117,37 +121,92 @@ int run_conv_fwd_nhwc(bool do_verification, ...@@ -117,37 +121,92 @@ int run_conv_fwd_nhwc(bool do_verification,
return HostTensorDescriptor(nhwc_lengths); return HostTensorDescriptor(nhwc_lengths);
}; };
Tensor<InDataType> in_n_hi_wi_c( auto f_nchw_host_tensor_descriptor =
f_nhwc_host_tensor_descriptor(params.N_, params.C_, params.input_spatial_lengths_)); [](ck::index_t n, ck::index_t c, std::vector<ck::index_t> spatial_lengths) {
Tensor<WeiDataType> wei_k_y_x_c( std::vector<std::size_t> nchw_lengths{static_cast<std::size_t>(n),
f_nhwc_host_tensor_descriptor(params.K_, params.C_, params.filter_spatial_lengths_)); static_cast<std::size_t>(c)};
Tensor<OutDataType> out_n_ho_wo_k_host( nchw_lengths.insert(nchw_lengths.end(), spatial_lengths.begin(), spatial_lengths.end());
f_nhwc_host_tensor_descriptor(params.N_, params.K_, params.GetOutputSpatialLengths()));
Tensor<OutDataType> out_n_ho_wo_k_device( return HostTensorDescriptor(nchw_lengths);
f_nhwc_host_tensor_descriptor(params.N_, params.K_, params.GetOutputSpatialLengths())); };
HostTensorDescriptor in_desc, wei_desc, out_desc;
// FIXME: properly implement "make host descriptor" for different layout
if constexpr(ck::is_same_v<InLayout, ck::tensor_layout::convolution::NWC> ||
ck::is_same_v<InLayout, ck::tensor_layout::convolution::NHWC> ||
ck::is_same_v<InLayout, ck::tensor_layout::convolution::NDHWC>)
{
in_desc =
f_nhwc_host_tensor_descriptor(params.N_, params.C_, params.input_spatial_lengths_);
}
else if constexpr(ck::is_same_v<InLayout, ck::tensor_layout::convolution::NCW> ||
ck::is_same_v<InLayout, ck::tensor_layout::convolution::NCHW> ||
ck::is_same_v<InLayout, ck::tensor_layout::convolution::NCDHW>)
{
in_desc =
f_nchw_host_tensor_descriptor(params.N_, params.C_, params.input_spatial_lengths_);
}
// FIXME: properly implement "make host descriptor" for different layout
if constexpr(ck::is_same_v<WeiLayout, ck::tensor_layout::convolution::KXC> ||
ck::is_same_v<WeiLayout, ck::tensor_layout::convolution::KYXC> ||
ck::is_same_v<WeiLayout, ck::tensor_layout::convolution::KZYXC>)
{
wei_desc =
f_nhwc_host_tensor_descriptor(params.K_, params.C_, params.filter_spatial_lengths_);
}
else if constexpr(ck::is_same_v<WeiLayout, ck::tensor_layout::convolution::KCX> ||
ck::is_same_v<WeiLayout, ck::tensor_layout::convolution::KCYX> ||
ck::is_same_v<WeiLayout, ck::tensor_layout::convolution::KCZYX>)
{
wei_desc =
f_nchw_host_tensor_descriptor(params.K_, params.C_, params.filter_spatial_lengths_);
}
// FIXME: properly implement "make host descriptor" for different layout
if constexpr(ck::is_same_v<OutLayout, ck::tensor_layout::convolution::NWK> ||
ck::is_same_v<OutLayout, ck::tensor_layout::convolution::NHWK> ||
ck::is_same_v<OutLayout, ck::tensor_layout::convolution::NDHWK>)
{
out_desc =
f_nhwc_host_tensor_descriptor(params.N_, params.K_, params.GetOutputSpatialLengths());
}
else if constexpr(ck::is_same_v<OutLayout, ck::tensor_layout::convolution::NKW> ||
ck::is_same_v<OutLayout, ck::tensor_layout::convolution::NKHW> ||
ck::is_same_v<OutLayout, ck::tensor_layout::convolution::NKDHW>)
{
out_desc =
f_nchw_host_tensor_descriptor(params.N_, params.K_, params.GetOutputSpatialLengths());
}
Tensor<InDataType> in(in_desc);
Tensor<WeiDataType> wei(wei_desc);
Tensor<OutDataType> out_host(out_desc);
Tensor<OutDataType> out_device(out_desc);
std::cout << "in_n_hi_wi_c: " << in_n_hi_wi_c.mDesc << std::endl; std::cout << "in: " << in.mDesc << std::endl;
std::cout << "wei_k_y_x_c: " << wei_k_y_x_c.mDesc << std::endl; std::cout << "wei: " << wei.mDesc << std::endl;
std::cout << "output: " << out_n_ho_wo_k_host.mDesc << std::endl; std::cout << "out: " << out_host.mDesc << std::endl;
switch(init_method) switch(init_method)
{ {
case 0: break; case 0: break;
case 1: case 1:
in_n_hi_wi_c.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5}); in.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5});
wei_k_y_x_c.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-5, 5}); wei.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-5, 5});
break; break;
default: default:
in_n_hi_wi_c.GenerateTensorValue(GeneratorTensor_3<InDataType>{0.0, 1.0}); in.GenerateTensorValue(GeneratorTensor_3<InDataType>{0.0, 1.0});
wei_k_y_x_c.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-0.5, 0.5}); wei.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-0.5, 0.5});
} }
DeviceMem in_device_buf(sizeof(InDataType) * in_n_hi_wi_c.mDesc.GetElementSpace()); DeviceMem in_device_buf(sizeof(InDataType) * in.mDesc.GetElementSpace());
DeviceMem wei_device_buf(sizeof(WeiDataType) * wei_k_y_x_c.mDesc.GetElementSpace()); DeviceMem wei_device_buf(sizeof(WeiDataType) * wei.mDesc.GetElementSpace());
DeviceMem out_device_buf(sizeof(OutDataType) * out_n_ho_wo_k_device.mDesc.GetElementSpace()); DeviceMem out_device_buf(sizeof(OutDataType) * out_device.mDesc.GetElementSpace());
in_device_buf.ToDevice(in_n_hi_wi_c.mData.data()); in_device_buf.ToDevice(in.mData.data());
wei_device_buf.ToDevice(wei_k_y_x_c.mData.data()); wei_device_buf.ToDevice(wei.mData.data());
// do GEMM // do GEMM
auto conv = DeviceConvNDFwdInstance{}; auto conv = DeviceConvNDFwdInstance{};
...@@ -188,20 +247,10 @@ int run_conv_fwd_nhwc(bool do_verification, ...@@ -188,20 +247,10 @@ int run_conv_fwd_nhwc(bool do_verification,
if(do_verification) if(do_verification)
{ {
auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd< auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd<NDimSpatial,
NDimSpatial, InLayout,
ck::tuple_element_t<NDimSpatial - 1, WeiLayout,
ck::Tuple<ck::tensor_layout::convolution::NWC, OutLayout,
ck::tensor_layout::convolution::NHWC,
ck::tensor_layout::convolution::NDHWC>>,
ck::tuple_element_t<NDimSpatial - 1,
ck::Tuple<ck::tensor_layout::convolution::KXC,
ck::tensor_layout::convolution::KYXC,
ck::tensor_layout::convolution::KZYXC>>,
ck::tuple_element_t<NDimSpatial - 1,
ck::Tuple<ck::tensor_layout::convolution::NWK,
ck::tensor_layout::convolution::NHWK,
ck::tensor_layout::convolution::NDHWK>>,
InDataType, InDataType,
WeiDataType, WeiDataType,
OutDataType, OutDataType,
...@@ -210,9 +259,9 @@ int run_conv_fwd_nhwc(bool do_verification, ...@@ -210,9 +259,9 @@ int run_conv_fwd_nhwc(bool do_verification,
OutElementOp>(); OutElementOp>();
auto ref_invoker = ref_conv.MakeInvoker(); auto ref_invoker = ref_conv.MakeInvoker();
auto ref_argument = ref_conv.MakeArgument(in_n_hi_wi_c, auto ref_argument = ref_conv.MakeArgument(in,
wei_k_y_x_c, wei,
out_n_ho_wo_k_host, out_host,
params.conv_filter_strides_, params.conv_filter_strides_,
params.conv_filter_dilations_, params.conv_filter_dilations_,
params.input_left_pads_, params.input_left_pads_,
...@@ -223,13 +272,10 @@ int run_conv_fwd_nhwc(bool do_verification, ...@@ -223,13 +272,10 @@ int run_conv_fwd_nhwc(bool do_verification,
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
out_device_buf.FromDevice(out_n_ho_wo_k_device.mData.data()); out_device_buf.FromDevice(out_device.mData.data());
return ck::utils::check_err(out_n_ho_wo_k_host.mData, return ck::utils::check_err(
out_n_ho_wo_k_device.mData, out_host.mData, out_device.mData, "Error: incorrect results!", 1e-5f, 1e-4f)
"Error: incorrect results!",
1e-5f,
1e-4f)
? 0 ? 0
: 1; : 1;
} }
......
...@@ -95,11 +95,13 @@ int main(int argc, char* argv[]) ...@@ -95,11 +95,13 @@ int main(int argc, char* argv[])
if(num_dim_spatial == 1) if(num_dim_spatial == 1)
{ {
return run_conv_fwd_nhwc<1, return run_conv_fwd<1,
ck::tensor_layout::convolution::NWC,
ck::tensor_layout::convolution::KXC,
ck::tensor_layout::convolution::NWK,
InDataType, InDataType,
WeiDataType, WeiDataType,
OutDataType, OutDataType,
AccDataType,
InElementOp, InElementOp,
WeiElementOp, WeiElementOp,
OutElementOp, OutElementOp,
...@@ -113,11 +115,13 @@ int main(int argc, char* argv[]) ...@@ -113,11 +115,13 @@ int main(int argc, char* argv[])
} }
else if(num_dim_spatial == 2) else if(num_dim_spatial == 2)
{ {
return run_conv_fwd_nhwc<2, return run_conv_fwd<2,
ck::tensor_layout::convolution::NHWC,
ck::tensor_layout::convolution::KYXC,
ck::tensor_layout::convolution::NHWK,
InDataType, InDataType,
WeiDataType, WeiDataType,
OutDataType, OutDataType,
AccDataType,
InElementOp, InElementOp,
WeiElementOp, WeiElementOp,
OutElementOp, OutElementOp,
...@@ -131,11 +135,13 @@ int main(int argc, char* argv[]) ...@@ -131,11 +135,13 @@ int main(int argc, char* argv[])
} }
else if(num_dim_spatial == 3) else if(num_dim_spatial == 3)
{ {
return run_conv_fwd_nhwc<3, return run_conv_fwd<3,
ck::tensor_layout::convolution::NDHWC,
ck::tensor_layout::convolution::KZYXC,
ck::tensor_layout::convolution::NDHWK,
InDataType, InDataType,
WeiDataType, WeiDataType,
OutDataType, OutDataType,
AccDataType,
InElementOp, InElementOp,
WeiElementOp, WeiElementOp,
OutElementOp, OutElementOp,
......
...@@ -95,11 +95,13 @@ int main(int argc, char* argv[]) ...@@ -95,11 +95,13 @@ int main(int argc, char* argv[])
if(num_dim_spatial == 1) if(num_dim_spatial == 1)
{ {
return run_conv_fwd_nhwc<1, return run_conv_fwd<1,
ck::tensor_layout::convolution::NWC,
ck::tensor_layout::convolution::KXC,
ck::tensor_layout::convolution::NWK,
InDataType, InDataType,
WeiDataType, WeiDataType,
OutDataType, OutDataType,
AccDataType,
InElementOp, InElementOp,
WeiElementOp, WeiElementOp,
OutElementOp, OutElementOp,
...@@ -113,11 +115,13 @@ int main(int argc, char* argv[]) ...@@ -113,11 +115,13 @@ int main(int argc, char* argv[])
} }
else if(num_dim_spatial == 2) else if(num_dim_spatial == 2)
{ {
return run_conv_fwd_nhwc<2, return run_conv_fwd<2,
ck::tensor_layout::convolution::NHWC,
ck::tensor_layout::convolution::KYXC,
ck::tensor_layout::convolution::NHWK,
InDataType, InDataType,
WeiDataType, WeiDataType,
OutDataType, OutDataType,
AccDataType,
InElementOp, InElementOp,
WeiElementOp, WeiElementOp,
OutElementOp, OutElementOp,
...@@ -131,11 +135,13 @@ int main(int argc, char* argv[]) ...@@ -131,11 +135,13 @@ int main(int argc, char* argv[])
} }
else if(num_dim_spatial == 3) else if(num_dim_spatial == 3)
{ {
return run_conv_fwd_nhwc<3, return run_conv_fwd<3,
ck::tensor_layout::convolution::NDHWC,
ck::tensor_layout::convolution::KZYXC,
ck::tensor_layout::convolution::NDHWK,
InDataType, InDataType,
WeiDataType, WeiDataType,
OutDataType, OutDataType,
AccDataType,
InElementOp, InElementOp,
WeiElementOp, WeiElementOp,
OutElementOp, OutElementOp,
......
...@@ -95,11 +95,13 @@ int main(int argc, char* argv[]) ...@@ -95,11 +95,13 @@ int main(int argc, char* argv[])
if(num_dim_spatial == 1) if(num_dim_spatial == 1)
{ {
return run_conv_fwd_nhwc<1, return run_conv_fwd<1,
ck::tensor_layout::convolution::NWC,
ck::tensor_layout::convolution::KXC,
ck::tensor_layout::convolution::NWK,
InDataType, InDataType,
WeiDataType, WeiDataType,
OutDataType, OutDataType,
AccDataType,
InElementOp, InElementOp,
WeiElementOp, WeiElementOp,
OutElementOp, OutElementOp,
...@@ -113,11 +115,13 @@ int main(int argc, char* argv[]) ...@@ -113,11 +115,13 @@ int main(int argc, char* argv[])
} }
else if(num_dim_spatial == 2) else if(num_dim_spatial == 2)
{ {
return run_conv_fwd_nhwc<2, return run_conv_fwd<2,
ck::tensor_layout::convolution::NHWC,
ck::tensor_layout::convolution::KYXC,
ck::tensor_layout::convolution::NHWK,
InDataType, InDataType,
WeiDataType, WeiDataType,
OutDataType, OutDataType,
AccDataType,
InElementOp, InElementOp,
WeiElementOp, WeiElementOp,
OutElementOp, OutElementOp,
...@@ -131,11 +135,13 @@ int main(int argc, char* argv[]) ...@@ -131,11 +135,13 @@ int main(int argc, char* argv[])
} }
else if(num_dim_spatial == 3) else if(num_dim_spatial == 3)
{ {
return run_conv_fwd_nhwc<3, return run_conv_fwd<3,
ck::tensor_layout::convolution::NDHWC,
ck::tensor_layout::convolution::KZYXC,
ck::tensor_layout::convolution::NDHWK,
InDataType, InDataType,
WeiDataType, WeiDataType,
OutDataType, OutDataType,
AccDataType,
InElementOp, InElementOp,
WeiElementOp, WeiElementOp,
OutElementOp, OutElementOp,
......
...@@ -95,11 +95,13 @@ int main(int argc, char* argv[]) ...@@ -95,11 +95,13 @@ int main(int argc, char* argv[])
if(num_dim_spatial == 1) if(num_dim_spatial == 1)
{ {
return run_conv_fwd_nhwc<1, return run_conv_fwd<1,
ck::tensor_layout::convolution::NWC,
ck::tensor_layout::convolution::KXC,
ck::tensor_layout::convolution::NWK,
InDataType, InDataType,
WeiDataType, WeiDataType,
OutDataType, OutDataType,
AccDataType,
InElementOp, InElementOp,
WeiElementOp, WeiElementOp,
OutElementOp, OutElementOp,
...@@ -113,11 +115,13 @@ int main(int argc, char* argv[]) ...@@ -113,11 +115,13 @@ int main(int argc, char* argv[])
} }
else if(num_dim_spatial == 2) else if(num_dim_spatial == 2)
{ {
return run_conv_fwd_nhwc<2, return run_conv_fwd<2,
ck::tensor_layout::convolution::NHWC,
ck::tensor_layout::convolution::KYXC,
ck::tensor_layout::convolution::NHWK,
InDataType, InDataType,
WeiDataType, WeiDataType,
OutDataType, OutDataType,
AccDataType,
InElementOp, InElementOp,
WeiElementOp, WeiElementOp,
OutElementOp, OutElementOp,
...@@ -131,11 +135,13 @@ int main(int argc, char* argv[]) ...@@ -131,11 +135,13 @@ int main(int argc, char* argv[])
} }
else if(num_dim_spatial == 3) else if(num_dim_spatial == 3)
{ {
return run_conv_fwd_nhwc<3, return run_conv_fwd<3,
ck::tensor_layout::convolution::NDHWC,
ck::tensor_layout::convolution::KZYXC,
ck::tensor_layout::convolution::NDHWK,
InDataType, InDataType,
WeiDataType, WeiDataType,
OutDataType, OutDataType,
AccDataType,
InElementOp, InElementOp,
WeiElementOp, WeiElementOp,
OutElementOp, OutElementOp,
......
...@@ -95,11 +95,13 @@ int main(int argc, char* argv[]) ...@@ -95,11 +95,13 @@ int main(int argc, char* argv[])
if(num_dim_spatial == 1) if(num_dim_spatial == 1)
{ {
return run_conv_fwd_nhwc<1, return run_conv_fwd<1,
ck::tensor_layout::convolution::NWC,
ck::tensor_layout::convolution::KXC,
ck::tensor_layout::convolution::NWK,
InDataType, InDataType,
WeiDataType, WeiDataType,
OutDataType, OutDataType,
AccDataType,
InElementOp, InElementOp,
WeiElementOp, WeiElementOp,
OutElementOp, OutElementOp,
...@@ -113,11 +115,13 @@ int main(int argc, char* argv[]) ...@@ -113,11 +115,13 @@ int main(int argc, char* argv[])
} }
else if(num_dim_spatial == 2) else if(num_dim_spatial == 2)
{ {
return run_conv_fwd_nhwc<2, return run_conv_fwd<2,
ck::tensor_layout::convolution::NHWC,
ck::tensor_layout::convolution::KYXC,
ck::tensor_layout::convolution::NHWK,
InDataType, InDataType,
WeiDataType, WeiDataType,
OutDataType, OutDataType,
AccDataType,
InElementOp, InElementOp,
WeiElementOp, WeiElementOp,
OutElementOp, OutElementOp,
...@@ -131,11 +135,13 @@ int main(int argc, char* argv[]) ...@@ -131,11 +135,13 @@ int main(int argc, char* argv[])
} }
else if(num_dim_spatial == 3) else if(num_dim_spatial == 3)
{ {
return run_conv_fwd_nhwc<3, return run_conv_fwd<3,
ck::tensor_layout::convolution::NDHWC,
ck::tensor_layout::convolution::KZYXC,
ck::tensor_layout::convolution::NDHWK,
InDataType, InDataType,
WeiDataType, WeiDataType,
OutDataType, OutDataType,
AccDataType,
InElementOp, InElementOp,
WeiElementOp, WeiElementOp,
OutElementOp, OutElementOp,
......
...@@ -90,16 +90,19 @@ void print_helper_msg() ...@@ -90,16 +90,19 @@ void print_helper_msg()
<< std::endl; << std::endl;
} }
// FIXME: current implementation only support NCHW/NHWC layout
template <ck::index_t NDimSpatial, template <ck::index_t NDimSpatial,
typename InLayout,
typename WeiLayout,
typename OutLayout,
typename InDataType, typename InDataType,
typename WeiDataType, typename WeiDataType,
typename OutDataType, typename OutDataType,
typename AccDataType,
typename InElementOp, typename InElementOp,
typename WeiElementOp, typename WeiElementOp,
typename OutElementOp, typename OutElementOp,
typename DeviceConvNdBwdDataInstance> typename DeviceConvNdBwdDataInstance>
int run_conv_bwd_data_nhwc(bool do_verification, int run_conv_bwd_data(bool do_verification,
int init_method, int init_method,
bool time_kernel, bool time_kernel,
const ck::tensor_operation::device::ConvParams& params, const ck::tensor_operation::device::ConvParams& params,
...@@ -107,7 +110,7 @@ int run_conv_bwd_data_nhwc(bool do_verification, ...@@ -107,7 +110,7 @@ int run_conv_bwd_data_nhwc(bool do_verification,
const WeiElementOp& wei_element_op, const WeiElementOp& wei_element_op,
const OutElementOp& out_element_op) const OutElementOp& out_element_op)
{ {
// make host tensor descriptor
auto f_nhwc_host_tensor_descriptor = auto f_nhwc_host_tensor_descriptor =
[](ck::index_t n, ck::index_t c, std::vector<ck::index_t> spatial_lengths) { [](ck::index_t n, ck::index_t c, std::vector<ck::index_t> spatial_lengths) {
std::vector<std::size_t> nhwc_lengths{static_cast<std::size_t>(n), std::vector<std::size_t> nhwc_lengths{static_cast<std::size_t>(n),
...@@ -118,37 +121,92 @@ int run_conv_bwd_data_nhwc(bool do_verification, ...@@ -118,37 +121,92 @@ int run_conv_bwd_data_nhwc(bool do_verification,
return HostTensorDescriptor(nhwc_lengths); return HostTensorDescriptor(nhwc_lengths);
}; };
Tensor<InDataType> in_n_hi_wi_c_host( auto f_nchw_host_tensor_descriptor =
f_nhwc_host_tensor_descriptor(params.N_, params.C_, params.input_spatial_lengths_)); [](ck::index_t n, ck::index_t c, std::vector<ck::index_t> spatial_lengths) {
Tensor<InDataType> in_n_hi_wi_c_device( std::vector<std::size_t> nchw_lengths{static_cast<std::size_t>(n),
f_nhwc_host_tensor_descriptor(params.N_, params.C_, params.input_spatial_lengths_)); static_cast<std::size_t>(c)};
Tensor<WeiDataType> wei_k_y_x_c( nchw_lengths.insert(nchw_lengths.end(), spatial_lengths.begin(), spatial_lengths.end());
f_nhwc_host_tensor_descriptor(params.K_, params.C_, params.filter_spatial_lengths_));
Tensor<OutDataType> out_n_ho_wo_k( return HostTensorDescriptor(nchw_lengths);
f_nhwc_host_tensor_descriptor(params.N_, params.K_, params.GetOutputSpatialLengths())); };
HostTensorDescriptor in_desc, wei_desc, out_desc;
// FIXME: properly implement "make host descriptor" for different layout
if constexpr(ck::is_same_v<InLayout, ck::tensor_layout::convolution::NWC> ||
ck::is_same_v<InLayout, ck::tensor_layout::convolution::NHWC> ||
ck::is_same_v<InLayout, ck::tensor_layout::convolution::NDHWC>)
{
in_desc =
f_nhwc_host_tensor_descriptor(params.N_, params.C_, params.input_spatial_lengths_);
}
else if constexpr(ck::is_same_v<InLayout, ck::tensor_layout::convolution::NCW> ||
ck::is_same_v<InLayout, ck::tensor_layout::convolution::NCHW> ||
ck::is_same_v<InLayout, ck::tensor_layout::convolution::NCDHW>)
{
in_desc =
f_nchw_host_tensor_descriptor(params.N_, params.C_, params.input_spatial_lengths_);
}
// FIXME: properly implement "make host descriptor" for different layout
if constexpr(ck::is_same_v<WeiLayout, ck::tensor_layout::convolution::KXC> ||
ck::is_same_v<WeiLayout, ck::tensor_layout::convolution::KYXC> ||
ck::is_same_v<WeiLayout, ck::tensor_layout::convolution::KZYXC>)
{
wei_desc =
f_nhwc_host_tensor_descriptor(params.K_, params.C_, params.filter_spatial_lengths_);
}
else if constexpr(ck::is_same_v<WeiLayout, ck::tensor_layout::convolution::KCX> ||
ck::is_same_v<WeiLayout, ck::tensor_layout::convolution::KCYX> ||
ck::is_same_v<WeiLayout, ck::tensor_layout::convolution::KCZYX>)
{
wei_desc =
f_nchw_host_tensor_descriptor(params.K_, params.C_, params.filter_spatial_lengths_);
}
// FIXME: properly implement "make host descriptor" for different layout
if constexpr(ck::is_same_v<OutLayout, ck::tensor_layout::convolution::NWK> ||
ck::is_same_v<OutLayout, ck::tensor_layout::convolution::NHWK> ||
ck::is_same_v<OutLayout, ck::tensor_layout::convolution::NDHWK>)
{
out_desc =
f_nhwc_host_tensor_descriptor(params.N_, params.K_, params.GetOutputSpatialLengths());
}
else if constexpr(ck::is_same_v<OutLayout, ck::tensor_layout::convolution::NKW> ||
ck::is_same_v<OutLayout, ck::tensor_layout::convolution::NKHW> ||
ck::is_same_v<OutLayout, ck::tensor_layout::convolution::NKDHW>)
{
out_desc =
f_nchw_host_tensor_descriptor(params.N_, params.K_, params.GetOutputSpatialLengths());
}
std::cout << "in_n_hi_wi_c: " << in_n_hi_wi_c_host.mDesc << std::endl; Tensor<InDataType> in_host(in_desc);
std::cout << "wei_k_y_x_c: " << wei_k_y_x_c.mDesc << std::endl; Tensor<InDataType> in_device(in_desc);
std::cout << "out_n_ho_wo_k: " << out_n_ho_wo_k.mDesc << std::endl; Tensor<WeiDataType> wei(wei_desc);
Tensor<OutDataType> out(out_desc);
std::cout << "in: " << in_host.mDesc << std::endl;
std::cout << "wei: " << wei.mDesc << std::endl;
std::cout << "out: " << out.mDesc << std::endl;
switch(init_method) switch(init_method)
{ {
case 0: break; case 0: break;
case 1: case 1:
out_n_ho_wo_k.GenerateTensorValue(GeneratorTensor_2<OutDataType>{-5, 5}); out.GenerateTensorValue(GeneratorTensor_2<OutDataType>{-5, 5});
wei_k_y_x_c.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-5, 5}); wei.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-5, 5});
break; break;
default: default:
out_n_ho_wo_k.GenerateTensorValue(GeneratorTensor_1<OutDataType>{1}); out.GenerateTensorValue(GeneratorTensor_1<OutDataType>{1});
wei_k_y_x_c.GenerateTensorValue(GeneratorTensor_1<WeiDataType>{1}); wei.GenerateTensorValue(GeneratorTensor_1<WeiDataType>{1});
} }
DeviceMem in_device_buf(sizeof(InDataType) * in_n_hi_wi_c_device.mDesc.GetElementSpace()); DeviceMem in_device_buf(sizeof(InDataType) * in_device.mDesc.GetElementSpace());
DeviceMem wei_device_buf(sizeof(WeiDataType) * wei_k_y_x_c.mDesc.GetElementSpace()); DeviceMem wei_device_buf(sizeof(WeiDataType) * wei.mDesc.GetElementSpace());
DeviceMem out_device_buf(sizeof(OutDataType) * out_n_ho_wo_k.mDesc.GetElementSpace()); DeviceMem out_device_buf(sizeof(OutDataType) * out.mDesc.GetElementSpace());
out_device_buf.ToDevice(out_n_ho_wo_k.mData.data()); out_device_buf.ToDevice(out.mData.data());
wei_device_buf.ToDevice(wei_k_y_x_c.mData.data()); wei_device_buf.ToDevice(wei.mData.data());
// reset input to zero // reset input to zero
in_device_buf.SetZero(); in_device_buf.SetZero();
...@@ -194,22 +252,10 @@ int run_conv_bwd_data_nhwc(bool do_verification, ...@@ -194,22 +252,10 @@ int run_conv_bwd_data_nhwc(bool do_verification,
if(do_verification) if(do_verification)
{ {
std::cout << "before ref" << std::endl; auto ref_conv = ck::tensor_operation::host::ReferenceConvBwdData<NDimSpatial,
InLayout,
auto ref_conv = ck::tensor_operation::host::ReferenceConvBwdData< WeiLayout,
NDimSpatial, OutLayout,
ck::tuple_element_t<NDimSpatial - 1,
ck::Tuple<ck::tensor_layout::convolution::NWC,
ck::tensor_layout::convolution::NHWC,
ck::tensor_layout::convolution::NDHWC>>,
ck::tuple_element_t<NDimSpatial - 1,
ck::Tuple<ck::tensor_layout::convolution::KXC,
ck::tensor_layout::convolution::KYXC,
ck::tensor_layout::convolution::KZYXC>>,
ck::tuple_element_t<NDimSpatial - 1,
ck::Tuple<ck::tensor_layout::convolution::NWK,
ck::tensor_layout::convolution::NHWK,
ck::tensor_layout::convolution::NDHWK>>,
InDataType, InDataType,
WeiDataType, WeiDataType,
OutDataType, OutDataType,
...@@ -219,9 +265,9 @@ int run_conv_bwd_data_nhwc(bool do_verification, ...@@ -219,9 +265,9 @@ int run_conv_bwd_data_nhwc(bool do_verification,
auto ref_invoker = ref_conv.MakeInvoker(); auto ref_invoker = ref_conv.MakeInvoker();
auto ref_argument = ref_conv.MakeArgument(in_n_hi_wi_c_host, auto ref_argument = ref_conv.MakeArgument(in_host,
wei_k_y_x_c, wei,
out_n_ho_wo_k, out,
params.conv_filter_strides_, params.conv_filter_strides_,
params.conv_filter_dilations_, params.conv_filter_dilations_,
params.input_left_pads_, params.input_left_pads_,
...@@ -230,15 +276,11 @@ int run_conv_bwd_data_nhwc(bool do_verification, ...@@ -230,15 +276,11 @@ int run_conv_bwd_data_nhwc(bool do_verification,
wei_element_op, wei_element_op,
out_element_op); out_element_op);
std::cout << "before ref" << std::endl;
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
std::cout << "after ref" << std::endl; in_device_buf.FromDevice(in_device.mData.data());
in_device_buf.FromDevice(in_n_hi_wi_c_device.mData.data());
return ck::utils::check_err(in_n_hi_wi_c_device.mData, in_n_hi_wi_c_host.mData) ? 0 : 1; return ck::utils::check_err(in_device.mData, in_host.mData) ? 0 : 1;
} }
return 0; return 0;
......
...@@ -95,11 +95,13 @@ int main(int argc, char* argv[]) ...@@ -95,11 +95,13 @@ int main(int argc, char* argv[])
if(num_dim_spatial == 1) if(num_dim_spatial == 1)
{ {
return run_conv_bwd_data_nhwc<1, return run_conv_bwd_data<1,
ck::tensor_layout::convolution::NWC,
ck::tensor_layout::convolution::KXC,
ck::tensor_layout::convolution::NWK,
InDataType, InDataType,
WeiDataType, WeiDataType,
OutDataType, OutDataType,
AccDataType,
InElementOp, InElementOp,
WeiElementOp, WeiElementOp,
OutElementOp, OutElementOp,
...@@ -113,11 +115,13 @@ int main(int argc, char* argv[]) ...@@ -113,11 +115,13 @@ int main(int argc, char* argv[])
} }
else if(num_dim_spatial == 2) else if(num_dim_spatial == 2)
{ {
return run_conv_bwd_data_nhwc<2, return run_conv_bwd_data<2,
ck::tensor_layout::convolution::NHWC,
ck::tensor_layout::convolution::KYXC,
ck::tensor_layout::convolution::NHWK,
InDataType, InDataType,
WeiDataType, WeiDataType,
OutDataType, OutDataType,
AccDataType,
InElementOp, InElementOp,
WeiElementOp, WeiElementOp,
OutElementOp, OutElementOp,
...@@ -131,11 +135,13 @@ int main(int argc, char* argv[]) ...@@ -131,11 +135,13 @@ int main(int argc, char* argv[])
} }
else if(num_dim_spatial == 3) else if(num_dim_spatial == 3)
{ {
return run_conv_bwd_data_nhwc<3, return run_conv_bwd_data<3,
ck::tensor_layout::convolution::NDHWC,
ck::tensor_layout::convolution::KZYXC,
ck::tensor_layout::convolution::NDHWK,
InDataType, InDataType,
WeiDataType, WeiDataType,
OutDataType, OutDataType,
AccDataType,
InElementOp, InElementOp,
WeiElementOp, WeiElementOp,
OutElementOp, OutElementOp,
......
...@@ -90,16 +90,19 @@ parse_conv_params(int num_dim_spatial, int arg_idx, char* const argv[]) ...@@ -90,16 +90,19 @@ parse_conv_params(int num_dim_spatial, int arg_idx, char* const argv[])
input_right_pads}; input_right_pads};
} }
// FIXME: current implementation only support NCHW/NHWC layout
template <ck::index_t NDimSpatial, template <ck::index_t NDimSpatial,
typename InLayout,
typename WeiLayout,
typename OutLayout,
typename InDataType, typename InDataType,
typename WeiDataType, typename WeiDataType,
typename OutDataType, typename OutDataType,
typename AccDataType,
typename InElementOp, typename InElementOp,
typename WeiElementOp, typename WeiElementOp,
typename OutElementOp, typename OutElementOp,
typename DeviceConvBwdWeightInstance> typename DeviceConvBwdWeightInstance>
int run_conv_bwd_weight_nhwc(bool do_verification, int run_conv_bwd_weight(bool do_verification,
int init_method, int init_method,
bool time_kernel, bool time_kernel,
const ck::tensor_operation::device::ConvParams& params, const ck::tensor_operation::device::ConvParams& params,
...@@ -108,6 +111,7 @@ int run_conv_bwd_weight_nhwc(bool do_verification, ...@@ -108,6 +111,7 @@ int run_conv_bwd_weight_nhwc(bool do_verification,
const OutElementOp& out_element_op, const OutElementOp& out_element_op,
ck::index_t split_k) ck::index_t split_k)
{ {
// make host tensor descriptor
auto f_nhwc_host_tensor_descriptor = auto f_nhwc_host_tensor_descriptor =
[](ck::index_t n, ck::index_t c, std::vector<ck::index_t> spatial_lengths) { [](ck::index_t n, ck::index_t c, std::vector<ck::index_t> spatial_lengths) {
std::vector<std::size_t> nhwc_lengths{static_cast<std::size_t>(n), std::vector<std::size_t> nhwc_lengths{static_cast<std::size_t>(n),
...@@ -118,38 +122,92 @@ int run_conv_bwd_weight_nhwc(bool do_verification, ...@@ -118,38 +122,92 @@ int run_conv_bwd_weight_nhwc(bool do_verification,
return HostTensorDescriptor(nhwc_lengths); return HostTensorDescriptor(nhwc_lengths);
}; };
Tensor<InDataType> in_n_hi_wi_c( auto f_nchw_host_tensor_descriptor =
f_nhwc_host_tensor_descriptor(params.N_, params.C_, params.input_spatial_lengths_)); [](ck::index_t n, ck::index_t c, std::vector<ck::index_t> spatial_lengths) {
Tensor<WeiDataType> wei_k_y_x_c_host_result( std::vector<std::size_t> nchw_lengths{static_cast<std::size_t>(n),
f_nhwc_host_tensor_descriptor(params.K_, params.C_, params.filter_spatial_lengths_)); static_cast<std::size_t>(c)};
Tensor<WeiDataType> wei_k_y_x_c_device_result( nchw_lengths.insert(nchw_lengths.end(), spatial_lengths.begin(), spatial_lengths.end());
f_nhwc_host_tensor_descriptor(params.K_, params.C_, params.filter_spatial_lengths_));
Tensor<OutDataType> out_n_ho_wo_k( return HostTensorDescriptor(nchw_lengths);
f_nhwc_host_tensor_descriptor(params.N_, params.K_, params.GetOutputSpatialLengths())); };
HostTensorDescriptor in_desc, wei_desc, out_desc;
// FIXME: properly implement "make host descriptor" for different layout
if constexpr(ck::is_same_v<InLayout, ck::tensor_layout::convolution::NWC> ||
ck::is_same_v<InLayout, ck::tensor_layout::convolution::NHWC> ||
ck::is_same_v<InLayout, ck::tensor_layout::convolution::NDHWC>)
{
in_desc =
f_nhwc_host_tensor_descriptor(params.N_, params.C_, params.input_spatial_lengths_);
}
else if constexpr(ck::is_same_v<InLayout, ck::tensor_layout::convolution::NCW> ||
ck::is_same_v<InLayout, ck::tensor_layout::convolution::NCHW> ||
ck::is_same_v<InLayout, ck::tensor_layout::convolution::NCDHW>)
{
in_desc =
f_nchw_host_tensor_descriptor(params.N_, params.C_, params.input_spatial_lengths_);
}
// FIXME: properly implement "make host descriptor" for different layout
if constexpr(ck::is_same_v<WeiLayout, ck::tensor_layout::convolution::KXC> ||
ck::is_same_v<WeiLayout, ck::tensor_layout::convolution::KYXC> ||
ck::is_same_v<WeiLayout, ck::tensor_layout::convolution::KZYXC>)
{
wei_desc =
f_nhwc_host_tensor_descriptor(params.K_, params.C_, params.filter_spatial_lengths_);
}
else if constexpr(ck::is_same_v<WeiLayout, ck::tensor_layout::convolution::KCX> ||
ck::is_same_v<WeiLayout, ck::tensor_layout::convolution::KCYX> ||
ck::is_same_v<WeiLayout, ck::tensor_layout::convolution::KCZYX>)
{
wei_desc =
f_nchw_host_tensor_descriptor(params.K_, params.C_, params.filter_spatial_lengths_);
}
// FIXME: properly implement "make host descriptor" for different layout
if constexpr(ck::is_same_v<OutLayout, ck::tensor_layout::convolution::NWK> ||
ck::is_same_v<OutLayout, ck::tensor_layout::convolution::NHWK> ||
ck::is_same_v<OutLayout, ck::tensor_layout::convolution::NDHWK>)
{
out_desc =
f_nhwc_host_tensor_descriptor(params.N_, params.K_, params.GetOutputSpatialLengths());
}
else if constexpr(ck::is_same_v<OutLayout, ck::tensor_layout::convolution::NKW> ||
ck::is_same_v<OutLayout, ck::tensor_layout::convolution::NKHW> ||
ck::is_same_v<OutLayout, ck::tensor_layout::convolution::NKDHW>)
{
out_desc =
f_nchw_host_tensor_descriptor(params.N_, params.K_, params.GetOutputSpatialLengths());
}
Tensor<InDataType> in(in_desc);
Tensor<WeiDataType> wei_host_result(wei_desc);
Tensor<WeiDataType> wei_device_result(wei_desc);
Tensor<OutDataType> out(out_desc);
std::cout << "input: " << in_n_hi_wi_c.mDesc << std::endl; std::cout << "in: " << in.mDesc << std::endl;
std::cout << "weight: " << wei_k_y_x_c_host_result.mDesc << std::endl; std::cout << "wei: " << wei_host_result.mDesc << std::endl;
std::cout << "output: " << out_n_ho_wo_k.mDesc << std::endl; std::cout << "out: " << out.mDesc << std::endl;
switch(init_method) switch(init_method)
{ {
case 0: break; case 0: break;
case 1: case 1:
in_n_hi_wi_c.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5}); in.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5});
out_n_ho_wo_k.GenerateTensorValue(GeneratorTensor_2<OutDataType>{-5, 5}); out.GenerateTensorValue(GeneratorTensor_2<OutDataType>{-5, 5});
break; break;
default: default:
in_n_hi_wi_c.GenerateTensorValue(GeneratorTensor_3<InDataType>{0.0, 1.0}); in.GenerateTensorValue(GeneratorTensor_3<InDataType>{0.0, 1.0});
out_n_ho_wo_k.GenerateTensorValue(GeneratorTensor_3<OutDataType>{-0.5, 0.5}); out.GenerateTensorValue(GeneratorTensor_3<OutDataType>{-0.5, 0.5});
} }
DeviceMem in_device_buf(sizeof(InDataType) * in_n_hi_wi_c.mDesc.GetElementSpace()); DeviceMem in_device_buf(sizeof(InDataType) * in.mDesc.GetElementSpace());
DeviceMem wei_device_buf(sizeof(WeiDataType) * DeviceMem wei_device_buf(sizeof(WeiDataType) * wei_device_result.mDesc.GetElementSpace());
wei_k_y_x_c_device_result.mDesc.GetElementSpace()); DeviceMem out_device_buf(sizeof(OutDataType) * out.mDesc.GetElementSpace());
DeviceMem out_device_buf(sizeof(OutDataType) * out_n_ho_wo_k.mDesc.GetElementSpace());
in_device_buf.ToDevice(in_n_hi_wi_c.mData.data()); in_device_buf.ToDevice(in.mData.data());
out_device_buf.ToDevice(out_n_ho_wo_k.mData.data()); out_device_buf.ToDevice(out.mData.data());
// init to 0 // init to 0
wei_device_buf.SetZero(); wei_device_buf.SetZero();
...@@ -197,20 +255,10 @@ int run_conv_bwd_weight_nhwc(bool do_verification, ...@@ -197,20 +255,10 @@ int run_conv_bwd_weight_nhwc(bool do_verification,
if(do_verification) if(do_verification)
{ {
auto ref_conv = ck::tensor_operation::host::ReferenceConvBwdWeight< auto ref_conv = ck::tensor_operation::host::ReferenceConvBwdWeight<NDimSpatial,
2, InLayout,
ck::tuple_element_t<NDimSpatial - 1, WeiLayout,
ck::Tuple<ck::tensor_layout::convolution::NWC, OutLayout,
ck::tensor_layout::convolution::NHWC,
ck::tensor_layout::convolution::NDHWC>>,
ck::tuple_element_t<NDimSpatial - 1,
ck::Tuple<ck::tensor_layout::convolution::KXC,
ck::tensor_layout::convolution::KYXC,
ck::tensor_layout::convolution::KZYXC>>,
ck::tuple_element_t<NDimSpatial - 1,
ck::Tuple<ck::tensor_layout::convolution::NWK,
ck::tensor_layout::convolution::NHWK,
ck::tensor_layout::convolution::NDHWK>>,
InDataType, InDataType,
WeiDataType, WeiDataType,
OutDataType, OutDataType,
...@@ -220,9 +268,9 @@ int run_conv_bwd_weight_nhwc(bool do_verification, ...@@ -220,9 +268,9 @@ int run_conv_bwd_weight_nhwc(bool do_verification,
auto ref_invoker = ref_conv.MakeInvoker(); auto ref_invoker = ref_conv.MakeInvoker();
auto ref_argument = ref_conv.MakeArgument(in_n_hi_wi_c, auto ref_argument = ref_conv.MakeArgument(in,
wei_k_y_x_c_host_result, wei_host_result,
out_n_ho_wo_k, out,
params.conv_filter_strides_, params.conv_filter_strides_,
params.conv_filter_dilations_, params.conv_filter_dilations_,
params.input_left_pads_, params.input_left_pads_,
...@@ -233,11 +281,9 @@ int run_conv_bwd_weight_nhwc(bool do_verification, ...@@ -233,11 +281,9 @@ int run_conv_bwd_weight_nhwc(bool do_verification,
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
wei_device_buf.FromDevice(wei_k_y_x_c_device_result.mData.data()); wei_device_buf.FromDevice(wei_device_result.mData.data());
return ck::utils::check_err(wei_k_y_x_c_device_result.mData, wei_k_y_x_c_host_result.mData) return ck::utils::check_err(wei_device_result.mData, wei_host_result.mData) ? 0 : 1;
? 0
: 1;
} }
return 0; return 0;
......
...@@ -104,11 +104,13 @@ int main(int argc, char* argv[]) ...@@ -104,11 +104,13 @@ int main(int argc, char* argv[])
if(num_dim_spatial == 1) if(num_dim_spatial == 1)
{ {
return run_conv_bwd_weight_nhwc<1, return run_conv_bwd_weight<1,
ck::tensor_layout::convolution::NWC,
ck::tensor_layout::convolution::KXC,
ck::tensor_layout::convolution::NWK,
InDataType, InDataType,
WeiDataType, WeiDataType,
OutDataType, OutDataType,
AccDataType,
InElementOp, InElementOp,
WeiElementOp, WeiElementOp,
OutElementOp, OutElementOp,
...@@ -123,11 +125,13 @@ int main(int argc, char* argv[]) ...@@ -123,11 +125,13 @@ int main(int argc, char* argv[])
} }
else if(num_dim_spatial == 2) else if(num_dim_spatial == 2)
{ {
return run_conv_bwd_weight_nhwc<2, return run_conv_bwd_weight<2,
ck::tensor_layout::convolution::NHWC,
ck::tensor_layout::convolution::KYXC,
ck::tensor_layout::convolution::NHWK,
InDataType, InDataType,
WeiDataType, WeiDataType,
OutDataType, OutDataType,
AccDataType,
InElementOp, InElementOp,
WeiElementOp, WeiElementOp,
OutElementOp, OutElementOp,
...@@ -142,11 +146,13 @@ int main(int argc, char* argv[]) ...@@ -142,11 +146,13 @@ int main(int argc, char* argv[])
} }
else if(num_dim_spatial == 3) else if(num_dim_spatial == 3)
{ {
return run_conv_bwd_weight_nhwc<3, return run_conv_bwd_weight<3,
ck::tensor_layout::convolution::NDHWC,
ck::tensor_layout::convolution::KZYXC,
ck::tensor_layout::convolution::NDHWK,
InDataType, InDataType,
WeiDataType, WeiDataType,
OutDataType, OutDataType,
AccDataType,
InElementOp, InElementOp,
WeiElementOp, WeiElementOp,
OutElementOp, OutElementOp,
......
...@@ -103,11 +103,13 @@ int main(int argc, char* argv[]) ...@@ -103,11 +103,13 @@ int main(int argc, char* argv[])
if(num_dim_spatial == 1) if(num_dim_spatial == 1)
{ {
return run_conv_bwd_weight_nhwc<1, return run_conv_bwd_weight<1,
ck::tensor_layout::convolution::NWC,
ck::tensor_layout::convolution::KXC,
ck::tensor_layout::convolution::NWK,
InDataType, InDataType,
WeiDataType, WeiDataType,
OutDataType, OutDataType,
AccDataType,
InElementOp, InElementOp,
WeiElementOp, WeiElementOp,
OutElementOp, OutElementOp,
...@@ -122,11 +124,13 @@ int main(int argc, char* argv[]) ...@@ -122,11 +124,13 @@ int main(int argc, char* argv[])
} }
else if(num_dim_spatial == 2) else if(num_dim_spatial == 2)
{ {
return run_conv_bwd_weight_nhwc<2, return run_conv_bwd_weight<2,
ck::tensor_layout::convolution::NHWC,
ck::tensor_layout::convolution::KYXC,
ck::tensor_layout::convolution::NHWK,
InDataType, InDataType,
WeiDataType, WeiDataType,
OutDataType, OutDataType,
AccDataType,
InElementOp, InElementOp,
WeiElementOp, WeiElementOp,
OutElementOp, OutElementOp,
...@@ -141,11 +145,13 @@ int main(int argc, char* argv[]) ...@@ -141,11 +145,13 @@ int main(int argc, char* argv[])
} }
else if(num_dim_spatial == 3) else if(num_dim_spatial == 3)
{ {
return run_conv_bwd_weight_nhwc<3, return run_conv_bwd_weight<3,
ck::tensor_layout::convolution::NDHWC,
ck::tensor_layout::convolution::KZYXC,
ck::tensor_layout::convolution::NDHWK,
InDataType, InDataType,
WeiDataType, WeiDataType,
OutDataType, OutDataType,
AccDataType,
InElementOp, InElementOp,
WeiElementOp, WeiElementOp,
OutElementOp, OutElementOp,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment