Commit 809799bf authored by Chao Liu's avatar Chao Liu
Browse files

update conv bwd-data and bwd-weight

parent 71b69694
...@@ -23,73 +23,7 @@ void print_helper_msg() ...@@ -23,73 +23,7 @@ void print_helper_msg()
std::cout << "arg1: verification (0=no, 1=yes)\n" std::cout << "arg1: verification (0=no, 1=yes)\n"
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n" << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"
<< "arg3: time kernel (0=no, 1=yes)\n" << "arg3: time kernel (0=no, 1=yes)\n"
<< "Following arguments (depending on number of spatial dims):\n" << get_conv_param_parser_helper_msg() << std::endl;
<< " N spatial dimensions (1=Conv1d, 2=Conv2d, 3=Conv3d)\n"
<< " G, N, K, C, \n"
<< " <filter spatial dimensions>, (ie Y, X for 2D)\n"
<< " <input image spatial dimensions>, (ie Hi, Wi for 2D)\n"
<< " <strides>, (ie Sy, Sx for 2D)\n"
<< " <dilations>, (ie Dy, Dx for 2D)\n"
<< " <left padding>, (ie LeftPy, LeftPx for 2D)\n"
<< " <right padding>, (ie RightPy, RightPx for 2D)\n"
<< std::endl;
}
ck::utils::conv::ConvParam parse_conv_param(int num_dim_spatial, int arg_idx, char* const argv[])
{
const ck::index_t G = std::stoi(argv[arg_idx++]);
const ck::index_t N = std::stoi(argv[arg_idx++]);
const ck::index_t K = std::stoi(argv[arg_idx++]);
const ck::index_t C = std::stoi(argv[arg_idx++]);
std::vector<ck::index_t> filter_spatial_lengths(num_dim_spatial);
std::vector<ck::index_t> input_spatial_lengths(num_dim_spatial);
std::vector<ck::index_t> conv_filter_strides(num_dim_spatial);
std::vector<ck::index_t> conv_filter_dilations(num_dim_spatial);
std::vector<ck::index_t> input_left_pads(num_dim_spatial);
std::vector<ck::index_t> input_right_pads(num_dim_spatial);
for(int i = 0; i < num_dim_spatial; ++i)
{
filter_spatial_lengths[i] = std::stoi(argv[arg_idx++]);
}
for(int i = 0; i < num_dim_spatial; ++i)
{
input_spatial_lengths[i] = std::stoi(argv[arg_idx++]);
}
for(int i = 0; i < num_dim_spatial; ++i)
{
conv_filter_strides[i] = std::stoi(argv[arg_idx++]);
}
for(int i = 0; i < num_dim_spatial; ++i)
{
conv_filter_dilations[i] = std::stoi(argv[arg_idx++]);
}
for(int i = 0; i < num_dim_spatial; ++i)
{
input_left_pads[i] = std::stoi(argv[arg_idx++]);
}
for(int i = 0; i < num_dim_spatial; ++i)
{
input_right_pads[i] = std::stoi(argv[arg_idx++]);
}
return ck::utils::conv::ConvParam{num_dim_spatial,
G,
N,
K,
C,
filter_spatial_lengths,
input_spatial_lengths,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads};
} }
template <ck::index_t NDimSpatial, template <ck::index_t NDimSpatial,
......
...@@ -18,83 +18,15 @@ ...@@ -18,83 +18,15 @@
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" #include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp"
ck::utils::conv::ConvParam parse_conv_params(int num_dim_spatial, int arg_idx, char* const argv[])
{
const ck::index_t N = std::stoi(argv[arg_idx++]);
const ck::index_t K = std::stoi(argv[arg_idx++]);
const ck::index_t C = std::stoi(argv[arg_idx++]);
std::vector<ck::index_t> filter_spatial_lengths(num_dim_spatial);
std::vector<ck::index_t> input_spatial_lengths(num_dim_spatial);
std::vector<ck::index_t> conv_filter_strides(num_dim_spatial);
std::vector<ck::index_t> conv_filter_dilations(num_dim_spatial);
std::vector<ck::index_t> input_left_pads(num_dim_spatial);
std::vector<ck::index_t> input_right_pads(num_dim_spatial);
for(int i = 0; i < num_dim_spatial; ++i)
{
filter_spatial_lengths[i] = std::stoi(argv[arg_idx++]);
}
for(int i = 0; i < num_dim_spatial; ++i)
{
input_spatial_lengths[i] = std::stoi(argv[arg_idx++]);
}
for(int i = 0; i < num_dim_spatial; ++i)
{
conv_filter_strides[i] = std::stoi(argv[arg_idx++]);
}
for(int i = 0; i < num_dim_spatial; ++i)
{
conv_filter_dilations[i] = std::stoi(argv[arg_idx++]);
}
for(int i = 0; i < num_dim_spatial; ++i)
{
input_left_pads[i] = std::stoi(argv[arg_idx++]);
}
for(int i = 0; i < num_dim_spatial; ++i)
{
input_right_pads[i] = std::stoi(argv[arg_idx++]);
}
return ck::utils::conv::ConvParam{num_dim_spatial,
N,
K,
C,
filter_spatial_lengths,
input_spatial_lengths,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads};
}
void print_helper_msg() void print_helper_msg()
{ {
std::cout << "arg1: verification (0=no, 1=yes)\n" std::cout << "arg1: verification (0=no, 1=yes)\n"
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n" << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"
<< "arg3: time kernel (0=no, 1=yes)\n" << "arg3: time kernel (0=no, 1=yes)\n"
<< "arg4: N spatial dimensions (default 2)\n" << get_conv_param_parser_helper_msg() << std::endl;
<< "Following arguments (depending on number of spatial dims):\n"
<< " N, K, C, \n"
<< " <filter spatial dimensions>, (ie Y, X for 2D)\n"
<< " <input image spatial dimensions>, (ie Hi, Wi for 2D)\n"
<< " <strides>, (ie Sy, Sx for 2D)\n"
<< " <dilations>, (ie Dy, Dx for 2D)\n"
<< " <left padding>, (ie LeftPy, LeftPx for 2D)\n"
<< " <right padding>, (ie RightPy, RightPx for 2D)\n"
<< std::endl;
} }
// FIXME: current implementation only support NCHW/NHWC layout
template <ck::index_t NDimSpatial, template <ck::index_t NDimSpatial,
typename InLayout,
typename WeiLayout,
typename OutLayout,
typename InDataType, typename InDataType,
typename WeiDataType, typename WeiDataType,
typename OutDataType, typename OutDataType,
...@@ -106,18 +38,17 @@ int run_conv_bwd_data(bool do_verification, ...@@ -106,18 +38,17 @@ int run_conv_bwd_data(bool do_verification,
int init_method, int init_method,
bool time_kernel, bool time_kernel,
const ck::utils::conv::ConvParam& conv_param, const ck::utils::conv::ConvParam& conv_param,
const HostTensorDescriptor& in_g_n_c_wis_desc,
const HostTensorDescriptor& wei_g_k_c_xs_desc,
const HostTensorDescriptor& out_g_n_k_wos_desc,
const InElementOp& in_element_op, const InElementOp& in_element_op,
const WeiElementOp& wei_element_op, const WeiElementOp& wei_element_op,
const OutElementOp& out_element_op) const OutElementOp& out_element_op)
{ {
const auto in_desc = ck::utils::conv::get_input_host_tensor_descriptor<InLayout>(conv_param); Tensor<InDataType> in_host(in_g_n_c_wis_desc);
const auto wei_desc = ck::utils::conv::get_weight_host_tensor_descriptor<WeiLayout>(conv_param); Tensor<InDataType> in_device(in_g_n_c_wis_desc);
const auto out_desc = ck::utils::conv::get_output_host_tensor_descriptor<OutLayout>(conv_param); Tensor<WeiDataType> wei(wei_g_k_c_xs_desc);
Tensor<OutDataType> out(out_g_n_k_wos_desc);
Tensor<InDataType> in_host(in_desc);
Tensor<InDataType> in_device(in_desc);
Tensor<WeiDataType> wei(wei_desc);
Tensor<OutDataType> out(out_desc);
std::cout << "in: " << in_host.mDesc << std::endl; std::cout << "in: " << in_host.mDesc << std::endl;
std::cout << "wei: " << wei.mDesc << std::endl; std::cout << "wei: " << wei.mDesc << std::endl;
...@@ -187,9 +118,6 @@ int run_conv_bwd_data(bool do_verification, ...@@ -187,9 +118,6 @@ int run_conv_bwd_data(bool do_verification,
if(do_verification) if(do_verification)
{ {
auto ref_conv = ck::tensor_operation::host::ReferenceConvBwdData<NDimSpatial, auto ref_conv = ck::tensor_operation::host::ReferenceConvBwdData<NDimSpatial,
InLayout,
WeiLayout,
OutLayout,
InDataType, InDataType,
WeiDataType, WeiDataType,
OutDataType, OutDataType,
......
...@@ -59,15 +59,16 @@ using DeviceConvNdBwdDataInstance = ck::tensor_operation::device::DeviceConvNdBw ...@@ -59,15 +59,16 @@ using DeviceConvNdBwdDataInstance = ck::tensor_operation::device::DeviceConvNdBw
int main(int argc, char* argv[]) int main(int argc, char* argv[])
{ {
namespace ctc = ck::tensor_layout::convolution;
print_helper_msg(); print_helper_msg();
bool do_verification = true; bool do_verification = true;
int init_method = 1; int init_method = 1;
bool time_kernel = false; bool time_kernel = false;
int num_dim_spatial = 2;
ck::utils::conv::ConvParam params{ ck::utils::conv::ConvParam conv_param{
2, 128, 256, 256, {3, 3}, {71, 71}, {2, 2}, {1, 1}, {1, 1}, {1, 1}}; 2, 1, 128, 256, 256, {3, 3}, {71, 71}, {2, 2}, {1, 1}, {1, 1}, {1, 1}};
if(argc == 1) if(argc == 1)
{ {
...@@ -81,24 +82,34 @@ int main(int argc, char* argv[]) ...@@ -81,24 +82,34 @@ int main(int argc, char* argv[])
} }
else else
{ {
do_verification = std::stoi(argv[1]); do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]); time_kernel = std::stoi(argv[3]);
num_dim_spatial = std::stoi(argv[4]); const ck::index_t num_dim_spatial = std::stoi(argv[4]);
params = parse_conv_params(num_dim_spatial, 5, argv); conv_param = parse_conv_param(num_dim_spatial, 5, argv);
} }
const auto in_element_op = InElementOp{}; const auto in_element_op = InElementOp{};
const auto wei_element_op = WeiElementOp{}; const auto wei_element_op = WeiElementOp{};
const auto out_element_op = OutElementOp{}; const auto out_element_op = OutElementOp{};
if(num_dim_spatial == 1) if(conv_param.num_dim_spatial_ == 1)
{ {
using InLayout = ctc::GNWC;
using WeiLayout = ctc::GKXC;
using OutLayout = ctc::GNWK;
const auto in_g_n_c_wis_desc =
make_input_host_tensor_descriptor_g_n_c_wis_packed<InLayout>(conv_param);
const auto wei_g_k_c_xs_desc =
make_weight_host_tensor_descriptor_g_k_c_xs_packed<WeiLayout>(conv_param);
const auto out_g_n_k_wos_desc =
make_output_host_tensor_descriptor_g_n_k_wos_packed<OutLayout>(conv_param);
return run_conv_bwd_data<1, return run_conv_bwd_data<1,
ck::tensor_layout::convolution::NWC,
ck::tensor_layout::convolution::KXC,
ck::tensor_layout::convolution::NWK,
InDataType, InDataType,
WeiDataType, WeiDataType,
OutDataType, OutDataType,
...@@ -108,17 +119,30 @@ int main(int argc, char* argv[]) ...@@ -108,17 +119,30 @@ int main(int argc, char* argv[])
DeviceConvNdBwdDataInstance<1>>(do_verification, DeviceConvNdBwdDataInstance<1>>(do_verification,
init_method, init_method,
time_kernel, time_kernel,
params, conv_param,
in_g_n_c_wis_desc,
wei_g_k_c_xs_desc,
out_g_n_k_wos_desc,
in_element_op, in_element_op,
wei_element_op, wei_element_op,
out_element_op); out_element_op);
} }
else if(num_dim_spatial == 2) else if(conv_param.num_dim_spatial_ == 2)
{ {
using InLayout = ctc::GNHWC;
using WeiLayout = ctc::GKYXC;
using OutLayout = ctc::GNHWK;
const auto in_g_n_c_wis_desc =
make_input_host_tensor_descriptor_g_n_c_wis_packed<InLayout>(conv_param);
const auto wei_g_k_c_xs_desc =
make_weight_host_tensor_descriptor_g_k_c_xs_packed<WeiLayout>(conv_param);
const auto out_g_n_k_wos_desc =
make_output_host_tensor_descriptor_g_n_k_wos_packed<OutLayout>(conv_param);
return run_conv_bwd_data<2, return run_conv_bwd_data<2,
ck::tensor_layout::convolution::NHWC,
ck::tensor_layout::convolution::KYXC,
ck::tensor_layout::convolution::NHWK,
InDataType, InDataType,
WeiDataType, WeiDataType,
OutDataType, OutDataType,
...@@ -128,17 +152,30 @@ int main(int argc, char* argv[]) ...@@ -128,17 +152,30 @@ int main(int argc, char* argv[])
DeviceConvNdBwdDataInstance<2>>(do_verification, DeviceConvNdBwdDataInstance<2>>(do_verification,
init_method, init_method,
time_kernel, time_kernel,
params, conv_param,
in_g_n_c_wis_desc,
wei_g_k_c_xs_desc,
out_g_n_k_wos_desc,
in_element_op, in_element_op,
wei_element_op, wei_element_op,
out_element_op); out_element_op);
} }
else if(num_dim_spatial == 3) else if(conv_param.num_dim_spatial_ == 3)
{ {
using InLayout = ctc::GNDHWC;
using WeiLayout = ctc::GKZYXC;
using OutLayout = ctc::GNDHWK;
const auto in_g_n_c_wis_desc =
make_input_host_tensor_descriptor_g_n_c_wis_packed<InLayout>(conv_param);
const auto wei_g_k_c_xs_desc =
make_weight_host_tensor_descriptor_g_k_c_xs_packed<WeiLayout>(conv_param);
const auto out_g_n_k_wos_desc =
make_output_host_tensor_descriptor_g_n_k_wos_packed<OutLayout>(conv_param);
return run_conv_bwd_data<3, return run_conv_bwd_data<3,
ck::tensor_layout::convolution::NDHWC,
ck::tensor_layout::convolution::KZYXC,
ck::tensor_layout::convolution::NDHWK,
InDataType, InDataType,
WeiDataType, WeiDataType,
OutDataType, OutDataType,
...@@ -148,7 +185,10 @@ int main(int argc, char* argv[]) ...@@ -148,7 +185,10 @@ int main(int argc, char* argv[])
DeviceConvNdBwdDataInstance<3>>(do_verification, DeviceConvNdBwdDataInstance<3>>(do_verification,
init_method, init_method,
time_kernel, time_kernel,
params, conv_param,
in_g_n_c_wis_desc,
wei_g_k_c_xs_desc,
out_g_n_k_wos_desc,
in_element_op, in_element_op,
wei_element_op, wei_element_op,
out_element_op); out_element_op);
......
...@@ -23,78 +23,10 @@ void print_helper_msg() ...@@ -23,78 +23,10 @@ void print_helper_msg()
std::cout << "arg1: verification (0=no, 1=yes)\n" std::cout << "arg1: verification (0=no, 1=yes)\n"
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n" << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"
<< "arg3: time kernel (0=no, 1=yes)\n" << "arg3: time kernel (0=no, 1=yes)\n"
<< "arg4: N spatial dimensions (default 2)\n" << get_conv_param_parser_helper_msg() << "split_k" << std::endl;
<< "Following arguments (depending on number of spatial dims):\n"
<< " N, K, C, \n"
<< " <filter spatial dimensions>, (ie Y, X for 2D)\n"
<< " <input image spatial dimensions>, (ie Hi, Wi for 2D)\n"
<< " <strides>, (ie Sy, Sx for 2D)\n"
<< " <dilations>, (ie Dy, Dx for 2D)\n"
<< " <left padding>, (ie LeftPy, LeftPx for 2D)\n"
<< " <right padding>, (ie RightPy, RightPx for 2D)\n"
<< "split_k" << std::endl;
} }
ck::utils::conv::ConvParam parse_conv_params(int num_dim_spatial, int arg_idx, char* const argv[])
{
const ck::index_t N = std::stoi(argv[arg_idx++]);
const ck::index_t K = std::stoi(argv[arg_idx++]);
const ck::index_t C = std::stoi(argv[arg_idx++]);
std::vector<ck::index_t> filter_spatial_lengths(num_dim_spatial);
std::vector<ck::index_t> input_spatial_lengths(num_dim_spatial);
std::vector<ck::index_t> conv_filter_strides(num_dim_spatial);
std::vector<ck::index_t> conv_filter_dilations(num_dim_spatial);
std::vector<ck::index_t> input_left_pads(num_dim_spatial);
std::vector<ck::index_t> input_right_pads(num_dim_spatial);
for(int i = 0; i < num_dim_spatial; ++i)
{
filter_spatial_lengths[i] = std::stoi(argv[arg_idx++]);
}
for(int i = 0; i < num_dim_spatial; ++i)
{
input_spatial_lengths[i] = std::stoi(argv[arg_idx++]);
}
for(int i = 0; i < num_dim_spatial; ++i)
{
conv_filter_strides[i] = std::stoi(argv[arg_idx++]);
}
for(int i = 0; i < num_dim_spatial; ++i)
{
conv_filter_dilations[i] = std::stoi(argv[arg_idx++]);
}
for(int i = 0; i < num_dim_spatial; ++i)
{
input_left_pads[i] = std::stoi(argv[arg_idx++]);
}
for(int i = 0; i < num_dim_spatial; ++i)
{
input_right_pads[i] = std::stoi(argv[arg_idx++]);
}
return ck::utils::conv::ConvParam{num_dim_spatial,
N,
K,
C,
filter_spatial_lengths,
input_spatial_lengths,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads};
}
// FIXME: current implementation only support NCHW/NHWC layout
template <ck::index_t NDimSpatial, template <ck::index_t NDimSpatial,
typename InLayout,
typename WeiLayout,
typename OutLayout,
typename InDataType, typename InDataType,
typename WeiDataType, typename WeiDataType,
typename OutDataType, typename OutDataType,
...@@ -106,19 +38,18 @@ int run_conv_bwd_weight(bool do_verification, ...@@ -106,19 +38,18 @@ int run_conv_bwd_weight(bool do_verification,
int init_method, int init_method,
bool time_kernel, bool time_kernel,
const ck::utils::conv::ConvParam& conv_param, const ck::utils::conv::ConvParam& conv_param,
const HostTensorDescriptor& in_g_n_c_wis_desc,
const HostTensorDescriptor& wei_g_k_c_xs_desc,
const HostTensorDescriptor& out_g_n_k_wos_desc,
const InElementOp& in_element_op, const InElementOp& in_element_op,
const WeiElementOp& wei_element_op, const WeiElementOp& wei_element_op,
const OutElementOp& out_element_op, const OutElementOp& out_element_op,
ck::index_t split_k) ck::index_t split_k)
{ {
const auto in_desc = ck::utils::conv::get_input_host_tensor_descriptor<InLayout>(conv_param); Tensor<InDataType> in(in_g_n_c_wis_desc);
const auto wei_desc = ck::utils::conv::get_weight_host_tensor_descriptor<WeiLayout>(conv_param); Tensor<WeiDataType> wei_host_result(wei_g_k_c_xs_desc);
const auto out_desc = ck::utils::conv::get_output_host_tensor_descriptor<OutLayout>(conv_param); Tensor<WeiDataType> wei_device_result(wei_g_k_c_xs_desc);
Tensor<OutDataType> out(out_g_n_k_wos_desc);
Tensor<InDataType> in(in_desc);
Tensor<WeiDataType> wei_host_result(wei_desc);
Tensor<WeiDataType> wei_device_result(wei_desc);
Tensor<OutDataType> out(out_desc);
std::cout << "in: " << in.mDesc << std::endl; std::cout << "in: " << in.mDesc << std::endl;
std::cout << "wei: " << wei_host_result.mDesc << std::endl; std::cout << "wei: " << wei_host_result.mDesc << std::endl;
...@@ -190,9 +121,6 @@ int run_conv_bwd_weight(bool do_verification, ...@@ -190,9 +121,6 @@ int run_conv_bwd_weight(bool do_verification,
if(do_verification) if(do_verification)
{ {
auto ref_conv = ck::tensor_operation::host::ReferenceConvBwdWeight<NDimSpatial, auto ref_conv = ck::tensor_operation::host::ReferenceConvBwdWeight<NDimSpatial,
InLayout,
WeiLayout,
OutLayout,
InDataType, InDataType,
WeiDataType, WeiDataType,
OutDataType, OutDataType,
......
...@@ -62,15 +62,14 @@ using DeviceConvndBwdWeightInstance = ...@@ -62,15 +62,14 @@ using DeviceConvndBwdWeightInstance =
int main(int argc, char* argv[]) int main(int argc, char* argv[])
{ {
print_helper_msg(); namespace ctc = ck::tensor_layout::convolution;
bool do_verification = true; bool do_verification = true;
int init_method = 1; int init_method = 1;
bool time_kernel = false; bool time_kernel = false;
int num_dim_spatial = 2;
ck::utils::conv::ConvParam params{ ck::utils::conv::ConvParam conv_param{
2, 32, 256, 1024, {3, 3}, {14, 14}, {2, 2}, {1, 1}, {1, 1}, {1, 1}}; 2, 1, 32, 256, 1024, {3, 3}, {14, 14}, {2, 2}, {1, 1}, {1, 1}, {1, 1}};
ck::index_t split_k = 4; ck::index_t split_k = 4;
...@@ -86,12 +85,12 @@ int main(int argc, char* argv[]) ...@@ -86,12 +85,12 @@ int main(int argc, char* argv[])
} }
else else
{ {
do_verification = std::stoi(argv[1]); do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]); time_kernel = std::stoi(argv[3]);
num_dim_spatial = std::stoi(argv[4]); const ck::index_t num_dim_spatial = std::stoi(argv[4]);
params = parse_conv_params(num_dim_spatial, 5, argv); conv_param = parse_conv_param(num_dim_spatial, 5, argv);
split_k = std::stoi(argv[5 + 3 + 6 * num_dim_spatial - 1]); split_k = std::stoi(argv[5 + 3 + 6 * num_dim_spatial - 1]);
split_k = std::max(1, split_k); split_k = std::max(1, split_k);
...@@ -101,12 +100,22 @@ int main(int argc, char* argv[]) ...@@ -101,12 +100,22 @@ int main(int argc, char* argv[])
const auto wei_element_op = WeiElementOp{}; const auto wei_element_op = WeiElementOp{};
const auto out_element_op = OutElementOp{}; const auto out_element_op = OutElementOp{};
if(num_dim_spatial == 1) if(conv_param.num_dim_spatial_ == 1)
{ {
using InLayout = ctc::GNWC;
using WeiLayout = ctc::GKXC;
using OutLayout = ctc::GNWK;
const auto in_g_n_c_wis_desc =
make_input_host_tensor_descriptor_g_n_c_wis_packed<InLayout>(conv_param);
const auto wei_g_k_c_xs_desc =
make_weight_host_tensor_descriptor_g_k_c_xs_packed<WeiLayout>(conv_param);
const auto out_g_n_k_wos_desc =
make_output_host_tensor_descriptor_g_n_k_wos_packed<OutLayout>(conv_param);
return run_conv_bwd_weight<1, return run_conv_bwd_weight<1,
ck::tensor_layout::convolution::NWC,
ck::tensor_layout::convolution::KXC,
ck::tensor_layout::convolution::NWK,
InDataType, InDataType,
WeiDataType, WeiDataType,
OutDataType, OutDataType,
...@@ -116,18 +125,31 @@ int main(int argc, char* argv[]) ...@@ -116,18 +125,31 @@ int main(int argc, char* argv[])
DeviceConvndBwdWeightInstance<1>>(do_verification, DeviceConvndBwdWeightInstance<1>>(do_verification,
init_method, init_method,
time_kernel, time_kernel,
params, conv_param,
in_g_n_c_wis_desc,
wei_g_k_c_xs_desc,
out_g_n_k_wos_desc,
in_element_op, in_element_op,
wei_element_op, wei_element_op,
out_element_op, out_element_op,
split_k); split_k);
} }
else if(num_dim_spatial == 2) else if(conv_param.num_dim_spatial_ == 2)
{ {
using InLayout = ctc::GNHWC;
using WeiLayout = ctc::GKYXC;
using OutLayout = ctc::GNHWK;
const auto in_g_n_c_wis_desc =
make_input_host_tensor_descriptor_g_n_c_wis_packed<InLayout>(conv_param);
const auto wei_g_k_c_xs_desc =
make_weight_host_tensor_descriptor_g_k_c_xs_packed<WeiLayout>(conv_param);
const auto out_g_n_k_wos_desc =
make_output_host_tensor_descriptor_g_n_k_wos_packed<OutLayout>(conv_param);
return run_conv_bwd_weight<2, return run_conv_bwd_weight<2,
ck::tensor_layout::convolution::NHWC,
ck::tensor_layout::convolution::KYXC,
ck::tensor_layout::convolution::NHWK,
InDataType, InDataType,
WeiDataType, WeiDataType,
OutDataType, OutDataType,
...@@ -137,18 +159,31 @@ int main(int argc, char* argv[]) ...@@ -137,18 +159,31 @@ int main(int argc, char* argv[])
DeviceConvndBwdWeightInstance<2>>(do_verification, DeviceConvndBwdWeightInstance<2>>(do_verification,
init_method, init_method,
time_kernel, time_kernel,
params, conv_param,
in_g_n_c_wis_desc,
wei_g_k_c_xs_desc,
out_g_n_k_wos_desc,
in_element_op, in_element_op,
wei_element_op, wei_element_op,
out_element_op, out_element_op,
split_k); split_k);
} }
else if(num_dim_spatial == 3) else if(conv_param.num_dim_spatial_ == 3)
{ {
using InLayout = ctc::GNDHWC;
using WeiLayout = ctc::GKZYXC;
using OutLayout = ctc::GNDHWK;
const auto in_g_n_c_wis_desc =
make_input_host_tensor_descriptor_g_n_c_wis_packed<InLayout>(conv_param);
const auto wei_g_k_c_xs_desc =
make_weight_host_tensor_descriptor_g_k_c_xs_packed<WeiLayout>(conv_param);
const auto out_g_n_k_wos_desc =
make_output_host_tensor_descriptor_g_n_k_wos_packed<OutLayout>(conv_param);
return run_conv_bwd_weight<3, return run_conv_bwd_weight<3,
ck::tensor_layout::convolution::NDHWC,
ck::tensor_layout::convolution::KZYXC,
ck::tensor_layout::convolution::NDHWK,
InDataType, InDataType,
WeiDataType, WeiDataType,
OutDataType, OutDataType,
...@@ -158,7 +193,10 @@ int main(int argc, char* argv[]) ...@@ -158,7 +193,10 @@ int main(int argc, char* argv[])
DeviceConvndBwdWeightInstance<3>>(do_verification, DeviceConvndBwdWeightInstance<3>>(do_verification,
init_method, init_method,
time_kernel, time_kernel,
params, conv_param,
in_g_n_c_wis_desc,
wei_g_k_c_xs_desc,
out_g_n_k_wos_desc,
in_element_op, in_element_op,
wei_element_op, wei_element_op,
out_element_op, out_element_op,
......
...@@ -22,73 +22,7 @@ void print_helper_msg() ...@@ -22,73 +22,7 @@ void print_helper_msg()
std::cout << "arg1: verification (0=no, 1=yes)\n" std::cout << "arg1: verification (0=no, 1=yes)\n"
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n" << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"
<< "arg3: time kernel (0=no, 1=yes)\n" << "arg3: time kernel (0=no, 1=yes)\n"
<< "Following arguments (depending on number of spatial dims):\n" << get_conv_param_parser_helper_msg() << std::endl;
<< " N spatial dimensions (1=Conv1d, 2=Conv2d, 3=Conv3d)\n"
<< " G, N, K, C, \n"
<< " <filter spatial dimensions>, (ie Y, X for 2D)\n"
<< " <input image spatial dimensions>, (ie Hi, Wi for 2D)\n"
<< " <strides>, (ie Sy, Sx for 2D)\n"
<< " <dilations>, (ie Dy, Dx for 2D)\n"
<< " <left padding>, (ie LeftPy, LeftPx for 2D)\n"
<< " <right padding>, (ie RightPy, RightPx for 2D)\n"
<< std::endl;
}
ck::utils::conv::ConvParam parse_conv_param(int num_dim_spatial, int arg_idx, char* const argv[])
{
const ck::index_t G = std::stoi(argv[arg_idx++]);
const ck::index_t N = std::stoi(argv[arg_idx++]);
const ck::index_t K = std::stoi(argv[arg_idx++]);
const ck::index_t C = std::stoi(argv[arg_idx++]);
std::vector<ck::index_t> filter_spatial_lengths(num_dim_spatial);
std::vector<ck::index_t> input_spatial_lengths(num_dim_spatial);
std::vector<ck::index_t> conv_filter_strides(num_dim_spatial);
std::vector<ck::index_t> conv_filter_dilations(num_dim_spatial);
std::vector<ck::index_t> input_left_pads(num_dim_spatial);
std::vector<ck::index_t> input_right_pads(num_dim_spatial);
for(int i = 0; i < num_dim_spatial; ++i)
{
filter_spatial_lengths[i] = std::stoi(argv[arg_idx++]);
}
for(int i = 0; i < num_dim_spatial; ++i)
{
input_spatial_lengths[i] = std::stoi(argv[arg_idx++]);
}
for(int i = 0; i < num_dim_spatial; ++i)
{
conv_filter_strides[i] = std::stoi(argv[arg_idx++]);
}
for(int i = 0; i < num_dim_spatial; ++i)
{
conv_filter_dilations[i] = std::stoi(argv[arg_idx++]);
}
for(int i = 0; i < num_dim_spatial; ++i)
{
input_left_pads[i] = std::stoi(argv[arg_idx++]);
}
for(int i = 0; i < num_dim_spatial; ++i)
{
input_right_pads[i] = std::stoi(argv[arg_idx++]);
}
return ck::utils::conv::ConvParam{num_dim_spatial,
G,
N,
K,
C,
filter_spatial_lengths,
input_spatial_lengths,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads};
} }
template <ck::index_t NDimSpatial, template <ck::index_t NDimSpatial,
......
...@@ -14,10 +14,8 @@ namespace ck { ...@@ -14,10 +14,8 @@ namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace host { namespace host {
// tensor descriptor in GNCHW/GKCXY/GNKHW dimensional order
template <ck::index_t NDimSpatial, template <ck::index_t NDimSpatial,
typename InLayout,
typename WeiLayout,
typename OutLayout,
typename InDataType, typename InDataType,
typename WeiDataType, typename WeiDataType,
typename OutDataType, typename OutDataType,
...@@ -72,71 +70,21 @@ struct ReferenceConvBwdData : public device::BaseOperator ...@@ -72,71 +70,21 @@ struct ReferenceConvBwdData : public device::BaseOperator
{ {
using Argument = ReferenceConvBwdData::Argument; using Argument = ReferenceConvBwdData::Argument;
// FIXME: properly implement "TensorView" for doing transpose or refer to dimension by name
float Run(const Argument& arg) float Run(const Argument& arg)
{ {
// tensor descriptor in NCHW/KXYC/NKHW dimensional order if(!(arg.input_.GetNumOfDimension() == NDimSpatial + 3 &&
HostTensorDescriptor in_desc = arg.input_.mDesc; arg.weight_.GetNumOfDimension() == NDimSpatial + 3 &&
HostTensorDescriptor wei_desc = arg.weight_.mDesc; arg.output_.GetNumOfDimension() == NDimSpatial + 3))
HostTensorDescriptor out_desc = arg.output_.mDesc;
// input
if constexpr(is_same_v<InLayout, ck::tensor_layout::convolution::NWC>)
{
in_desc = transpose_host_tensor_descriptor_given_new2old(
in_desc, std::vector<std::size_t>{0, 2, 1});
}
else if constexpr(is_same_v<InLayout, ck::tensor_layout::convolution::NHWC>)
{ {
in_desc = transpose_host_tensor_descriptor_given_new2old( throw std::runtime_error("wrong! inconsistent dimension");
in_desc, std::vector<std::size_t>{0, 3, 1, 2});
}
else if constexpr(is_same_v<InLayout, ck::tensor_layout::convolution::NDHWC>)
{
in_desc = transpose_host_tensor_descriptor_given_new2old(
in_desc, std::vector<std::size_t>{0, 4, 1, 2, 3});
}
// weight
if constexpr(is_same_v<WeiLayout, ck::tensor_layout::convolution::KXC>)
{
wei_desc = transpose_host_tensor_descriptor_given_new2old(
wei_desc, std::vector<std::size_t>{0, 2, 1});
}
else if constexpr(is_same_v<WeiLayout, ck::tensor_layout::convolution::KYXC>)
{
wei_desc = transpose_host_tensor_descriptor_given_new2old(
wei_desc, std::vector<std::size_t>{0, 3, 1, 2});
}
else if constexpr(is_same_v<WeiLayout, ck::tensor_layout::convolution::KZYXC>)
{
wei_desc = transpose_host_tensor_descriptor_given_new2old(
wei_desc, std::vector<std::size_t>{0, 4, 1, 2, 3});
}
// output
if constexpr(is_same_v<OutLayout, ck::tensor_layout::convolution::NWK>)
{
out_desc = transpose_host_tensor_descriptor_given_new2old(
out_desc, std::vector<std::size_t>{0, 2, 1});
}
else if constexpr(is_same_v<OutLayout, ck::tensor_layout::convolution::NHWK>)
{
out_desc = transpose_host_tensor_descriptor_given_new2old(
out_desc, std::vector<std::size_t>{0, 3, 1, 2});
}
else if constexpr(is_same_v<OutLayout, ck::tensor_layout::convolution::NDHWK>)
{
out_desc = transpose_host_tensor_descriptor_given_new2old(
out_desc, std::vector<std::size_t>{0, 4, 1, 2, 3});
} }
if constexpr(NDimSpatial == 1) if constexpr(NDimSpatial == 1)
{ {
auto f_ncw = [&](auto n, auto c, auto wi) { auto f_ncw = [&](auto g, auto n, auto c, auto wi) {
std::size_t K = wei_desc.GetLengths()[0]; std::size_t K = arg.weight_.GetLengths()[1];
std::size_t X = wei_desc.GetLengths()[2]; std::size_t X = arg.weight_.GetLengths()[3];
std::size_t Wo = out_desc.GetLengths()[2]; std::size_t Wo = arg.output_.GetLengths()[3];
float v_acc = 0; float v_acc = 0;
...@@ -158,19 +106,11 @@ struct ReferenceConvBwdData : public device::BaseOperator ...@@ -158,19 +106,11 @@ struct ReferenceConvBwdData : public device::BaseOperator
float v_out = 0; float v_out = 0;
float v_wei = 0; float v_wei = 0;
// FIXME hacky
arg.out_element_op_( arg.out_element_op_(
v_out, v_out, ck::type_convert<float>(arg.output_(g, n, k, wo)));
ck::type_convert<float>(
arg.output_.mData[out_desc.GetOffsetFromMultiIndex(
n, k, wo)]));
// FIXME hacky
arg.wei_element_op_( arg.wei_element_op_(
v_wei, v_wei, ck::type_convert<float>(arg.weight_(g, k, c, x)));
ck::type_convert<float>(
arg.weight_
.mData[wei_desc.GetOffsetFromMultiIndex(k, c, x)]));
v_acc += v_out * v_wei; v_acc += v_out * v_wei;
} }
...@@ -182,28 +122,27 @@ struct ReferenceConvBwdData : public device::BaseOperator ...@@ -182,28 +122,27 @@ struct ReferenceConvBwdData : public device::BaseOperator
arg.in_element_op_(v_in, v_acc); arg.in_element_op_(v_in, v_acc);
// FIXME hacky arg.input_(g, n, c, wi) = ck::type_convert<InDataType>(v_acc);
arg.input_.mData[in_desc.GetOffsetFromMultiIndex(n, c, wi)] =
ck::type_convert<InDataType>(v_acc);
}; };
make_ParallelTensorFunctor(f_ncw, make_ParallelTensorFunctor(f_ncw,
in_desc.GetLengths()[0], arg.input_.GetLengths()[0],
in_desc.GetLengths()[1], arg.input_.GetLengths()[1],
in_desc.GetLengths()[2])( arg.input_.GetLengths()[2],
arg.input_.GetLengths()[3])(
std::thread::hardware_concurrency()); std::thread::hardware_concurrency());
return 0; return 0;
} }
else if constexpr(NDimSpatial == 2) else if constexpr(NDimSpatial == 2)
{ {
auto f_nchw = [&](auto n, auto c, auto hi, auto wi) { auto f_nchw = [&](auto g, auto n, auto c, auto hi, auto wi) {
std::size_t K = wei_desc.GetLengths()[0]; std::size_t K = arg.weight_.GetLengths()[1];
std::size_t Y = wei_desc.GetLengths()[2]; std::size_t Y = arg.weight_.GetLengths()[3];
std::size_t X = wei_desc.GetLengths()[3]; std::size_t X = arg.weight_.GetLengths()[4];
std::size_t Ho = out_desc.GetLengths()[2]; std::size_t Ho = arg.output_.GetLengths()[3];
std::size_t Wo = out_desc.GetLengths()[3]; std::size_t Wo = arg.output_.GetLengths()[4];
float v_acc = 0; float v_acc = 0;
...@@ -236,21 +175,15 @@ struct ReferenceConvBwdData : public device::BaseOperator ...@@ -236,21 +175,15 @@ struct ReferenceConvBwdData : public device::BaseOperator
float v_out = 0; float v_out = 0;
float v_wei = 0; float v_wei = 0;
// FIXME hacky
arg.out_element_op_( arg.out_element_op_(
v_out, v_out,
ck::type_convert<float>( ck::type_convert<float>(
arg.output_ arg.output_(g, n, k, ho, wo)));
.mData[out_desc.GetOffsetFromMultiIndex(
n, k, ho, wo)]));
// FIXME hacky
arg.wei_element_op_( arg.wei_element_op_(
v_wei, v_wei,
ck::type_convert<float>( ck::type_convert<float>(
arg.weight_ arg.weight_(g, k, c, y, x)));
.mData[wei_desc.GetOffsetFromMultiIndex(
k, c, y, x)]));
v_acc += v_out * v_wei; v_acc += v_out * v_wei;
} }
...@@ -265,31 +198,30 @@ struct ReferenceConvBwdData : public device::BaseOperator ...@@ -265,31 +198,30 @@ struct ReferenceConvBwdData : public device::BaseOperator
arg.in_element_op_(v_in, v_acc); arg.in_element_op_(v_in, v_acc);
// FIXME hacky arg.input_(g, n, c, hi, wi) = ck::type_convert<InDataType>(v_acc);
arg.input_.mData[in_desc.GetOffsetFromMultiIndex(n, c, hi, wi)] =
ck::type_convert<InDataType>(v_acc);
}; };
make_ParallelTensorFunctor(f_nchw, make_ParallelTensorFunctor(f_nchw,
in_desc.GetLengths()[0], arg.input_.GetLengths()[0],
in_desc.GetLengths()[1], arg.input_.GetLengths()[1],
in_desc.GetLengths()[2], arg.input_.GetLengths()[2],
in_desc.GetLengths()[3])( arg.input_.GetLengths()[3],
arg.input_.GetLengths()[4])(
std::thread::hardware_concurrency()); std::thread::hardware_concurrency());
return 0; return 0;
} }
else if constexpr(NDimSpatial == 3) else if constexpr(NDimSpatial == 3)
{ {
auto f_ncdhw = [&](auto n, auto c, auto di, auto hi, auto wi) { auto f_ncdhw = [&](auto g, auto n, auto c, auto di, auto hi, auto wi) {
std::size_t K = wei_desc.GetLengths()[0]; std::size_t K = arg.weight_.GetLengths()[1];
std::size_t Z = wei_desc.GetLengths()[2]; std::size_t Z = arg.weight_.GetLengths()[3];
std::size_t Y = wei_desc.GetLengths()[3]; std::size_t Y = arg.weight_.GetLengths()[4];
std::size_t X = wei_desc.GetLengths()[4]; std::size_t X = arg.weight_.GetLengths()[5];
std::size_t Do = out_desc.GetLengths()[2]; std::size_t Do = arg.output_.GetLengths()[3];
std::size_t Ho = out_desc.GetLengths()[3]; std::size_t Ho = arg.output_.GetLengths()[4];
std::size_t Wo = out_desc.GetLengths()[4]; std::size_t Wo = arg.output_.GetLengths()[5];
float v_acc = 0; float v_acc = 0;
...@@ -338,27 +270,15 @@ struct ReferenceConvBwdData : public device::BaseOperator ...@@ -338,27 +270,15 @@ struct ReferenceConvBwdData : public device::BaseOperator
float v_out = 0; float v_out = 0;
float v_wei = 0; float v_wei = 0;
// FIXME hacky
arg.out_element_op_( arg.out_element_op_(
v_out, v_out,
ck::type_convert<float>( ck::type_convert<float>(arg.output_(
arg.output_.mData g, n, k, do_, ho, wo)));
[out_desc
.GetOffsetFromMultiIndex(
n,
k,
do_,
ho,
wo)]));
// FIXME hacky
arg.wei_element_op_( arg.wei_element_op_(
v_wei, v_wei,
ck::type_convert<float>( ck::type_convert<float>(
arg.weight_.mData arg.weight_(g, k, c, z, y, x)));
[wei_desc
.GetOffsetFromMultiIndex(
k, c, z, y, x)]));
v_acc += v_out * v_wei; v_acc += v_out * v_wei;
} }
...@@ -376,17 +296,16 @@ struct ReferenceConvBwdData : public device::BaseOperator ...@@ -376,17 +296,16 @@ struct ReferenceConvBwdData : public device::BaseOperator
arg.in_element_op_(v_in, v_acc); arg.in_element_op_(v_in, v_acc);
// FIXME hacky arg.input_(g, n, c, di, hi, wi) = ck::type_convert<InDataType>(v_acc);
arg.input_.mData[in_desc.GetOffsetFromMultiIndex(n, c, di, hi, wi)] =
ck::type_convert<InDataType>(v_acc);
}; };
make_ParallelTensorFunctor(f_ncdhw, make_ParallelTensorFunctor(f_ncdhw,
in_desc.GetLengths()[0], arg.input_.GetLengths()[0],
in_desc.GetLengths()[1], arg.input_.GetLengths()[1],
in_desc.GetLengths()[2], arg.input_.GetLengths()[2],
in_desc.GetLengths()[3], arg.input_.GetLengths()[3],
in_desc.GetLengths()[4])( arg.input_.GetLengths()[4],
arg.input_.GetLengths()[5])(
std::thread::hardware_concurrency()); std::thread::hardware_concurrency());
return 0; return 0;
......
...@@ -7,23 +7,22 @@ ...@@ -7,23 +7,22 @@
#include <sstream> #include <sstream>
#include "ck/tensor_operation/gpu/device/device_base.hpp" #include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/utility/host_tensor.hpp" #include "ck/library/utility/host_tensor.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace host { namespace host {
template <ck::index_t NumDimSpatial, // tensor descriptor in GNCHW/GKCXY/GNKHW dimensional order
typename InLayout, template <ck::index_t NDimSpatial,
typename WeiLayout,
typename OutLayout,
typename InDataType, typename InDataType,
typename WeiDataType, typename WeiDataType,
typename OutDataType, typename OutDataType,
typename InElementwiseOperation, typename InElementwiseOperation,
typename WeiElementwiseOperation, typename WeiElementwiseOperation,
typename OutElementwiseOperation, typename OutElementwiseOperation,
typename std::enable_if<NumDimSpatial >= 1 && NumDimSpatial <= 3, bool>::type = false> typename std::enable_if<NDimSpatial >= 1 && NDimSpatial <= 3, bool>::type = false>
struct ReferenceConvBwdWeight : public device::BaseOperator struct ReferenceConvBwdWeight : public device::BaseOperator
{ {
// Argument // Argument
...@@ -71,97 +70,39 @@ struct ReferenceConvBwdWeight : public device::BaseOperator ...@@ -71,97 +70,39 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
{ {
using Argument = ReferenceConvBwdWeight::Argument; using Argument = ReferenceConvBwdWeight::Argument;
// FIXME: properly implement "TensorView" for doing transpose or refer to dimension by name
float Run(const Argument& arg) float Run(const Argument& arg)
{ {
// tensor descriptor in NCHW/KXYC/NKHW dimensional order if(!(arg.input_.GetNumOfDimension() == NDimSpatial + 3 &&
HostTensorDescriptor in_desc = arg.input_.mDesc; arg.weight_.GetNumOfDimension() == NDimSpatial + 3 &&
HostTensorDescriptor wei_desc = arg.weight_.mDesc; arg.output_.GetNumOfDimension() == NDimSpatial + 3))
HostTensorDescriptor out_desc = arg.output_.mDesc;
// input
if constexpr(is_same_v<InLayout, ck::tensor_layout::convolution::NWC>)
{
in_desc = transpose_host_tensor_descriptor_given_new2old(
in_desc, std::vector<std::size_t>{0, 2, 1});
}
else if constexpr(is_same_v<InLayout, ck::tensor_layout::convolution::NHWC>)
{
in_desc = transpose_host_tensor_descriptor_given_new2old(
in_desc, std::vector<std::size_t>{0, 3, 1, 2});
}
else if constexpr(is_same_v<InLayout, ck::tensor_layout::convolution::NDHWC>)
{ {
in_desc = transpose_host_tensor_descriptor_given_new2old( throw std::runtime_error("wrong! inconsistent dimension");
in_desc, std::vector<std::size_t>{0, 4, 1, 2, 3});
} }
// weight if constexpr(NDimSpatial == 1)
if constexpr(is_same_v<WeiLayout, ck::tensor_layout::convolution::KXC>)
{ {
wei_desc = transpose_host_tensor_descriptor_given_new2old( auto f_kcx = [&](auto g, auto k, auto c, auto x) {
wei_desc, std::vector<std::size_t>{0, 2, 1});
}
else if constexpr(is_same_v<WeiLayout, ck::tensor_layout::convolution::KYXC>)
{
wei_desc = transpose_host_tensor_descriptor_given_new2old(
wei_desc, std::vector<std::size_t>{0, 3, 1, 2});
}
else if constexpr(is_same_v<WeiLayout, ck::tensor_layout::convolution::KZYXC>)
{
wei_desc = transpose_host_tensor_descriptor_given_new2old(
wei_desc, std::vector<std::size_t>{0, 4, 1, 2, 3});
}
// output
if constexpr(is_same_v<OutLayout, ck::tensor_layout::convolution::NWK>)
{
out_desc = transpose_host_tensor_descriptor_given_new2old(
out_desc, std::vector<std::size_t>{0, 2, 1});
}
else if constexpr(is_same_v<OutLayout, ck::tensor_layout::convolution::NHWK>)
{
out_desc = transpose_host_tensor_descriptor_given_new2old(
out_desc, std::vector<std::size_t>{0, 3, 1, 2});
}
else if constexpr(is_same_v<OutLayout, ck::tensor_layout::convolution::NDHWK>)
{
out_desc = transpose_host_tensor_descriptor_given_new2old(
out_desc, std::vector<std::size_t>{0, 4, 1, 2, 3});
}
if constexpr(NumDimSpatial == 1)
{
auto f_kcx = [&](auto k, auto c, auto x) {
float v_acc = 0; float v_acc = 0;
for(std::size_t n = 0; n < out_desc.GetLengths()[0]; ++n) for(std::size_t n = 0; n < arg.output_.GetLengths()[1]; ++n)
{ {
for(std::size_t wo = 0; wo < out_desc.GetLengths()[2]; ++wo) for(std::size_t wo = 0; wo < arg.output_.GetLengths()[3]; ++wo)
{ {
auto wi = static_cast<ck::long_index_t>(wo * arg.conv_strides_[0]) + auto wi = static_cast<ck::long_index_t>(wo * arg.conv_strides_[0]) +
static_cast<ck::long_index_t>(x * arg.conv_dilations_[0]) - static_cast<ck::long_index_t>(x * arg.conv_dilations_[0]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[0]); static_cast<ck::long_index_t>(arg.in_left_pads_[0]);
if(wi >= 0 && if(wi >= 0 &&
ck::type_convert<std::size_t>(wi) < in_desc.GetLengths()[2]) ck::type_convert<std::size_t>(wi) < arg.input_.GetLengths()[3])
{ {
float v_out; float v_out;
float v_in; float v_in;
// FIXME hacky
arg.out_element_op_( arg.out_element_op_(
v_out, v_out, ck::type_convert<float>(arg.output_(g, n, k, wo)));
ck::type_convert<float>(
arg.output_
.mData[out_desc.GetOffsetFromMultiIndex(n, k, wo)]));
// FIXME hacky
arg.in_element_op_( arg.in_element_op_(
v_in, v_in, ck::type_convert<float>(arg.input_(g, n, c, wi)));
ck::type_convert<float>(
arg.input_
.mData[in_desc.GetOffsetFromMultiIndex(n, c, wi)]));
v_acc += v_out * v_in; v_acc += v_out * v_in;
} }
...@@ -172,33 +113,32 @@ struct ReferenceConvBwdWeight : public device::BaseOperator ...@@ -172,33 +113,32 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
arg.wei_element_op_(v_wei, v_acc); arg.wei_element_op_(v_wei, v_acc);
// FIXME hacky arg.weight_(g, k, c, x) = ck::type_convert<WeiDataType>(v_wei);
arg.weight_.mData[wei_desc.GetOffsetFromMultiIndex(k, c, x)] =
ck::type_convert<WeiDataType>(v_wei);
}; };
make_ParallelTensorFunctor(f_kcx, make_ParallelTensorFunctor(f_kcx,
wei_desc.GetLengths()[0], arg.weight_.GetLengths()[0],
wei_desc.GetLengths()[1], arg.weight_.GetLengths()[1],
wei_desc.GetLengths()[2])( arg.weight_.GetLengths()[2],
arg.weight_.GetLengths()[3])(
std::thread::hardware_concurrency()); std::thread::hardware_concurrency());
return 0; return 0;
} }
else if constexpr(NumDimSpatial == 2) else if constexpr(NDimSpatial == 2)
{ {
auto f_kcyx = [&](auto k, auto c, auto y, auto x) { auto f_kcyx = [&](auto g, auto k, auto c, auto y, auto x) {
float v_acc = 0; float v_acc = 0;
for(std::size_t n = 0; n < out_desc.GetLengths()[0]; ++n) for(std::size_t n = 0; n < arg.output_.GetLengths()[1]; ++n)
{ {
for(std::size_t ho = 0; ho < out_desc.GetLengths()[2]; ++ho) for(std::size_t ho = 0; ho < arg.output_.GetLengths()[3]; ++ho)
{ {
auto hi = static_cast<ck::long_index_t>(ho * arg.conv_strides_[0]) + auto hi = static_cast<ck::long_index_t>(ho * arg.conv_strides_[0]) +
static_cast<ck::long_index_t>(y * arg.conv_dilations_[0]) - static_cast<ck::long_index_t>(y * arg.conv_dilations_[0]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[0]); static_cast<ck::long_index_t>(arg.in_left_pads_[0]);
for(std::size_t wo = 0; wo < out_desc.GetLengths()[3]; ++wo) for(std::size_t wo = 0; wo < arg.output_.GetLengths()[4]; ++wo)
{ {
auto wi = auto wi =
static_cast<ck::long_index_t>(wo * arg.conv_strides_[1]) + static_cast<ck::long_index_t>(wo * arg.conv_strides_[1]) +
...@@ -206,26 +146,19 @@ struct ReferenceConvBwdWeight : public device::BaseOperator ...@@ -206,26 +146,19 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
static_cast<ck::long_index_t>(arg.in_left_pads_[1]); static_cast<ck::long_index_t>(arg.in_left_pads_[1]);
if(hi >= 0 && if(hi >= 0 &&
ck::type_convert<std::size_t>(hi) < in_desc.GetLengths()[2] && ck::type_convert<std::size_t>(hi) < arg.input_.GetLengths()[3] &&
wi >= 0 && wi >= 0 &&
ck::type_convert<std::size_t>(wi) < in_desc.GetLengths()[3]) ck::type_convert<std::size_t>(wi) < arg.input_.GetLengths()[4])
{ {
float v_out; float v_out;
float v_in; float v_in;
// FIXME hacky
arg.out_element_op_( arg.out_element_op_(
v_out, v_out,
ck::type_convert<float>( ck::type_convert<float>(arg.output_(g, n, k, ho, wo)));
arg.output_.mData[out_desc.GetOffsetFromMultiIndex(
n, k, ho, wo)]));
// FIXME hacky
arg.in_element_op_( arg.in_element_op_(
v_in, v_in, ck::type_convert<float>(arg.input_(g, n, c, hi, wi)));
ck::type_convert<float>(
arg.input_.mData[in_desc.GetOffsetFromMultiIndex(
n, c, hi, wi)]));
v_acc += v_out * v_in; v_acc += v_out * v_in;
} }
...@@ -237,38 +170,38 @@ struct ReferenceConvBwdWeight : public device::BaseOperator ...@@ -237,38 +170,38 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
arg.wei_element_op_(v_wei, v_acc); arg.wei_element_op_(v_wei, v_acc);
// FIXME hacky arg.weight_(g, k, c, y, x) = ck::type_convert<WeiDataType>(v_wei);
arg.weight_.mData[wei_desc.GetOffsetFromMultiIndex(k, c, y, x)] =
ck::type_convert<WeiDataType>(v_wei);
}; };
make_ParallelTensorFunctor(f_kcyx, make_ParallelTensorFunctor(f_kcyx,
wei_desc.GetLengths()[0], arg.weight_.GetLengths()[0],
wei_desc.GetLengths()[1], arg.weight_.GetLengths()[1],
wei_desc.GetLengths()[2], arg.weight_.GetLengths()[2],
wei_desc.GetLengths()[3])( arg.weight_.GetLengths()[3],
arg.weight_.GetLengths()[4])(
std::thread::hardware_concurrency()); std::thread::hardware_concurrency());
return 0; return 0;
} }
else if constexpr(NumDimSpatial == 3) else if constexpr(NDimSpatial == 3)
{ {
auto f_kczyx = [&](auto k, auto c, auto z, auto y, auto x) { auto f_kczyx = [&](auto g, auto k, auto c, auto z, auto y, auto x) {
float v_acc = 0; float v_acc = 0;
for(std::size_t n = 0; n < out_desc.GetLengths()[0]; ++n)
for(std::size_t n = 0; n < arg.output_.GetLengths()[1]; ++n)
{ {
for(std::size_t do_ = 0; do_ < out_desc.GetLengths()[2]; ++do_) for(std::size_t do_ = 0; do_ < arg.output_.GetLengths()[3]; ++do_)
{ {
auto di = static_cast<ck::long_index_t>(do_ * arg.conv_strides_[0]) + auto di = static_cast<ck::long_index_t>(do_ * arg.conv_strides_[0]) +
static_cast<ck::long_index_t>(z * arg.conv_dilations_[0]) - static_cast<ck::long_index_t>(z * arg.conv_dilations_[0]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[0]); static_cast<ck::long_index_t>(arg.in_left_pads_[0]);
for(std::size_t ho = 0; ho < out_desc.GetLengths()[3]; ++ho) for(std::size_t ho = 0; ho < arg.output_.GetLengths()[4]; ++ho)
{ {
auto hi = auto hi =
static_cast<ck::long_index_t>(ho * arg.conv_strides_[1]) + static_cast<ck::long_index_t>(ho * arg.conv_strides_[1]) +
static_cast<ck::long_index_t>(y * arg.conv_dilations_[1]) - static_cast<ck::long_index_t>(y * arg.conv_dilations_[1]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[1]); static_cast<ck::long_index_t>(arg.in_left_pads_[1]);
for(std::size_t wo = 0; wo < out_desc.GetLengths()[4]; ++wo) for(std::size_t wo = 0; wo < arg.output_.GetLengths()[5]; ++wo)
{ {
auto wi = auto wi =
static_cast<ck::long_index_t>(wo * arg.conv_strides_[2]) + static_cast<ck::long_index_t>(wo * arg.conv_strides_[2]) +
...@@ -277,29 +210,24 @@ struct ReferenceConvBwdWeight : public device::BaseOperator ...@@ -277,29 +210,24 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
if(di >= 0 && if(di >= 0 &&
ck::type_convert<std::size_t>(di) < ck::type_convert<std::size_t>(di) <
in_desc.GetLengths()[2] && arg.input_.GetLengths()[3] &&
hi >= 0 && hi >= 0 &&
ck::type_convert<std::size_t>(hi) < ck::type_convert<std::size_t>(hi) <
in_desc.GetLengths()[3] && arg.input_.GetLengths()[4] &&
wi >= 0 && wi >= 0 &&
ck::type_convert<std::size_t>(wi) < in_desc.GetLengths()[4]) ck::type_convert<std::size_t>(wi) <
arg.input_.GetLengths()[5])
{ {
float v_out; float v_out;
float v_in; float v_in;
// FIXME hacky arg.out_element_op_(v_out,
arg.out_element_op_( ck::type_convert<float>(
v_out, arg.output_(g, n, k, do_, ho, wo)));
ck::type_convert<float>(
arg.output_.mData[out_desc.GetOffsetFromMultiIndex(
n, k, do_, ho, wo)]));
// FIXME hacky arg.in_element_op_(v_in,
arg.in_element_op_( ck::type_convert<float>(
v_in, arg.input_(g, n, c, di, hi, wi)));
ck::type_convert<float>(
arg.input_.mData[in_desc.GetOffsetFromMultiIndex(
n, c, di, hi, wi)]));
v_acc += v_out * v_in; v_acc += v_out * v_in;
} }
...@@ -312,17 +240,16 @@ struct ReferenceConvBwdWeight : public device::BaseOperator ...@@ -312,17 +240,16 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
arg.wei_element_op_(v_wei, v_acc); arg.wei_element_op_(v_wei, v_acc);
// FIXME hacky arg.weight_(g, k, c, z, y, x) = ck::type_convert<WeiDataType>(v_wei);
arg.weight_.mData[wei_desc.GetOffsetFromMultiIndex(k, c, z, y, x)] =
ck::type_convert<WeiDataType>(v_wei);
}; };
make_ParallelTensorFunctor(f_kczyx, make_ParallelTensorFunctor(f_kczyx,
wei_desc.GetLengths()[0], arg.weight_.GetLengths()[0],
wei_desc.GetLengths()[1], arg.weight_.GetLengths()[1],
wei_desc.GetLengths()[2], arg.weight_.GetLengths()[2],
wei_desc.GetLengths()[3], arg.weight_.GetLengths()[3],
wei_desc.GetLengths()[4])( arg.weight_.GetLengths()[4],
arg.weight_.GetLengths()[5])(
std::thread::hardware_concurrency()); std::thread::hardware_concurrency());
return 0; return 0;
......
...@@ -17,9 +17,28 @@ namespace detail { ...@@ -17,9 +17,28 @@ namespace detail {
template <typename OldLayout> template <typename OldLayout>
std::vector<std::size_t> get_layout_transpose_gnchw_to_old() std::vector<std::size_t> get_layout_transpose_gnchw_to_old()
{ {
if constexpr(ck::is_same_v<OldLayout, ck::tensor_layout::convolution::GNCW> || // NHWC tp NCHW
ck::is_same_v<OldLayout, ck::tensor_layout::convolution::GKCX> || if constexpr(ck::is_same_v<OldLayout, ck::tensor_layout::convolution::NWC> ||
ck::is_same_v<OldLayout, ck::tensor_layout::convolution::GNKW>) ck::is_same_v<OldLayout, ck::tensor_layout::convolution::KXC> ||
ck::is_same_v<OldLayout, ck::tensor_layout::convolution::NWK>)
{
return {0, 2, 1};
}
else if constexpr(ck::is_same_v<OldLayout, ck::tensor_layout::convolution::NHWC> ||
ck::is_same_v<OldLayout, ck::tensor_layout::convolution::KYXC> ||
ck::is_same_v<OldLayout, ck::tensor_layout::convolution::NHWK>)
{
return {0, 3, 1, 2};
}
else if constexpr(ck::is_same_v<OldLayout, ck::tensor_layout::convolution::NDHWC> ||
ck::is_same_v<OldLayout, ck::tensor_layout::convolution::KZYXC> ||
ck::is_same_v<OldLayout, ck::tensor_layout::convolution::NDHWK>)
{
return {0, 4, 1, 2, 3};
}
else if constexpr(ck::is_same_v<OldLayout, ck::tensor_layout::convolution::GNCW> ||
ck::is_same_v<OldLayout, ck::tensor_layout::convolution::GKCX> ||
ck::is_same_v<OldLayout, ck::tensor_layout::convolution::GNKW>)
{ {
return {0, 1, 2, 3}; return {0, 1, 2, 3};
} }
...@@ -88,9 +107,25 @@ make_input_host_tensor_descriptor_g_n_c_wis_packed(const ck::utils::conv::ConvPa ...@@ -88,9 +107,25 @@ make_input_host_tensor_descriptor_g_n_c_wis_packed(const ck::utils::conv::ConvPa
{ {
std::vector<std::size_t> physical_lengths; std::vector<std::size_t> physical_lengths;
if constexpr(ck::is_same_v<InLayout, ck::tensor_layout::convolution::GNCW> || if constexpr(ck::is_same_v<InLayout, ck::tensor_layout::convolution::NWC> ||
ck::is_same_v<InLayout, ck::tensor_layout::convolution::GNCHW> || ck::is_same_v<InLayout, ck::tensor_layout::convolution::NHWC> ||
ck::is_same_v<InLayout, ck::tensor_layout::convolution::GNCDHW>) ck::is_same_v<InLayout, ck::tensor_layout::convolution::NDHWC>)
{
if(param.G_ != 1)
{
throw std::runtime_error("wrong! G != 1");
}
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.N_),
static_cast<std::size_t>(param.C_)};
physical_lengths.insert(physical_lengths.end(),
param.input_spatial_lengths_.begin(),
param.input_spatial_lengths_.begin() + param.num_dim_spatial_);
}
else if constexpr(ck::is_same_v<InLayout, ck::tensor_layout::convolution::GNCW> ||
ck::is_same_v<InLayout, ck::tensor_layout::convolution::GNCHW> ||
ck::is_same_v<InLayout, ck::tensor_layout::convolution::GNCDHW>)
{ {
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_), physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_),
static_cast<std::size_t>(param.N_), static_cast<std::size_t>(param.N_),
......
...@@ -80,3 +80,7 @@ struct ConvParam ...@@ -80,3 +80,7 @@ struct ConvParam
} // namespace ck } // namespace ck
std::ostream& operator<<(std::ostream& os, const ck::utils::conv::ConvParam& p); std::ostream& operator<<(std::ostream& os, const ck::utils::conv::ConvParam& p);
std::string get_conv_param_parser_helper_msg();
ck::utils::conv::ConvParam parse_conv_param(int num_dim_spatial, int arg_idx, char* const argv[]);
...@@ -100,3 +100,77 @@ std::ostream& operator<<(std::ostream& os, const ck::utils::conv::ConvParam& p) ...@@ -100,3 +100,77 @@ std::ostream& operator<<(std::ostream& os, const ck::utils::conv::ConvParam& p)
return os; return os;
} }
// Build the usage text describing the conv-parameter command-line arguments
// (the per-spatial-dimension argument list consumed by parse_conv_param).
// Returned as a std::string so callers can embed it in larger help messages.
std::string get_conv_param_parser_helper_msg()
{
    return std::string{"Following arguments (depending on number of spatial dims):\n"
                       " Number of spatial dimensions (1=Conv1d, 2=Conv2d, 3=Conv3d)\n"
                       " G, N, K, C, \n"
                       " <filter spatial dimensions>, (ie Y, X for 2D)\n"
                       " <input image spatial dimensions>, (ie Hi, Wi for 2D)\n"
                       " <strides>, (ie Sy, Sx for 2D)\n"
                       " <dilations>, (ie Dy, Dx for 2D)\n"
                       " <left padding>, (ie LeftPy, LeftPx for 2D)\n"
                       " <right padding>, (ie RightPy, RightPx for 2D)\n"};
}
// Parse a ConvParam from the command line, starting at argv[arg_idx].
// Expected layout (see get_conv_param_parser_helper_msg): G, N, K, C followed by
// six groups of num_dim_spatial integers (filter lengths, input lengths, strides,
// dilations, left pads, right pads). Throws std::invalid_argument /
// std::out_of_range via std::stoi on malformed numeric arguments.
ck::utils::conv::ConvParam parse_conv_param(int num_dim_spatial, int arg_idx, char* const argv[])
{
    // Scalar problem sizes come first: group count, batch, output channels, input channels.
    const ck::index_t G = std::stoi(argv[arg_idx++]);
    const ck::index_t N = std::stoi(argv[arg_idx++]);
    const ck::index_t K = std::stoi(argv[arg_idx++]);
    const ck::index_t C = std::stoi(argv[arg_idx++]);

    // Consume the next num_dim_spatial arguments as one per-dimension vector.
    // Captures arg_idx by reference so successive calls advance through argv.
    auto parse_spatial_group = [&]() {
        std::vector<ck::index_t> v(num_dim_spatial);

        for(int i = 0; i < num_dim_spatial; ++i)
        {
            v[i] = std::stoi(argv[arg_idx++]);
        }

        return v;
    };

    // Group order must match the helper message; each call advances arg_idx.
    const auto filter_spatial_lengths = parse_spatial_group();
    const auto input_spatial_lengths  = parse_spatial_group();
    const auto conv_filter_strides    = parse_spatial_group();
    const auto conv_filter_dilations  = parse_spatial_group();
    const auto input_left_pads        = parse_spatial_group();
    const auto input_right_pads       = parse_spatial_group();

    return ck::utils::conv::ConvParam{num_dim_spatial,
                                      G,
                                      N,
                                      K,
                                      C,
                                      filter_spatial_lengths,
                                      input_spatial_lengths,
                                      conv_filter_strides,
                                      conv_filter_dilations,
                                      input_left_pads,
                                      input_right_pads};
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment