Commit 19173ab7 authored by Chao Liu

add G

parent c3379310
@@ -25,7 +25,7 @@ void print_helper_msg()
<< "arg3: time kernel (0=no, 1=yes)\n"
<< "arg4: N spatial dimensions (default 2)\n"
<< "Following arguments (depending on number of spatial dims):\n"
<< " N, K, C, \n"
<< " G, N, K, C, \n"
<< " <filter spatial dimensions>, (ie Y, X for 2D)\n"
<< " <input image spatial dimensions>, (ie Hi, Wi for 2D)\n"
<< " <strides>, (ie Sy, Sx for 2D)\n"
@@ -37,6 +37,7 @@ void print_helper_msg()
ck::utils::conv::ConvParam parse_conv_params(int num_dim_spatial, int arg_idx, char* const argv[])
{
const ck::index_t G = std::stoi(argv[arg_idx++]);
const ck::index_t N = std::stoi(argv[arg_idx++]);
const ck::index_t K = std::stoi(argv[arg_idx++]);
const ck::index_t C = std::stoi(argv[arg_idx++]);
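// Sketch of the updated argument order for num_dim_spatial == 2 (example
// values taken from the default ConvParam in main() below):
//   G N   K   C    Y X  Hi Wi  Sy Sx  Dy Dx  LeftPy LeftPx  RightPy RightPx
//   1 128 256 192  3 3  71 71  2  2   1  1   1      1       1       1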
@@ -79,6 +80,7 @@ ck::utils::conv::ConvParam parse_conv_params(int num_dim_spatial, int arg_idx, c
}
return ck::utils::conv::ConvParam{num_dim_spatial,
G,
N,
K,
C,
@@ -110,23 +112,56 @@ int run_conv_fwd(bool do_verification,
const WeiElementOp& wei_element_op,
const OutElementOp& out_element_op)
{
const auto in_desc = ck::utils::conv::get_input_host_tensor_descriptor<InLayout>(conv_param);
const auto wei_desc = ck::utils::conv::get_weight_host_tensor_descriptor<WeiLayout>(conv_param);
const auto out_desc = ck::utils::conv::get_output_host_tensor_descriptor<OutLayout>(conv_param);
// hacky, hardcoded for 2d NHWK
const auto bias_desc = HostTensorDescriptor(
std::vector<std::size_t>{static_cast<std::size_t>(conv_param.N_),
#if 0
const auto in_g_n_c_wis_desc = ck::utils::conv::get_input_host_tensor_descriptor<InLayout>(conv_param);
const auto wei_g_k_c_xs_desc = ck::utils::conv::get_weight_host_tensor_descriptor<WeiLayout>(conv_param);
const auto out_g_n_k_wos_desc = ck::utils::conv::get_output_host_tensor_descriptor<OutLayout>(conv_param);
#else
const auto in_g_n_wis_c_desc = HostTensorDescriptor(
std::vector<std::size_t>{static_cast<std::size_t>(conv_param.G_),
static_cast<std::size_t>(conv_param.N_),
static_cast<std::size_t>(conv_param.input_spatial_lengths_[0]),
static_cast<std::size_t>(conv_param.input_spatial_lengths_[1]),
static_cast<std::size_t>(conv_param.C_)});
const auto wei_g_k_xs_c_desc = HostTensorDescriptor(
std::vector<std::size_t>{static_cast<std::size_t>(conv_param.G_),
static_cast<std::size_t>(conv_param.K_),
static_cast<std::size_t>(conv_param.filter_spatial_lengths_[0]),
static_cast<std::size_t>(conv_param.filter_spatial_lengths_[1]),
static_cast<std::size_t>(conv_param.C_)});
const auto bias_g_n_wos_k_desc = HostTensorDescriptor(
std::vector<std::size_t>{static_cast<std::size_t>(conv_param.G_),
static_cast<std::size_t>(conv_param.N_),
static_cast<std::size_t>(conv_param.output_spatial_lengths_[0]),
static_cast<std::size_t>(conv_param.output_spatial_lengths_[1]),
static_cast<std::size_t>(conv_param.K_)},
std::vector<std::size_t>{0, 0, 0, 1});
std::vector<std::size_t>{0, 0, 0, 0, 1});
Tensor<InDataType> in(in_desc);
Tensor<WeiDataType> wei(wei_desc);
Tensor<OutDataType> bias(bias_desc);
Tensor<OutDataType> out_host(out_desc);
Tensor<OutDataType> out_device(out_desc);
const auto out_g_n_wos_k_desc = HostTensorDescriptor(
std::vector<std::size_t>{static_cast<std::size_t>(conv_param.G_),
static_cast<std::size_t>(conv_param.N_),
static_cast<std::size_t>(conv_param.output_spatial_lengths_[0]),
static_cast<std::size_t>(conv_param.output_spatial_lengths_[1]),
static_cast<std::size_t>(conv_param.K_)});
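// Note: bias_g_n_wos_k_desc above carries strides {0, 0, 0, 0, 1}; only the K
// dimension advances in memory, so every (g, n, ho, wo) index reads the same
// K-length bias vector (a stride of 0 broadcasts that dimension).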
// tensor descriptors in GNCHW/GKCYX/GNKHW dimensional order
const auto in_g_n_c_wis_desc = transpose_host_tensor_descriptor_given_new2old(
in_g_n_wis_c_desc, std::vector<ck::index_t>{0, 1, 4, 2, 3});
const auto wei_g_k_c_xs_desc = transpose_host_tensor_descriptor_given_new2old(
wei_g_k_xs_c_desc, std::vector<ck::index_t>{0, 1, 4, 2, 3});
const auto bias_g_n_k_wos_desc = transpose_host_tensor_descriptor_given_new2old(
bias_g_n_wos_k_desc, std::vector<ck::index_t>{0, 1, 4, 2, 3});
const auto out_g_n_k_wos_desc = transpose_host_tensor_descriptor_given_new2old(
out_g_n_wos_k_desc, std::vector<ck::index_t>{0, 1, 4, 2, 3});
#endif
Tensor<InDataType> in(in_g_n_c_wis_desc);
Tensor<WeiDataType> wei(wei_g_k_c_xs_desc);
Tensor<OutDataType> bias(bias_g_n_k_wos_desc);
Tensor<OutDataType> out_host(out_g_n_k_wos_desc);
Tensor<OutDataType> out_device(out_g_n_k_wos_desc);
std::cout << "in: " << in.mDesc << std::endl;
std::cout << "wei: " << wei.mDesc << std::endl;
@@ -156,80 +191,14 @@ int run_conv_fwd(bool do_verification,
wei_device_buf.ToDevice(wei.mData.data());
bias_device_buf.ToDevice(bias.mData.data());
// tensor descriptor in NCHW/KXYC/NKHW dimensional order
HostTensorDescriptor in_n_c_wis_desc = in_desc;
HostTensorDescriptor wei_k_c_xs_desc = wei_desc;
HostTensorDescriptor bias_n_k_wos_desc = bias_desc;
HostTensorDescriptor out_n_k_wos_desc = out_desc;
// input
if constexpr(ck::is_same_v<InLayout, ck::tensor_layout::convolution::NWC>)
{
in_n_c_wis_desc = transpose_host_tensor_descriptor_given_new2old(
in_desc, std::vector<std::size_t>{0, 2, 1});
}
else if constexpr(ck::is_same_v<InLayout, ck::tensor_layout::convolution::NHWC>)
{
in_n_c_wis_desc = transpose_host_tensor_descriptor_given_new2old(
in_desc, std::vector<std::size_t>{0, 3, 1, 2});
}
else if constexpr(ck::is_same_v<InLayout, ck::tensor_layout::convolution::NDHWC>)
{
in_n_c_wis_desc = transpose_host_tensor_descriptor_given_new2old(
in_desc, std::vector<std::size_t>{0, 4, 1, 2, 3});
}
// weight
if constexpr(ck::is_same_v<WeiLayout, ck::tensor_layout::convolution::KXC>)
{
wei_k_c_xs_desc = transpose_host_tensor_descriptor_given_new2old(
wei_desc, std::vector<std::size_t>{0, 2, 1});
}
else if constexpr(ck::is_same_v<WeiLayout, ck::tensor_layout::convolution::KYXC>)
{
wei_k_c_xs_desc = transpose_host_tensor_descriptor_given_new2old(
wei_desc, std::vector<std::size_t>{0, 3, 1, 2});
}
else if constexpr(ck::is_same_v<WeiLayout, ck::tensor_layout::convolution::KZYXC>)
{
wei_k_c_xs_desc = transpose_host_tensor_descriptor_given_new2old(
wei_desc, std::vector<std::size_t>{0, 4, 1, 2, 3});
}
// output
if constexpr(ck::is_same_v<OutLayout, ck::tensor_layout::convolution::NWK>)
{
out_n_k_wos_desc = transpose_host_tensor_descriptor_given_new2old(
out_desc, std::vector<std::size_t>{0, 2, 1});
bias_n_k_wos_desc = transpose_host_tensor_descriptor_given_new2old(
bias_desc, std::vector<std::size_t>{0, 2, 1});
}
else if constexpr(ck::is_same_v<OutLayout, ck::tensor_layout::convolution::NHWK>)
{
out_n_k_wos_desc = transpose_host_tensor_descriptor_given_new2old(
out_desc, std::vector<std::size_t>{0, 3, 1, 2});
bias_n_k_wos_desc = transpose_host_tensor_descriptor_given_new2old(
bias_desc, std::vector<std::size_t>{0, 3, 1, 2});
}
else if constexpr(ck::is_same_v<OutLayout, ck::tensor_layout::convolution::NDHWK>)
{
out_n_k_wos_desc = transpose_host_tensor_descriptor_given_new2old(
out_desc, std::vector<std::size_t>{0, 4, 1, 2, 3});
bias_n_k_wos_desc = transpose_host_tensor_descriptor_given_new2old(
bias_desc, std::vector<std::size_t>{0, 4, 1, 2, 3});
}
std::array<ck::index_t, NDimSpatial + 2> a_n_c_wis_lengths{};
std::array<ck::index_t, NDimSpatial + 2> a_n_c_wis_strides{};
std::array<ck::index_t, NDimSpatial + 2> b_k_c_xs_lengths{};
std::array<ck::index_t, NDimSpatial + 2> b_k_c_xs_strides{};
std::array<ck::index_t, NDimSpatial + 2> d_n_k_wos_lengths{};
std::array<ck::index_t, NDimSpatial + 2> d_n_k_wos_strides{};
std::array<ck::index_t, NDimSpatial + 2> e_n_k_wos_lengths{};
std::array<ck::index_t, NDimSpatial + 2> e_n_k_wos_strides{};
std::array<ck::index_t, NDimSpatial + 3> a_g_n_c_wis_lengths{};
std::array<ck::index_t, NDimSpatial + 3> a_g_n_c_wis_strides{};
std::array<ck::index_t, NDimSpatial + 3> b_g_k_c_xs_lengths{};
std::array<ck::index_t, NDimSpatial + 3> b_g_k_c_xs_strides{};
std::array<ck::index_t, NDimSpatial + 3> d_g_n_k_wos_lengths{};
std::array<ck::index_t, NDimSpatial + 3> d_g_n_k_wos_strides{};
std::array<ck::index_t, NDimSpatial + 3> e_g_n_k_wos_lengths{};
std::array<ck::index_t, NDimSpatial + 3> e_g_n_k_wos_strides{};
std::array<ck::index_t, NDimSpatial> conv_filter_strides{};
std::array<ck::index_t, NDimSpatial> conv_filter_dilations{};
std::array<ck::index_t, NDimSpatial> input_left_pads{};
@@ -237,14 +206,14 @@ int run_conv_fwd(bool do_verification,
auto copy = [](auto& x, auto& y) { std::copy(x.begin(), x.end(), y.begin()); };
copy(in_n_c_wis_desc.GetLengths(), a_n_c_wis_lengths);
copy(in_n_c_wis_desc.GetStrides(), a_n_c_wis_strides);
copy(wei_k_c_xs_desc.GetLengths(), b_k_c_xs_lengths);
copy(wei_k_c_xs_desc.GetStrides(), b_k_c_xs_strides);
copy(bias_n_k_wos_desc.GetLengths(), d_n_k_wos_lengths);
copy(bias_n_k_wos_desc.GetStrides(), d_n_k_wos_strides);
copy(out_n_k_wos_desc.GetLengths(), e_n_k_wos_lengths);
copy(out_n_k_wos_desc.GetStrides(), e_n_k_wos_strides);
copy(in_g_n_c_wis_desc.GetLengths(), a_g_n_c_wis_lengths);
copy(in_g_n_c_wis_desc.GetStrides(), a_g_n_c_wis_strides);
copy(wei_g_k_c_xs_desc.GetLengths(), b_g_k_c_xs_lengths);
copy(wei_g_k_c_xs_desc.GetStrides(), b_g_k_c_xs_strides);
copy(bias_g_n_k_wos_desc.GetLengths(), d_g_n_k_wos_lengths);
copy(bias_g_n_k_wos_desc.GetStrides(), d_g_n_k_wos_strides);
copy(out_g_n_k_wos_desc.GetLengths(), e_g_n_k_wos_lengths);
copy(out_g_n_k_wos_desc.GetStrides(), e_g_n_k_wos_strides);
copy(conv_param.conv_filter_strides_, conv_filter_strides);
copy(conv_param.conv_filter_dilations_, conv_filter_dilations);
copy(conv_param.input_left_pads_, input_left_pads);
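// Note: these arrays grew from NDimSpatial + 2 to NDimSpatial + 3 because each
// tensor now carries G ahead of N (or K) and C, matching the g_n_c_wis /
// g_k_c_xs / g_n_k_wos descriptor orders above.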
@@ -258,14 +227,14 @@ int run_conv_fwd(bool do_verification,
wei_device_buf.GetDeviceBuffer(),
std::array<const void*, 1>{bias_device_buf.GetDeviceBuffer()},
out_device_buf.GetDeviceBuffer(),
a_n_c_wis_lengths,
a_n_c_wis_strides,
b_k_c_xs_lengths,
b_k_c_xs_strides,
std::array<std::array<ck::index_t, NDimSpatial + 2>, 1>{{d_n_k_wos_lengths}},
std::array<std::array<ck::index_t, NDimSpatial + 2>, 1>{{d_n_k_wos_strides}},
e_n_k_wos_lengths,
e_n_k_wos_strides,
a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
std::array<std::array<ck::index_t, NDimSpatial + 3>, 1>{{d_g_n_k_wos_lengths}},
std::array<std::array<ck::index_t, NDimSpatial + 3>, 1>{{d_g_n_k_wos_strides}},
e_g_n_k_wos_lengths,
e_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
@@ -295,7 +264,7 @@ int run_conv_fwd(bool do_verification,
{
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
Tensor<OutDataType> c_host(out_desc);
Tensor<OutDataType> c_host(out_g_n_k_wos_desc);
auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd<NDimSpatial,
InLayout,
......@@ -322,16 +291,20 @@ int run_conv_fwd(bool do_verification,
ref_invoker.Run(ref_argument);
for(int n = 0; n < out_host.mDesc.GetLengths()[0]; n++)
for(int g = 0; g < out_host.mDesc.GetLengths()[0]; g++)
{
for(int ho = 0; ho < out_host.mDesc.GetLengths()[1]; ho++)
for(int n = 0; n < out_host.mDesc.GetLengths()[1]; n++)
{
for(int wo = 0; wo < out_host.mDesc.GetLengths()[2]; wo++)
for(int k = 0; k < out_host.mDesc.GetLengths()[2]; k++)
{
for(int k = 0; k < out_host.mDesc.GetLengths()[3]; k++)
for(int ho = 0; ho < out_host.mDesc.GetLengths()[3]; ho++)
{
out_element_op(
out_host(n, ho, wo, k), c_host(n, ho, wo, k), bias(n, ho, wo, k));
for(int wo = 0; wo < out_host.mDesc.GetLengths()[4]; wo++)
{
out_element_op(out_host(g, n, k, ho, wo),
c_host(g, n, k, ho, wo),
bias(g, n, k, ho, wo));
}
}
}
}
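// out_host is now indexed as [G, N, K, Ho, Wo] (after the transposes above),
// so GetLengths()[0] is G rather than N and every loop bound shifts by one.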
......
@@ -138,7 +138,7 @@ int main(int argc, char* argv[])
int num_dim_spatial = 2;
ck::utils::conv::ConvParam params{
2, 128, 256, 192, {3, 3}, {71, 71}, {2, 2}, {1, 1}, {1, 1}, {1, 1}};
2, 1, 128, 256, 192, {3, 3}, {71, 71}, {2, 2}, {1, 1}, {1, 1}, {1, 1}};
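// The new second argument (1) is the group count G; G == 1 keeps the
// effective problem identical to the old non-grouped default.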
if(argc == 1)
{
......
@@ -39,14 +39,14 @@ struct DeviceConvFwdMultipleD : public BaseOperator
const void* p_b,
const std::array<const void*, NumDTensor>& p_ds,
void* p_e,
const std::array<index_t, NDimSpatial + 2>& a_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 2>& a_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 2>& b_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 2>& b_k_c_xs_strides,
const std::array<std::array<index_t, NDimSpatial + 2>, NumDTensor>& ds_n_k_wos_lengths,
const std::array<std::array<index_t, NDimSpatial + 2>, NumDTensor>& ds_n_k_wos_strides,
const std::array<index_t, NDimSpatial + 2>& e_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 2>& e_n_k_wos_strides,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_lengths,
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_strides,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
......
@@ -89,97 +89,33 @@ struct ReferenceConvFwd : public device::BaseOperator
{
using Argument = ReferenceConvFwd::Argument;
// FIXME: properly implement "TensorView" for doing transpose or refer to dimension by name
float Run(const Argument& arg)
{
// tensor descriptor in NCHW/KXYC/NKHW dimensional order
HostTensorDescriptor in_desc = arg.input_.mDesc;
HostTensorDescriptor wei_desc = arg.weight_.mDesc;
HostTensorDescriptor out_desc = arg.output_.mDesc;
// input
if constexpr(is_same_v<InLayout, ck::tensor_layout::convolution::NWC>)
{
in_desc = transpose_host_tensor_descriptor_given_new2old(
arg.input_.mDesc, std::vector<std::size_t>{0, 2, 1});
}
else if constexpr(is_same_v<InLayout, ck::tensor_layout::convolution::NHWC>)
{
in_desc = transpose_host_tensor_descriptor_given_new2old(
arg.input_.mDesc, std::vector<std::size_t>{0, 3, 1, 2});
}
else if constexpr(is_same_v<InLayout, ck::tensor_layout::convolution::NDHWC>)
{
in_desc = transpose_host_tensor_descriptor_given_new2old(
arg.input_.mDesc, std::vector<std::size_t>{0, 4, 1, 2, 3});
}
// weight
if constexpr(is_same_v<WeiLayout, ck::tensor_layout::convolution::KXC>)
{
wei_desc = transpose_host_tensor_descriptor_given_new2old(
arg.weight_.mDesc, std::vector<std::size_t>{0, 2, 1});
}
else if constexpr(is_same_v<WeiLayout, ck::tensor_layout::convolution::KYXC>)
{
wei_desc = transpose_host_tensor_descriptor_given_new2old(
arg.weight_.mDesc, std::vector<std::size_t>{0, 3, 1, 2});
}
else if constexpr(is_same_v<WeiLayout, ck::tensor_layout::convolution::KZYXC>)
{
wei_desc = transpose_host_tensor_descriptor_given_new2old(
arg.weight_.mDesc, std::vector<std::size_t>{0, 4, 1, 2, 3});
}
// output
if constexpr(is_same_v<OutLayout, ck::tensor_layout::convolution::NWK>)
{
out_desc = transpose_host_tensor_descriptor_given_new2old(
arg.output_.mDesc, std::vector<std::size_t>{0, 2, 1});
}
else if constexpr(is_same_v<OutLayout, ck::tensor_layout::convolution::NHWK>)
{
out_desc = transpose_host_tensor_descriptor_given_new2old(
arg.output_.mDesc, std::vector<std::size_t>{0, 3, 1, 2});
}
else if constexpr(is_same_v<OutLayout, ck::tensor_layout::convolution::NDHWK>)
{
out_desc = transpose_host_tensor_descriptor_given_new2old(
arg.output_.mDesc, std::vector<std::size_t>{0, 4, 1, 2, 3});
}
if constexpr(NumDimSpatial == 1)
{
auto f_ncw = [&](auto n, auto k, auto wo) {
auto func = [&](auto g, auto n, auto k, auto wo) {
float v_acc = 0;
for(std::size_t c = 0; c < wei_desc.GetLengths()[1]; ++c)
for(std::size_t c = 0; c < arg.weight_.GetLengths()[2]; ++c)
{
for(std::size_t x = 0; x < wei_desc.GetLengths()[2]; ++x)
for(std::size_t x = 0; x < arg.weight_.GetLengths()[3]; ++x)
{
auto wi = static_cast<ck::long_index_t>(wo * arg.conv_strides_[0]) +
static_cast<ck::long_index_t>(x * arg.conv_dilations_[0]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[0]);
if(wi >= 0 &&
ck::type_convert<std::size_t>(wi) < in_desc.GetLengths()[2])
ck::type_convert<std::size_t>(wi) < arg.input_.GetLengths()[3])
{
float v_in;
float v_wei;
// FIXME hacky
arg.in_element_op_(
v_in,
ck::type_convert<float>(
arg.input_
.mData[in_desc.GetOffsetFromMultiIndex(n, c, wi)]));
v_in, ck::type_convert<float>(arg.input_(g, n, c, wi)));
// FIXME hacky
arg.wei_element_op_(
v_wei,
ck::type_convert<float>(
arg.weight_
.mData[wei_desc.GetOffsetFromMultiIndex(k, c, x)]));
v_wei, ck::type_convert<float>(arg.weight_(g, k, c, x)));
v_acc += v_in * v_wei;
}
@@ -190,33 +126,32 @@ struct ReferenceConvFwd : public device::BaseOperator
arg.out_element_op_(v_out, v_acc);
// FIXME hacky
arg.output_.mData[out_desc.GetOffsetFromMultiIndex({n, k, wo})] =
ck::type_convert<OutDataType>(v_out);
arg.output_(g, n, k, wo) = ck::type_convert<OutDataType>(v_out);
};
make_ParallelTensorFunctor(f_ncw,
out_desc.GetLengths()[0],
out_desc.GetLengths()[1],
out_desc.GetLengths()[2])(
make_ParallelTensorFunctor(func,
arg.output_.GetLengths()[0],
arg.output_.GetLengths()[1],
arg.output_.GetLengths()[2],
arg.output_.GetLengths()[3])(
std::thread::hardware_concurrency());
return 0;
}
else if constexpr(NumDimSpatial == 2)
{
auto f_nchw = [&](auto n, auto k, auto ho, auto wo) {
auto func = [&](auto g, auto n, auto k, auto ho, auto wo) {
float v_acc = 0;
for(std::size_t c = 0; c < wei_desc.GetLengths()[1]; ++c)
for(std::size_t c = 0; c < arg.weight_.GetLengths()[2]; ++c)
{
for(std::size_t y = 0; y < wei_desc.GetLengths()[2]; ++y)
for(std::size_t y = 0; y < arg.weight_.GetLengths()[3]; ++y)
{
auto hi = static_cast<ck::long_index_t>(ho * arg.conv_strides_[0]) +
static_cast<ck::long_index_t>(y * arg.conv_dilations_[0]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[0]);
for(std::size_t x = 0; x < wei_desc.GetLengths()[3]; ++x)
for(std::size_t x = 0; x < arg.weight_.GetLengths()[4]; ++x)
{
auto wi =
static_cast<ck::long_index_t>(wo * arg.conv_strides_[1]) +
@@ -224,26 +159,18 @@ struct ReferenceConvFwd : public device::BaseOperator
static_cast<ck::long_index_t>(arg.in_left_pads_[1]);
if(hi >= 0 &&
ck::type_convert<std::size_t>(hi) < in_desc.GetLengths()[2] &&
ck::type_convert<std::size_t>(hi) < arg.input_.GetLengths()[3] &&
wi >= 0 &&
ck::type_convert<std::size_t>(wi) < in_desc.GetLengths()[3])
ck::type_convert<std::size_t>(wi) < arg.input_.GetLengths()[4])
{
float v_in;
float v_wei;
// FIXME hacky
arg.in_element_op_(
v_in,
ck::type_convert<float>(
arg.input_.mData[in_desc.GetOffsetFromMultiIndex(
n, c, hi, wi)]));
v_in, ck::type_convert<float>(arg.input_(g, n, c, hi, wi)));
// FIXME hacky
arg.wei_element_op_(
v_wei,
ck::type_convert<float>(
arg.weight_.mData[wei_desc.GetOffsetFromMultiIndex(
k, c, y, x)]));
v_wei, ck::type_convert<float>(arg.weight_(g, k, c, y, x)));
v_acc += v_in * v_wei;
}
@@ -255,39 +182,38 @@ struct ReferenceConvFwd : public device::BaseOperator
arg.out_element_op_(v_out, v_acc);
// FIXME hacky
arg.output_.mData[out_desc.GetOffsetFromMultiIndex({n, k, ho, wo})] =
ck::type_convert<OutDataType>(v_out);
arg.output_(g, n, k, ho, wo) = ck::type_convert<OutDataType>(v_out);
};
make_ParallelTensorFunctor(f_nchw,
out_desc.GetLengths()[0],
out_desc.GetLengths()[1],
out_desc.GetLengths()[2],
out_desc.GetLengths()[3])(
make_ParallelTensorFunctor(func,
arg.output_.GetLengths()[0],
arg.output_.GetLengths()[1],
arg.output_.GetLengths()[2],
arg.output_.GetLengths()[3],
arg.output_.GetLengths()[4])(
std::thread::hardware_concurrency());
return 0;
}
else if constexpr(NumDimSpatial == 3)
{
auto f_nchw = [&](auto n, auto k, auto d_o, auto ho, auto wo) {
auto func = [&](auto g, auto n, auto k, auto d_o, auto ho, auto wo) {
float v_acc = 0;
for(std::size_t c = 0; c < wei_desc.GetLengths()[1]; ++c)
for(std::size_t c = 0; c < arg.weight_.GetLengths()[2]; ++c)
{
for(std::size_t z = 0; z < wei_desc.GetLengths()[2]; ++z)
for(std::size_t z = 0; z < arg.weight_.GetLengths()[3]; ++z)
{
auto di = static_cast<ck::long_index_t>(d_o * arg.conv_strides_[0]) +
static_cast<ck::long_index_t>(z * arg.conv_dilations_[0]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[0]);
for(std::size_t y = 0; y < wei_desc.GetLengths()[3]; ++y)
for(std::size_t y = 0; y < arg.weight_.GetLengths()[4]; ++y)
{
auto hi =
static_cast<ck::long_index_t>(ho * arg.conv_strides_[1]) +
static_cast<ck::long_index_t>(y * arg.conv_dilations_[1]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[1]);
for(std::size_t x = 0; x < wei_desc.GetLengths()[4]; ++x)
for(std::size_t x = 0; x < arg.weight_.GetLengths()[5]; ++x)
{
auto wi =
static_cast<ck::long_index_t>(wo * arg.conv_strides_[2]) +
@@ -295,29 +221,24 @@ struct ReferenceConvFwd : public device::BaseOperator
static_cast<ck::long_index_t>(arg.in_left_pads_[2]);
if(di >= 0 &&
ck::type_convert<std::size_t>(di) <
in_desc.GetLengths()[2] &&
arg.input_.GetLengths()[3] &&
hi >= 0 &&
ck::type_convert<std::size_t>(hi) <
in_desc.GetLengths()[3] &&
arg.input_.GetLengths()[4] &&
wi >= 0 &&
ck::type_convert<std::size_t>(wi) < in_desc.GetLengths()[4])
ck::type_convert<std::size_t>(wi) <
arg.input_.GetLengths()[5])
{
float v_in;
float v_wei;
// FIXME hacky
arg.in_element_op_(
v_in,
arg.in_element_op_(v_in,
ck::type_convert<float>(
arg.input_.mData[in_desc.GetOffsetFromMultiIndex(
n, c, di, hi, wi)]));
arg.input_(g, n, c, di, hi, wi)));
// FIXME hacky
arg.wei_element_op_(
v_wei,
ck::type_convert<float>(
arg.weight_.mData[wei_desc.GetOffsetFromMultiIndex(
k, c, z, y, x)]));
ck::type_convert<float>(arg.weight_(g, k, c, z, y, x)));
v_acc += v_in * v_wei;
}
@@ -330,17 +251,16 @@ struct ReferenceConvFwd : public device::BaseOperator
arg.out_element_op_(v_out, v_acc);
// FIXME hacky
arg.output_.mData[out_desc.GetOffsetFromMultiIndex({n, k, d_o, ho, wo})] =
ck::type_convert<OutDataType>(v_out);
arg.output_(g, n, k, d_o, ho, wo) = ck::type_convert<OutDataType>(v_out);
};
make_ParallelTensorFunctor(f_nchw,
out_desc.GetLengths()[0],
out_desc.GetLengths()[1],
out_desc.GetLengths()[2],
out_desc.GetLengths()[3],
out_desc.GetLengths()[4])(
make_ParallelTensorFunctor(func,
arg.output_.GetLengths()[0],
arg.output_.GetLengths()[1],
arg.output_.GetLengths()[2],
arg.output_.GetLengths()[3],
arg.output_.GetLengths()[4],
arg.output_.GetLengths()[5])(
std::thread::hardware_concurrency());
return 0;
......
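// Worked example of the reference kernel's index math, using the default
// ConvParam (3x3 filter, 71x71 input, stride 2, dilation 1, pad 1, 36x36
// output): for ho = 0, y = 0,
//   hi = 0*2 + 0*1 - 1 = -1              (left padding, skipped)
// and for ho = 35, y = 2,
//   hi = 35*2 + 2*1 - 1 = 71 == Hi       (past the input, also skipped),
// which is exactly what the hi >= 0 && hi < arg.input_.GetLengths()[3]
// guard filters out.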
@@ -18,6 +18,7 @@ struct ConvParam
{
ConvParam();
ConvParam(ck::index_t n_dim,
ck::index_t group_count,
ck::index_t n_batch,
ck::index_t n_out_channels,
ck::index_t n_in_channels,
@@ -29,6 +30,7 @@ struct ConvParam
const std::vector<ck::index_t>& right_pads);
ck::index_t num_dim_spatial_;
ck::index_t G_;
ck::index_t N_;
ck::index_t K_;
ck::index_t C_;
@@ -50,20 +52,22 @@ struct ConvParam
template <typename InDataType, typename WeiDataType, typename OutDataType>
std::size_t GetByte() const
{
// sizeof(InDataType) * (N * C * <input spatial lengths product>) +
// sizeof(WeiDataType) * (K * C * <filter spatial lengths product>) +
// sizeof(OutDataType) * (N * K * <output spatial lengths product>);
return sizeof(InDataType) * (N_ * C_ *
// sizeof(InDataType) * (G * N * C * <input spatial lengths product>) +
// sizeof(WeiDataType) * (G * K * C * <filter spatial lengths product>) +
// sizeof(OutDataType) * (G * N * K * <output spatial lengths product>);
return sizeof(InDataType) *
(G_ * N_ * C_ *
std::accumulate(std::begin(input_spatial_lengths_),
std::end(input_spatial_lengths_),
std::begin(input_spatial_lengths_) + num_dim_spatial_,
static_cast<std::size_t>(1),
std::multiplies<std::size_t>())) +
sizeof(WeiDataType) * (K_ * C_ *
sizeof(WeiDataType) *
(G_ * K_ * C_ *
std::accumulate(std::begin(filter_spatial_lengths_),
std::end(filter_spatial_lengths_),
std::begin(filter_spatial_lengths_) + num_dim_spatial_,
static_cast<std::size_t>(1),
std::multiplies<std::size_t>())) +
sizeof(OutDataType) * (N_ * K_ *
sizeof(OutDataType) * (G_ * N_ * K_ *
std::accumulate(std::begin(output_spatial_lengths_),
std::end(output_spatial_lengths_),
static_cast<std::size_t>(1),
......
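// Worked example for GetByte() with the default ConvParam (G=1, N=128, K=256,
// C=192, 3x3 filter, 71x71 input, 36x36 output):
//   input  elements: 1 * 128 * 192 * 71 * 71 = 123,887,616
//   weight elements: 1 * 256 * 192 * 3 * 3   =     442,368
//   output elements: 1 * 128 * 256 * 36 * 36 =  42,467,328
// each term is then scaled by sizeof() of the corresponding data type.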
@@ -256,6 +256,10 @@ struct Tensor
return *this;
}
const std::vector<std::size_t>& GetLengths() const { return mDesc.GetLengths(); }
const std::vector<std::size_t>& GetStrides() const { return mDesc.GetStrides(); }
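// These forwarding accessors let callers query shapes directly off a Tensor,
// as the rewritten reference convolution does with arg.weight_.GetLengths()[2]
// for C and arg.output_.GetLengths() for the parallel functor bounds.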
void SetZero()
{
for(auto& v : mData)
......
@@ -10,6 +10,7 @@ namespace utils {
namespace conv {
ConvParam::ConvParam(ck::index_t n_dim,
ck::index_t group_count,
ck::index_t n_batch,
ck::index_t n_out_channels,
ck::index_t n_in_channels,
@@ -20,6 +21,7 @@ ConvParam::ConvParam(ck::index_t n_dim,
const std::vector<ck::index_t>& left_pads,
const std::vector<ck::index_t>& right_pads)
: num_dim_spatial_(n_dim),
G_(group_count),
N_(n_batch),
K_(n_out_channels),
C_(n_in_channels),
@@ -57,7 +59,7 @@ ConvParam::ConvParam(ck::index_t n_dim,
}
ConvParam::ConvParam()
: ConvParam::ConvParam(2, 128, 256, 192, {3, 3}, {71, 71}, {2, 2}, {1, 1}, {1, 1}, {1, 1})
: ConvParam::ConvParam(2, 1, 128, 256, 192, {3, 3}, {71, 71}, {2, 2}, {1, 1}, {1, 1}, {1, 1})
{
}
@@ -68,14 +70,14 @@ std::vector<ck::index_t> ConvParam::GetOutputSpatialLengths() const
std::size_t ConvParam::GetFlops() const
{
// 2 * N * K * C * <output spatial lengths product> * <filter spatial lengths product>
return static_cast<std::size_t>(2) * N_ * K_ * C_ *
// 2 * G * N * K * C * <output spatial lengths product> * <filter spatial lengths product>
return static_cast<std::size_t>(2) * G_ * N_ * K_ * C_ *
std::accumulate(std::begin(output_spatial_lengths_),
std::end(output_spatial_lengths_),
std::begin(output_spatial_lengths_) + num_dim_spatial_,
static_cast<std::size_t>(1),
std::multiplies<std::size_t>()) *
std::accumulate(std::begin(filter_spatial_lengths_),
std::end(filter_spatial_lengths_),
std::begin(filter_spatial_lengths_) + num_dim_spatial_,
static_cast<std::size_t>(1),
std::multiplies<std::size_t>());
}
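// Worked example with the default ConvParam (G=1, N=128, K=256, C=192,
// 36x36 output, 3x3 filter):
//   2 * 1 * 128 * 256 * 192 * (36 * 36) * (3 * 3) = 146,767,085,568
// i.e. roughly 147 GFLOP per forward pass.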
@@ -87,13 +89,14 @@ std::size_t ConvParam::GetFlops() const
std::ostream& operator<<(std::ostream& os, const ck::utils::conv::ConvParam& p)
{
os << "ConvParam {"
<< "\nnum_dim_spatial: " << p.num_dim_spatial_ << "\nN: " << p.N_ << "\nK: " << p.K_
<< "\nC: " << p.C_ << "\nfilter_spatial_lengths: " << p.filter_spatial_lengths_
<< "\nnum_dim_spatial: " << p.num_dim_spatial_ << "\nG: " << p.G_ << "\nN: " << p.N_
<< "\nK: " << p.K_ << "\nC: " << p.C_
<< "\nfilter_spatial_lengths: " << p.filter_spatial_lengths_
<< "\ninput_spatial_lengths: " << p.input_spatial_lengths_
<< "\nconv_filter_strides: " << p.conv_filter_strides_
<< "\nconv_filter_dilations: " << p.conv_filter_dilations_
<< "\ninput_left_pads: " << p.input_left_pads_
<< "\ninput_right_pads: " << p.input_right_pads_;
<< "\ninput_right_pads: " << p.input_right_pads_ << "}\n";
return os;
}