Commit 19173ab7 authored by Chao Liu

add G

parent c3379310
@@ -25,7 +25,7 @@ void print_helper_msg()
         << "arg3: time kernel (0=no, 1=yes)\n"
         << "arg4: N spatial dimensions (default 2)\n"
         << "Following arguments (depending on number of spatial dims):\n"
-        << " N, K, C, \n"
+        << " G, N, K, C, \n"
         << " <filter spatial dimensions>, (ie Y, X for 2D)\n"
         << " <input image spatial dimensions>, (ie Hi, Wi for 2D)\n"
         << " <strides>, (ie Sy, Sx for 2D)\n"
@@ -37,6 +37,7 @@ void print_helper_msg()
 ck::utils::conv::ConvParam parse_conv_params(int num_dim_spatial, int arg_idx, char* const argv[])
 {
+    const ck::index_t G = std::stoi(argv[arg_idx++]);
     const ck::index_t N = std::stoi(argv[arg_idx++]);
     const ck::index_t K = std::stoi(argv[arg_idx++]);
     const ck::index_t C = std::stoi(argv[arg_idx++]);
@@ -79,6 +80,7 @@ ck::utils::conv::ConvParam parse_conv_params(int num_dim_spatial, int arg_idx, c
     }

     return ck::utils::conv::ConvParam{num_dim_spatial,
+                                      G,
                                       N,
                                       K,
                                       C,
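Note: with the new leading G argument, the example's command line shifts by one position; parse_conv_params() now consumes argv in the order G, N, K, C, filter lengths, input lengths, strides, dilations, left pads, right pads. A minimal standalone imitation of that parsing order (illustrative only, not the library's code):

// Hypothetical stand-in for the new argument order (illustrative only; the
// real parsing lives in parse_conv_params above). Expects at least four
// numeric arguments: G N K C ...
#include <cstdio>
#include <string>

int main(int, char* argv[])
{
    int arg_idx = 1;

    const int G = std::stoi(argv[arg_idx++]); // new: group count comes first
    const int N = std::stoi(argv[arg_idx++]);
    const int K = std::stoi(argv[arg_idx++]);
    const int C = std::stoi(argv[arg_idx++]);

    std::printf("G=%d N=%d K=%d C=%d\n", G, N, K, C);
    return 0;
}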
@@ -110,23 +112,56 @@ int run_conv_fwd(bool do_verification,
                  const WeiElementOp& wei_element_op,
                  const OutElementOp& out_element_op)
 {
-    const auto in_desc  = ck::utils::conv::get_input_host_tensor_descriptor<InLayout>(conv_param);
-    const auto wei_desc = ck::utils::conv::get_weight_host_tensor_descriptor<WeiLayout>(conv_param);
-    const auto out_desc = ck::utils::conv::get_output_host_tensor_descriptor<OutLayout>(conv_param);
-
-    // hacky, hardcoded for 2d NHWK
-    const auto bias_desc = HostTensorDescriptor(
-        std::vector<std::size_t>{static_cast<std::size_t>(conv_param.N_),
-                                 static_cast<std::size_t>(conv_param.output_spatial_lengths_[0]),
-                                 static_cast<std::size_t>(conv_param.output_spatial_lengths_[1]),
-                                 static_cast<std::size_t>(conv_param.K_)},
-        std::vector<std::size_t>{0, 0, 0, 1});
-
-    Tensor<InDataType> in(in_desc);
-    Tensor<WeiDataType> wei(wei_desc);
-    Tensor<OutDataType> bias(bias_desc);
-    Tensor<OutDataType> out_host(out_desc);
-    Tensor<OutDataType> out_device(out_desc);
+#if 0
+    const auto in_g_n_c_wis_desc  = ck::utils::conv::get_input_host_tensor_descriptor<InLayout>(conv_param);
+    const auto wei_g_k_c_xs_desc  = ck::utils::conv::get_weight_host_tensor_descriptor<WeiLayout>(conv_param);
+    const auto out_g_n_k_wos_desc = ck::utils::conv::get_output_host_tensor_descriptor<OutLayout>(conv_param);
+#else
+    const auto in_g_n_wis_c_desc = HostTensorDescriptor(
+        std::vector<std::size_t>{static_cast<std::size_t>(conv_param.G_),
+                                 static_cast<std::size_t>(conv_param.N_),
+                                 static_cast<std::size_t>(conv_param.input_spatial_lengths_[0]),
+                                 static_cast<std::size_t>(conv_param.input_spatial_lengths_[1]),
+                                 static_cast<std::size_t>(conv_param.C_)});
+
+    const auto wei_g_k_xs_c_desc = HostTensorDescriptor(
+        std::vector<std::size_t>{static_cast<std::size_t>(conv_param.G_),
+                                 static_cast<std::size_t>(conv_param.K_),
+                                 static_cast<std::size_t>(conv_param.filter_spatial_lengths_[0]),
+                                 static_cast<std::size_t>(conv_param.filter_spatial_lengths_[1]),
+                                 static_cast<std::size_t>(conv_param.C_)});
+
+    const auto bias_g_n_wos_k_desc = HostTensorDescriptor(
+        std::vector<std::size_t>{static_cast<std::size_t>(conv_param.G_),
+                                 static_cast<std::size_t>(conv_param.N_),
+                                 static_cast<std::size_t>(conv_param.output_spatial_lengths_[0]),
+                                 static_cast<std::size_t>(conv_param.output_spatial_lengths_[1]),
+                                 static_cast<std::size_t>(conv_param.K_)},
+        std::vector<std::size_t>{0, 0, 0, 0, 1});
+
+    const auto out_g_n_wos_k_desc = HostTensorDescriptor(
+        std::vector<std::size_t>{static_cast<std::size_t>(conv_param.G_),
+                                 static_cast<std::size_t>(conv_param.N_),
+                                 static_cast<std::size_t>(conv_param.output_spatial_lengths_[0]),
+                                 static_cast<std::size_t>(conv_param.output_spatial_lengths_[1]),
+                                 static_cast<std::size_t>(conv_param.K_)});
+
+    // tensor descriptor in NCHW/KXYC/NKHW dimensional order
+    const auto in_g_n_c_wis_desc = transpose_host_tensor_descriptor_given_new2old(
+        in_g_n_wis_c_desc, std::vector<ck::index_t>{0, 1, 4, 2, 3});
+    const auto wei_g_k_c_xs_desc = transpose_host_tensor_descriptor_given_new2old(
+        wei_g_k_xs_c_desc, std::vector<ck::index_t>{0, 1, 4, 2, 3});
+    const auto bias_g_n_k_wos_desc = transpose_host_tensor_descriptor_given_new2old(
+        bias_g_n_wos_k_desc, std::vector<ck::index_t>{0, 1, 4, 2, 3});
+    const auto out_g_n_k_wos_desc = transpose_host_tensor_descriptor_given_new2old(
+        out_g_n_wos_k_desc, std::vector<ck::index_t>{0, 1, 4, 2, 3});
+#endif
+
+    Tensor<InDataType> in(in_g_n_c_wis_desc);
+    Tensor<WeiDataType> wei(wei_g_k_c_xs_desc);
+    Tensor<OutDataType> bias(bias_g_n_k_wos_desc);
+    Tensor<OutDataType> out_host(out_g_n_k_wos_desc);
+    Tensor<OutDataType> out_device(out_g_n_k_wos_desc);

     std::cout << "in: " << in.mDesc << std::endl;
     std::cout << "wei: " << wei.mDesc << std::endl;
@@ -156,80 +191,14 @@ int run_conv_fwd(bool do_verification,
     wei_device_buf.ToDevice(wei.mData.data());
     bias_device_buf.ToDevice(bias.mData.data());

-    // tensor descriptor in NCHW/KXYC/NKHW dimensional order
-    HostTensorDescriptor in_n_c_wis_desc   = in_desc;
-    HostTensorDescriptor wei_k_c_xs_desc   = wei_desc;
-    HostTensorDescriptor bias_n_k_wos_desc = bias_desc;
-    HostTensorDescriptor out_n_k_wos_desc  = out_desc;
-
-    // input
-    if constexpr(ck::is_same_v<InLayout, ck::tensor_layout::convolution::NWC>)
-    {
-        in_n_c_wis_desc = transpose_host_tensor_descriptor_given_new2old(
-            in_desc, std::vector<std::size_t>{0, 2, 1});
-    }
-    else if constexpr(ck::is_same_v<InLayout, ck::tensor_layout::convolution::NHWC>)
-    {
-        in_n_c_wis_desc = transpose_host_tensor_descriptor_given_new2old(
-            in_desc, std::vector<std::size_t>{0, 3, 1, 2});
-    }
-    else if constexpr(ck::is_same_v<InLayout, ck::tensor_layout::convolution::NDHWC>)
-    {
-        in_n_c_wis_desc = transpose_host_tensor_descriptor_given_new2old(
-            in_desc, std::vector<std::size_t>{0, 4, 1, 2, 3});
-    }
-
-    // weight
-    if constexpr(ck::is_same_v<WeiLayout, ck::tensor_layout::convolution::KXC>)
-    {
-        wei_k_c_xs_desc = transpose_host_tensor_descriptor_given_new2old(
-            wei_desc, std::vector<std::size_t>{0, 2, 1});
-    }
-    else if constexpr(ck::is_same_v<WeiLayout, ck::tensor_layout::convolution::KYXC>)
-    {
-        wei_k_c_xs_desc = transpose_host_tensor_descriptor_given_new2old(
-            wei_desc, std::vector<std::size_t>{0, 3, 1, 2});
-    }
-    else if constexpr(ck::is_same_v<WeiLayout, ck::tensor_layout::convolution::KZYXC>)
-    {
-        wei_k_c_xs_desc = transpose_host_tensor_descriptor_given_new2old(
-            wei_desc, std::vector<std::size_t>{0, 4, 1, 2, 3});
-    }
-
-    // output
-    if constexpr(ck::is_same_v<OutLayout, ck::tensor_layout::convolution::NWK>)
-    {
-        out_n_k_wos_desc = transpose_host_tensor_descriptor_given_new2old(
-            out_desc, std::vector<std::size_t>{0, 2, 1});
-        bias_n_k_wos_desc = transpose_host_tensor_descriptor_given_new2old(
-            bias_desc, std::vector<std::size_t>{0, 2, 1});
-    }
-    else if constexpr(ck::is_same_v<OutLayout, ck::tensor_layout::convolution::NHWK>)
-    {
-        out_n_k_wos_desc = transpose_host_tensor_descriptor_given_new2old(
-            out_desc, std::vector<std::size_t>{0, 3, 1, 2});
-        bias_n_k_wos_desc = transpose_host_tensor_descriptor_given_new2old(
-            bias_desc, std::vector<std::size_t>{0, 3, 1, 2});
-    }
-    else if constexpr(ck::is_same_v<OutLayout, ck::tensor_layout::convolution::NDHWK>)
-    {
-        out_n_k_wos_desc = transpose_host_tensor_descriptor_given_new2old(
-            out_desc, std::vector<std::size_t>{0, 4, 1, 2, 3});
-        bias_n_k_wos_desc = transpose_host_tensor_descriptor_given_new2old(
-            bias_desc, std::vector<std::size_t>{0, 4, 1, 2, 3});
-    }
-
-    std::array<ck::index_t, NDimSpatial + 2> a_n_c_wis_lengths{};
-    std::array<ck::index_t, NDimSpatial + 2> a_n_c_wis_strides{};
-    std::array<ck::index_t, NDimSpatial + 2> b_k_c_xs_lengths{};
-    std::array<ck::index_t, NDimSpatial + 2> b_k_c_xs_strides{};
-    std::array<ck::index_t, NDimSpatial + 2> d_n_k_wos_lengths{};
-    std::array<ck::index_t, NDimSpatial + 2> d_n_k_wos_strides{};
-    std::array<ck::index_t, NDimSpatial + 2> e_n_k_wos_lengths{};
-    std::array<ck::index_t, NDimSpatial + 2> e_n_k_wos_strides{};
+    std::array<ck::index_t, NDimSpatial + 3> a_g_n_c_wis_lengths{};
+    std::array<ck::index_t, NDimSpatial + 3> a_g_n_c_wis_strides{};
+    std::array<ck::index_t, NDimSpatial + 3> b_g_k_c_xs_lengths{};
+    std::array<ck::index_t, NDimSpatial + 3> b_g_k_c_xs_strides{};
+    std::array<ck::index_t, NDimSpatial + 3> d_g_n_k_wos_lengths{};
+    std::array<ck::index_t, NDimSpatial + 3> d_g_n_k_wos_strides{};
+    std::array<ck::index_t, NDimSpatial + 3> e_g_n_k_wos_lengths{};
+    std::array<ck::index_t, NDimSpatial + 3> e_g_n_k_wos_strides{};

     std::array<ck::index_t, NDimSpatial> conv_filter_strides{};
     std::array<ck::index_t, NDimSpatial> conv_filter_dilations{};
     std::array<ck::index_t, NDimSpatial> input_left_pads{};
@@ -237,14 +206,14 @@ int run_conv_fwd(bool do_verification,
     auto copy = [](auto& x, auto& y) { std::copy(x.begin(), x.end(), y.begin()); };

-    copy(in_n_c_wis_desc.GetLengths(), a_n_c_wis_lengths);
-    copy(in_n_c_wis_desc.GetStrides(), a_n_c_wis_strides);
-    copy(wei_k_c_xs_desc.GetLengths(), b_k_c_xs_lengths);
-    copy(wei_k_c_xs_desc.GetStrides(), b_k_c_xs_strides);
-    copy(bias_n_k_wos_desc.GetLengths(), d_n_k_wos_lengths);
-    copy(bias_n_k_wos_desc.GetStrides(), d_n_k_wos_strides);
-    copy(out_n_k_wos_desc.GetLengths(), e_n_k_wos_lengths);
-    copy(out_n_k_wos_desc.GetStrides(), e_n_k_wos_strides);
+    copy(in_g_n_c_wis_desc.GetLengths(), a_g_n_c_wis_lengths);
+    copy(in_g_n_c_wis_desc.GetStrides(), a_g_n_c_wis_strides);
+    copy(wei_g_k_c_xs_desc.GetLengths(), b_g_k_c_xs_lengths);
+    copy(wei_g_k_c_xs_desc.GetStrides(), b_g_k_c_xs_strides);
+    copy(bias_g_n_k_wos_desc.GetLengths(), d_g_n_k_wos_lengths);
+    copy(bias_g_n_k_wos_desc.GetStrides(), d_g_n_k_wos_strides);
+    copy(out_g_n_k_wos_desc.GetLengths(), e_g_n_k_wos_lengths);
+    copy(out_g_n_k_wos_desc.GetStrides(), e_g_n_k_wos_strides);
     copy(conv_param.conv_filter_strides_, conv_filter_strides);
     copy(conv_param.conv_filter_dilations_, conv_filter_dilations);
     copy(conv_param.input_left_pads_, input_left_pads);
@@ -258,14 +227,14 @@ int run_conv_fwd(bool do_verification,
                                       wei_device_buf.GetDeviceBuffer(),
                                       std::array<const void*, 1>{bias_device_buf.GetDeviceBuffer()},
                                       out_device_buf.GetDeviceBuffer(),
-                                      a_n_c_wis_lengths,
-                                      a_n_c_wis_strides,
-                                      b_k_c_xs_lengths,
-                                      b_k_c_xs_strides,
-                                      std::array<std::array<ck::index_t, NDimSpatial + 2>, 1>{{d_n_k_wos_lengths}},
-                                      std::array<std::array<ck::index_t, NDimSpatial + 2>, 1>{{d_n_k_wos_strides}},
-                                      e_n_k_wos_lengths,
-                                      e_n_k_wos_strides,
+                                      a_g_n_c_wis_lengths,
+                                      a_g_n_c_wis_strides,
+                                      b_g_k_c_xs_lengths,
+                                      b_g_k_c_xs_strides,
+                                      std::array<std::array<ck::index_t, NDimSpatial + 3>, 1>{{d_g_n_k_wos_lengths}},
+                                      std::array<std::array<ck::index_t, NDimSpatial + 3>, 1>{{d_g_n_k_wos_strides}},
+                                      e_g_n_k_wos_lengths,
+                                      e_g_n_k_wos_strides,
                                       conv_filter_strides,
                                       conv_filter_dilations,
                                       input_left_pads,
@@ -295,7 +264,7 @@ int run_conv_fwd(bool do_verification,
     {
         using PassThrough = ck::tensor_operation::element_wise::PassThrough;

-        Tensor<OutDataType> c_host(out_desc);
+        Tensor<OutDataType> c_host(out_g_n_k_wos_desc);

         auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd<NDimSpatial,
                                                                      InLayout,
@@ -322,16 +291,20 @@ int run_conv_fwd(bool do_verification,
         ref_invoker.Run(ref_argument);

-        for(int n = 0; n < out_host.mDesc.GetLengths()[0]; n++)
-        {
-            for(int ho = 0; ho < out_host.mDesc.GetLengths()[1]; ho++)
-            {
-                for(int wo = 0; wo < out_host.mDesc.GetLengths()[2]; wo++)
-                {
-                    for(int k = 0; k < out_host.mDesc.GetLengths()[3]; k++)
-                    {
-                        out_element_op(
-                            out_host(n, ho, wo, k), c_host(n, ho, wo, k), bias(n, ho, wo, k));
-                    }
-                }
-            }
-        }
+        for(int g = 0; g < out_host.mDesc.GetLengths()[0]; g++)
+        {
+            for(int n = 0; n < out_host.mDesc.GetLengths()[1]; n++)
+            {
+                for(int k = 0; k < out_host.mDesc.GetLengths()[2]; k++)
+                {
+                    for(int ho = 0; ho < out_host.mDesc.GetLengths()[3]; ho++)
+                    {
+                        for(int wo = 0; wo < out_host.mDesc.GetLengths()[4]; wo++)
+                        {
+                            out_element_op(out_host(g, n, k, ho, wo),
+                                           c_host(g, n, k, ho, wo),
+                                           bias(g, n, k, ho, wo));
+                        }
+                    }
+                }
+            }
+        }
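Note on the bias descriptor defined earlier in this file: broadcasting is done with zero strides. A descriptor offset is the dot product of the index with the strides, so strides {0, 0, 0, 0, 1} make a [G, N, Ho, Wo, K] view that stores only K values and reuses them for every (g, n, ho, wo). A minimal standalone sketch of that mechanism (not the library's HostTensorDescriptor):

#include <array>
#include <cstddef>
#include <cstdio>

// offset = sum_i idx[i] * stride[i]; a stride of 0 means every index in that
// dimension maps to the same element
std::size_t offset(const std::array<std::size_t, 5>& idx,
                   const std::array<std::size_t, 5>& stride)
{
    std::size_t off = 0;
    for(std::size_t i = 0; i < idx.size(); ++i)
        off += idx[i] * stride[i];
    return off;
}

int main()
{
    const std::array<std::size_t, 5> bias_stride{0, 0, 0, 0, 1};

    // (g, n, ho, wo, k) = (1, 7, 3, 5, 2) and (0, 0, 0, 0, 2) hit the same element
    std::printf("%zu == %zu\n",
                offset({1, 7, 3, 5, 2}, bias_stride),
                offset({0, 0, 0, 0, 2}, bias_stride));
    return 0;
}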
@@ -138,7 +138,7 @@ int main(int argc, char* argv[])
     int num_dim_spatial = 2;

     ck::utils::conv::ConvParam params{
-        2, 128, 256, 192, {3, 3}, {71, 71}, {2, 2}, {1, 1}, {1, 1}, {1, 1}};
+        2, 1, 128, 256, 192, {3, 3}, {71, 71}, {2, 2}, {1, 1}, {1, 1}, {1, 1}};

     if(argc == 1)
     {
@@ -39,14 +39,14 @@ struct DeviceConvFwdMultipleD : public BaseOperator
                        const void* p_b,
                        const std::array<const void*, NumDTensor>& p_ds,
                        void* p_e,
-                       const std::array<index_t, NDimSpatial + 2>& a_n_c_wis_lengths,
-                       const std::array<index_t, NDimSpatial + 2>& a_n_c_wis_strides,
-                       const std::array<index_t, NDimSpatial + 2>& b_k_c_xs_lengths,
-                       const std::array<index_t, NDimSpatial + 2>& b_k_c_xs_strides,
-                       const std::array<std::array<index_t, NDimSpatial + 2>, NumDTensor>& ds_n_k_wos_lengths,
-                       const std::array<std::array<index_t, NDimSpatial + 2>, NumDTensor>& ds_n_k_wos_strides,
-                       const std::array<index_t, NDimSpatial + 2>& e_n_k_wos_lengths,
-                       const std::array<index_t, NDimSpatial + 2>& e_n_k_wos_strides,
+                       const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
+                       const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
+                       const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
+                       const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
+                       const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_lengths,
+                       const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_strides,
+                       const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
+                       const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
                        const std::array<index_t, NDimSpatial>& conv_filter_strides,
                        const std::array<index_t, NDimSpatial>& conv_filter_dilations,
                        const std::array<index_t, NDimSpatial>& input_left_pads,
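Note: every length/stride array in this interface grows from NDimSpatial + 2 to NDimSpatial + 3 entries because G is now an explicit leading dimension. A sketch of what the input arrays would carry for the 2D default problem, assuming packed GNHWC storage viewed in [G, N, C, Hi, Wi] order (names and numbers illustrative, not library code):

#include <array>

int main()
{
    const int G = 1, N = 128, C = 192, Hi = 71, Wi = 71;

    // lengths in [G, N, C, Hi, Wi] order
    const std::array<int, 5> a_g_n_c_wis_lengths{G, N, C, Hi, Wi};

    // strides of packed G-N-Hi-Wi-C memory, permuted into the same order;
    // C has stride 1 because it is the fastest-varying dimension in GNHWC
    const std::array<int, 5> a_g_n_c_wis_strides{
        N * Hi * Wi * C, Hi * Wi * C, 1, Wi * C, C};

    return (a_g_n_c_wis_lengths[2] == C && a_g_n_c_wis_strides[2] == 1) ? 0 : 1;
}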
@@ -89,97 +89,33 @@ struct ReferenceConvFwd : public device::BaseOperator
     {
         using Argument = ReferenceConvFwd::Argument;

+        // FIXME: properly implement "TensorView" for doing transpose or refer to dimension by name
         float Run(const Argument& arg)
         {
             // tensor descriptor in NCHW/KXYC/NKHW dimensional order
-            HostTensorDescriptor in_desc  = arg.input_.mDesc;
-            HostTensorDescriptor wei_desc = arg.weight_.mDesc;
-            HostTensorDescriptor out_desc = arg.output_.mDesc;
-
-            // input
-            if constexpr(is_same_v<InLayout, ck::tensor_layout::convolution::NWC>)
-            {
-                in_desc = transpose_host_tensor_descriptor_given_new2old(
-                    arg.input_.mDesc, std::vector<std::size_t>{0, 2, 1});
-            }
-            else if constexpr(is_same_v<InLayout, ck::tensor_layout::convolution::NHWC>)
-            {
-                in_desc = transpose_host_tensor_descriptor_given_new2old(
-                    arg.input_.mDesc, std::vector<std::size_t>{0, 3, 1, 2});
-            }
-            else if constexpr(is_same_v<InLayout, ck::tensor_layout::convolution::NDHWC>)
-            {
-                in_desc = transpose_host_tensor_descriptor_given_new2old(
-                    arg.input_.mDesc, std::vector<std::size_t>{0, 4, 1, 2, 3});
-            }
-
-            // weight
-            if constexpr(is_same_v<WeiLayout, ck::tensor_layout::convolution::KXC>)
-            {
-                wei_desc = transpose_host_tensor_descriptor_given_new2old(
-                    arg.weight_.mDesc, std::vector<std::size_t>{0, 2, 1});
-            }
-            else if constexpr(is_same_v<WeiLayout, ck::tensor_layout::convolution::KYXC>)
-            {
-                wei_desc = transpose_host_tensor_descriptor_given_new2old(
-                    arg.weight_.mDesc, std::vector<std::size_t>{0, 3, 1, 2});
-            }
-            else if constexpr(is_same_v<WeiLayout, ck::tensor_layout::convolution::KZYXC>)
-            {
-                wei_desc = transpose_host_tensor_descriptor_given_new2old(
-                    arg.weight_.mDesc, std::vector<std::size_t>{0, 4, 1, 2, 3});
-            }
-
-            // output
-            if constexpr(is_same_v<OutLayout, ck::tensor_layout::convolution::NWK>)
-            {
-                out_desc = transpose_host_tensor_descriptor_given_new2old(
-                    arg.output_.mDesc, std::vector<std::size_t>{0, 2, 1});
-            }
-            else if constexpr(is_same_v<OutLayout, ck::tensor_layout::convolution::NHWK>)
-            {
-                out_desc = transpose_host_tensor_descriptor_given_new2old(
-                    arg.output_.mDesc, std::vector<std::size_t>{0, 3, 1, 2});
-            }
-            else if constexpr(is_same_v<OutLayout, ck::tensor_layout::convolution::NDHWK>)
-            {
-                out_desc = transpose_host_tensor_descriptor_given_new2old(
-                    arg.output_.mDesc, std::vector<std::size_t>{0, 4, 1, 2, 3});
-            }
-
             if constexpr(NumDimSpatial == 1)
             {
-                auto f_ncw = [&](auto n, auto k, auto wo) {
+                auto func = [&](auto g, auto n, auto k, auto wo) {
                     float v_acc = 0;

-                    for(std::size_t c = 0; c < wei_desc.GetLengths()[1]; ++c)
+                    for(std::size_t c = 0; c < arg.weight_.GetLengths()[2]; ++c)
                     {
-                        for(std::size_t x = 0; x < wei_desc.GetLengths()[2]; ++x)
+                        for(std::size_t x = 0; x < arg.weight_.GetLengths()[3]; ++x)
                         {
                             auto wi = static_cast<ck::long_index_t>(wo * arg.conv_strides_[0]) +
                                       static_cast<ck::long_index_t>(x * arg.conv_dilations_[0]) -
                                       static_cast<ck::long_index_t>(arg.in_left_pads_[0]);

                             if(wi >= 0 &&
-                               ck::type_convert<std::size_t>(wi) < in_desc.GetLengths()[2])
+                               ck::type_convert<std::size_t>(wi) < arg.input_.GetLengths()[3])
                             {
                                 float v_in;
                                 float v_wei;

-                                // FIXME hacky
                                 arg.in_element_op_(
-                                    v_in,
-                                    ck::type_convert<float>(
-                                        arg.input_
-                                            .mData[in_desc.GetOffsetFromMultiIndex(n, c, wi)]));
+                                    v_in, ck::type_convert<float>(arg.input_(g, n, c, wi)));

-                                // FIXME hacky
                                 arg.wei_element_op_(
-                                    v_wei,
-                                    ck::type_convert<float>(
-                                        arg.weight_
-                                            .mData[wei_desc.GetOffsetFromMultiIndex(k, c, x)]));
+                                    v_wei, ck::type_convert<float>(arg.weight_(g, k, c, x)));

                                 v_acc += v_in * v_wei;
                             }
@@ -190,33 +126,32 @@ struct ReferenceConvFwd : public device::BaseOperator
                     arg.out_element_op_(v_out, v_acc);

-                    // FIXME hacky
-                    arg.output_.mData[out_desc.GetOffsetFromMultiIndex({n, k, wo})] =
-                        ck::type_convert<OutDataType>(v_out);
+                    arg.output_(g, n, k, wo) = ck::type_convert<OutDataType>(v_out);
                 };

-                make_ParallelTensorFunctor(f_ncw,
-                                           out_desc.GetLengths()[0],
-                                           out_desc.GetLengths()[1],
-                                           out_desc.GetLengths()[2])(
+                make_ParallelTensorFunctor(func,
+                                           arg.output_.GetLengths()[0],
+                                           arg.output_.GetLengths()[1],
+                                           arg.output_.GetLengths()[2],
+                                           arg.output_.GetLengths()[3])(
                     std::thread::hardware_concurrency());

                 return 0;
             }
             else if constexpr(NumDimSpatial == 2)
             {
-                auto f_nchw = [&](auto n, auto k, auto ho, auto wo) {
+                auto func = [&](auto g, auto n, auto k, auto ho, auto wo) {
                     float v_acc = 0;

-                    for(std::size_t c = 0; c < wei_desc.GetLengths()[1]; ++c)
+                    for(std::size_t c = 0; c < arg.weight_.GetLengths()[2]; ++c)
                     {
-                        for(std::size_t y = 0; y < wei_desc.GetLengths()[2]; ++y)
+                        for(std::size_t y = 0; y < arg.weight_.GetLengths()[3]; ++y)
                         {
                             auto hi = static_cast<ck::long_index_t>(ho * arg.conv_strides_[0]) +
                                       static_cast<ck::long_index_t>(y * arg.conv_dilations_[0]) -
                                       static_cast<ck::long_index_t>(arg.in_left_pads_[0]);

-                            for(std::size_t x = 0; x < wei_desc.GetLengths()[3]; ++x)
+                            for(std::size_t x = 0; x < arg.weight_.GetLengths()[4]; ++x)
                             {
                                 auto wi =
                                     static_cast<ck::long_index_t>(wo * arg.conv_strides_[1]) +
@@ -224,26 +159,18 @@ struct ReferenceConvFwd : public device::BaseOperator
                                     static_cast<ck::long_index_t>(arg.in_left_pads_[1]);

                                 if(hi >= 0 &&
-                                   ck::type_convert<std::size_t>(hi) < in_desc.GetLengths()[2] &&
+                                   ck::type_convert<std::size_t>(hi) < arg.input_.GetLengths()[3] &&
                                    wi >= 0 &&
-                                   ck::type_convert<std::size_t>(wi) < in_desc.GetLengths()[3])
+                                   ck::type_convert<std::size_t>(wi) < arg.input_.GetLengths()[4])
                                 {
                                     float v_in;
                                     float v_wei;

-                                    // FIXME hacky
                                     arg.in_element_op_(
-                                        v_in,
-                                        ck::type_convert<float>(
-                                            arg.input_.mData[in_desc.GetOffsetFromMultiIndex(
-                                                n, c, hi, wi)]));
+                                        v_in, ck::type_convert<float>(arg.input_(g, n, c, hi, wi)));

-                                    // FIXME hacky
                                     arg.wei_element_op_(
-                                        v_wei,
-                                        ck::type_convert<float>(
-                                            arg.weight_.mData[wei_desc.GetOffsetFromMultiIndex(
-                                                k, c, y, x)]));
+                                        v_wei, ck::type_convert<float>(arg.weight_(g, k, c, y, x)));

                                     v_acc += v_in * v_wei;
                                 }
@@ -255,39 +182,38 @@ struct ReferenceConvFwd : public device::BaseOperator
                     arg.out_element_op_(v_out, v_acc);

-                    // FIXME hacky
-                    arg.output_.mData[out_desc.GetOffsetFromMultiIndex({n, k, ho, wo})] =
-                        ck::type_convert<OutDataType>(v_out);
+                    arg.output_(g, n, k, ho, wo) = ck::type_convert<OutDataType>(v_out);
                 };

-                make_ParallelTensorFunctor(f_nchw,
-                                           out_desc.GetLengths()[0],
-                                           out_desc.GetLengths()[1],
-                                           out_desc.GetLengths()[2],
-                                           out_desc.GetLengths()[3])(
+                make_ParallelTensorFunctor(func,
+                                           arg.output_.GetLengths()[0],
+                                           arg.output_.GetLengths()[1],
+                                           arg.output_.GetLengths()[2],
+                                           arg.output_.GetLengths()[3],
+                                           arg.output_.GetLengths()[4])(
                     std::thread::hardware_concurrency());

                 return 0;
             }
             else if constexpr(NumDimSpatial == 3)
             {
-                auto f_nchw = [&](auto n, auto k, auto d_o, auto ho, auto wo) {
+                auto func = [&](auto g, auto n, auto k, auto d_o, auto ho, auto wo) {
                     float v_acc = 0;

-                    for(std::size_t c = 0; c < wei_desc.GetLengths()[1]; ++c)
+                    for(std::size_t c = 0; c < arg.weight_.GetLengths()[2]; ++c)
                     {
-                        for(std::size_t z = 0; z < wei_desc.GetLengths()[2]; ++z)
+                        for(std::size_t z = 0; z < arg.weight_.GetLengths()[3]; ++z)
                         {
                             auto di = static_cast<ck::long_index_t>(d_o * arg.conv_strides_[0]) +
                                       static_cast<ck::long_index_t>(z * arg.conv_dilations_[0]) -
                                       static_cast<ck::long_index_t>(arg.in_left_pads_[0]);

-                            for(std::size_t y = 0; y < wei_desc.GetLengths()[3]; ++y)
+                            for(std::size_t y = 0; y < arg.weight_.GetLengths()[4]; ++y)
                             {
                                 auto hi =
                                     static_cast<ck::long_index_t>(ho * arg.conv_strides_[1]) +
                                     static_cast<ck::long_index_t>(y * arg.conv_dilations_[1]) -
                                     static_cast<ck::long_index_t>(arg.in_left_pads_[1]);

-                                for(std::size_t x = 0; x < wei_desc.GetLengths()[4]; ++x)
+                                for(std::size_t x = 0; x < arg.weight_.GetLengths()[5]; ++x)
                                 {
                                     auto wi =
                                         static_cast<ck::long_index_t>(wo * arg.conv_strides_[2]) +
@@ -295,29 +221,24 @@ struct ReferenceConvFwd : public device::BaseOperator
                                         static_cast<ck::long_index_t>(arg.in_left_pads_[2]);

                                     if(di >= 0 &&
                                        ck::type_convert<std::size_t>(di) <
-                                           in_desc.GetLengths()[2] &&
+                                           arg.input_.GetLengths()[3] &&
                                        hi >= 0 &&
                                        ck::type_convert<std::size_t>(hi) <
-                                           in_desc.GetLengths()[3] &&
+                                           arg.input_.GetLengths()[4] &&
                                        wi >= 0 &&
-                                       ck::type_convert<std::size_t>(wi) < in_desc.GetLengths()[4])
+                                       ck::type_convert<std::size_t>(wi) <
+                                           arg.input_.GetLengths()[5])
                                     {
                                         float v_in;
                                         float v_wei;

-                                        // FIXME hacky
-                                        arg.in_element_op_(
-                                            v_in,
-                                            ck::type_convert<float>(
-                                                arg.input_.mData[in_desc.GetOffsetFromMultiIndex(
-                                                    n, c, di, hi, wi)]));
+                                        arg.in_element_op_(v_in,
+                                                           ck::type_convert<float>(
+                                                               arg.input_(g, n, c, di, hi, wi)));

-                                        // FIXME hacky
                                         arg.wei_element_op_(
                                             v_wei,
-                                            ck::type_convert<float>(
-                                                arg.weight_.mData[wei_desc.GetOffsetFromMultiIndex(
-                                                    k, c, z, y, x)]));
+                                            ck::type_convert<float>(arg.weight_(g, k, c, z, y, x)));

                                         v_acc += v_in * v_wei;
                                     }
@@ -330,17 +251,16 @@ struct ReferenceConvFwd : public device::BaseOperator
                     arg.out_element_op_(v_out, v_acc);

-                    // FIXME hacky
-                    arg.output_.mData[out_desc.GetOffsetFromMultiIndex({n, k, d_o, ho, wo})] =
-                        ck::type_convert<OutDataType>(v_out);
+                    arg.output_(g, n, k, d_o, ho, wo) = ck::type_convert<OutDataType>(v_out);
                 };

-                make_ParallelTensorFunctor(f_nchw,
-                                           out_desc.GetLengths()[0],
-                                           out_desc.GetLengths()[1],
-                                           out_desc.GetLengths()[2],
-                                           out_desc.GetLengths()[3],
-                                           out_desc.GetLengths()[4])(
+                make_ParallelTensorFunctor(func,
+                                           arg.output_.GetLengths()[0],
+                                           arg.output_.GetLengths()[1],
+                                           arg.output_.GetLengths()[2],
+                                           arg.output_.GetLengths()[3],
+                                           arg.output_.GetLengths()[4],
+                                           arg.output_.GetLengths()[5])(
                     std::thread::hardware_concurrency());

                 return 0;
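Note: after this change the reference kernel indexes all tensors in a fixed [G, N, C-or-K, spatial...] order and parallelizes over the extra group dimension. A standalone sketch of the math it implements for the 1D case, with wi = wo * stride + x * dilation - left_pad and out-of-range taps skipped (plain std::vector stand-in for Tensor<>, illustrative only):

#include <cstddef>
#include <vector>

// out(g, n, k, wo) = sum over c, x of in(g, n, c, wi) * wei(g, k, c, x)
std::vector<float> grouped_conv1d_ref(const std::vector<float>& in,  // [G][N][C][Wi]
                                      const std::vector<float>& wei, // [G][K][C][X]
                                      std::size_t G, std::size_t N, std::size_t C,
                                      std::size_t Wi, std::size_t K, std::size_t X,
                                      std::size_t stride, std::size_t dilation,
                                      std::size_t left_pad)
{
    // output length with right pad folded into left_pad for brevity
    const std::size_t Wo = (Wi + left_pad - dilation * (X - 1) - 1) / stride + 1;
    std::vector<float> out(G * N * K * Wo, 0.f);

    for(std::size_t g = 0; g < G; ++g)
        for(std::size_t n = 0; n < N; ++n)
            for(std::size_t k = 0; k < K; ++k)
                for(std::size_t wo = 0; wo < Wo; ++wo)
                {
                    float acc = 0.f;
                    for(std::size_t c = 0; c < C; ++c)
                        for(std::size_t x = 0; x < X; ++x)
                        {
                            const long wi = static_cast<long>(wo * stride + x * dilation) -
                                            static_cast<long>(left_pad);
                            if(wi >= 0 && static_cast<std::size_t>(wi) < Wi)
                                acc += in[((g * N + n) * C + c) * Wi + wi] *
                                       wei[((g * K + k) * C + c) * X + x];
                        }
                    out[((g * N + n) * K + k) * Wo + wo] = acc;
                }
    return out;
}

int main()
{
    // 1 group, 1 image, 1 channel, Wi = 4, K = 1, X = 2, stride 1, dilation 1, no pad
    const std::vector<float> in{1, 2, 3, 4};
    const std::vector<float> wei{1, 1};
    const auto out = grouped_conv1d_ref(in, wei, 1, 1, 1, 4, 1, 2, 1, 1, 0);
    // expected: {3, 5, 7}
    return out.size() == 3 ? 0 : 1;
}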
@@ -18,6 +18,7 @@ struct ConvParam
 {
     ConvParam();
     ConvParam(ck::index_t n_dim,
+              ck::index_t group_count,
               ck::index_t n_batch,
               ck::index_t n_out_channels,
               ck::index_t n_in_channels,
@@ -29,6 +30,7 @@ struct ConvParam
               const std::vector<ck::index_t>& right_pads);

     ck::index_t num_dim_spatial_;
+    ck::index_t G_;
     ck::index_t N_;
     ck::index_t K_;
     ck::index_t C_;
@@ -50,20 +52,22 @@ struct ConvParam
     template <typename InDataType, typename WeiDataType, typename OutDataType>
     std::size_t GetByte() const
     {
-        // sizeof(InDataType) * (N * C * <input spatial lengths product>) +
-        // sizeof(WeiDataType) * (K * C * <filter spatial lengths product>) +
-        // sizeof(OutDataType) * (N * K * <output spatial lengths product>);
-        return sizeof(InDataType) * (N_ * C_ *
-                                     std::accumulate(std::begin(input_spatial_lengths_),
-                                                     std::end(input_spatial_lengths_),
-                                                     static_cast<std::size_t>(1),
-                                                     std::multiplies<std::size_t>())) +
-               sizeof(WeiDataType) * (K_ * C_ *
-                                      std::accumulate(std::begin(filter_spatial_lengths_),
-                                                      std::end(filter_spatial_lengths_),
-                                                      static_cast<std::size_t>(1),
-                                                      std::multiplies<std::size_t>())) +
-               sizeof(OutDataType) * (N_ * K_ *
+        // sizeof(InDataType) * (G * N * C * <input spatial lengths product>) +
+        // sizeof(WeiDataType) * (G * K * C * <filter spatial lengths product>) +
+        // sizeof(OutDataType) * (G * N * K * <output spatial lengths product>);
+        return sizeof(InDataType) *
+                   (G_ * N_ * C_ *
+                    std::accumulate(std::begin(input_spatial_lengths_),
+                                    std::begin(input_spatial_lengths_) + num_dim_spatial_,
+                                    static_cast<std::size_t>(1),
+                                    std::multiplies<std::size_t>())) +
+               sizeof(WeiDataType) *
+                   (G_ * K_ * C_ *
+                    std::accumulate(std::begin(filter_spatial_lengths_),
+                                    std::begin(filter_spatial_lengths_) + num_dim_spatial_,
+                                    static_cast<std::size_t>(1),
+                                    std::multiplies<std::size_t>())) +
+               sizeof(OutDataType) * (G_ * N_ * K_ *
                                       std::accumulate(std::begin(output_spatial_lengths_),
                                                       std::end(output_spatial_lengths_),
                                                       static_cast<std::size_t>(1),
                                                       std::multiplies<std::size_t>()));
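Note: GetByte() now counts the G dimension and clamps the input/filter accumulations to the first num_dim_spatial_ entries. A standalone sanity check of the formula for the new default problem (G=1, N=128, K=256, C=192, 3x3 filter, 71x71 input, stride 2, pad 1, giving a 36x36 output; 2-byte elements assumed):

#include <cstddef>
#include <cstdio>

int main()
{
    const std::size_t G = 1, N = 128, K = 256, C = 192;
    const std::size_t Hi = 71, Wi = 71, Y = 3, X = 3, Ho = 36, Wo = 36;
    const std::size_t elem = 2; // e.g. fp16

    const std::size_t bytes = elem * (G * N * C * Hi * Wi)  // input
                            + elem * (G * K * C * Y * X)    // weight
                            + elem * (G * N * K * Ho * Wo); // output

    std::printf("total bytes moved: %zu\n", bytes);
    return 0;
}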
@@ -256,6 +256,10 @@ struct Tensor
         return *this;
     }

+    const std::vector<std::size_t>& GetLengths() const { return mDesc.GetLengths(); }
+
+    const std::vector<std::size_t>& GetStrides() const { return mDesc.GetStrides(); }
+
     void SetZero()
     {
         for(auto& v : mData)
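Note: these forwarders exist so call sites such as the reference convolution can write arg.output_.GetLengths() instead of reaching through mDesc. A usage sketch (assumes the library's host tensor header is available on the include path; the shape is arbitrary):

#include <cstddef>
#include <vector>
#include "host_tensor.hpp"

int main()
{
    Tensor<float> t(HostTensorDescriptor(std::vector<std::size_t>{2, 3, 4}));

    // before this commit call sites reached through the descriptor:
    //     t.mDesc.GetLengths()
    // the forwarders make dimension loops read more directly:
    std::size_t total = 1;
    for(std::size_t len : t.GetLengths())
        total *= len;

    return total == 2 * 3 * 4 ? 0 : 1;
}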
@@ -10,6 +10,7 @@ namespace utils {
 namespace conv {

 ConvParam::ConvParam(ck::index_t n_dim,
+                     ck::index_t group_count,
                      ck::index_t n_batch,
                      ck::index_t n_out_channels,
                      ck::index_t n_in_channels,
@@ -20,6 +21,7 @@ ConvParam::ConvParam(ck::index_t n_dim,
                      const std::vector<ck::index_t>& left_pads,
                      const std::vector<ck::index_t>& right_pads)
     : num_dim_spatial_(n_dim),
+      G_(group_count),
       N_(n_batch),
       K_(n_out_channels),
       C_(n_in_channels),
@@ -57,7 +59,7 @@ ConvParam::ConvParam(ck::index_t n_dim,
 }

 ConvParam::ConvParam()
-    : ConvParam::ConvParam(2, 128, 256, 192, {3, 3}, {71, 71}, {2, 2}, {1, 1}, {1, 1}, {1, 1})
+    : ConvParam::ConvParam(2, 1, 128, 256, 192, {3, 3}, {71, 71}, {2, 2}, {1, 1}, {1, 1}, {1, 1})
 {
 }

@@ -68,14 +70,14 @@ std::vector<ck::index_t> ConvParam::GetOutputSpatialLengths() const
 std::size_t ConvParam::GetFlops() const
 {
-    // 2 * N * K * C * <output spatial lengths product> * <filter spatial lengths product>
-    return static_cast<std::size_t>(2) * N_ * K_ * C_ *
+    // 2 * G * N * K * C * <output spatial lengths product> * <filter spatial lengths product>
+    return static_cast<std::size_t>(2) * G_ * N_ * K_ * C_ *
            std::accumulate(std::begin(output_spatial_lengths_),
-                           std::end(output_spatial_lengths_),
+                           std::begin(output_spatial_lengths_) + num_dim_spatial_,
                            static_cast<std::size_t>(1),
                            std::multiplies<std::size_t>()) *
           std::accumulate(std::begin(filter_spatial_lengths_),
-                          std::end(filter_spatial_lengths_),
+                          std::begin(filter_spatial_lengths_) + num_dim_spatial_,
                           static_cast<std::size_t>(1),
                           std::multiplies<std::size_t>());
 }

@@ -87,13 +89,14 @@ std::size_t ConvParam::GetFlops() const
 std::ostream& operator<<(std::ostream& os, const ck::utils::conv::ConvParam& p)
 {
     os << "ConvParam {"
-       << "\nnum_dim_spatial: " << p.num_dim_spatial_ << "\nN: " << p.N_ << "\nK: " << p.K_
-       << "\nC: " << p.C_ << "\nfilter_spatial_lengths: " << p.filter_spatial_lengths_
+       << "\nnum_dim_spatial: " << p.num_dim_spatial_ << "\nG: " << p.G_ << "\nN: " << p.N_
+       << "\nK: " << p.K_ << "\nC: " << p.C_
+       << "\nfilter_spatial_lengths: " << p.filter_spatial_lengths_
        << "\ninput_spatial_lengths: " << p.input_spatial_lengths_
        << "\nconv_filter_strides: " << p.conv_filter_strides_
        << "\nconv_filter_dilations: " << p.conv_filter_dilations_
        << "\ninput_left_pads: " << p.input_left_pads_
-       << "\ninput_right_pads: " << p.input_right_pads_;
+       << "\ninput_right_pads: " << p.input_right_pads_ << "}\n";

     return os;
 }
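Note: GetFlops() likewise gains the G factor and clamps both accumulations to the first num_dim_spatial_ entries. A standalone sanity check for the new default ConvParam (G=1, N=128, K=256, C=192, 3x3 filter, 36x36 output), mirroring the formula rather than calling the library:

#include <cstddef>
#include <cstdio>

int main()
{
    const std::size_t G = 1, N = 128, K = 256, C = 192;
    const std::size_t Ho = 36, Wo = 36, Y = 3, X = 3;

    // 2 * G * N * K * C * <output spatial product> * <filter spatial product>
    const std::size_t flops = 2 * G * N * K * C * (Ho * Wo) * (Y * X);

    std::printf("%zu FLOP (~%.1f GFLOP)\n", flops, static_cast<double>(flops) / 1.0e9);
    return 0;
}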