Commit 360184cd authored by Chao Liu's avatar Chao Liu
Browse files

update examples

parent f6922d3f
......@@ -26,7 +26,7 @@ void print_helper_msg()
<< "Following arguments (depending on number of spatial dims):\n"
<< " N, K, C, \n"
<< " <filter spatial dimensions>, (ie Y, X for 2D)\n"
<< " <in_n_hi_wi_c image spatial dimensions>, (ie Hi, Wi for 2D)\n"
<< " <input image spatial dimensions>, (ie Hi, Wi for 2D)\n"
<< " <strides>, (ie Sy, Sx for 2D)\n"
<< " <dilations>, (ie Dy, Dx for 2D)\n"
<< " <left padding>, (ie LeftPy, LeftPx for 2D)\n"
......@@ -90,23 +90,27 @@ parse_conv_params(int num_dim_spatial, int arg_idx, char* const argv[])
input_right_pads};
}
// FIXME: current implementation only support NCHW/NHWC layout
template <ck::index_t NDimSpatial,
typename InLayout,
typename WeiLayout,
typename OutLayout,
typename InDataType,
typename WeiDataType,
typename OutDataType,
typename AccDataType,
typename InElementOp,
typename WeiElementOp,
typename OutElementOp,
typename DeviceConvNDFwdInstance>
int run_conv_fwd_nhwc(bool do_verification,
int init_method,
bool time_kernel,
const ck::tensor_operation::device::ConvParams& params,
const InElementOp& in_element_op,
const WeiElementOp& wei_element_op,
const OutElementOp& out_element_op)
int run_conv_fwd(bool do_verification,
int init_method,
bool time_kernel,
const ck::tensor_operation::device::ConvParams& params,
const InElementOp& in_element_op,
const WeiElementOp& wei_element_op,
const OutElementOp& out_element_op)
{
// make host tensor descriptor
auto f_nhwc_host_tensor_descriptor =
[](ck::index_t n, ck::index_t c, std::vector<ck::index_t> spatial_lengths) {
std::vector<std::size_t> nhwc_lengths{static_cast<std::size_t>(n),
......@@ -117,37 +121,92 @@ int run_conv_fwd_nhwc(bool do_verification,
return HostTensorDescriptor(nhwc_lengths);
};
Tensor<InDataType> in_n_hi_wi_c(
f_nhwc_host_tensor_descriptor(params.N_, params.C_, params.input_spatial_lengths_));
Tensor<WeiDataType> wei_k_y_x_c(
f_nhwc_host_tensor_descriptor(params.K_, params.C_, params.filter_spatial_lengths_));
Tensor<OutDataType> out_n_ho_wo_k_host(
f_nhwc_host_tensor_descriptor(params.N_, params.K_, params.GetOutputSpatialLengths()));
Tensor<OutDataType> out_n_ho_wo_k_device(
f_nhwc_host_tensor_descriptor(params.N_, params.K_, params.GetOutputSpatialLengths()));
auto f_nchw_host_tensor_descriptor =
[](ck::index_t n, ck::index_t c, std::vector<ck::index_t> spatial_lengths) {
std::vector<std::size_t> nchw_lengths{static_cast<std::size_t>(n),
static_cast<std::size_t>(c)};
nchw_lengths.insert(nchw_lengths.end(), spatial_lengths.begin(), spatial_lengths.end());
return HostTensorDescriptor(nchw_lengths);
};
HostTensorDescriptor in_desc, wei_desc, out_desc;
// FIXME: properly implement "make host descriptor" for different layout
if constexpr(ck::is_same_v<InLayout, ck::tensor_layout::convolution::NWC> ||
ck::is_same_v<InLayout, ck::tensor_layout::convolution::NHWC> ||
ck::is_same_v<InLayout, ck::tensor_layout::convolution::NDHWC>)
{
in_desc =
f_nhwc_host_tensor_descriptor(params.N_, params.C_, params.input_spatial_lengths_);
}
else if constexpr(ck::is_same_v<InLayout, ck::tensor_layout::convolution::NCW> ||
ck::is_same_v<InLayout, ck::tensor_layout::convolution::NCHW> ||
ck::is_same_v<InLayout, ck::tensor_layout::convolution::NCDHW>)
{
in_desc =
f_nchw_host_tensor_descriptor(params.N_, params.C_, params.input_spatial_lengths_);
}
// FIXME: properly implement "make host descriptor" for different layout
if constexpr(ck::is_same_v<WeiLayout, ck::tensor_layout::convolution::KXC> ||
ck::is_same_v<WeiLayout, ck::tensor_layout::convolution::KYXC> ||
ck::is_same_v<WeiLayout, ck::tensor_layout::convolution::KZYXC>)
{
wei_desc =
f_nhwc_host_tensor_descriptor(params.K_, params.C_, params.filter_spatial_lengths_);
}
else if constexpr(ck::is_same_v<WeiLayout, ck::tensor_layout::convolution::KCX> ||
ck::is_same_v<WeiLayout, ck::tensor_layout::convolution::KCYX> ||
ck::is_same_v<WeiLayout, ck::tensor_layout::convolution::KCZYX>)
{
wei_desc =
f_nchw_host_tensor_descriptor(params.K_, params.C_, params.filter_spatial_lengths_);
}
// FIXME: properly implement "make host descriptor" for different layout
if constexpr(ck::is_same_v<OutLayout, ck::tensor_layout::convolution::NWK> ||
ck::is_same_v<OutLayout, ck::tensor_layout::convolution::NHWK> ||
ck::is_same_v<OutLayout, ck::tensor_layout::convolution::NDHWK>)
{
out_desc =
f_nhwc_host_tensor_descriptor(params.N_, params.K_, params.GetOutputSpatialLengths());
}
else if constexpr(ck::is_same_v<OutLayout, ck::tensor_layout::convolution::NKW> ||
ck::is_same_v<OutLayout, ck::tensor_layout::convolution::NKHW> ||
ck::is_same_v<OutLayout, ck::tensor_layout::convolution::NKDHW>)
{
out_desc =
f_nchw_host_tensor_descriptor(params.N_, params.K_, params.GetOutputSpatialLengths());
}
Tensor<InDataType> in(in_desc);
Tensor<WeiDataType> wei(wei_desc);
Tensor<OutDataType> out_host(out_desc);
Tensor<OutDataType> out_device(out_desc);
std::cout << "in_n_hi_wi_c: " << in_n_hi_wi_c.mDesc << std::endl;
std::cout << "wei_k_y_x_c: " << wei_k_y_x_c.mDesc << std::endl;
std::cout << "output: " << out_n_ho_wo_k_host.mDesc << std::endl;
std::cout << "in: " << in.mDesc << std::endl;
std::cout << "wei: " << wei.mDesc << std::endl;
std::cout << "out: " << out_host.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
in_n_hi_wi_c.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5});
wei_k_y_x_c.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-5, 5});
in.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5});
wei.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-5, 5});
break;
default:
in_n_hi_wi_c.GenerateTensorValue(GeneratorTensor_3<InDataType>{0.0, 1.0});
wei_k_y_x_c.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-0.5, 0.5});
in.GenerateTensorValue(GeneratorTensor_3<InDataType>{0.0, 1.0});
wei.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-0.5, 0.5});
}
DeviceMem in_device_buf(sizeof(InDataType) * in_n_hi_wi_c.mDesc.GetElementSpace());
DeviceMem wei_device_buf(sizeof(WeiDataType) * wei_k_y_x_c.mDesc.GetElementSpace());
DeviceMem out_device_buf(sizeof(OutDataType) * out_n_ho_wo_k_device.mDesc.GetElementSpace());
DeviceMem in_device_buf(sizeof(InDataType) * in.mDesc.GetElementSpace());
DeviceMem wei_device_buf(sizeof(WeiDataType) * wei.mDesc.GetElementSpace());
DeviceMem out_device_buf(sizeof(OutDataType) * out_device.mDesc.GetElementSpace());
in_device_buf.ToDevice(in_n_hi_wi_c.mData.data());
wei_device_buf.ToDevice(wei_k_y_x_c.mData.data());
in_device_buf.ToDevice(in.mData.data());
wei_device_buf.ToDevice(wei.mData.data());
// do GEMM
auto conv = DeviceConvNDFwdInstance{};
......@@ -188,31 +247,21 @@ int run_conv_fwd_nhwc(bool do_verification,
if(do_verification)
{
auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd<
NDimSpatial,
ck::tuple_element_t<NDimSpatial - 1,
ck::Tuple<ck::tensor_layout::convolution::NWC,
ck::tensor_layout::convolution::NHWC,
ck::tensor_layout::convolution::NDHWC>>,
ck::tuple_element_t<NDimSpatial - 1,
ck::Tuple<ck::tensor_layout::convolution::KXC,
ck::tensor_layout::convolution::KYXC,
ck::tensor_layout::convolution::KZYXC>>,
ck::tuple_element_t<NDimSpatial - 1,
ck::Tuple<ck::tensor_layout::convolution::NWK,
ck::tensor_layout::convolution::NHWK,
ck::tensor_layout::convolution::NDHWK>>,
InDataType,
WeiDataType,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp>();
auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd<NDimSpatial,
InLayout,
WeiLayout,
OutLayout,
InDataType,
WeiDataType,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp>();
auto ref_invoker = ref_conv.MakeInvoker();
auto ref_argument = ref_conv.MakeArgument(in_n_hi_wi_c,
wei_k_y_x_c,
out_n_ho_wo_k_host,
auto ref_argument = ref_conv.MakeArgument(in,
wei,
out_host,
params.conv_filter_strides_,
params.conv_filter_dilations_,
params.input_left_pads_,
......@@ -223,13 +272,10 @@ int run_conv_fwd_nhwc(bool do_verification,
ref_invoker.Run(ref_argument);
out_device_buf.FromDevice(out_n_ho_wo_k_device.mData.data());
out_device_buf.FromDevice(out_device.mData.data());
return ck::utils::check_err(out_n_ho_wo_k_host.mData,
out_n_ho_wo_k_device.mData,
"Error: incorrect results!",
1e-5f,
1e-4f)
return ck::utils::check_err(
out_host.mData, out_device.mData, "Error: incorrect results!", 1e-5f, 1e-4f)
? 0
: 1;
}
......
......@@ -95,57 +95,63 @@ int main(int argc, char* argv[])
if(num_dim_spatial == 1)
{
return run_conv_fwd_nhwc<1,
InDataType,
WeiDataType,
OutDataType,
AccDataType,
InElementOp,
WeiElementOp,
OutElementOp,
DeviceConvNDFwdInstance<1>>(do_verification,
init_method,
time_kernel,
params,
in_element_op,
wei_element_op,
out_element_op);
return run_conv_fwd<1,
ck::tensor_layout::convolution::NWC,
ck::tensor_layout::convolution::KXC,
ck::tensor_layout::convolution::NWK,
InDataType,
WeiDataType,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp,
DeviceConvNDFwdInstance<1>>(do_verification,
init_method,
time_kernel,
params,
in_element_op,
wei_element_op,
out_element_op);
}
else if(num_dim_spatial == 2)
{
return run_conv_fwd_nhwc<2,
InDataType,
WeiDataType,
OutDataType,
AccDataType,
InElementOp,
WeiElementOp,
OutElementOp,
DeviceConvNDFwdInstance<2>>(do_verification,
init_method,
time_kernel,
params,
in_element_op,
wei_element_op,
out_element_op);
return run_conv_fwd<2,
ck::tensor_layout::convolution::NHWC,
ck::tensor_layout::convolution::KYXC,
ck::tensor_layout::convolution::NHWK,
InDataType,
WeiDataType,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp,
DeviceConvNDFwdInstance<2>>(do_verification,
init_method,
time_kernel,
params,
in_element_op,
wei_element_op,
out_element_op);
}
else if(num_dim_spatial == 3)
{
return run_conv_fwd_nhwc<3,
InDataType,
WeiDataType,
OutDataType,
AccDataType,
InElementOp,
WeiElementOp,
OutElementOp,
DeviceConvNDFwdInstance<3>>(do_verification,
init_method,
time_kernel,
params,
in_element_op,
wei_element_op,
out_element_op);
return run_conv_fwd<3,
ck::tensor_layout::convolution::NDHWC,
ck::tensor_layout::convolution::KZYXC,
ck::tensor_layout::convolution::NDHWK,
InDataType,
WeiDataType,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp,
DeviceConvNDFwdInstance<3>>(do_verification,
init_method,
time_kernel,
params,
in_element_op,
wei_element_op,
out_element_op);
}
return 0;
......
......@@ -95,57 +95,63 @@ int main(int argc, char* argv[])
if(num_dim_spatial == 1)
{
return run_conv_fwd_nhwc<1,
InDataType,
WeiDataType,
OutDataType,
AccDataType,
InElementOp,
WeiElementOp,
OutElementOp,
DeviceConvNDFwdInstance<1>>(do_verification,
init_method,
time_kernel,
params,
in_element_op,
wei_element_op,
out_element_op);
return run_conv_fwd<1,
ck::tensor_layout::convolution::NWC,
ck::tensor_layout::convolution::KXC,
ck::tensor_layout::convolution::NWK,
InDataType,
WeiDataType,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp,
DeviceConvNDFwdInstance<1>>(do_verification,
init_method,
time_kernel,
params,
in_element_op,
wei_element_op,
out_element_op);
}
else if(num_dim_spatial == 2)
{
return run_conv_fwd_nhwc<2,
InDataType,
WeiDataType,
OutDataType,
AccDataType,
InElementOp,
WeiElementOp,
OutElementOp,
DeviceConvNDFwdInstance<2>>(do_verification,
init_method,
time_kernel,
params,
in_element_op,
wei_element_op,
out_element_op);
return run_conv_fwd<2,
ck::tensor_layout::convolution::NHWC,
ck::tensor_layout::convolution::KYXC,
ck::tensor_layout::convolution::NHWK,
InDataType,
WeiDataType,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp,
DeviceConvNDFwdInstance<2>>(do_verification,
init_method,
time_kernel,
params,
in_element_op,
wei_element_op,
out_element_op);
}
else if(num_dim_spatial == 3)
{
return run_conv_fwd_nhwc<3,
InDataType,
WeiDataType,
OutDataType,
AccDataType,
InElementOp,
WeiElementOp,
OutElementOp,
DeviceConvNDFwdInstance<3>>(do_verification,
init_method,
time_kernel,
params,
in_element_op,
wei_element_op,
out_element_op);
return run_conv_fwd<3,
ck::tensor_layout::convolution::NDHWC,
ck::tensor_layout::convolution::KZYXC,
ck::tensor_layout::convolution::NDHWK,
InDataType,
WeiDataType,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp,
DeviceConvNDFwdInstance<3>>(do_verification,
init_method,
time_kernel,
params,
in_element_op,
wei_element_op,
out_element_op);
}
return 0;
......
......@@ -95,57 +95,63 @@ int main(int argc, char* argv[])
if(num_dim_spatial == 1)
{
return run_conv_fwd_nhwc<1,
InDataType,
WeiDataType,
OutDataType,
AccDataType,
InElementOp,
WeiElementOp,
OutElementOp,
DeviceConvNDFwdInstance<1>>(do_verification,
init_method,
time_kernel,
params,
in_element_op,
wei_element_op,
out_element_op);
return run_conv_fwd<1,
ck::tensor_layout::convolution::NWC,
ck::tensor_layout::convolution::KXC,
ck::tensor_layout::convolution::NWK,
InDataType,
WeiDataType,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp,
DeviceConvNDFwdInstance<1>>(do_verification,
init_method,
time_kernel,
params,
in_element_op,
wei_element_op,
out_element_op);
}
else if(num_dim_spatial == 2)
{
return run_conv_fwd_nhwc<2,
InDataType,
WeiDataType,
OutDataType,
AccDataType,
InElementOp,
WeiElementOp,
OutElementOp,
DeviceConvNDFwdInstance<2>>(do_verification,
init_method,
time_kernel,
params,
in_element_op,
wei_element_op,
out_element_op);
return run_conv_fwd<2,
ck::tensor_layout::convolution::NHWC,
ck::tensor_layout::convolution::KYXC,
ck::tensor_layout::convolution::NHWK,
InDataType,
WeiDataType,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp,
DeviceConvNDFwdInstance<2>>(do_verification,
init_method,
time_kernel,
params,
in_element_op,
wei_element_op,
out_element_op);
}
else if(num_dim_spatial == 3)
{
return run_conv_fwd_nhwc<3,
InDataType,
WeiDataType,
OutDataType,
AccDataType,
InElementOp,
WeiElementOp,
OutElementOp,
DeviceConvNDFwdInstance<3>>(do_verification,
init_method,
time_kernel,
params,
in_element_op,
wei_element_op,
out_element_op);
return run_conv_fwd<3,
ck::tensor_layout::convolution::NDHWC,
ck::tensor_layout::convolution::KZYXC,
ck::tensor_layout::convolution::NDHWK,
InDataType,
WeiDataType,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp,
DeviceConvNDFwdInstance<3>>(do_verification,
init_method,
time_kernel,
params,
in_element_op,
wei_element_op,
out_element_op);
}
return 0;
......
......@@ -95,57 +95,63 @@ int main(int argc, char* argv[])
if(num_dim_spatial == 1)
{
return run_conv_fwd_nhwc<1,
InDataType,
WeiDataType,
OutDataType,
AccDataType,
InElementOp,
WeiElementOp,
OutElementOp,
DeviceConvNDFwdInstance<1>>(do_verification,
init_method,
time_kernel,
params,
in_element_op,
wei_element_op,
out_element_op);
return run_conv_fwd<1,
ck::tensor_layout::convolution::NWC,
ck::tensor_layout::convolution::KXC,
ck::tensor_layout::convolution::NWK,
InDataType,
WeiDataType,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp,
DeviceConvNDFwdInstance<1>>(do_verification,
init_method,
time_kernel,
params,
in_element_op,
wei_element_op,
out_element_op);
}
else if(num_dim_spatial == 2)
{
return run_conv_fwd_nhwc<2,
InDataType,
WeiDataType,
OutDataType,
AccDataType,
InElementOp,
WeiElementOp,
OutElementOp,
DeviceConvNDFwdInstance<2>>(do_verification,
init_method,
time_kernel,
params,
in_element_op,
wei_element_op,
out_element_op);
return run_conv_fwd<2,
ck::tensor_layout::convolution::NHWC,
ck::tensor_layout::convolution::KYXC,
ck::tensor_layout::convolution::NHWK,
InDataType,
WeiDataType,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp,
DeviceConvNDFwdInstance<2>>(do_verification,
init_method,
time_kernel,
params,
in_element_op,
wei_element_op,
out_element_op);
}
else if(num_dim_spatial == 3)
{
return run_conv_fwd_nhwc<3,
InDataType,
WeiDataType,
OutDataType,
AccDataType,
InElementOp,
WeiElementOp,
OutElementOp,
DeviceConvNDFwdInstance<3>>(do_verification,
init_method,
time_kernel,
params,
in_element_op,
wei_element_op,
out_element_op);
return run_conv_fwd<3,
ck::tensor_layout::convolution::NDHWC,
ck::tensor_layout::convolution::KZYXC,
ck::tensor_layout::convolution::NDHWK,
InDataType,
WeiDataType,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp,
DeviceConvNDFwdInstance<3>>(do_verification,
init_method,
time_kernel,
params,
in_element_op,
wei_element_op,
out_element_op);
}
return 0;
......
......@@ -95,57 +95,63 @@ int main(int argc, char* argv[])
if(num_dim_spatial == 1)
{
return run_conv_fwd_nhwc<1,
InDataType,
WeiDataType,
OutDataType,
AccDataType,
InElementOp,
WeiElementOp,
OutElementOp,
DeviceConvNDFwdInstance<1>>(do_verification,
init_method,
time_kernel,
params,
in_element_op,
wei_element_op,
out_element_op);
return run_conv_fwd<1,
ck::tensor_layout::convolution::NWC,
ck::tensor_layout::convolution::KXC,
ck::tensor_layout::convolution::NWK,
InDataType,
WeiDataType,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp,
DeviceConvNDFwdInstance<1>>(do_verification,
init_method,
time_kernel,
params,
in_element_op,
wei_element_op,
out_element_op);
}
else if(num_dim_spatial == 2)
{
return run_conv_fwd_nhwc<2,
InDataType,
WeiDataType,
OutDataType,
AccDataType,
InElementOp,
WeiElementOp,
OutElementOp,
DeviceConvNDFwdInstance<2>>(do_verification,
init_method,
time_kernel,
params,
in_element_op,
wei_element_op,
out_element_op);
return run_conv_fwd<2,
ck::tensor_layout::convolution::NHWC,
ck::tensor_layout::convolution::KYXC,
ck::tensor_layout::convolution::NHWK,
InDataType,
WeiDataType,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp,
DeviceConvNDFwdInstance<2>>(do_verification,
init_method,
time_kernel,
params,
in_element_op,
wei_element_op,
out_element_op);
}
else if(num_dim_spatial == 3)
{
return run_conv_fwd_nhwc<3,
InDataType,
WeiDataType,
OutDataType,
AccDataType,
InElementOp,
WeiElementOp,
OutElementOp,
DeviceConvNDFwdInstance<3>>(do_verification,
init_method,
time_kernel,
params,
in_element_op,
wei_element_op,
out_element_op);
return run_conv_fwd<3,
ck::tensor_layout::convolution::NDHWC,
ck::tensor_layout::convolution::KZYXC,
ck::tensor_layout::convolution::NDHWK,
InDataType,
WeiDataType,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp,
DeviceConvNDFwdInstance<3>>(do_verification,
init_method,
time_kernel,
params,
in_element_op,
wei_element_op,
out_element_op);
}
return 0;
......
......@@ -90,24 +90,27 @@ void print_helper_msg()
<< std::endl;
}
// FIXME: current implementation only support NCHW/NHWC layout
template <ck::index_t NDimSpatial,
typename InLayout,
typename WeiLayout,
typename OutLayout,
typename InDataType,
typename WeiDataType,
typename OutDataType,
typename AccDataType,
typename InElementOp,
typename WeiElementOp,
typename OutElementOp,
typename DeviceConvNdBwdDataInstance>
int run_conv_bwd_data_nhwc(bool do_verification,
int init_method,
bool time_kernel,
const ck::tensor_operation::device::ConvParams& params,
const InElementOp& in_element_op,
const WeiElementOp& wei_element_op,
const OutElementOp& out_element_op)
int run_conv_bwd_data(bool do_verification,
int init_method,
bool time_kernel,
const ck::tensor_operation::device::ConvParams& params,
const InElementOp& in_element_op,
const WeiElementOp& wei_element_op,
const OutElementOp& out_element_op)
{
// make host tensor descriptor
auto f_nhwc_host_tensor_descriptor =
[](ck::index_t n, ck::index_t c, std::vector<ck::index_t> spatial_lengths) {
std::vector<std::size_t> nhwc_lengths{static_cast<std::size_t>(n),
......@@ -118,37 +121,92 @@ int run_conv_bwd_data_nhwc(bool do_verification,
return HostTensorDescriptor(nhwc_lengths);
};
Tensor<InDataType> in_n_hi_wi_c_host(
f_nhwc_host_tensor_descriptor(params.N_, params.C_, params.input_spatial_lengths_));
Tensor<InDataType> in_n_hi_wi_c_device(
f_nhwc_host_tensor_descriptor(params.N_, params.C_, params.input_spatial_lengths_));
Tensor<WeiDataType> wei_k_y_x_c(
f_nhwc_host_tensor_descriptor(params.K_, params.C_, params.filter_spatial_lengths_));
Tensor<OutDataType> out_n_ho_wo_k(
f_nhwc_host_tensor_descriptor(params.N_, params.K_, params.GetOutputSpatialLengths()));
auto f_nchw_host_tensor_descriptor =
[](ck::index_t n, ck::index_t c, std::vector<ck::index_t> spatial_lengths) {
std::vector<std::size_t> nchw_lengths{static_cast<std::size_t>(n),
static_cast<std::size_t>(c)};
nchw_lengths.insert(nchw_lengths.end(), spatial_lengths.begin(), spatial_lengths.end());
return HostTensorDescriptor(nchw_lengths);
};
HostTensorDescriptor in_desc, wei_desc, out_desc;
// FIXME: properly implement "make host descriptor" for different layout
if constexpr(ck::is_same_v<InLayout, ck::tensor_layout::convolution::NWC> ||
ck::is_same_v<InLayout, ck::tensor_layout::convolution::NHWC> ||
ck::is_same_v<InLayout, ck::tensor_layout::convolution::NDHWC>)
{
in_desc =
f_nhwc_host_tensor_descriptor(params.N_, params.C_, params.input_spatial_lengths_);
}
else if constexpr(ck::is_same_v<InLayout, ck::tensor_layout::convolution::NCW> ||
ck::is_same_v<InLayout, ck::tensor_layout::convolution::NCHW> ||
ck::is_same_v<InLayout, ck::tensor_layout::convolution::NCDHW>)
{
in_desc =
f_nchw_host_tensor_descriptor(params.N_, params.C_, params.input_spatial_lengths_);
}
// FIXME: properly implement "make host descriptor" for different layout
if constexpr(ck::is_same_v<WeiLayout, ck::tensor_layout::convolution::KXC> ||
ck::is_same_v<WeiLayout, ck::tensor_layout::convolution::KYXC> ||
ck::is_same_v<WeiLayout, ck::tensor_layout::convolution::KZYXC>)
{
wei_desc =
f_nhwc_host_tensor_descriptor(params.K_, params.C_, params.filter_spatial_lengths_);
}
else if constexpr(ck::is_same_v<WeiLayout, ck::tensor_layout::convolution::KCX> ||
ck::is_same_v<WeiLayout, ck::tensor_layout::convolution::KCYX> ||
ck::is_same_v<WeiLayout, ck::tensor_layout::convolution::KCZYX>)
{
wei_desc =
f_nchw_host_tensor_descriptor(params.K_, params.C_, params.filter_spatial_lengths_);
}
// FIXME: properly implement "make host descriptor" for different layout
if constexpr(ck::is_same_v<OutLayout, ck::tensor_layout::convolution::NWK> ||
ck::is_same_v<OutLayout, ck::tensor_layout::convolution::NHWK> ||
ck::is_same_v<OutLayout, ck::tensor_layout::convolution::NDHWK>)
{
out_desc =
f_nhwc_host_tensor_descriptor(params.N_, params.K_, params.GetOutputSpatialLengths());
}
else if constexpr(ck::is_same_v<OutLayout, ck::tensor_layout::convolution::NKW> ||
ck::is_same_v<OutLayout, ck::tensor_layout::convolution::NKHW> ||
ck::is_same_v<OutLayout, ck::tensor_layout::convolution::NKDHW>)
{
out_desc =
f_nchw_host_tensor_descriptor(params.N_, params.K_, params.GetOutputSpatialLengths());
}
std::cout << "in_n_hi_wi_c: " << in_n_hi_wi_c_host.mDesc << std::endl;
std::cout << "wei_k_y_x_c: " << wei_k_y_x_c.mDesc << std::endl;
std::cout << "out_n_ho_wo_k: " << out_n_ho_wo_k.mDesc << std::endl;
Tensor<InDataType> in_host(in_desc);
Tensor<InDataType> in_device(in_desc);
Tensor<WeiDataType> wei(wei_desc);
Tensor<OutDataType> out(out_desc);
std::cout << "in: " << in_host.mDesc << std::endl;
std::cout << "wei: " << wei.mDesc << std::endl;
std::cout << "out: " << out.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
out_n_ho_wo_k.GenerateTensorValue(GeneratorTensor_2<OutDataType>{-5, 5});
wei_k_y_x_c.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-5, 5});
out.GenerateTensorValue(GeneratorTensor_2<OutDataType>{-5, 5});
wei.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-5, 5});
break;
default:
out_n_ho_wo_k.GenerateTensorValue(GeneratorTensor_1<OutDataType>{1});
wei_k_y_x_c.GenerateTensorValue(GeneratorTensor_1<WeiDataType>{1});
out.GenerateTensorValue(GeneratorTensor_1<OutDataType>{1});
wei.GenerateTensorValue(GeneratorTensor_1<WeiDataType>{1});
}
DeviceMem in_device_buf(sizeof(InDataType) * in_n_hi_wi_c_device.mDesc.GetElementSpace());
DeviceMem wei_device_buf(sizeof(WeiDataType) * wei_k_y_x_c.mDesc.GetElementSpace());
DeviceMem out_device_buf(sizeof(OutDataType) * out_n_ho_wo_k.mDesc.GetElementSpace());
DeviceMem in_device_buf(sizeof(InDataType) * in_device.mDesc.GetElementSpace());
DeviceMem wei_device_buf(sizeof(WeiDataType) * wei.mDesc.GetElementSpace());
DeviceMem out_device_buf(sizeof(OutDataType) * out.mDesc.GetElementSpace());
out_device_buf.ToDevice(out_n_ho_wo_k.mData.data());
wei_device_buf.ToDevice(wei_k_y_x_c.mData.data());
out_device_buf.ToDevice(out.mData.data());
wei_device_buf.ToDevice(wei.mData.data());
// reset input to zero
in_device_buf.SetZero();
......@@ -194,34 +252,22 @@ int run_conv_bwd_data_nhwc(bool do_verification,
if(do_verification)
{
std::cout << "before ref" << std::endl;
auto ref_conv = ck::tensor_operation::host::ReferenceConvBwdData<
NDimSpatial,
ck::tuple_element_t<NDimSpatial - 1,
ck::Tuple<ck::tensor_layout::convolution::NWC,
ck::tensor_layout::convolution::NHWC,
ck::tensor_layout::convolution::NDHWC>>,
ck::tuple_element_t<NDimSpatial - 1,
ck::Tuple<ck::tensor_layout::convolution::KXC,
ck::tensor_layout::convolution::KYXC,
ck::tensor_layout::convolution::KZYXC>>,
ck::tuple_element_t<NDimSpatial - 1,
ck::Tuple<ck::tensor_layout::convolution::NWK,
ck::tensor_layout::convolution::NHWK,
ck::tensor_layout::convolution::NDHWK>>,
InDataType,
WeiDataType,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp>();
auto ref_conv = ck::tensor_operation::host::ReferenceConvBwdData<NDimSpatial,
InLayout,
WeiLayout,
OutLayout,
InDataType,
WeiDataType,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp>();
auto ref_invoker = ref_conv.MakeInvoker();
auto ref_argument = ref_conv.MakeArgument(in_n_hi_wi_c_host,
wei_k_y_x_c,
out_n_ho_wo_k,
auto ref_argument = ref_conv.MakeArgument(in_host,
wei,
out,
params.conv_filter_strides_,
params.conv_filter_dilations_,
params.input_left_pads_,
......@@ -230,15 +276,11 @@ int run_conv_bwd_data_nhwc(bool do_verification,
wei_element_op,
out_element_op);
std::cout << "before ref" << std::endl;
ref_invoker.Run(ref_argument);
std::cout << "after ref" << std::endl;
in_device_buf.FromDevice(in_n_hi_wi_c_device.mData.data());
in_device_buf.FromDevice(in_device.mData.data());
return ck::utils::check_err(in_n_hi_wi_c_device.mData, in_n_hi_wi_c_host.mData) ? 0 : 1;
return ck::utils::check_err(in_device.mData, in_host.mData) ? 0 : 1;
}
return 0;
......
......@@ -95,57 +95,63 @@ int main(int argc, char* argv[])
if(num_dim_spatial == 1)
{
return run_conv_bwd_data_nhwc<1,
InDataType,
WeiDataType,
OutDataType,
AccDataType,
InElementOp,
WeiElementOp,
OutElementOp,
DeviceConvNdBwdDataInstance<1>>(do_verification,
init_method,
time_kernel,
params,
in_element_op,
wei_element_op,
out_element_op);
return run_conv_bwd_data<1,
ck::tensor_layout::convolution::NWC,
ck::tensor_layout::convolution::KXC,
ck::tensor_layout::convolution::NWK,
InDataType,
WeiDataType,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp,
DeviceConvNdBwdDataInstance<1>>(do_verification,
init_method,
time_kernel,
params,
in_element_op,
wei_element_op,
out_element_op);
}
else if(num_dim_spatial == 2)
{
return run_conv_bwd_data_nhwc<2,
InDataType,
WeiDataType,
OutDataType,
AccDataType,
InElementOp,
WeiElementOp,
OutElementOp,
DeviceConvNdBwdDataInstance<2>>(do_verification,
init_method,
time_kernel,
params,
in_element_op,
wei_element_op,
out_element_op);
return run_conv_bwd_data<2,
ck::tensor_layout::convolution::NHWC,
ck::tensor_layout::convolution::KYXC,
ck::tensor_layout::convolution::NHWK,
InDataType,
WeiDataType,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp,
DeviceConvNdBwdDataInstance<2>>(do_verification,
init_method,
time_kernel,
params,
in_element_op,
wei_element_op,
out_element_op);
}
else if(num_dim_spatial == 3)
{
return run_conv_bwd_data_nhwc<3,
InDataType,
WeiDataType,
OutDataType,
AccDataType,
InElementOp,
WeiElementOp,
OutElementOp,
DeviceConvNdBwdDataInstance<3>>(do_verification,
init_method,
time_kernel,
params,
in_element_op,
wei_element_op,
out_element_op);
return run_conv_bwd_data<3,
ck::tensor_layout::convolution::NDHWC,
ck::tensor_layout::convolution::KZYXC,
ck::tensor_layout::convolution::NDHWK,
InDataType,
WeiDataType,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp,
DeviceConvNdBwdDataInstance<3>>(do_verification,
init_method,
time_kernel,
params,
in_element_op,
wei_element_op,
out_element_op);
}
return 0;
......
......@@ -90,24 +90,28 @@ parse_conv_params(int num_dim_spatial, int arg_idx, char* const argv[])
input_right_pads};
}
// FIXME: current implementation only support NCHW/NHWC layout
template <ck::index_t NDimSpatial,
typename InLayout,
typename WeiLayout,
typename OutLayout,
typename InDataType,
typename WeiDataType,
typename OutDataType,
typename AccDataType,
typename InElementOp,
typename WeiElementOp,
typename OutElementOp,
typename DeviceConvBwdWeightInstance>
int run_conv_bwd_weight_nhwc(bool do_verification,
int init_method,
bool time_kernel,
const ck::tensor_operation::device::ConvParams& params,
const InElementOp& in_element_op,
const WeiElementOp& wei_element_op,
const OutElementOp& out_element_op,
ck::index_t split_k)
int run_conv_bwd_weight(bool do_verification,
int init_method,
bool time_kernel,
const ck::tensor_operation::device::ConvParams& params,
const InElementOp& in_element_op,
const WeiElementOp& wei_element_op,
const OutElementOp& out_element_op,
ck::index_t split_k)
{
// make host tensor descriptor
auto f_nhwc_host_tensor_descriptor =
[](ck::index_t n, ck::index_t c, std::vector<ck::index_t> spatial_lengths) {
std::vector<std::size_t> nhwc_lengths{static_cast<std::size_t>(n),
......@@ -118,38 +122,92 @@ int run_conv_bwd_weight_nhwc(bool do_verification,
return HostTensorDescriptor(nhwc_lengths);
};
Tensor<InDataType> in_n_hi_wi_c(
f_nhwc_host_tensor_descriptor(params.N_, params.C_, params.input_spatial_lengths_));
Tensor<WeiDataType> wei_k_y_x_c_host_result(
f_nhwc_host_tensor_descriptor(params.K_, params.C_, params.filter_spatial_lengths_));
Tensor<WeiDataType> wei_k_y_x_c_device_result(
f_nhwc_host_tensor_descriptor(params.K_, params.C_, params.filter_spatial_lengths_));
Tensor<OutDataType> out_n_ho_wo_k(
f_nhwc_host_tensor_descriptor(params.N_, params.K_, params.GetOutputSpatialLengths()));
auto f_nchw_host_tensor_descriptor =
[](ck::index_t n, ck::index_t c, std::vector<ck::index_t> spatial_lengths) {
std::vector<std::size_t> nchw_lengths{static_cast<std::size_t>(n),
static_cast<std::size_t>(c)};
nchw_lengths.insert(nchw_lengths.end(), spatial_lengths.begin(), spatial_lengths.end());
return HostTensorDescriptor(nchw_lengths);
};
HostTensorDescriptor in_desc, wei_desc, out_desc;
// FIXME: properly implement "make host descriptor" for different layout
if constexpr(ck::is_same_v<InLayout, ck::tensor_layout::convolution::NWC> ||
ck::is_same_v<InLayout, ck::tensor_layout::convolution::NHWC> ||
ck::is_same_v<InLayout, ck::tensor_layout::convolution::NDHWC>)
{
in_desc =
f_nhwc_host_tensor_descriptor(params.N_, params.C_, params.input_spatial_lengths_);
}
else if constexpr(ck::is_same_v<InLayout, ck::tensor_layout::convolution::NCW> ||
ck::is_same_v<InLayout, ck::tensor_layout::convolution::NCHW> ||
ck::is_same_v<InLayout, ck::tensor_layout::convolution::NCDHW>)
{
in_desc =
f_nchw_host_tensor_descriptor(params.N_, params.C_, params.input_spatial_lengths_);
}
// FIXME: properly implement "make host descriptor" for different layout
if constexpr(ck::is_same_v<WeiLayout, ck::tensor_layout::convolution::KXC> ||
ck::is_same_v<WeiLayout, ck::tensor_layout::convolution::KYXC> ||
ck::is_same_v<WeiLayout, ck::tensor_layout::convolution::KZYXC>)
{
wei_desc =
f_nhwc_host_tensor_descriptor(params.K_, params.C_, params.filter_spatial_lengths_);
}
else if constexpr(ck::is_same_v<WeiLayout, ck::tensor_layout::convolution::KCX> ||
ck::is_same_v<WeiLayout, ck::tensor_layout::convolution::KCYX> ||
ck::is_same_v<WeiLayout, ck::tensor_layout::convolution::KCZYX>)
{
wei_desc =
f_nchw_host_tensor_descriptor(params.K_, params.C_, params.filter_spatial_lengths_);
}
// FIXME: properly implement "make host descriptor" for different layout
if constexpr(ck::is_same_v<OutLayout, ck::tensor_layout::convolution::NWK> ||
ck::is_same_v<OutLayout, ck::tensor_layout::convolution::NHWK> ||
ck::is_same_v<OutLayout, ck::tensor_layout::convolution::NDHWK>)
{
out_desc =
f_nhwc_host_tensor_descriptor(params.N_, params.K_, params.GetOutputSpatialLengths());
}
else if constexpr(ck::is_same_v<OutLayout, ck::tensor_layout::convolution::NKW> ||
ck::is_same_v<OutLayout, ck::tensor_layout::convolution::NKHW> ||
ck::is_same_v<OutLayout, ck::tensor_layout::convolution::NKDHW>)
{
out_desc =
f_nchw_host_tensor_descriptor(params.N_, params.K_, params.GetOutputSpatialLengths());
}
Tensor<InDataType> in(in_desc);
Tensor<WeiDataType> wei_host_result(wei_desc);
Tensor<WeiDataType> wei_device_result(wei_desc);
Tensor<OutDataType> out(out_desc);
std::cout << "input: " << in_n_hi_wi_c.mDesc << std::endl;
std::cout << "weight: " << wei_k_y_x_c_host_result.mDesc << std::endl;
std::cout << "output: " << out_n_ho_wo_k.mDesc << std::endl;
std::cout << "in: " << in.mDesc << std::endl;
std::cout << "wei: " << wei_host_result.mDesc << std::endl;
std::cout << "out: " << out.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
in_n_hi_wi_c.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5});
out_n_ho_wo_k.GenerateTensorValue(GeneratorTensor_2<OutDataType>{-5, 5});
in.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5});
out.GenerateTensorValue(GeneratorTensor_2<OutDataType>{-5, 5});
break;
default:
in_n_hi_wi_c.GenerateTensorValue(GeneratorTensor_3<InDataType>{0.0, 1.0});
out_n_ho_wo_k.GenerateTensorValue(GeneratorTensor_3<OutDataType>{-0.5, 0.5});
in.GenerateTensorValue(GeneratorTensor_3<InDataType>{0.0, 1.0});
out.GenerateTensorValue(GeneratorTensor_3<OutDataType>{-0.5, 0.5});
}
DeviceMem in_device_buf(sizeof(InDataType) * in_n_hi_wi_c.mDesc.GetElementSpace());
DeviceMem wei_device_buf(sizeof(WeiDataType) *
wei_k_y_x_c_device_result.mDesc.GetElementSpace());
DeviceMem out_device_buf(sizeof(OutDataType) * out_n_ho_wo_k.mDesc.GetElementSpace());
DeviceMem in_device_buf(sizeof(InDataType) * in.mDesc.GetElementSpace());
DeviceMem wei_device_buf(sizeof(WeiDataType) * wei_device_result.mDesc.GetElementSpace());
DeviceMem out_device_buf(sizeof(OutDataType) * out.mDesc.GetElementSpace());
in_device_buf.ToDevice(in_n_hi_wi_c.mData.data());
out_device_buf.ToDevice(out_n_ho_wo_k.mData.data());
in_device_buf.ToDevice(in.mData.data());
out_device_buf.ToDevice(out.mData.data());
// init to 0
wei_device_buf.SetZero();
......@@ -197,32 +255,22 @@ int run_conv_bwd_weight_nhwc(bool do_verification,
if(do_verification)
{
auto ref_conv = ck::tensor_operation::host::ReferenceConvBwdWeight<
2,
ck::tuple_element_t<NDimSpatial - 1,
ck::Tuple<ck::tensor_layout::convolution::NWC,
ck::tensor_layout::convolution::NHWC,
ck::tensor_layout::convolution::NDHWC>>,
ck::tuple_element_t<NDimSpatial - 1,
ck::Tuple<ck::tensor_layout::convolution::KXC,
ck::tensor_layout::convolution::KYXC,
ck::tensor_layout::convolution::KZYXC>>,
ck::tuple_element_t<NDimSpatial - 1,
ck::Tuple<ck::tensor_layout::convolution::NWK,
ck::tensor_layout::convolution::NHWK,
ck::tensor_layout::convolution::NDHWK>>,
InDataType,
WeiDataType,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp>{};
auto ref_conv = ck::tensor_operation::host::ReferenceConvBwdWeight<NDimSpatial,
InLayout,
WeiLayout,
OutLayout,
InDataType,
WeiDataType,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp>{};
auto ref_invoker = ref_conv.MakeInvoker();
auto ref_argument = ref_conv.MakeArgument(in_n_hi_wi_c,
wei_k_y_x_c_host_result,
out_n_ho_wo_k,
auto ref_argument = ref_conv.MakeArgument(in,
wei_host_result,
out,
params.conv_filter_strides_,
params.conv_filter_dilations_,
params.input_left_pads_,
......@@ -233,11 +281,9 @@ int run_conv_bwd_weight_nhwc(bool do_verification,
ref_invoker.Run(ref_argument);
wei_device_buf.FromDevice(wei_k_y_x_c_device_result.mData.data());
wei_device_buf.FromDevice(wei_device_result.mData.data());
return ck::utils::check_err(wei_k_y_x_c_device_result.mData, wei_k_y_x_c_host_result.mData)
? 0
: 1;
return ck::utils::check_err(wei_device_result.mData, wei_host_result.mData) ? 0 : 1;
}
return 0;
......
......@@ -104,60 +104,66 @@ int main(int argc, char* argv[])
if(num_dim_spatial == 1)
{
return run_conv_bwd_weight_nhwc<1,
InDataType,
WeiDataType,
OutDataType,
AccDataType,
InElementOp,
WeiElementOp,
OutElementOp,
DeviceConvndBwdWeightInstance<1>>(do_verification,
init_method,
time_kernel,
params,
in_element_op,
wei_element_op,
out_element_op,
split_k);
return run_conv_bwd_weight<1,
ck::tensor_layout::convolution::NWC,
ck::tensor_layout::convolution::KXC,
ck::tensor_layout::convolution::NWK,
InDataType,
WeiDataType,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp,
DeviceConvndBwdWeightInstance<1>>(do_verification,
init_method,
time_kernel,
params,
in_element_op,
wei_element_op,
out_element_op,
split_k);
}
else if(num_dim_spatial == 2)
{
return run_conv_bwd_weight_nhwc<2,
InDataType,
WeiDataType,
OutDataType,
AccDataType,
InElementOp,
WeiElementOp,
OutElementOp,
DeviceConvndBwdWeightInstance<2>>(do_verification,
init_method,
time_kernel,
params,
in_element_op,
wei_element_op,
out_element_op,
split_k);
return run_conv_bwd_weight<2,
ck::tensor_layout::convolution::NHWC,
ck::tensor_layout::convolution::KYXC,
ck::tensor_layout::convolution::NHWK,
InDataType,
WeiDataType,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp,
DeviceConvndBwdWeightInstance<2>>(do_verification,
init_method,
time_kernel,
params,
in_element_op,
wei_element_op,
out_element_op,
split_k);
}
else if(num_dim_spatial == 3)
{
return run_conv_bwd_weight_nhwc<3,
InDataType,
WeiDataType,
OutDataType,
AccDataType,
InElementOp,
WeiElementOp,
OutElementOp,
DeviceConvndBwdWeightInstance<3>>(do_verification,
init_method,
time_kernel,
params,
in_element_op,
wei_element_op,
out_element_op,
split_k);
return run_conv_bwd_weight<3,
ck::tensor_layout::convolution::NDHWC,
ck::tensor_layout::convolution::KZYXC,
ck::tensor_layout::convolution::NDHWK,
InDataType,
WeiDataType,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp,
DeviceConvndBwdWeightInstance<3>>(do_verification,
init_method,
time_kernel,
params,
in_element_op,
wei_element_op,
out_element_op,
split_k);
}
return 0;
......
......@@ -103,60 +103,66 @@ int main(int argc, char* argv[])
if(num_dim_spatial == 1)
{
return run_conv_bwd_weight_nhwc<1,
InDataType,
WeiDataType,
OutDataType,
AccDataType,
InElementOp,
WeiElementOp,
OutElementOp,
DeviceConvndBwdWeightInstance<1>>(do_verification,
init_method,
time_kernel,
params,
in_element_op,
wei_element_op,
out_element_op,
split_k);
return run_conv_bwd_weight<1,
ck::tensor_layout::convolution::NWC,
ck::tensor_layout::convolution::KXC,
ck::tensor_layout::convolution::NWK,
InDataType,
WeiDataType,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp,
DeviceConvndBwdWeightInstance<1>>(do_verification,
init_method,
time_kernel,
params,
in_element_op,
wei_element_op,
out_element_op,
split_k);
}
else if(num_dim_spatial == 2)
{
return run_conv_bwd_weight_nhwc<2,
InDataType,
WeiDataType,
OutDataType,
AccDataType,
InElementOp,
WeiElementOp,
OutElementOp,
DeviceConvndBwdWeightInstance<2>>(do_verification,
init_method,
time_kernel,
params,
in_element_op,
wei_element_op,
out_element_op,
split_k);
return run_conv_bwd_weight<2,
ck::tensor_layout::convolution::NHWC,
ck::tensor_layout::convolution::KYXC,
ck::tensor_layout::convolution::NHWK,
InDataType,
WeiDataType,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp,
DeviceConvndBwdWeightInstance<2>>(do_verification,
init_method,
time_kernel,
params,
in_element_op,
wei_element_op,
out_element_op,
split_k);
}
else if(num_dim_spatial == 3)
{
return run_conv_bwd_weight_nhwc<3,
InDataType,
WeiDataType,
OutDataType,
AccDataType,
InElementOp,
WeiElementOp,
OutElementOp,
DeviceConvndBwdWeightInstance<3>>(do_verification,
init_method,
time_kernel,
params,
in_element_op,
wei_element_op,
out_element_op,
split_k);
return run_conv_bwd_weight<3,
ck::tensor_layout::convolution::NDHWC,
ck::tensor_layout::convolution::KZYXC,
ck::tensor_layout::convolution::NDHWK,
InDataType,
WeiDataType,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp,
DeviceConvndBwdWeightInstance<3>>(do_verification,
init_method,
time_kernel,
params,
in_element_op,
wei_element_op,
out_element_op,
split_k);
}
return 0;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment